Skip to content

Commit 364651a

Browse files
committed
Make tests more generic; add checking matrix equality by indices
1 parent a317ef3 commit 364651a

1 file changed

Lines changed: 31 additions & 78 deletions

File tree

tests/GraphBLAS-sharp.Tests/OperationsTests/EWiseAddTests.fs

Lines changed: 31 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ let createMatrix<'a when 'a : struct and 'a : equality> matrixFormat args =
7474

7575
let logger = Log.create "Sample"
7676

77-
let correctnessOnNumbers<'a when 'a : struct and 'a : equality>
77+
let checkCorrectnessGeneric<'a when 'a : struct and 'a : equality>
7878
(sum: 'a -> 'a -> 'a)
7979
(diff: 'a -> 'a -> 'a)
8080
(isZero: 'a -> bool)
@@ -87,12 +87,25 @@ let correctnessOnNumbers<'a when 'a : struct and 'a : equality>
8787
let right = matrixB |> Seq.cast<'a>
8888

8989
(left, right)
90-
||> Seq.map2
91-
(fun x y ->
90+
||> Seq.mapi2
91+
(fun idx x y ->
92+
let i = idx / Array2D.length2 matrixA
93+
let j = idx % Array2D.length2 matrixA
94+
9295
if isZero x && isZero y then None
93-
else Some <| sum x y
96+
else Some (i, j, sum x y)
9497
)
9598
|> Seq.choose id
99+
|> Array.ofSeq
100+
|> Array.unzip3
101+
|>
102+
(fun (rows, cols, vals) ->
103+
{
104+
RowIndices = rows
105+
ColumnIndices = cols
106+
Values = vals
107+
}
108+
)
96109

97110
let eWiseAddGB (matrixA: 'a[,]) (matrixB: 'a[,]) =
98111
try
@@ -115,8 +128,6 @@ let correctnessOnNumbers<'a when 'a : struct and 'a : equality>
115128
return! tuples.ToHost()
116129
}
117130
|> oclContext.RunSync
118-
|> (fun tuples -> tuples.Values)
119-
|> Seq.ofArray
120131

121132
finally
122133
oclContext.Provider.CloseAllBuffers()
@@ -126,87 +137,29 @@ let correctnessOnNumbers<'a when 'a : struct and 'a : equality>
126137

127138
logger.debug (
128139
eventX "Expected result is {matrix}"
129-
>> setField "matrix" (sprintf "%A" <| List.ofSeq expected)
140+
>> setField "matrix" (sprintf "%A" expected.Values)
130141
)
131142

132143
logger.debug (
133144
eventX "Actual result is {matrix}"
134-
>> setField "matrix" (sprintf "%A" <| List.ofSeq actual)
145+
>> setField "matrix" (sprintf "%A" actual.Values)
135146
)
136147

137-
"Length of expected and result seq should be equal"
138-
|> Expect.hasLength actual (Seq.length expected)
139-
140-
let difference =
141-
(expected, actual)
142-
||> Seq.map2 diff
143-
144-
"There should be no difference between expected and received values"
145-
|> Expect.all difference isZero
146-
147-
let correctnessOnBool (case: OperationCase) (matrixA: bool[,], matrixB: bool[,]) =
148-
let eWiseAddNaive (matrixA: bool[,]) (matrixB: bool[,]) =
149-
let left = matrixA |> Seq.cast<bool>
150-
let right = matrixB |> Seq.cast<bool>
151-
152-
(left, right)
153-
||> Seq.map2 (||)
154-
|> Seq.filter id
155-
156-
let eWiseAddGB (matrixA: bool[,]) (matrixB: bool[,]) =
157-
try
158-
let left = createMatrix<bool> case.MatrixCase [|matrixA; not|]
159-
let right = createMatrix<bool> case.MatrixCase [|matrixB; not|]
160-
161-
logger.debug (
162-
eventX "Left matrix is \n{matrix}"
163-
>> setField "matrix" left
164-
)
148+
let actualIndices = Seq.zip actual.RowIndices actual.ColumnIndices
149+
let expectedIndices = Seq.zip expected.RowIndices expected.ColumnIndices
165150

166-
logger.debug (
167-
eventX "Right matrix is \n{matrix}"
168-
>> setField "matrix" right
169-
)
170-
171-
opencl {
172-
let! result = left.EWiseAdd right None AnyAll.bool
173-
let! tuples = result.GetTuples()
174-
return! tuples.ToHost()
175-
}
176-
|> oclContext.RunSync
177-
|> (fun tuples -> tuples.Values)
178-
|> Seq.ofArray
179-
180-
finally
181-
oclContext.Provider.CloseAllBuffers()
182-
183-
let expected = eWiseAddNaive matrixA matrixB
184-
let actual = eWiseAddGB matrixA matrixB
185-
186-
logger.debug (
187-
eventX "Expected result is {matrix}"
188-
>> setField "matrix" (sprintf "%A" <| List.ofSeq expected)
189-
)
190-
191-
logger.debug (
192-
eventX "Actual result is {matrix}"
193-
>> setField "matrix" (sprintf "%A" <| List.ofSeq actual)
194-
)
195-
196-
"Length of expected and result seq should be equal"
197-
|> Expect.hasLength actual (Seq.length expected)
151+
"Indices of expected and result matrix must be the same"
152+
|> Expect.sequenceEqual actualIndices expectedIndices
198153

199154
let difference =
200-
(expected, actual)
201-
||> Seq.map2 (<>)
155+
(expected.Values, actual.Values)
156+
||> Seq.map2 diff
202157

203-
logger.debug (
204-
eventX "Difference result is {matrix}"
205-
>> setField "matrix" (sprintf "%A" <| List.ofSeq difference)
206-
)
158+
"Length of expected and result values should be equal"
159+
|> Expect.hasLength actual.Values (Seq.length expected.Values)
207160

208161
"There should be no difference between expected and received values"
209-
|> Expect.all difference not
162+
|> Expect.all difference isZero
210163

211164
let config = {
212165
FsCheckConfig.defaultConfig with
@@ -218,15 +171,15 @@ let config = {
218171
// https://docs.microsoft.com/ru-ru/dotnet/csharp/language-reference/language-specification/types#value-types
219172
let testFixtures case = [
220173
case
221-
|> correctnessOnNumbers<int> (+) (-) ((=) 0) AddMult.int
174+
|> checkCorrectnessGeneric<int> (+) (-) ((=) 0) AddMult.int
222175
|> testPropertyWithConfig config (sprintf "Correctness on int, %A, %A" case.MatrixCase case.MaskCase)
223176

224177
case
225-
|> correctnessOnNumbers<float> (+) (-) (fun x -> abs x < Accuracy.medium.absolute) AddMult.float
178+
|> checkCorrectnessGeneric<float> (+) (-) (fun x -> abs x < Accuracy.medium.absolute) AddMult.float
226179
|> testPropertyWithConfig config (sprintf "Correctness on float, %A, %A" case.MatrixCase case.MaskCase)
227180

228181
case
229-
|> correctnessOnBool
182+
|> checkCorrectnessGeneric<bool> (||) (<>) not AnyAll.bool
230183
|> testPropertyWithConfig config (sprintf "Correctness on bool, %A, %A" case.MatrixCase case.MaskCase)
231184
]
232185

0 commit comments

Comments
 (0)