Skip to content

Commit d22621e

Browse files
committed
wip: module Expand test
1 parent f77b4d2 commit d22621e

4 files changed

Lines changed: 102 additions & 16 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,19 +266,19 @@ module Expand =
266266

267267
// since prefix sum include
268268
// positions in global array for right matrix
269-
let globalRightMatrixValuesRawsStartPositions = requiredRawsLengths
269+
let globalRightMatrixRawsStartPositions = requiredRawsLengths
270270

271271
// pointers to required raws in right matrix values
272272
let requiredRightMatrixValuesPointers =
273273
getRequiredRightMatrixValuesPointers processor leftMatrix.Columns rightMatrix.RowPointers
274274

275275
// bitmap to distinguish different raws in a general array
276276
let globalPositions =
277-
getGlobalPositions processor globalLength globalRightMatrixValuesRawsStartPositions
277+
getGlobalPositions processor globalLength globalRightMatrixRawsStartPositions
278278

279279
// extended pointers to all required right matrix numbers
280280
let globalRightMatrixValuesPointers =
281-
getRightMatrixValuesPointers processor globalLength globalPositions globalRightMatrixValuesRawsStartPositions requiredRightMatrixValuesPointers
281+
getRightMatrixValuesPointers processor globalLength globalPositions globalRightMatrixRawsStartPositions requiredRightMatrixValuesPointers
282282

283283
// gather all required right matrix values
284284
let extendedRightMatrixValues =
@@ -300,6 +300,6 @@ module Expand =
300300
map2 processor DeviceOnly extendedLeftMatrixValues extendedRightMatrixValues
301301

302302
let rowPointers =
303-
getRawPointers processor leftMatrix.RowPointers globalRightMatrixValuesRawsStartPositions
303+
getRawPointers processor leftMatrix.RowPointers globalRightMatrixRawsStartPositions
304304

305305
multiplicationResult, extendedRightMatrixColumns, rowPointers

src/GraphBLAS-sharp.Backend/Objects/ArraysExtentions.fs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ module ArraysExtensions =
1313
let dst = Array.zeroCreate this.Length
1414
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(this, dst, ch))
1515

16+
member this.ToHostAndFree(q: MailboxProcessor<Msg>) =
17+
let result = this.ToHost q
18+
this.Dispose q
19+
20+
result
21+
1622
member this.Size = this.Length
1723

1824
type 'a ``[]`` with

tests/GraphBLAS-sharp.Tests/Matrix/SpGEMM/Expand.fs

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@ open GraphBLAS.FSharp.Backend.Matrix.CSR.SpGEMM
55
open GraphBLAS.FSharp.Tests
66
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
77
open Expecto
8+
open GraphBLAS.FSharp.Backend.Common
9+
open GraphBLAS.FSharp.Backend.Predefined
10+
open GraphBLAS.FSharp.Backend.Objects.ClCell
811

912
let context = Context.defaultContext
1013

14+
let clContext = context.ClContext
15+
let processor = context.Queue
16+
1117
/// <remarks>
1218
/// Left matrix
1319
/// </remarks>
@@ -42,25 +48,94 @@ let rightMatrix =
4248
ColumnIndices = [| 1; 4; 6; 2; 5; 1; 5; 6; 4; 6 |]
4349
Values = [| 3; 4; 4; 2; 2; 5; 9; 1; 1; 8 |] }
4450

45-
let requiredRowLength =
51+
let deviceLeftMatrix = leftMatrix.ToDevice clContext
52+
let deviceRightMatrix = rightMatrix.ToDevice clContext
53+
54+
let requiredRawsLengths () =
55+
let getRequiredRawsLengths =
56+
Expand.processLeftMatrixColumnsAndRightMatrixRawPointers clContext Utils.defaultWorkGroupSize Expand.requiredRawsLengths
57+
58+
getRequiredRawsLengths processor deviceLeftMatrix.Columns deviceRightMatrix.RowPointers
59+
60+
let requiredRowLengthTest =
4661
testCase "requiredRowLength"
4762
<| fun () ->
48-
let clContext = context.ClContext
49-
let processor = context.Queue
63+
let actual = requiredRawsLengths().ToHostAndFree processor
64+
65+
"Results must be the same"
66+
|> Expect.equal actual [| 2; 3; 3; 3; 2; 2; 0; 3 |]
67+
68+
let globalLength =
69+
let prefixSumExclude =
70+
PrefixSum.standardExcludeInplace clContext Utils.defaultWorkGroupSize
5071

51-
let deviceLeftMatrix = leftMatrix.ToDevice clContext
52-
let deviceRightMatrix = rightMatrix.ToDevice clContext
72+
let requiredRawsLengths = requiredRawsLengths ()
5373

54-
let getRequiredRawsLengths =
55-
Expand.processLeftMatrixColumnsAndRightMatrixRawPointers clContext Utils.defaultWorkGroupSize Expand.requiredRawsLengths
74+
(prefixSumExclude processor requiredRawsLengths).ToHostAndFree processor
5675

57-
let requiredRawsLengths =
58-
getRequiredRawsLengths processor deviceLeftMatrix.Columns deviceRightMatrix.RowPointers
76+
let globalLengthTest =
77+
testCase "global length test"
78+
<| fun () -> Expect.equal globalLength 18 "Results must be the same"
5979

60-
let requiredRawsLengthsHost = requiredRawsLengths.ToHost processor
80+
let getGlobalRightMatrixRawsStartPositions () =
81+
let prefixSumExclude =
82+
PrefixSum.standardExcludeInplace clContext Utils.defaultWorkGroupSize
83+
84+
let requiredRawsLengths = requiredRawsLengths ()
85+
86+
(prefixSumExclude processor requiredRawsLengths).Free processor
87+
88+
requiredRawsLengths
89+
90+
let globalRightMatrixRawsStartPositionsTest =
91+
testCase "global right matrix raws start positions"
92+
<| fun () ->
93+
let result = (getGlobalRightMatrixRawsStartPositions ()).ToHostAndFree processor
6194

6295
"Results must be the same"
63-
|> Expect.equal requiredRawsLengthsHost [| 2; 3; 3; 3; 2; 2; 0; 3 |]
96+
|> Expect.equal result [| 0; 2; 5; 8; 11; 13; 15; 15; |]
97+
98+
let getRequiredRightMatrixValuesPointers () =
99+
let getRequiredRightMatrixValuesPointers =
100+
Expand.processLeftMatrixColumnsAndRightMatrixRawPointers clContext Utils.defaultWorkGroupSize Expand.requiredRawPointers
64101

102+
getRequiredRightMatrixValuesPointers processor deviceLeftMatrix.Columns deviceRightMatrix.RowPointers
65103

104+
let getRequiredRightMatrixValuesPointersTest =
105+
testCase "get required right matrix values pointers"
106+
<| fun () ->
107+
let result = (getRequiredRightMatrixValuesPointers ()).ToHostAndFree processor
108+
109+
"Result must be the same"
110+
|> Expect.equal result [| 3; 5; 0; 5; 8; 3; 0; 0; |]
111+
112+
let getGlobalPositions () =
113+
let getGlobalPositions = Expand.getGlobalPositions clContext Utils.defaultWorkGroupSize
114+
115+
getGlobalPositions processor globalLength (getGlobalRightMatrixRawsStartPositions ())
116+
117+
let getGlobalPositionsTest =
118+
testCase "getGlobalPositions test"
119+
<| fun () ->
120+
let result = (getGlobalPositions ()).ToHostAndFree processor
121+
122+
"Result must be the same"
123+
|> Expect.equal result [| 1; 1; 2; 2; 2; 3; 3; 3; 4; 4; 4; 5; 5; 6; 6; 7; 7; 7; |]
124+
125+
let getRightMatrixValuesPointers () =
126+
let getRightMatrixValuesPointers =
127+
Expand.getRightMatrixPointers clContext Utils.defaultWorkGroupSize
128+
129+
let globalPositions = getGlobalPositions ()
130+
let globalRightMatrixRawsStartPositions = getGlobalRightMatrixRawsStartPositions ()
131+
let requiredRightMatrixValuesPointers = getRequiredRightMatrixValuesPointers ()
132+
133+
getRightMatrixValuesPointers processor globalLength globalPositions globalRightMatrixRawsStartPositions requiredRightMatrixValuesPointers
134+
135+
let rightMatrixValuesPointersTest =
136+
testCase "RightMatrixValuesPointers"
137+
<| fun () ->
138+
let result = (getRightMatrixValuesPointers ()).ToHostAndFree processor
66139

140+
"Result must be the same"
141+
|> Expect.equal result [| 3; 4; 5; 6; 7; 0; 1; 2; 5; 6; 7; 8; 9; 3; 4; 0; 1; 2; |]

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@ open GraphBLAS.FSharp.Tests.Backend
6464
let allTests =
6565
testList
6666
"All tests"
67-
[ Matrix.SpGEMM.Expand.requiredRowLength ]
67+
[ Matrix.SpGEMM.Expand.requiredRowLengthTest
68+
Matrix.SpGEMM.Expand.globalLengthTest
69+
Matrix.SpGEMM.Expand.globalRightMatrixRawsStartPositionsTest
70+
Matrix.SpGEMM.Expand.getRequiredRightMatrixValuesPointersTest
71+
Matrix.SpGEMM.Expand.getGlobalPositionsTest
72+
Matrix.SpGEMM.Expand.rightMatrixValuesPointersTest ]
6873
|> testSequenced
6974

7075
[<EntryPoint>]

0 commit comments

Comments
 (0)