@@ -19,20 +19,20 @@ let processor = Context.defaultContext.Queue
1919
2020let config = { Utils.defaultConfig with arbitrary = [ typeof< Generators.PairOfMatricesOfCompatibleSize> ] }
2121
22+ let createCSRMatrix array isZero =
23+ Utils.createMatrixFromArray2D CSR array isZero
24+ |> Utils.castMatrixToCSR
25+
2226let getSegmentsPointers ( leftMatrix : Matrix.CSR < 'a >) ( rightMatrix : Matrix.CSR < 'b >) =
2327 Array.map ( fun item ->
2428 rightMatrix.RowPointers.[ item + 1 ] - rightMatrix.RowPointers.[ item]) leftMatrix.ColumnIndices
2529 |> HostPrimitives.prefixSumExclude
2630
2731let makeTest isZero testFun ( leftArray : 'a [,], rightArray : 'a [,]) =
2832
29- let leftMatrix =
30- Utils.createMatrixFromArray2D CSR leftArray isZero
31- |> Utils.castMatrixToCSR
33+ let leftMatrix = createCSRMatrix leftArray isZero
3234
33- let rightMatrix =
34- Utils.createMatrixFromArray2D CSR rightArray isZero
35- |> Utils.castMatrixToCSR
35+ let rightMatrix = createCSRMatrix rightArray isZero
3636
3737 if leftMatrix.NNZ > 0 && rightMatrix.NNZ > 0 then
3838
@@ -44,6 +44,7 @@ let makeTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
4444 testFun processor clLeftMatrix clRightMatrix
4545
4646 clLeftMatrix.Dispose processor
47+ clRightMatrix.Dispose processor
4748
4849 let actualPointers = clActual.ToHostAndFree processor
4950
@@ -61,7 +62,7 @@ let createTest<'a when 'a : struct> (isZero: 'a -> bool) testFun =
6162 let testFun = testFun context Utils.defaultWorkGroupSize
6263
6364 makeTest isZero testFun
64- |> testPropertyWithConfig { config with endSize = 10 } $" test on {typeof<'a>}"
65+ |> testPropertyWithConfig config $" test on {typeof<'a>}"
6566
6667let getSegmentsTests =
6768 [ createTest ((=) 0 ) Expand.getSegmentPointers
@@ -71,18 +72,48 @@ let getSegmentsTests =
7172
7273 createTest ((=) 0 f) Expand.getSegmentPointers
7374 createTest ((=) false ) Expand.getSegmentPointers
74- createTest ((=) 0 u ) Expand.getSegmentPointers ]
75+ createTest ((=) 0 uy ) Expand.getSegmentPointers ]
7576 |> testList " get segment pointers"
7677
77- let makeExpandTest isZero testFun ( leftArray : 'a [,], rightArray : 'a [,]) =
78+ let expand length segmentPointers mulOp ( leftMatrix : Matrix.CSR < 'a >) ( rightMatrix : Matrix.CSR < 'b >) =
79+ let extendPointers pointers =
80+ Array.pairwise pointers
81+ |> Array.map ( fun ( fst , snd ) -> snd - fst)
82+ |> Array.mapi ( fun index length -> Array.create length index)
83+ |> Array.concat
84+
85+ let segmentsLengths =
86+ Array.append segmentPointers [| length |]
87+ |> Array.pairwise
88+ |> Array.map ( fun ( fst , snd ) -> snd - fst)
89+
90+ let leftMatrixValues , expectedRows =
91+ let tripleFst ( fst , _ , _ ) = fst
92+
93+ Array.zip3 segmentsLengths leftMatrix.Values <| extendPointers leftMatrix.RowPointers // TODO(expand row pointers)
94+ // select items each segment length not zero
95+ |> Array.filter ( tripleFst >> ((=) 0 ) >> not )
96+ |> Array.collect ( fun ( length , value , rowIndex ) -> Array.create length ( value, rowIndex))
97+ |> Array.unzip
98+
99+ let rightMatrixValues , expectedColumns =
100+ let valuesAndColumns = Array.zip rightMatrix.Values rightMatrix.ColumnIndices
78101
79- let leftMatrix =
80- Utils.createMatrixFromArray2D CSR leftArray isZero
81- |> Utils.castMatrixToCSR
102+ Array.map2 ( fun column length ->
103+ let rowStart = rightMatrix.RowPointers.[ column]
104+ Array.take length valuesAndColumns.[ rowStart..]) leftMatrix.ColumnIndices segmentsLengths
105+ |> Array.concat
106+ |> Array.unzip
82107
83- let rightMatrix =
84- Utils.createMatrixFromArray2D CSR rightArray isZero
85- |> Utils.castMatrixToCSR
108+ let expectedValues = Array.map2 mulOp leftMatrixValues rightMatrixValues
109+
110+ expectedValues, expectedColumns, expectedRows
111+
112+ let makeExpandTest isEqual zero opMul testFun ( leftArray : 'a [,], rightArray : 'a [,]) =
113+
114+ let leftMatrix = createCSRMatrix leftArray <| isEqual zero
115+
116+ let rightMatrix = createCSRMatrix rightArray <| isEqual zero
86117
87118 if leftMatrix.NNZ > 0
88119 && rightMatrix.NNZ > 0 then
@@ -94,10 +125,43 @@ let makeExpandTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
94125 let clRightMatrix = rightMatrix.ToDevice context
95126 let clSegmentPointers = context.CreateClArray segmentPointers
96127
97- let ( actualValues : ClArray < 'a >), ( actualColumns : ClArray < int >), ( actualRows : ClArray < int >) =
128+ let ( clActualValues : ClArray < 'a >), ( clActualColumns : ClArray < int >), ( clActualRows : ClArray < int >) =
98129 testFun processor length clSegmentPointers clLeftMatrix clRightMatrix
99130
100- clLeftMatrix.Free processor
101- clRightMatrix. processor
102- clSegmentPointers
131+ clLeftMatrix.Dispose processor
132+ clRightMatrix.Dispose processor
133+ clSegmentPointers.Free processor
134+
135+ let actualValues = clActualValues.ToHostAndFree processor
136+ let actualColumns = clActualColumns.ToHostAndFree processor
137+ let actualRows = clActualRows.ToHostAndFree processor
138+
139+ let expectedValues , expectedColumns , expectedRows =
140+ expand length segmentPointers opMul leftMatrix rightMatrix
141+
142+ " Values must be the same"
143+ |> Utils.compareArrays isEqual actualValues expectedValues
144+
145+ " Columns must be the same"
146+ |> Utils.compareArrays (=) actualColumns expectedColumns
147+
148+ " Rows must be the same"
149+ |> Utils.compareArrays (=) actualRows expectedRows
150+
151+ let createExpandTest isEqual ( zero : 'a ) opMul opMulQ testFun =
152+
153+ let testFun = testFun context Utils.defaultWorkGroupSize opMulQ
154+
155+ makeExpandTest isEqual zero opMul testFun
156+ |> testPropertyWithConfig config $" test on %A {typeof<'a>}"
157+
158+ let expandTests =
159+ [ createExpandTest (=) 0 (*) <@ (*) @> Expand.expand
160+
161+ if Utils.isFloat64Available context.ClDevice then
162+ createExpandTest Utils.floatIsEqual 0.0 (*) <@ (*) @> Expand.expand
103163
164+ createExpandTest Utils.float32IsEqual 0 f (*) <@ (*) @> Expand.expand
165+ createExpandTest (=) false (&&) <@ (&&) @> Expand.expand
166+ createExpandTest (=) 0 uy (*) <@ (*) @> Expand.expand ]
167+ |> testList " Expand.expand"
0 commit comments