@@ -2,6 +2,7 @@ module GraphBLAS.FSharp.Tests.Matrix.SpGeMM
22
33open Expecto
44open GraphBLAS.FSharp .Backend .Matrix .CSR .SpGeMM
5+ open GraphBLAS.FSharp .Backend .Quotes
56open GraphBLAS.FSharp .Test
67open Microsoft.FSharp .Collections
78open GraphBLAS.FSharp .Backend
@@ -76,7 +77,7 @@ let getSegmentsTests =
7677 createTest ((=) 0 uy) Expand.getSegmentPointers ]
7778 |> testList " get segment pointers"
7879
79- let expand length segmentPointers mulOp ( leftMatrix : Matrix.CSR < 'a >) ( rightMatrix : Matrix.CSR < 'b >) =
80+ let expand length segmentPointers ( leftMatrix : Matrix.CSR < 'a >) ( rightMatrix : Matrix.CSR < 'b >) =
8081 let extendPointers pointers =
8182 Array.pairwise pointers
8283 |> Array.map ( fun ( fst , snd ) -> snd - fst)
@@ -106,11 +107,9 @@ let expand length segmentPointers mulOp (leftMatrix: Matrix.CSR<'a>) (rightMatri
106107 |> Array.concat
107108 |> Array.unzip
108109
109- let expectedValues = Array.map2 mulOp leftMatrixValues rightMatrixValues
110+ leftMatrixValues, rightMatrixValues, expectedColumns , expectedRows
110111
111- expectedValues, expectedColumns, expectedRows
112-
113- let makeExpandTest isEqual zero opMul testFun ( leftArray : 'a [,], rightArray : 'a [,]) =
112+ let makeExpandTest isEqual zero testFun ( leftArray : 'a [,], rightArray : 'a [,]) =
114113
115114 let leftMatrix = createCSRMatrix leftArray <| isEqual zero
116115
@@ -126,51 +125,55 @@ let makeExpandTest isEqual zero opMul testFun (leftArray: 'a [,], rightArray: 'a
126125 let clRightMatrix = rightMatrix.ToDevice context
127126 let clSegmentPointers = context.CreateClArray segmentPointers
128127
129- let ( clActualValues : ClArray < 'a >), ( clActualColumns : ClArray < int >), ( clActualRows : ClArray < int >) =
128+ let ( clActualLeftValues : ClArray < 'a >), ( clActualRightValues : ClArray < 'a >), ( clActualColumns : ClArray < int >), ( clActualRows : ClArray < int >) =
130129 testFun processor length clSegmentPointers clLeftMatrix clRightMatrix
131130
132131 clLeftMatrix.Dispose processor
133132 clRightMatrix.Dispose processor
134133 clSegmentPointers.Free processor
135134
136- let actualValues = clActualValues.ToHostAndFree processor
135+ let actualLeftValues = clActualLeftValues.ToHostAndFree processor
136+ let actualRightValues = clActualRightValues.ToHostAndFree processor
137137 let actualColumns = clActualColumns.ToHostAndFree processor
138138 let actualRows = clActualRows.ToHostAndFree processor
139139
140- let expectedValues , expectedColumns , expectedRows =
141- expand length segmentPointers opMul leftMatrix rightMatrix
140+ let expectedLeftMatrixValues , expectedRightMatrixValues , expectedColumns , expectedRows =
141+ expand length segmentPointers leftMatrix rightMatrix
142+
143+ " Left values must be the same"
144+ |> Utils.compareArrays isEqual actualLeftValues expectedLeftMatrixValues
142145
143- " Values must be the same"
144- |> Utils.compareArrays isEqual actualValues expectedValues
146+ " Right values must be the same"
147+ |> Utils.compareArrays isEqual actualRightValues expectedRightMatrixValues
145148
146149 " Columns must be the same"
147150 |> Utils.compareArrays (=) actualColumns expectedColumns
148151
149152 " Rows must be the same"
150153 |> Utils.compareArrays (=) actualRows expectedRows
151154
152- let createExpandTest isEqual ( zero : 'a ) opMul opMulQ testFun =
155+ let createExpandTest isEqual ( zero : 'a ) testFun =
153156
154- let testFun = testFun context Utils.defaultWorkGroupSize opMulQ
157+ let testFun = testFun context Utils.defaultWorkGroupSize
155158
156- makeExpandTest isEqual zero opMul testFun
159+ makeExpandTest isEqual zero testFun
157160 |> testPropertyWithConfig config $" test on %A {typeof<'a>}"
158161
159162let expandTests =
160- [ createExpandTest (=) 0 (*) <@ (*) @> Expand.expand
163+ [ createExpandTest (=) 0 Expand.expand
161164
162165 if Utils.isFloat64Available context.ClDevice then
163- createExpandTest Utils.floatIsEqual 0.0 (*) <@ (*) @> Expand.expand
166+ createExpandTest Utils.floatIsEqual 0.0 Expand.expand
164167
165- createExpandTest Utils.float32IsEqual 0 f (*) <@ (*) @> Expand.expand
166- createExpandTest (=) false (&&) <@ (&&) @> Expand.expand
167- createExpandTest (=) 0 uy (*) <@ (*) @> Expand.expand ]
168+ createExpandTest Utils.float32IsEqual 0 f Expand.expand
169+ createExpandTest (=) false Expand.expand
170+ createExpandTest (=) 0 uy Expand.expand ]
168171 |> testList " Expand.expand"
169172
170173let checkGeneralResult zero isEqual actualValues actualColumns actualRows mul add ( leftArray : 'a [,]) ( rightArray : 'a [,]) =
171174
172175 let expected =
173- HostPrimitives.array2DMultiplication mul add leftArray rightArray
176+ HostPrimitives.array2DMultiplication zero mul add leftArray rightArray
174177 |> fun array -> Utils.createMatrixFromArray2D COO array ( isEqual zero)
175178 |> function Matrix.COO matrix -> matrix | _ -> failwith " format miss"
176179
@@ -217,15 +220,15 @@ let makeGeneralTest zero isEqual opMul opAdd testFun (leftArray: 'a [,], rightAr
217220 checkGeneralResult zero isEqual actualValues actualColumns actualRows opMul opAdd leftArray rightArray
218221 with
219222 | ex when ex.Message = " InvalidBufferSize" -> ()
220- | ex -> raise ex
223+ | _ -> reraise ()
221224
222- let createGeneralTest ( zero : 'a ) isEqual opAdd opAddQ opMul opMulQ testFun =
225+ let createGeneralTest ( zero : 'a ) isEqual opAddQ opAdd ( opMulQ , opMul ) testFun =
223226
224227 let testFun = testFun context Utils.defaultWorkGroupSize opAddQ opMulQ
225228
226229 makeGeneralTest zero isEqual opMul opAdd testFun
227230 |> testPropertyWithConfig { config with endSize = 10 ; maxTest = 1000 } $" test on %A {typeof<'a>}"
228231
229232let generalTests =
230- [ createGeneralTest 0 (=) (+) <@ (+) @> (*) <@ (*) @> Expand.run ]
233+ [ createGeneralTest 0 (=) <@ (+) @> (+) ArithmeticOperations.intMul Expand.run ]
231234 |> testList " general"
0 commit comments