Skip to content

Commit f77b4d2

Browse files
committed
add: requiredRawsLengths test
1 parent a136bd3 commit f77b4d2

7 files changed

Lines changed: 183 additions & 74 deletions

File tree

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
<Compile Include="Matrix/Common.fs" />
3939
<Compile Include="Matrix/COOMatrix/COOMatrix.fs" />
4040
<Compile Include="Matrix/CSRMatrix/Map2.fs" />
41-
<Compile Include="Matrix/CSRMatrix/SpGEMM.fs" />
41+
<Compile Include="Matrix\CSRMatrix\SpGEMMMasked.fs" />
4242
<Compile Include="Matrix/CSRMatrix/CSRMatrix.fs" />
4343
<Compile Include="Matrix\CSRMatrix\SpGEMM\Expand.fs" />
4444
<Compile Include="Matrix/Matrix.fs" />

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/CSRMatrix.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ module CSRMatrix =
250250
=
251251

252252
let run =
253-
SpGEMM.run clContext workGroupSize opAdd opMul
253+
SpGEMMMasked.run clContext workGroupSize opAdd opMul
254254

255255
fun (queue: MailboxProcessor<_>) (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSC<'b>) (mask: ClMatrix.COO<_>) ->
256256

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
namespace GraphBLAS.FSharp.Backend.Matrix.CSRMatrix.SpGEMM
1+
namespace GraphBLAS.FSharp.Backend.Matrix.CSR.SpGEMM
22

33
open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend.Common
55
open GraphBLAS.FSharp.Backend.Predefined
66
open GraphBLAS.FSharp.Backend.Objects.ClContext
77
open GraphBLAS.FSharp.Backend.Objects
88
open GraphBLAS.FSharp.Backend.Objects.ClCell
9+
open FSharp.Quotations
910

1011
type Indices = ClArray<int>
1112

@@ -143,7 +144,6 @@ module Expand =
143144
)
144145

145146
processor.Post <| Msg.CreateRunMsg<_, _> kernel
146-
processor.Post <| Msg.CreateFreeMsg globalPositions
147147

148148
globalRightMatrixValuesPointers
149149

@@ -157,7 +157,7 @@ module Expand =
157157
if gid < globalLength then
158158
let valuePosition = globalPositions.[gid] - 1
159159

160-
result.[gid] <- rightMatrixValues.[valuePosition]@>
160+
result.[gid] <- rightMatrixValues.[valuePosition] @>
161161

162162
let kernel = clContext.Compile kernel
163163

@@ -184,11 +184,51 @@ module Expand =
184184
)
185185

186186
processor.Post <| Msg.CreateRunMsg<_, _> kernel
187-
processor.Post <| Msg.CreateFreeMsg globalPositions
188187

189188
resultLeftMatrixValues
190189

191-
let run (clContext: ClContext) workGroupSize multiplication =
190+
let getResultRowPointers (clContext: ClContext) workGroupSize =
191+
192+
let kernel =
193+
<@ fun (ndRange: Range1D) length (leftMatrixRowPointers: Indices) (globalArrayRightMatrixRawPointers: Indices) (result: Indices) ->
194+
195+
let gid = ndRange.GlobalID0
196+
197+
if gid < length then
198+
let rowPointer = leftMatrixRowPointers.[gid]
199+
let globalPointer = globalArrayRightMatrixRawPointers.[rowPointer]
200+
201+
result.[gid] <- globalPointer
202+
@>
203+
204+
let kernel = clContext.Compile kernel
205+
206+
fun (processor: MailboxProcessor<_>) (leftMatrixRowPointers: Indices) (globalArrayRightMatrixRawPointers: Indices) ->
207+
208+
let result =
209+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, leftMatrixRowPointers.Length)
210+
211+
let kernel = kernel.GetKernel()
212+
213+
let ndRange =
214+
Range1D.CreateValid( leftMatrixRowPointers.Length, workGroupSize)
215+
216+
processor.Post(
217+
Msg.MsgSetArguments
218+
(fun () ->
219+
kernel.KernelFunc
220+
ndRange
221+
leftMatrixRowPointers.Length
222+
leftMatrixRowPointers
223+
globalArrayRightMatrixRawPointers
224+
result)
225+
)
226+
227+
processor.Post <| Msg.CreateRunMsg<_, _> kernel
228+
229+
result
230+
231+
let run (clContext: ClContext) workGroupSize (multiplication: Expr<'a -> 'b -> 'c>) =
192232

193233
let getRequiredRawsLengths =
194234
processLeftMatrixColumnsAndRightMatrixRawPointers clContext workGroupSize requiredRawsLengths
@@ -199,11 +239,11 @@ module Expand =
199239
let getRequiredRightMatrixValuesPointers =
200240
processLeftMatrixColumnsAndRightMatrixRawPointers clContext workGroupSize requiredRawPointers
201241

242+
let getGlobalPositions = getGlobalPositions clContext workGroupSize
243+
202244
let getRightMatrixValuesPointers =
203245
getRightMatrixPointers clContext workGroupSize
204246

205-
let getGlobalPositions = getGlobalPositions clContext workGroupSize
206-
207247
let gatherRightMatrixData = Gather.run clContext workGroupSize
208248

209249
let gatherIndices = Gather.run clContext workGroupSize
@@ -213,6 +253,8 @@ module Expand =
213253

214254
let map2 = ClArray.map2 clContext workGroupSize multiplication
215255

256+
let getRawPointers = getResultRowPointers clContext workGroupSize
257+
216258
fun (processor: MailboxProcessor<_>) (leftMatrix: ClMatrix.CSR<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->
217259

218260
let requiredRawsLengths =
@@ -252,9 +294,12 @@ module Expand =
252294

253295
// left matrix values correspondingly to right matrix values
254296
let extendedLeftMatrixValues =
255-
getLeftMatrixValues processor globalLength globalPositions rightMatrix.Values
297+
getLeftMatrixValues processor globalLength globalPositions leftMatrix.Values
256298

257299
let multiplicationResult =
258300
map2 processor DeviceOnly extendedLeftMatrixValues extendedRightMatrixValues
259301

260-
multiplicationResult, extendedRightMatrixColumns
302+
let rowPointers =
303+
getRawPointers processor leftMatrix.RowPointers globalRightMatrixValuesRawsStartPositions
304+
305+
multiplicationResult, extendedRightMatrixColumns, rowPointers

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM.fs renamed to src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMMMasked.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ open GraphBLAS.FSharp.Backend.Objects.ClMatrix
99
open GraphBLAS.FSharp.Backend.Objects.ClContext
1010
open GraphBLAS.FSharp.Backend.Objects.ClCell
1111

12-
module internal SpGEMM =
12+
module internal SpGEMMMasked =
1313
let private calculate
1414
(context: ClContext)
1515
workGroupSize

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
<Compile Include="Matrix/Map2.fs" />
4646
<Compile Include="Matrix/Mxm.fs" />
4747
<Compile Include="Matrix/Transpose.fs" />
48+
<Compile Include="Matrix\SpGEMM\Expand.fs" />
4849
<Compile Include="Program.fs" />
4950
</ItemGroup>
5051
<Import Project="..\..\.paket\Paket.Restore.targets" />
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Matrix.SpGEMM.Expand
2+
3+
open GraphBLAS.FSharp.Objects.Matrix
4+
open GraphBLAS.FSharp.Backend.Matrix.CSR.SpGEMM
5+
open GraphBLAS.FSharp.Tests
6+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
7+
open Expecto
8+
9+
let context = Context.defaultContext
10+
11+
/// <remarks>
12+
/// Left matrix
13+
/// </remarks>
14+
/// <code>
15+
/// [ 0 0 2 3 0
16+
/// 0 0 0 0 0
17+
/// 0 8 0 5 4
18+
/// 0 0 2 0 0
19+
/// 1 7 0 0 0 ]
20+
/// </code>
21+
let leftMatrix =
22+
{ RowCount = 5
23+
ColumnCount = 5
24+
RowPointers = [| 0; 2; 2; 5; 6; 8 |]
25+
ColumnIndices = [| 2; 3; 1; 3; 4; 2; 0; 1|]
26+
Values = [| 2; 3; 8; 5; 4; 2; 1; 7 |] }
27+
28+
/// <remarks>
29+
/// Right matrix
30+
/// </remarks>
31+
/// <code>
32+
/// [ 0 0 0 0 0 0 0
33+
/// 0 3 0 0 4 0 4
34+
/// 0 0 2 0 0 2 0
35+
/// 0 5 0 0 0 9 1
36+
/// 0 0 0 0 1 0 8 ]
37+
/// </code>
38+
let rightMatrix =
39+
{ RowCount = 5
40+
ColumnCount = 7
41+
RowPointers = [| 0; 0; 3; 5; 8; 10 |]
42+
ColumnIndices = [| 1; 4; 6; 2; 5; 1; 5; 6; 4; 6 |]
43+
Values = [| 3; 4; 4; 2; 2; 5; 9; 1; 1; 8 |] }
44+
45+
let requiredRowLength =
46+
testCase "requiredRowLength"
47+
<| fun () ->
48+
let clContext = context.ClContext
49+
let processor = context.Queue
50+
51+
let deviceLeftMatrix = leftMatrix.ToDevice clContext
52+
let deviceRightMatrix = rightMatrix.ToDevice clContext
53+
54+
let getRequiredRawsLengths =
55+
Expand.processLeftMatrixColumnsAndRightMatrixRawPointers clContext Utils.defaultWorkGroupSize Expand.requiredRawsLengths
56+
57+
let requiredRawsLengths =
58+
getRequiredRawsLengths processor deviceLeftMatrix.Columns deviceRightMatrix.RowPointers
59+
60+
let requiredRawsLengthsHost = requiredRawsLengths.ToHost processor
61+
62+
"Results must be the same"
63+
|> Expect.equal requiredRawsLengthsHost [| 2; 3; 3; 3; 2; 2; 0; 3 |]
64+
65+
66+

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 59 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,70 @@
11
open Expecto
22
open GraphBLAS.FSharp.Tests.Backend
33

4-
let matrixTests =
5-
testList
6-
"Matrix tests"
7-
[ Matrix.Convert.tests
8-
Matrix.Map2.addTests
9-
Matrix.Map2.addAtLeastOneTests
10-
Matrix.Map2.mulAtLeastOneTests
11-
Matrix.Map2.addAtLeastOneToCOOTests
12-
Matrix.Mxm.tests
13-
Matrix.Transpose.tests ]
14-
|> testSequenced
15-
16-
let commonTests =
17-
let clArrayTests =
18-
testList
19-
"ClArray"
20-
[ Common.ClArray.PrefixSum.tests
21-
Common.ClArray.RemoveDuplicates.tests
22-
Common.ClArray.Copy.tests
23-
Common.ClArray.Replicate.tests
24-
Common.ClArray.Exists.tests
25-
Common.ClArray.Map.tests
26-
Common.ClArray.Map2.addTests
27-
Common.ClArray.Map2.mulTests
28-
Common.ClArray.Choose.tests ]
29-
30-
testList
31-
"Common tests"
32-
[ clArrayTests
33-
Common.BitonicSort.tests
34-
Common.Scatter.tests
35-
Common.Reduce.tests
36-
Common.Sum.tests ]
37-
|> testSequenced
38-
39-
let vectorTests =
40-
testList
41-
"Vector tests"
42-
[ Vector.SpMV.tests
43-
Vector.ZeroCreate.tests
44-
Vector.OfList.tests
45-
Vector.Copy.tests
46-
Vector.Convert.tests
47-
Vector.Map2.addTests
48-
Vector.Map2.mulTests
49-
Vector.Map2.addAtLeastOneTests
50-
Vector.Map2.mulAtLeastOneTests
51-
Vector.Map2.addGeneralTests
52-
Vector.Map2.mulGeneralTests
53-
Vector.Map2.complementedGeneralTests
54-
Vector.AssignByMask.tests
55-
Vector.AssignByMask.complementedTests
56-
Vector.Reduce.tests ]
57-
|> testSequenced
58-
59-
let algorithmsTests =
60-
testList "Algorithms tests" [ Algorithms.BFS.tests ]
61-
|> testSequenced
4+
// let matrixTests =
5+
// testList
6+
// "Matrix tests"
7+
// [ Matrix.Convert.tests
8+
// Matrix.Map2.addTests
9+
// Matrix.Map2.addAtLeastOneTests
10+
// Matrix.Map2.mulAtLeastOneTests
11+
// Matrix.Map2.addAtLeastOneToCOOTests
12+
// Matrix.Mxm.tests
13+
// Matrix.Transpose.tests ]
14+
// |> testSequenced
15+
//
16+
// let commonTests =
17+
// let clArrayTests =
18+
// testList
19+
// "ClArray"
20+
// [ Common.ClArray.PrefixSum.tests
21+
// Common.ClArray.RemoveDuplicates.tests
22+
// Common.ClArray.Copy.tests
23+
// Common.ClArray.Replicate.tests
24+
// Common.ClArray.Exists.tests
25+
// Common.ClArray.Map.tests
26+
// Common.ClArray.Map2.addTests
27+
// Common.ClArray.Map2.mulTests
28+
// Common.ClArray.Choose.tests ]
29+
//
30+
// testList
31+
// "Common tests"
32+
// [ clArrayTests
33+
// Common.BitonicSort.tests
34+
// Common.Scatter.tests
35+
// Common.Reduce.tests
36+
// Common.Sum.tests ]
37+
// |> testSequenced
38+
//
39+
// let vectorTests =
40+
// testList
41+
// "Vector tests"
42+
// [ Vector.SpMV.tests
43+
// Vector.ZeroCreate.tests
44+
// Vector.OfList.tests
45+
// Vector.Copy.tests
46+
// Vector.Convert.tests
47+
// Vector.Map2.addTests
48+
// Vector.Map2.mulTests
49+
// Vector.Map2.addAtLeastOneTests
50+
// Vector.Map2.mulAtLeastOneTests
51+
// Vector.Map2.addGeneralTests
52+
// Vector.Map2.mulGeneralTests
53+
// Vector.Map2.complementedGeneralTests
54+
// Vector.AssignByMask.tests
55+
// Vector.AssignByMask.complementedTests
56+
// Vector.Reduce.tests ]
57+
// |> testSequenced
58+
//
59+
// let algorithmsTests =
60+
// testList "Algorithms tests" [ Algorithms.BFS.tests ]
61+
// |> testSequenced
6262

6363
[<Tests>]
6464
let allTests =
6565
testList
6666
"All tests"
67-
[ commonTests
68-
matrixTests
69-
vectorTests
70-
algorithmsTests ]
67+
[ Matrix.SpGEMM.Expand.requiredRowLength ]
7168
|> testSequenced
7269

7370
[<EntryPoint>]

0 commit comments

Comments
 (0)