Skip to content

Commit 7e09219

Browse files
committed
wip: expand tests passed
1 parent b679650 commit 7e09219

3 files changed

Lines changed: 84 additions & 22 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,6 @@ module Expand =
7777

7878
length, segmentsLengths
7979

80-
let
81-
8280
let expand (clContext: ClContext) workGroupSize opMul =
8381

8482
let init = ClArray.init clContext workGroupSize Map.id

tests/GraphBLAS-sharp.Tests/Matrix/SpGeMM.fs

Lines changed: 83 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,20 @@ let processor = Context.defaultContext.Queue
1919

2020
let config = { Utils.defaultConfig with arbitrary = [ typeof<Generators.PairOfMatricesOfCompatibleSize> ] }
2121

22+
let createCSRMatrix array isZero =
23+
Utils.createMatrixFromArray2D CSR array isZero
24+
|> Utils.castMatrixToCSR
25+
2226
let getSegmentsPointers (leftMatrix: Matrix.CSR<'a>) (rightMatrix: Matrix.CSR<'b>) =
2327
Array.map (fun item ->
2428
rightMatrix.RowPointers.[item + 1] - rightMatrix.RowPointers.[item]) leftMatrix.ColumnIndices
2529
|> HostPrimitives.prefixSumExclude
2630

2731
let makeTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
2832

29-
let leftMatrix =
30-
Utils.createMatrixFromArray2D CSR leftArray isZero
31-
|> Utils.castMatrixToCSR
33+
let leftMatrix = createCSRMatrix leftArray isZero
3234

33-
let rightMatrix =
34-
Utils.createMatrixFromArray2D CSR rightArray isZero
35-
|> Utils.castMatrixToCSR
35+
let rightMatrix = createCSRMatrix rightArray isZero
3636

3737
if leftMatrix.NNZ > 0 && rightMatrix.NNZ > 0 then
3838

@@ -44,6 +44,7 @@ let makeTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
4444
testFun processor clLeftMatrix clRightMatrix
4545

4646
clLeftMatrix.Dispose processor
47+
clRightMatrix.Dispose processor
4748

4849
let actualPointers = clActual.ToHostAndFree processor
4950

@@ -61,7 +62,7 @@ let createTest<'a when 'a : struct> (isZero: 'a -> bool) testFun =
6162
let testFun = testFun context Utils.defaultWorkGroupSize
6263

6364
makeTest isZero testFun
64-
|> testPropertyWithConfig { config with endSize = 10 } $"test on {typeof<'a>}"
65+
|> testPropertyWithConfig config $"test on {typeof<'a>}"
6566

6667
let getSegmentsTests =
6768
[ createTest ((=) 0) Expand.getSegmentPointers
@@ -71,18 +72,48 @@ let getSegmentsTests =
7172

7273
createTest ((=) 0f) Expand.getSegmentPointers
7374
createTest ((=) false) Expand.getSegmentPointers
74-
createTest ((=) 0u) Expand.getSegmentPointers ]
75+
createTest ((=) 0uy) Expand.getSegmentPointers ]
7576
|> testList "get segment pointers"
7677

77-
let makeExpandTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
78+
let expand length segmentPointers mulOp (leftMatrix: Matrix.CSR<'a>) (rightMatrix: Matrix.CSR<'b>) =
79+
let extendPointers pointers =
80+
Array.pairwise pointers
81+
|> Array.map (fun (fst, snd) -> snd - fst)
82+
|> Array.mapi (fun index length -> Array.create length index)
83+
|> Array.concat
84+
85+
let segmentsLengths =
86+
Array.append segmentPointers [| length |]
87+
|> Array.pairwise
88+
|> Array.map (fun (fst, snd) -> snd - fst)
89+
90+
let leftMatrixValues, expectedRows =
91+
let tripleFst (fst, _, _) = fst
92+
93+
Array.zip3 segmentsLengths leftMatrix.Values <| extendPointers leftMatrix.RowPointers // TODO(expand row pointers)
94+
// select items each segment length not zero
95+
|> Array.filter (tripleFst >> ((=) 0) >> not)
96+
|> Array.collect (fun (length, value, rowIndex) -> Array.create length (value, rowIndex))
97+
|> Array.unzip
98+
99+
let rightMatrixValues, expectedColumns =
100+
let valuesAndColumns = Array.zip rightMatrix.Values rightMatrix.ColumnIndices
78101

79-
let leftMatrix =
80-
Utils.createMatrixFromArray2D CSR leftArray isZero
81-
|> Utils.castMatrixToCSR
102+
Array.map2 (fun column length ->
103+
let rowStart = rightMatrix.RowPointers.[column]
104+
Array.take length valuesAndColumns.[rowStart..]) leftMatrix.ColumnIndices segmentsLengths
105+
|> Array.concat
106+
|> Array.unzip
82107

83-
let rightMatrix =
84-
Utils.createMatrixFromArray2D CSR rightArray isZero
85-
|> Utils.castMatrixToCSR
108+
let expectedValues = Array.map2 mulOp leftMatrixValues rightMatrixValues
109+
110+
expectedValues, expectedColumns, expectedRows
111+
112+
let makeExpandTest isEqual zero opMul testFun (leftArray: 'a [,], rightArray: 'a [,]) =
113+
114+
let leftMatrix = createCSRMatrix leftArray <| isEqual zero
115+
116+
let rightMatrix = createCSRMatrix rightArray <| isEqual zero
86117

87118
if leftMatrix.NNZ > 0
88119
&& rightMatrix.NNZ > 0 then
@@ -94,10 +125,43 @@ let makeExpandTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
94125
let clRightMatrix = rightMatrix.ToDevice context
95126
let clSegmentPointers = context.CreateClArray segmentPointers
96127

97-
let (actualValues: ClArray<'a>), (actualColumns: ClArray<int>), (actualRows: ClArray<int>) =
128+
let (clActualValues: ClArray<'a>), (clActualColumns: ClArray<int>), (clActualRows: ClArray<int>) =
98129
testFun processor length clSegmentPointers clLeftMatrix clRightMatrix
99130

100-
clLeftMatrix.Free processor
101-
clRightMatrix. processor
102-
clSegmentPointers
131+
clLeftMatrix.Dispose processor
132+
clRightMatrix.Dispose processor
133+
clSegmentPointers.Free processor
134+
135+
let actualValues = clActualValues.ToHostAndFree processor
136+
let actualColumns = clActualColumns.ToHostAndFree processor
137+
let actualRows = clActualRows.ToHostAndFree processor
138+
139+
let expectedValues, expectedColumns, expectedRows =
140+
expand length segmentPointers opMul leftMatrix rightMatrix
141+
142+
"Values must be the same"
143+
|> Utils.compareArrays isEqual actualValues expectedValues
144+
145+
"Columns must be the same"
146+
|> Utils.compareArrays (=) actualColumns expectedColumns
147+
148+
"Rows must be the same"
149+
|> Utils.compareArrays (=) actualRows expectedRows
150+
151+
let createExpandTest isEqual (zero: 'a) opMul opMulQ testFun =
152+
153+
let testFun = testFun context Utils.defaultWorkGroupSize opMulQ
154+
155+
makeExpandTest isEqual zero opMul testFun
156+
|> testPropertyWithConfig config $"test on %A{typeof<'a>}"
157+
158+
let expandTests =
159+
[ createExpandTest (=) 0 (*) <@ (*) @> Expand.expand
160+
161+
if Utils.isFloat64Available context.ClDevice then
162+
createExpandTest Utils.floatIsEqual 0.0 (*) <@ (*) @> Expand.expand
103163

164+
createExpandTest Utils.float32IsEqual 0f (*) <@ (*) @> Expand.expand
165+
createExpandTest (=) false (&&) <@ (&&) @> Expand.expand
166+
createExpandTest (=) 0uy (*) <@ (*) @> Expand.expand ]
167+
|> testList "Expand.expand"

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ open GraphBLAS.FSharp.Tests.Matrix
9494
let allTests =
9595
testList
9696
"All tests"
97-
[ SpGeMM.getSegmentsTests ]
97+
[ SpGeMM.expandTests ]
9898

9999
|> testSequenced
100100

0 commit comments

Comments
 (0)