Skip to content

Commit 2907167

Browse files
committed
refactor: kronecker
1 parent fc55f85 commit 2907167

3 files changed

Lines changed: 123 additions & 124 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSR/Kronecker.fs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ open Microsoft.FSharp.Quotations
66
open Brahma.FSharp
77
open GraphBLAS.FSharp.Backend.Quotes
88
open GraphBLAS.FSharp.Backend.Common
9-
open GraphBLAS.FSharp.Backend.Objects
109
open GraphBLAS.FSharp.Backend.Matrix.COO
1110
open GraphBLAS.FSharp.Backend.Matrix.CSR
1211
open GraphBLAS.FSharp.Backend.Objects.ClCell
@@ -24,19 +23,25 @@ module internal Kronecker =
2423

2524
if gid = 0 then
2625

26+
let item = resultBitmap.[0]
27+
let newItem = item + zeroCount
28+
2729
match (%op) (Some operand.Value) None with
28-
| Some _ -> resultBitmap.[0] <- resultBitmap.[0] + zeroCount
30+
| Some _ -> resultBitmap.[0] <- newItem
2931
| _ -> ()
3032

3133
else if (gid - 1) < valuesLength then
3234

35+
let item = resultBitmap.[gid]
36+
let newItem = item + 1
37+
3338
match (%op) (Some operand.Value) (Some values.[gid - 1]) with
34-
| Some _ -> resultBitmap.[gid] <- resultBitmap.[gid] + 1
39+
| Some _ -> resultBitmap.[gid] <- newItem
3540
| _ -> () @>
3641

3742
let updateBitmap = clContext.Compile <| updateBitmap op
3843

39-
fun (processor: MailboxProcessor<_>) (operand: ClCell<'a>) (matrixRight: ClMatrix.CSR<'b>) (bitmap: ClArray<int>) ->
44+
fun (processor: MailboxProcessor<_>) (operand: ClCell<'a>) (matrixRight: CSR<'b>) (bitmap: ClArray<int>) ->
4045

4146
let resultLength = matrixRight.NNZ + 1
4247

@@ -137,7 +142,7 @@ module internal Kronecker =
137142

138143
let kernel = clContext.Compile <| preparePositions op
139144

140-
fun (processor: MailboxProcessor<_>) (operand: ClCell<'a>) (matrix: ClMatrix.CSR<'b>) (resultDenseMatrix: ClArray<'c>) (resultBitmap: ClArray<int>) ->
145+
fun (processor: MailboxProcessor<_>) (operand: ClCell<'a>) (matrix: CSR<'b>) (resultDenseMatrix: ClArray<'c>) (resultBitmap: ClArray<int>) ->
141146

142147
let resultLength = matrix.RowCount * matrix.ColumnCount
143148

@@ -163,10 +168,10 @@ module internal Kronecker =
163168

164169
processor.Post(Msg.CreateRunMsg<_, _> kernel)
165170

166-
let setPositions<'c when 'c: struct> (clContext: ClContext) workGroupSize =
171+
let private setPositions<'c when 'c: struct> (clContext: ClContext) workGroupSize =
167172

168173
let setPositions =
169-
<@ fun (ndRange: Range1D) rowCount columnCount startIndex (nnz: ClCell<int>) (rowOffset: ClCell<int>) (columnOffset: ClCell<int>) (bitmap: ClArray<int>) (values: ClArray<'c>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) (resultValues: ClArray<'c>) ->
174+
<@ fun (ndRange: Range1D) rowCount columnCount startIndex (rowOffset: ClCell<int>) (columnOffset: ClCell<int>) (bitmap: ClArray<int>) (values: ClArray<'c>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) (resultValues: ClArray<'c>) ->
170175

171176
let gid = ndRange.GlobalID0
172177

@@ -208,7 +213,6 @@ module internal Kronecker =
208213
rowCount
209214
columnCount
210215
startIndex
211-
sum
212216
rowOffset
213217
columnOffset
214218
bitmap
@@ -222,7 +226,7 @@ module internal Kronecker =
222226

223227
(sum.ToHostAndFree processor) + startIndex
224228

225-
let copyToResult (clContext: ClContext) workGroupSize =
229+
let private copyToResult (clContext: ClContext) workGroupSize =
226230

227231
let copyToResult =
228232
<@ fun (ndRange: Range1D) startIndex sourceLength (rowOffset: ClCell<int>) (columnOffset: ClCell<int>) (sourceRows: ClArray<int>) (sourceColumns: ClArray<int>) (sourceValues: ClArray<'c>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) (resultValues: ClArray<'c>) ->
@@ -267,7 +271,7 @@ module internal Kronecker =
267271

268272
processor.Post(Msg.CreateRunMsg<_, _> kernel)
269273

270-
let insertZero (clContext: ClContext) workGroupSize =
274+
let private insertZero (clContext: ClContext) workGroupSize =
271275

272276
let copy = copyToResult clContext workGroupSize
273277

@@ -299,7 +303,7 @@ module internal Kronecker =
299303
for row in 0 .. rowCount - 1 do
300304
insertInRowRec zeroCounts.[row] row 0
301305

302-
let insertNonZero (clContext: ClContext) workGroupSize op =
306+
let private insertNonZero (clContext: ClContext) workGroupSize op =
303307

304308
let item = ClArray.item clContext workGroupSize
305309

@@ -347,7 +351,7 @@ module internal Kronecker =
347351

348352
startIndex
349353

350-
let mapAll<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
354+
let private mapAll<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
351355
(clContext: ClContext)
352356
workGroupSize
353357
(op: Expr<'a option -> 'b option -> 'c option>)
@@ -429,7 +433,7 @@ module internal Kronecker =
429433
let bitonic =
430434
Sort.Bitonic.sortKeyValuesInplace clContext workGroupSize
431435

432-
fun (queue: MailboxProcessor<_>) allocationMode (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSR<'b>) ->
436+
fun (queue: MailboxProcessor<_>) allocationMode (matrixLeft: CSR<'a>) (matrixRight: CSR<'b>) ->
433437

434438
let matrixZero =
435439
mapWithValue queue allocationMode None matrixRight

tests/GraphBLAS-sharp.Tests/Backend/Matrix/Kronecker.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ open GraphBLAS.FSharp.Objects.MatrixExtensions
1515

1616
let config =
1717
{ Utils.defaultConfig with
18-
endSize = 30
18+
endSize = 100
1919
maxTest = 20 }
2020

2121
let logger = Log.create "kronecker.Tests"

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 105 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -2,119 +2,114 @@ open Expecto
22
open GraphBLAS.FSharp.Tests.Backend
33
open GraphBLAS.FSharp.Tests
44

5-
// let matrixTests =
6-
// testList
7-
// "Matrix"
8-
// [ Matrix.Convert.tests
9-
// Matrix.Map2.allTests
10-
// Matrix.Map.allTests
11-
// Matrix.Merge.allTests
12-
// Matrix.Transpose.tests
13-
// Matrix.RowsLengths.tests
14-
// Matrix.ByRows.tests
15-
// Matrix.ExpandRows.tests
16-
// Matrix.SubRows.tests
17-
// Matrix.Kronecker.tests
18-
//
19-
// Matrix.SpGeMM.Expand.generalTests
20-
// Matrix.SpGeMM.Masked.tests ]
21-
// |> testSequenced
22-
//
23-
// let commonTests =
24-
// let scanTests =
25-
// testList
26-
// "Scan"
27-
// [ Common.Scan.ByKey.sequentialSegmentsTests
28-
// Common.Scan.PrefixSum.tests ]
29-
//
30-
// let reduceTests =
31-
// testList
32-
// "Reduce"
33-
// [ Common.Reduce.ByKey.allTests
34-
// Common.Reduce.Reduce.tests
35-
// Common.Reduce.Sum.tests ]
36-
//
37-
// let clArrayTests =
38-
// testList
39-
// "ClArray"
40-
// [ Common.ClArray.RemoveDuplicates.tests
41-
// Common.ClArray.Copy.tests
42-
// Common.ClArray.Replicate.tests
43-
// Common.ClArray.Exists.tests
44-
// Common.ClArray.Map.tests
45-
// Common.ClArray.Map2.addTests
46-
// Common.ClArray.Map2.mulTests
47-
// Common.ClArray.Choose.allTests
48-
// Common.ClArray.ChunkBySize.allTests
49-
// Common.ClArray.Blit.tests
50-
// Common.ClArray.Concat.tests
51-
// Common.ClArray.Fill.tests
52-
// Common.ClArray.Pairwise.tests
53-
// Common.ClArray.UpperBound.tests
54-
// Common.ClArray.Set.tests
55-
// Common.ClArray.Item.tests ]
56-
//
57-
// let sortTests =
58-
// testList
59-
// "Sort"
60-
// [ Common.Sort.Bitonic.tests
61-
// Common.Sort.Radix.allTests ]
62-
//
63-
// testList
64-
// "Common"
65-
// [ Common.Scatter.allTests
66-
// Common.Gather.allTests
67-
// Common.Merge.tests
68-
// clArrayTests
69-
// sortTests
70-
// reduceTests
71-
// scanTests ]
72-
// |> testSequenced
73-
//
74-
// let vectorTests =
75-
// testList
76-
// "Vector"
77-
// [ Vector.SpMV.tests
78-
// Vector.ZeroCreate.tests
79-
// Vector.OfList.tests
80-
// Vector.Copy.tests
81-
// Vector.Convert.tests
82-
// Vector.Map2.allTests
83-
// Vector.AssignByMask.tests
84-
// Vector.AssignByMask.complementedTests
85-
// Vector.Reduce.tests
86-
// Vector.Merge.tests ]
87-
// |> testSequenced
88-
//
89-
// let algorithmsTests =
90-
// testList "Algorithms tests" [ Algorithms.BFS.tests ]
91-
// |> testSequenced
92-
//
93-
// let deviceTests =
94-
// testList
95-
// "Device"
96-
// [ matrixTests
97-
// commonTests
98-
// vectorTests
99-
// algorithmsTests ]
100-
// |> testSequenced
101-
//
102-
// let hostTests =
103-
// testList
104-
// "Host"
105-
// [ Host.Matrix.FromArray2D.tests
106-
// Host.Matrix.Convert.tests
107-
// Host.IO.MtxReader.test ]
108-
// |> testSequenced
109-
//
110-
// [<Tests>]
111-
// let allTests =
112-
// testList "All" [ deviceTests; hostTests ]
113-
// |> testSequenced
5+
let matrixTests =
6+
testList
7+
"Matrix"
8+
[ Matrix.Convert.tests
9+
Matrix.Map2.allTests
10+
Matrix.Map.allTests
11+
Matrix.Merge.allTests
12+
Matrix.Transpose.tests
13+
Matrix.RowsLengths.tests
14+
Matrix.ByRows.tests
15+
Matrix.ExpandRows.tests
16+
Matrix.SubRows.tests
17+
Matrix.Kronecker.tests
18+
19+
Matrix.SpGeMM.Expand.generalTests
20+
Matrix.SpGeMM.Masked.tests ]
21+
|> testSequenced
22+
23+
let commonTests =
24+
let scanTests =
25+
testList
26+
"Scan"
27+
[ Common.Scan.ByKey.sequentialSegmentsTests
28+
Common.Scan.PrefixSum.tests ]
29+
30+
let reduceTests =
31+
testList
32+
"Reduce"
33+
[ Common.Reduce.ByKey.allTests
34+
Common.Reduce.Reduce.tests
35+
Common.Reduce.Sum.tests ]
36+
37+
let clArrayTests =
38+
testList
39+
"ClArray"
40+
[ Common.ClArray.RemoveDuplicates.tests
41+
Common.ClArray.Copy.tests
42+
Common.ClArray.Replicate.tests
43+
Common.ClArray.Exists.tests
44+
Common.ClArray.Map.tests
45+
Common.ClArray.Map2.addTests
46+
Common.ClArray.Map2.mulTests
47+
Common.ClArray.Choose.allTests
48+
Common.ClArray.ChunkBySize.allTests
49+
Common.ClArray.Blit.tests
50+
Common.ClArray.Concat.tests
51+
Common.ClArray.Fill.tests
52+
Common.ClArray.Pairwise.tests
53+
Common.ClArray.UpperBound.tests
54+
Common.ClArray.Set.tests
55+
Common.ClArray.Item.tests ]
56+
57+
let sortTests =
58+
testList
59+
"Sort"
60+
[ Common.Sort.Bitonic.tests
61+
Common.Sort.Radix.allTests ]
62+
63+
testList
64+
"Common"
65+
[ Common.Scatter.allTests
66+
Common.Gather.allTests
67+
Common.Merge.tests
68+
clArrayTests
69+
sortTests
70+
reduceTests
71+
scanTests ]
72+
|> testSequenced
73+
74+
let vectorTests =
75+
testList
76+
"Vector"
77+
[ Vector.SpMV.tests
78+
Vector.ZeroCreate.tests
79+
Vector.OfList.tests
80+
Vector.Copy.tests
81+
Vector.Convert.tests
82+
Vector.Map2.allTests
83+
Vector.AssignByMask.tests
84+
Vector.AssignByMask.complementedTests
85+
Vector.Reduce.tests
86+
Vector.Merge.tests ]
87+
|> testSequenced
88+
89+
let algorithmsTests =
90+
testList "Algorithms tests" [ Algorithms.BFS.tests ]
91+
|> testSequenced
92+
93+
let deviceTests =
94+
testList
95+
"Device"
96+
[ matrixTests
97+
commonTests
98+
vectorTests
99+
algorithmsTests ]
100+
|> testSequenced
101+
102+
let hostTests =
103+
testList
104+
"Host"
105+
[ Host.Matrix.FromArray2D.tests
106+
Host.Matrix.Convert.tests
107+
Host.IO.MtxReader.test ]
108+
|> testSequenced
114109

115110
[<Tests>]
116111
let allTests =
117-
testList "All" [ Matrix.Kronecker.tests ]
112+
testList "All" [ deviceTests; hostTests ]
118113
|> testSequenced
119114

120115
[<EntryPoint>]

0 commit comments

Comments
 (0)