@@ -2,6 +2,7 @@ module GraphBLAS.FSharp.Tests.Matrix.SpGeMM
22
33open Expecto
44open GraphBLAS.FSharp .Backend .Matrix .CSR .SpGeMM
5+ open GraphBLAS.FSharp .Test
56open Microsoft.FSharp .Collections
67open GraphBLAS.FSharp .Backend
78open GraphBLAS.FSharp .Backend .Matrix
@@ -16,29 +17,14 @@ let context = Context.defaultContext.ClContext
1617
1718let processor = Context.defaultContext.Queue
1819
19- let getSegmentsPointers ( leftMatrix : Matrix.CSR < 'a >) ( rightMatrix : Matrix.CSR < 'b >) =
20- printfn $" all: %A {rightMatrix.RowPointers}"
21-
22- let firstRowPointers =
23- rightMatrix.RowPointers.[.. rightMatrix.RowPointers.Length - 2 ]
24-
25- printfn $" first pointers: %A {firstRowPointers}"
26-
27- let lastRowPointers = rightMatrix.RowPointers.[ 1 ..]
28-
29- printfn $" last pointers: %A {lastRowPointers}"
30-
31- let rowsLengths = Array.map2 (-) lastRowPointers firstRowPointers
32-
33- printfn $" all row lengths %A {rowsLengths}"
34-
35- let neededLengths = Array.init leftMatrix.ColumnIndices.Length ( fun index -> Array.item index rowsLengths)
36-
37- printfn $" needed lengths %A {neededLengths}"
20+ let config = { Utils.defaultConfig with arbitrary = [ typeof< Generators.PairOfMatricesOfCompatibleSize> ] }
3821
39- HostPrimitives.prefixSumExclude neededLengths
22+ let getSegmentsPointers ( leftMatrix : Matrix.CSR < 'a >) ( rightMatrix : Matrix.CSR < 'b >) =
23+ Array.map ( fun item ->
24+ rightMatrix.RowPointers.[ item + 1 ] - rightMatrix.RowPointers.[ item]) leftMatrix.ColumnIndices
25+ |> HostPrimitives.prefixSumExclude
4026
41- let makeTest isZero testFun ( leftArray : 'a [,], rightArray : 'a [,], _ : bool [,] ) =
27+ let makeTest isZero testFun ( leftArray : 'a [,], rightArray : 'a [,]) =
4228
4329 let leftMatrix =
4430 Utils.createMatrixFromArray2D CSR leftArray isZero
@@ -57,6 +43,8 @@ let makeTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,], _: bool [,])
5743 let actualLength , ( clActual : ClArray < int >) =
5844 testFun processor clLeftMatrix clRightMatrix
5945
46+ clLeftMatrix.Dispose processor
47+
6048 let actualPointers = clActual.ToHostAndFree processor
6149
6250 let expectedPointers , expectedLength =
@@ -73,7 +61,7 @@ let createTest<'a when 'a : struct> (isZero: 'a -> bool) testFun =
7361 let testFun = testFun context Utils.defaultWorkGroupSize
7462
7563 makeTest isZero testFun
76- |> testPropertyWithConfig { Utils.defaultConfig with endSize = 10 } $" test on {typeof<'a>}"
64+ |> testPropertyWithConfig { config with endSize = 10 } $" test on {typeof<'a>}"
7765
7866let getSegmentsTests =
7967 [ createTest ((=) 0 ) Expand.getSegmentPointers
@@ -86,4 +74,30 @@ let getSegmentsTests =
8674 createTest ((=) 0 u) Expand.getSegmentPointers ]
8775 |> testList " get segment pointers"
8876
77+ let makeExpandTest isZero testFun ( leftArray : 'a [,], rightArray : 'a [,]) =
78+
79+ let leftMatrix =
80+ Utils.createMatrixFromArray2D CSR leftArray isZero
81+ |> Utils.castMatrixToCSR
82+
83+ let rightMatrix =
84+ Utils.createMatrixFromArray2D CSR rightArray isZero
85+ |> Utils.castMatrixToCSR
86+
87+ if leftMatrix.NNZ > 0
88+ && rightMatrix.NNZ > 0 then
89+
90+ let segmentPointers , length =
91+ getSegmentsPointers leftMatrix rightMatrix
92+
93+ let clLeftMatrix = leftMatrix.ToDevice context
94+ let clRightMatrix = rightMatrix.ToDevice context
95+ let clSegmentPointers = context.CreateClArray segmentPointers
96+
97+ let ( actualValues : ClArray < 'a >), ( actualColumns : ClArray < int >), ( actualRows : ClArray < int >) =
98+ testFun processor length clSegmentPointers clLeftMatrix clRightMatrix
99+
100+ clLeftMatrix.Free processor
101+ clRightMatrix. processor
102+ clSegmentPointers
89103
0 commit comments