Skip to content

Commit b679650

Browse files
committed
wip: segments computing tests
1 parent 91a72e2 commit b679650

7 files changed

Lines changed: 51 additions & 43 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,30 +45,22 @@ module Expand =
4545

4646
positions.Free processor
4747

48-
printfn $"first pointers gpu: %A{firstPointers.ToHost processor}"
49-
5048
// extract last rightMatrix.RowPointers.Lengths - 1 indices from rightMatrix.RowPointers
5149
// (right matrix row pointers without first item)
5250
let shiftedPositions = // TODO(fuse)
5351
createShifted processor DeviceOnly positionsLength
5452

55-
printfn "shifted positions gpu: %A" <| shiftedPositions.ToHost processor
56-
5753
let lastPointers =
5854
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, positionsLength)
5955

6056
gather processor shiftedPositions rightMatrix.RowPointers lastPointers
6157

62-
printfn $"last pointers gpu: %A{lastPointers.ToHost processor}"
63-
6458
shiftedPositions.Free processor
6559

6660
// subtract
6761
let rightMatrixRowsLengths =
6862
subtract processor DeviceOnly lastPointers firstPointers
6963

70-
printfn $"subtract result gpu: %A{rightMatrixRowsLengths.ToHost processor}"
71-
7264
firstPointers.Free processor
7365
lastPointers.Free processor
7466

@@ -78,17 +70,15 @@ module Expand =
7870
// extract needed lengths by left matrix nnz
7971
gather processor leftMatrix.Columns rightMatrixRowsLengths segmentsLengths
8072

81-
printfn $"subtract after gather result gpu: %A{segmentsLengths.ToHost processor}"
82-
8373
rightMatrixRowsLengths.Free processor
8474

8575
// compute pointers
8676
let length = (prefixSum processor segmentsLengths).ToHostAndFree processor
8777

88-
printfn $"subtract after prefix sum gpu: %A{segmentsLengths.ToHost processor}"
89-
9078
length, segmentsLengths
9179

80+
let
81+
9282
let expand (clContext: ClContext) workGroupSize opMul =
9383

9484
let init = ClArray.init clContext workGroupSize Map.id

src/GraphBLAS-sharp.Backend/Objects/Matrix.fs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ module ClMatrix =
2323
q.Post(Msg.CreateFreeMsg<_>(this.RowPointers))
2424
q.PostAndReply(Msg.MsgNotifyMe)
2525

26+
member this.Dispose q = (this :> IDeviceMemObject).Dispose q
27+
2628
member this.NNZ = this.Values.Length
2729

2830
member this.ToCSC =
@@ -48,6 +50,8 @@ module ClMatrix =
4850
q.Post(Msg.CreateFreeMsg<_>(this.ColumnPointers))
4951
q.PostAndReply(Msg.MsgNotifyMe)
5052

53+
member this.Dispose q = (this :> IDeviceMemObject).Dispose q
54+
5155
member this.NNZ = this.Values.Length
5256

5357
member this.ToCSR =
@@ -73,6 +77,8 @@ module ClMatrix =
7377
q.Post(Msg.CreateFreeMsg<_>(this.Rows))
7478
q.PostAndReply(Msg.MsgNotifyMe)
7579

80+
member this.Dispose q = (this :> IDeviceMemObject).Dispose q
81+
7682
member this.NNZ = this.Values.Length
7783

7884
type Tuple<'elem when 'elem: struct> =
@@ -88,6 +94,8 @@ module ClMatrix =
8894
q.Post(Msg.CreateFreeMsg<_>(this.Values))
8995
q.PostAndReply(Msg.MsgNotifyMe)
9096

97+
member this.Dispose q = (this :> IDeviceMemObject).Dispose q
98+
9199
member this.NNZ = this.Values.Length
92100

93101
[<RequireQualifiedAccess>]
@@ -110,9 +118,9 @@ type ClMatrix<'a when 'a: struct> =
110118

111119
member this.Dispose q =
112120
match this with
113-
| ClMatrix.CSR matrix -> (matrix :> IDeviceMemObject).Dispose q
114-
| ClMatrix.COO matrix -> (matrix :> IDeviceMemObject).Dispose q
115-
| ClMatrix.CSC matrix -> (matrix :> IDeviceMemObject).Dispose q
121+
| ClMatrix.CSR matrix -> matrix.Dispose q
122+
| ClMatrix.COO matrix -> matrix.Dispose q
123+
| ClMatrix.CSC matrix -> matrix.Dispose q
116124

117125
member this.NNZ =
118126
match this with

tests/GraphBLAS-sharp.Tests/Common/Gather.fs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ let processor = Context.defaultContext.Queue
1414
let check isEqual actual positions values target =
1515

1616
HostPrimitives.gather positions values target
17+
|> ignore
1718

1819
"Results must be the same"
1920
|> Utils.compareArrays isEqual actual target

tests/GraphBLAS-sharp.Tests/Generators.fs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,9 +389,6 @@ module Generators =
389389
valuesGenerator
390390
|> Gen.array2DOfDim (nColsA, nColsB)
391391

392-
printf $"left matrix column count: %A{Array2D.length1 matrixA}"
393-
printf $"right matrix row count: %A{Array2D.length2 matrixA}"
394-
395392
return (matrixA, matrixB)
396393
}
397394

tests/GraphBLAS-sharp.Tests/Helpers.fs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,6 @@ module HostPrimitives =
225225

226226
result
227227

228-
229228
module Context =
230229
type TestContext =
231230
{ ClContext: ClContext

tests/GraphBLAS-sharp.Tests/Matrix/SpGeMM.fs

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module GraphBLAS.FSharp.Tests.Matrix.SpGeMM
22

33
open Expecto
44
open GraphBLAS.FSharp.Backend.Matrix.CSR.SpGeMM
5+
open GraphBLAS.FSharp.Test
56
open Microsoft.FSharp.Collections
67
open GraphBLAS.FSharp.Backend
78
open GraphBLAS.FSharp.Backend.Matrix
@@ -16,29 +17,14 @@ let context = Context.defaultContext.ClContext
1617

1718
let processor = Context.defaultContext.Queue
1819

19-
let getSegmentsPointers (leftMatrix: Matrix.CSR<'a>) (rightMatrix: Matrix.CSR<'b>) =
20-
printfn $"all: %A{rightMatrix.RowPointers}"
21-
22-
let firstRowPointers =
23-
rightMatrix.RowPointers.[..rightMatrix.RowPointers.Length - 2]
24-
25-
printfn $"first pointers: %A{firstRowPointers}"
26-
27-
let lastRowPointers = rightMatrix.RowPointers.[1..]
28-
29-
printfn $"last pointers: %A{lastRowPointers}"
30-
31-
let rowsLengths = Array.map2 (-) lastRowPointers firstRowPointers
32-
33-
printfn $"all row lengths %A{rowsLengths}"
34-
35-
let neededLengths = Array.init leftMatrix.ColumnIndices.Length (fun index -> Array.item index rowsLengths)
36-
37-
printfn $"needed lengths %A{neededLengths}"
20+
let config = { Utils.defaultConfig with arbitrary = [ typeof<Generators.PairOfMatricesOfCompatibleSize> ] }
3821

39-
HostPrimitives.prefixSumExclude neededLengths
22+
let getSegmentsPointers (leftMatrix: Matrix.CSR<'a>) (rightMatrix: Matrix.CSR<'b>) =
23+
Array.map (fun item ->
24+
rightMatrix.RowPointers.[item + 1] - rightMatrix.RowPointers.[item]) leftMatrix.ColumnIndices
25+
|> HostPrimitives.prefixSumExclude
4026

41-
let makeTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,], _: bool [,]) =
27+
let makeTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
4228

4329
let leftMatrix =
4430
Utils.createMatrixFromArray2D CSR leftArray isZero
@@ -57,6 +43,8 @@ let makeTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,], _: bool [,])
5743
let actualLength, (clActual: ClArray<int>) =
5844
testFun processor clLeftMatrix clRightMatrix
5945

46+
clLeftMatrix.Dispose processor
47+
6048
let actualPointers = clActual.ToHostAndFree processor
6149

6250
let expectedPointers, expectedLength =
@@ -73,7 +61,7 @@ let createTest<'a when 'a : struct> (isZero: 'a -> bool) testFun =
7361
let testFun = testFun context Utils.defaultWorkGroupSize
7462

7563
makeTest isZero testFun
76-
|> testPropertyWithConfig { Utils.defaultConfig with endSize = 10 } $"test on {typeof<'a>}"
64+
|> testPropertyWithConfig { config with endSize = 10 } $"test on {typeof<'a>}"
7765

7866
let getSegmentsTests =
7967
[ createTest ((=) 0) Expand.getSegmentPointers
@@ -86,4 +74,30 @@ let getSegmentsTests =
8674
createTest ((=) 0u) Expand.getSegmentPointers ]
8775
|> testList "get segment pointers"
8876

77+
let makeExpandTest isZero testFun (leftArray: 'a [,], rightArray: 'a [,]) =
78+
79+
let leftMatrix =
80+
Utils.createMatrixFromArray2D CSR leftArray isZero
81+
|> Utils.castMatrixToCSR
82+
83+
let rightMatrix =
84+
Utils.createMatrixFromArray2D CSR rightArray isZero
85+
|> Utils.castMatrixToCSR
86+
87+
if leftMatrix.NNZ > 0
88+
&& rightMatrix.NNZ > 0 then
89+
90+
let segmentPointers, length =
91+
getSegmentsPointers leftMatrix rightMatrix
92+
93+
let clLeftMatrix = leftMatrix.ToDevice context
94+
let clRightMatrix = rightMatrix.ToDevice context
95+
let clSegmentPointers = context.CreateClArray segmentPointers
96+
97+
let (actualValues: ClArray<'a>), (actualColumns: ClArray<int>), (actualRows: ClArray<int>) =
98+
testFun processor length clSegmentPointers clLeftMatrix clRightMatrix
99+
100+
clLeftMatrix.Free processor
101+
clRightMatrix. processor
102+
clSegmentPointers
89103

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,7 @@ open GraphBLAS.FSharp.Tests.Matrix
9494
let allTests =
9595
testList
9696
"All tests"
97-
[ Common.Scatter.tests
98-
Common.Gather.tests ]
97+
[ SpGeMM.getSegmentsTests ]
9998

10099
|> testSequenced
101100

0 commit comments

Comments
 (0)