Skip to content

Commit 96c0c08

Browse files
committed
refactor: spgemm
1 parent 3f7c0bf commit 96c0c08

24 files changed

Lines changed: 746 additions & 593 deletions

File tree

benchmarks/GraphBLAS-sharp.Benchmarks/BenchmarksMxm.fs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ module Operations =
229229
type MxmBenchmarks4Float32MultiplicationOnly() =
230230

231231
inherit MxmBenchmarksMultiplicationOnly<float32>(
232-
(Matrix.mxm Operations.add Operations.mult),
232+
(Matrix.SpGeMM.masked Operations.add Operations.mult),
233233
float32,
234234
(fun _ -> Utils.nextSingle (System.Random())),
235235
(fun context matrix -> ClMatrix.CSR (Matrix.ToBackendCSR context matrix))
@@ -241,7 +241,7 @@ type MxmBenchmarks4Float32MultiplicationOnly() =
241241
type MxmBenchmarks4Float32WithTransposing() =
242242

243243
inherit MxmBenchmarksWithTransposing<float32>(
244-
(Matrix.mxm Operations.add Operations.mult),
244+
(Matrix.SpGeMM.masked Operations.add Operations.mult),
245245
float32,
246246
(fun _ -> Utils.nextSingle (System.Random())),
247247
(fun context matrix -> ClMatrix.CSR (Matrix.ToBackendCSR context matrix))
@@ -253,7 +253,7 @@ type MxmBenchmarks4Float32WithTransposing() =
253253
type MxmBenchmarks4BoolMultiplicationOnly() =
254254

255255
inherit MxmBenchmarksMultiplicationOnly<bool>(
256-
(Matrix.mxm Operations.logicalOr Operations.logicalAnd),
256+
(Matrix.SpGeMM.masked Operations.logicalOr Operations.logicalAnd),
257257
(fun _ -> true),
258258
(fun _ -> true),
259259
(fun context matrix -> ClMatrix.CSR (Matrix.ToBackendCSR context matrix))
@@ -265,7 +265,7 @@ type MxmBenchmarks4BoolMultiplicationOnly() =
265265
type MxmBenchmarks4BoolWithTransposing() =
266266

267267
inherit MxmBenchmarksWithTransposing<bool>(
268-
(Matrix.mxm Operations.logicalOr Operations.logicalAnd),
268+
(Matrix.SpGeMM.masked Operations.logicalOr Operations.logicalAnd),
269269
(fun _ -> true),
270270
(fun _ -> true),
271271
(fun context matrix -> ClMatrix.CSR (Matrix.ToBackendCSR context matrix))
@@ -277,7 +277,7 @@ type MxmBenchmarks4BoolWithTransposing() =
277277
type MxmBenchmarks4Float32MultiplicationOnlyWithZerosFilter() =
278278

279279
inherit MxmBenchmarksMultiplicationOnly<float32>(
280-
(Matrix.mxm Operations.addWithFilter Operations.mult),
280+
(Matrix.SpGeMM.masked Operations.addWithFilter Operations.mult),
281281
float32,
282282
(fun _ -> Utils.nextSingle (System.Random())),
283283
(fun context matrix -> ClMatrix.CSR (Matrix.ToBackendCSR context matrix))
@@ -289,7 +289,7 @@ type MxmBenchmarks4Float32MultiplicationOnlyWithZerosFilter() =
289289
type MxmBenchmarks4Float32WithTransposingWithZerosFilter() =
290290

291291
inherit MxmBenchmarksWithTransposing<float32>(
292-
(Matrix.mxm Operations.addWithFilter Operations.mult),
292+
(Matrix.SpGeMM.masked Operations.addWithFilter Operations.mult),
293293
float32,
294294
(fun _ -> Utils.nextSingle (System.Random())),
295295
(fun context matrix -> ClMatrix.CSR (Matrix.ToBackendCSR context matrix))

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -168,13 +168,15 @@ module ClArray =
168168
let getUniqueBitmapFirstOccurrence clContext =
169169
getUniqueBitmapGeneral
170170
<| <@ fun (gid: int) (_: int) (inputArray: ClArray<'a>) ->
171-
gid = 0 || inputArray.[gid - 1] <> inputArray.[gid] @>
171+
gid = 0
172+
|| inputArray.[gid - 1] <> inputArray.[gid] @>
172173
<| clContext
173174

174175
let getUniqueBitmapLastOccurrence clContext =
175176
getUniqueBitmapGeneral
176177
<| <@ fun (gid: int) (length: int) (inputArray: ClArray<'a>) ->
177-
gid = length - 1 || inputArray.[gid] <> inputArray.[gid + 1] @>
178+
gid = length - 1
179+
|| inputArray.[gid] <> inputArray.[gid + 1] @>
178180
<| clContext
179181

180182
///<description>Remove duplicates form the given array.</description>
@@ -186,7 +188,8 @@ module ClArray =
186188
let scatter =
187189
Scatter.lastOccurrence clContext workGroupSize
188190

189-
let getUniqueBitmap = getUniqueBitmapLastOccurrence clContext workGroupSize
191+
let getUniqueBitmap =
192+
getUniqueBitmapLastOccurrence clContext workGroupSize
190193

191194
let prefixSumExclude =
192195
PrefixSum.runExcludeInplace <@ (+) @> clContext workGroupSize
@@ -308,16 +311,20 @@ module ClArray =
308311

309312
let getUniqueBitmap2General<'a when 'a: equality> getUniqueBitmap (clContext: ClContext) workGroupSize =
310313

311-
let map = map2 clContext workGroupSize <@ fun x y -> x ||| y @>
314+
let map =
315+
map2 clContext workGroupSize <@ fun x y -> x ||| y @>
312316

313317
let firstGetBitmap = getUniqueBitmap clContext workGroupSize
314318

315319
fun (processor: MailboxProcessor<_>) allocationMode (firstArray: ClArray<'a>) (secondArray: ClArray<'a>) ->
316-
let firstBitmap = firstGetBitmap processor DeviceOnly firstArray
320+
let firstBitmap =
321+
firstGetBitmap processor DeviceOnly firstArray
317322

318-
let secondBitmap = firstGetBitmap processor DeviceOnly secondArray
323+
let secondBitmap =
324+
firstGetBitmap processor DeviceOnly secondArray
319325

320-
let result = map processor allocationMode firstBitmap secondBitmap
326+
let result =
327+
map processor allocationMode firstBitmap secondBitmap
321328

322329
firstBitmap.Free processor
323330
secondBitmap.Free processor
@@ -344,15 +351,15 @@ module ClArray =
344351
// seems like scatter (option scatter) ???
345352
if 0 <= position && position < resultLength then
346353
match (%op) value with
347-
| Some value ->
348-
result.[position] <- value
354+
| Some value -> result.[position] <- value
349355
| None -> () @>
350356

351357
let kernel = clContext.Compile assign
352358

353359
fun (processor: MailboxProcessor<_>) (values: ClArray<'a>) (positions: ClArray<int>) (result: ClArray<'b>) ->
354360

355-
if values.Length <> positions.Length then failwith "lengths must be the same"
361+
if values.Length <> positions.Length then
362+
failwith "lengths must be the same"
356363

357364
let ndRange =
358365
Range1D.CreateValid(values.Length, workGroupSize)
@@ -371,19 +378,23 @@ module ClArray =
371378
map<'a, int> clContext workGroupSize
372379
<| Map.chooseBitmap predicate
373380

374-
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
381+
let prefixSum =
382+
PrefixSum.standardExcludeInplace clContext workGroupSize
375383

376-
let assignValues = assignOption clContext workGroupSize predicate
384+
let assignValues =
385+
assignOption clContext workGroupSize predicate
377386

378387
fun (processor: MailboxProcessor<_>) allocationMode (sourceValues: ClArray<'a>) ->
379388

380-
let positions = getBitmap processor DeviceOnly sourceValues
389+
let positions =
390+
getBitmap processor DeviceOnly sourceValues
381391

382392
let resultLength =
383393
(prefixSum processor positions)
384394
.ToHostAndFree(processor)
385395

386-
let result = clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
396+
let result =
397+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
387398

388399
assignValues processor sourceValues positions result
389400

@@ -405,17 +416,16 @@ module ClArray =
405416
// seems like scatter2 (option scatter2) ???
406417
if 0 <= position && position < resultLength then
407418
match (%op) leftValue rightValue with
408-
| Some value ->
409-
result.[position] <- value
419+
| Some value -> result.[position] <- value
410420
| None -> () @>
411421

412422
let kernel = clContext.Compile assign
413423

414424
fun (processor: MailboxProcessor<_>) (firstValues: ClArray<'a>) (secondValues: ClArray<'b>) (positions: ClArray<int>) (result: ClArray<'c>) ->
415425

416426
if firstValues.Length <> secondValues.Length
417-
|| secondValues.Length <> positions.Length then
418-
failwith "lengths must be the same"
427+
|| secondValues.Length <> positions.Length then
428+
failwith "lengths must be the same"
419429

420430
let ndRange =
421431
Range1D.CreateValid(firstValues.Length, workGroupSize)
@@ -424,7 +434,15 @@ module ClArray =
424434

425435
processor.Post(
426436
Msg.MsgSetArguments
427-
(fun () -> kernel.KernelFunc ndRange firstValues.Length firstValues secondValues positions result result.Length)
437+
(fun () ->
438+
kernel.KernelFunc
439+
ndRange
440+
firstValues.Length
441+
firstValues
442+
secondValues
443+
positions
444+
result
445+
result.Length)
428446
)
429447

430448
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
@@ -434,19 +452,23 @@ module ClArray =
434452
map2<'a, 'b, int> clContext workGroupSize
435453
<| Map.chooseBitmap2 predicate
436454

437-
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
455+
let prefixSum =
456+
PrefixSum.standardExcludeInplace clContext workGroupSize
438457

439-
let assignValues = assignOption2 clContext workGroupSize predicate
458+
let assignValues =
459+
assignOption2 clContext workGroupSize predicate
440460

441461
fun (processor: MailboxProcessor<_>) allocationMode (firstValues: ClArray<'a>) (secondValues: ClArray<'b>) ->
442462

443-
let positions = getBitmap processor DeviceOnly firstValues secondValues
463+
let positions =
464+
getBitmap processor DeviceOnly firstValues secondValues
444465

445466
let resultLength =
446467
(prefixSum processor positions)
447468
.ToHostAndFree(processor)
448469

449-
let result = clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
470+
let result =
471+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
450472

451473
assignValues processor firstValues secondValues positions result
452474

src/GraphBLAS-sharp.Backend/Common/Gather.fs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,10 @@ module internal Gather =
2222

2323
let kernel = program.GetKernel()
2424

25-
let ndRange = Range1D.CreateValid(outputArray.Length, workGroupSize)
25+
let ndRange =
26+
Range1D.CreateValid(outputArray.Length, workGroupSize)
2627

27-
processor.Post(
28-
Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange values.Length values outputArray)
29-
)
28+
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange values.Length values outputArray))
3029

3130
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
3231

@@ -59,14 +58,17 @@ module internal Gather =
5958

6059
fun (processor: MailboxProcessor<_>) (positions: ClArray<int>) (values: ClArray<'a>) (outputArray: ClArray<'a>) ->
6160

62-
if positions.Length <> outputArray.Length then failwith "Lengths must be the same"
61+
if positions.Length <> outputArray.Length then
62+
failwith "Lengths must be the same"
6363

6464
let kernel = program.GetKernel()
6565

66-
let ndRange = Range1D.CreateValid(positions.Length, workGroupSize)
66+
let ndRange =
67+
Range1D.CreateValid(positions.Length, workGroupSize)
6768

6869
processor.Post(
69-
Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange positions.Length values.Length positions values outputArray)
70+
Msg.MsgSetArguments
71+
(fun () -> kernel.KernelFunc ndRange positions.Length values.Length positions values outputArray)
7072
)
7173

7274
processor.Post(Msg.CreateRunMsg<_, _>(kernel))

src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,4 +341,3 @@ module PrefixSum =
341341
/// </example>
342342
let sequentialInclude clContext =
343343
sequentialSegments (Map.snd ()) clContext
344-

src/GraphBLAS-sharp.Backend/Common/Scatter.fs

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@ open Brahma.FSharp
55
module internal Scatter =
66
let private firstOccurencePredicate () =
77
<@ fun gid _ (positions: ClArray<int>) ->
8-
// first occurrence condition
9-
(gid = 0 || positions.[gid - 1] <> positions.[gid]) @>
8+
// first occurrence condition
9+
(gid = 0 || positions.[gid - 1] <> positions.[gid]) @>
1010

1111
let private lastOccurrencePredicate () =
1212
<@ fun gid positionsLength (positions: ClArray<int>) ->
13-
// last occurrence condition
14-
(gid = positionsLength - 1 || positions.[gid] <> positions.[gid + 1]) @>
13+
// last occurrence condition
14+
(gid = positionsLength - 1
15+
|| positions.[gid] <> positions.[gid + 1]) @>
1516

1617

1718
let private general<'a> predicate (clContext: ClContext) workGroupSize =
@@ -23,19 +24,23 @@ module internal Scatter =
2324

2425
if gid < positionsLength then
2526
// positions lengths == values length
26-
let predicateResult = (%predicate) gid positionsLength positions
27+
let predicateResult =
28+
(%predicate) gid positionsLength positions
29+
2730
let position = positions.[gid]
2831

2932
if predicateResult
30-
&& 0 <= position && position < resultLength then
33+
&& 0 <= position
34+
&& position < resultLength then
3135

3236
result.[positions.[gid]] <- values.[gid] @>
3337

3438
let program = clContext.Compile(run)
3539

3640
fun (processor: MailboxProcessor<_>) (positions: ClArray<int>) (values: ClArray<'a>) (result: ClArray<'a>) ->
3741

38-
if positions.Length <> values.Length then failwith "Lengths must be the same"
42+
if positions.Length <> values.Length then
43+
failwith "Lengths must be the same"
3944

4045
let positionsLength = positions.Length
4146

@@ -70,9 +75,7 @@ module internal Scatter =
7075
/// </code>
7176
/// </example>
7277
let firstOccurrence clContext =
73-
general
74-
<| firstOccurencePredicate ()
75-
<| clContext
78+
general <| firstOccurencePredicate () <| clContext
7679

7780
/// <summary>
7881
/// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
@@ -93,9 +96,7 @@ module internal Scatter =
9396
/// </code>
9497
/// </example>
9598
let lastOccurrence clContext =
96-
general
97-
<| lastOccurrencePredicate ()
98-
<| clContext
99+
general <| lastOccurrencePredicate () <| clContext
99100

100101
let private generalInit<'a> predicate valueMap (clContext: ClContext) workGroupSize =
101102

@@ -106,12 +107,14 @@ module internal Scatter =
106107

107108
if gid < positionsLength then
108109
// positions lengths == values length
109-
let predicateResult = (%predicate) gid positionsLength positions
110+
let predicateResult =
111+
(%predicate) gid positionsLength positions
110112

111113
let position = positions.[gid]
112114

113115
if predicateResult
114-
&& 0 <= position && position < resultLength then
116+
&& 0 <= position
117+
&& position < resultLength then
115118

116119
result.[positions.[gid]] <- (%valueMap) gid @>
117120

@@ -127,8 +130,7 @@ module internal Scatter =
127130
let kernel = program.GetKernel()
128131

129132
processor.Post(
130-
Msg.MsgSetArguments
131-
(fun () -> kernel.KernelFunc ndRange positions positionsLength result result.Length)
133+
Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange positions positionsLength result result.Length)
132134
)
133135

134136
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
@@ -152,7 +154,10 @@ module internal Scatter =
152154
/// </code>
153155
/// </example>
154156
/// <param name="valueMap">Maps global id to a value</param>
155-
let initFirsOccurrence<'a> valueMap = generalInit<'a> <| firstOccurencePredicate () <| valueMap
157+
let initFirsOccurrence<'a> valueMap =
158+
generalInit<'a>
159+
<| firstOccurencePredicate ()
160+
<| valueMap
156161

157162
/// <summary>
158163
/// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
@@ -173,4 +178,7 @@ module internal Scatter =
173178
/// </code>
174179
/// </example>
175180
/// <param name="valueMap">Maps global id to a value</param>
176-
let initLastOccurrence<'a> valueMap = generalInit<'a> <| lastOccurrencePredicate () <| valueMap
181+
let initLastOccurrence<'a> valueMap =
182+
generalInit<'a>
183+
<| lastOccurrencePredicate ()
184+
<| valueMap

0 commit comments

Comments
 (0)