Skip to content

Commit 1ebf84b

Browse files
committed
refactor: CSR.expandRowsPointers
1 parent 269fc21 commit 1ebf84b

9 files changed

Lines changed: 383 additions & 669 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSR/Matrix.fs

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,54 @@ open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
1313

1414

1515
module Matrix =
16+
let expandRowPointers (clContext: ClContext) workGroupSize =
17+
18+
let kernel =
19+
<@ fun (ndRange: Range1D) columnsLength pointersLength (pointers: ClArray<int>) (results: ClArray<int>) ->
20+
21+
let gid = ndRange.GlobalID0
22+
23+
if gid < columnsLength then
24+
let result =
25+
(%Search.Bin.lowerBound 0) pointersLength gid pointers
26+
27+
results.[gid] <- result - 1 @>
28+
29+
let program = clContext.Compile kernel
30+
31+
fun (processor: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
32+
33+
let rows =
34+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, matrix.Columns.Length)
35+
36+
let kernel = program.GetKernel()
37+
38+
let ndRange =
39+
Range1D.CreateValid(matrix.Columns.Length, workGroupSize)
40+
41+
processor.Post(Msg.MsgSetArguments(
42+
fun () ->
43+
kernel.KernelFunc
44+
ndRange
45+
matrix.Columns.Length
46+
matrix.RowPointers.Length
47+
matrix.RowPointers
48+
rows))
49+
50+
processor.Post(Msg.CreateRunMsg<_, _> kernel)
51+
52+
rows
53+
1654
let toCOO (clContext: ClContext) workGroupSize =
17-
let prepare =
18-
Common.expandRowPointers clContext workGroupSize
55+
let prepare = expandRowPointers clContext workGroupSize
1956

2057
let copy = ClArray.copy clContext workGroupSize
2158

2259
let copyData = ClArray.copy clContext workGroupSize
2360

2461
fun (processor: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
2562
let rows =
26-
prepare processor allocationMode matrix.RowPointers matrix.Columns.Length matrix.RowCount
63+
prepare processor allocationMode matrix
2764

2865
let cols =
2966
copy processor allocationMode matrix.Columns
@@ -39,12 +76,11 @@ module Matrix =
3976
Values = values }
4077

4178
let toCOOInPlace (clContext: ClContext) workGroupSize =
42-
let prepare =
43-
Common.expandRowPointers clContext workGroupSize
79+
let prepare = expandRowPointers clContext workGroupSize
4480

4581
fun (processor: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
4682
let rows =
47-
prepare processor allocationMode matrix.RowPointers matrix.Columns.Length matrix.RowCount
83+
prepare processor allocationMode matrix
4884

4985
processor.Post(Msg.CreateFreeMsg(matrix.RowPointers))
5086

@@ -92,7 +128,6 @@ module Matrix =
92128
let toCSRInPlace =
93129
COO.Matrix.toCSRInPlace clContext workGroupSize
94130

95-
96131
fun (queue: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
97132
toCOO queue allocationMode matrix
98133
|> transposeInPlace queue

src/GraphBLAS-sharp.Backend/Matrix/Common.fs

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend.Common
55
open GraphBLAS.FSharp.Backend.Objects.ClContext
66
open GraphBLAS.FSharp.Backend.Objects.ClCell
7+
open GraphBLAS.FSharp.Backend.Objects
8+
open GraphBLAS.FSharp.Backend.Quotes
79

8-
module Common =
10+
module internal Common =
911
///<param name="clContext">.</param>
1012
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
1113
let setPositions<'a when 'a: struct> (clContext: ClContext) workGroupSize =
@@ -40,40 +42,3 @@ module Common =
4042
valuesScatter processor positions allValues resultValues
4143

4244
resultRows, resultColumns, resultValues, resultLength
43-
44-
let expandRowPointers (clContext: ClContext) workGroupSize =
45-
46-
let expandRowPointers =
47-
<@ fun (ndRange: Range1D) (rowPointers: ClArray<int>) (rowCount: int) (rows: ClArray<int>) ->
48-
49-
let i = ndRange.GlobalID0
50-
51-
if i < rowCount then
52-
let rowPointer = rowPointers.[i]
53-
54-
if rowPointer <> rowPointers.[i + 1] then
55-
rows.[rowPointer] <- i @>
56-
57-
let program = clContext.Compile expandRowPointers
58-
59-
let create =
60-
ClArray.zeroCreate clContext workGroupSize
61-
62-
let scan =
63-
PrefixSum.runIncludeInPlace <@ max @> clContext workGroupSize
64-
65-
fun (processor: MailboxProcessor<_>) allocationMode (rowPointers: ClArray<int>) nnz rowCount ->
66-
67-
let rows = create processor allocationMode nnz
68-
69-
let kernel = program.GetKernel()
70-
71-
let ndRange =
72-
Range1D.CreateValid(rowCount, workGroupSize)
73-
74-
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange rowPointers rowCount rows))
75-
processor.Post(Msg.CreateRunMsg<_, _> kernel)
76-
77-
(scan processor rows 0).Free processor
78-
79-
rows

src/GraphBLAS-sharp.Backend/Matrix/Matrix.fs

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -405,19 +405,19 @@ module Matrix =
405405
| ClMatrix.CSR m1, ClMatrix.CSC m2, ClMatrix.COO mask -> runCSRnCSC queue m1 m2 mask |> ClMatrix.COO
406406
| _ -> failwith "Matrix formats are not matching"
407407

408-
let expand
409-
(opAdd: Expr<'c -> 'c -> 'c option>)
410-
(opMul: Expr<'a -> 'b -> 'c option>)
411-
(clContext: ClContext)
412-
workGroupSize
413-
=
414-
415-
let run =
416-
SpGeMM.Expand.run clContext workGroupSize opAdd opMul
417-
418-
fun (processor: MailboxProcessor<_>) allocationMode (leftMatrix: ClMatrix<'a>) (rightMatrix: ClMatrix<'b>) ->
419-
match leftMatrix, rightMatrix with
420-
| ClMatrix.CSR leftMatrix, ClMatrix.CSR rightMatrix ->
421-
ClMatrix.LIL
422-
<| run processor allocationMode leftMatrix rightMatrix
423-
| _ -> failwith "Matrix formats are not matching"
408+
// let expand
409+
// (opAdd: Expr<'c -> 'c -> 'c option>)
410+
// (opMul: Expr<'a -> 'b -> 'c option>)
411+
// (clContext: ClContext)
412+
// workGroupSize
413+
// =
414+
//
415+
// let run =
416+
// SpGeMM.Expand.run clContext workGroupSize opAdd opMul
417+
//
418+
// fun (processor: MailboxProcessor<_>) allocationMode (leftMatrix: ClMatrix<'a>) (rightMatrix: ClMatrix<'b>) ->
419+
// match leftMatrix, rightMatrix with
420+
// | ClMatrix.CSR leftMatrix, ClMatrix.CSR rightMatrix ->
421+
// ClMatrix.LIL
422+
// <| run processor allocationMode leftMatrix rightMatrix
423+
// | _ -> failwith "Matrix formats are not matching"

0 commit comments

Comments
 (0)