Skip to content

Commit 972b392

Browse files
committed
refactor: init in spgemm
1 parent 29c564c commit 972b392

3 files changed

Lines changed: 11 additions & 19 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/Scatter.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ namespace GraphBLAS.FSharp.Backend.Common
33
open Brahma.FSharp
44

55
module internal Scatter =
6-
let firstOccurencePredicate () =
6+
let private firstOccurencePredicate () =
77
<@ fun gid _ (positions: ClArray<int>) ->
88
// first occurrence condition
99
(gid = 0 || positions.[gid - 1] <> positions.[gid]) @>
1010

11-
let lastOccurrencePredicate () =
11+
let private lastOccurrencePredicate () =
1212
<@ fun gid positionsLength (positions: ClArray<int>) ->
1313
// last occurrence condition
1414
(gid = positionsLength - 1 || positions.[gid] <> positions.[gid + 1]) @>

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ module Expand =
6969

7070
let init = ClArray.init clContext workGroupSize Map.id
7171

72+
let idScatter = Scatter.initLastOccurrence Map.id clContext workGroupSize
73+
7274
let scatter = Scatter.lastOccurrence clContext workGroupSize
7375

7476
let zeroCreate = ClArray.zeroCreate clContext workGroupSize
@@ -94,18 +96,14 @@ module Expand =
9496
fun (processor: MailboxProcessor<_>) lengths (segmentsPointers: Indices) (leftMatrix: ClMatrix.CSR<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->
9597

9698
// Compute A positions
97-
let sequence = init processor DeviceOnly segmentsPointers.Length // TODO(fuse)
98-
9999
let APositions = zeroCreate processor DeviceOnly lengths
100100

101-
scatter processor segmentsPointers sequence APositions
102-
103-
sequence.Free processor
101+
idScatter processor segmentsPointers APositions
104102

105103
(maxPrefixSum processor APositions 0).Free processor
106104

107105
// Compute B positions
108-
let BPositions = create processor DeviceOnly lengths 1 // TODO(fuse)
106+
let BPositions = create processor DeviceOnly lengths 1
109107

110108
let requiredBPointers = zeroCreate processor DeviceOnly leftMatrix.Columns.Length
111109

@@ -200,9 +198,7 @@ module Expand =
200198

201199
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
202200

203-
let init = ClArray.init clContext workGroupSize Map.id // TODO(fuse)
204-
205-
let scatter = Scatter.firstOccurrence clContext workGroupSize
201+
let idScatter = Scatter.initFirsOccurrence Map.id clContext workGroupSize
206202

207203
fun (processor: MailboxProcessor<_>) allocationMode (values: ClArray<'a>) (columns: Indices) (rows: Indices) ->
208204

@@ -214,18 +210,13 @@ module Expand =
214210

215211
printfn $"key bitmap after prefix sum: %A{bitmap.ToHost processor}"
216212

217-
let positions = init processor DeviceOnly bitmap.Length
218-
219-
printfn $"positions: %A{positions.ToHost processor}"
220-
221213
let offsets = clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, uniqueKeysCount)
222214

223-
scatter processor bitmap positions offsets
215+
idScatter processor bitmap offsets
224216

225217
printfn $"offsets: %A{offsets.ToHost processor}"
226218

227219
bitmap.Free processor
228-
positions.Free processor
229220

230221
let reducedColumns, reducedRows, reducedValues = // by size variance TODO()
231222
reduce processor allocationMode uniqueKeysCount offsets columns rows values

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,10 @@ open GraphBLAS.FSharp.Tests.Matrix
9494
let allTests =
9595
testList
9696
"All tests"
97-
[ // SpGeMM.getSegmentsTests
97+
[ SpGeMM.generalTests
9898
// Common.Gather.initTests
99-
Common.Scatter.allTests ]
99+
//Common.Scatter.allTests ]
100+
]
100101

101102
|> testSequenced
102103

0 commit comments

Comments
 (0)