Skip to content

Commit fc55f85

Browse files
committed
perf: kronecker
1 parent ab946ea commit fc55f85

3 files changed

Lines changed: 60 additions & 97 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSR/Kronecker.fs

Lines changed: 54 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -15,66 +15,63 @@ open GraphBLAS.FSharp.Backend.Objects.ClContext
1515
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
1616

1717
module internal Kronecker =
18-
let private getBitmap (clContext: ClContext) workGroupSize op =
18+
let private updateBitmap (clContext: ClContext) workGroupSize op =
1919

20-
let getBitmap (op: Expr<'a option -> 'b option -> 'c option>) =
21-
<@ fun (ndRange: Range1D) (prevSum: ClCell<int>) (operand: ClCell<'a>) valuesLength numberOfZeros (values: ClArray<'b>) (resultBitmap: ClArray<int>) ->
20+
let updateBitmap (op: Expr<'a option -> 'b option -> 'c option>) =
21+
<@ fun (ndRange: Range1D) (operand: ClCell<'a>) valuesLength zeroCount (values: ClArray<'b>) (resultBitmap: ClArray<int>) ->
2222

2323
let gid = ndRange.GlobalID0
2424

2525
if gid = 0 then
2626

2727
match (%op) (Some operand.Value) None with
28-
| Some _ -> resultBitmap.[0] <- prevSum.Value + numberOfZeros
29-
| _ -> resultBitmap.[0] <- prevSum.Value
28+
| Some _ -> resultBitmap.[0] <- resultBitmap.[0] + zeroCount
29+
| _ -> ()
3030

3131
else if (gid - 1) < valuesLength then
3232

3333
match (%op) (Some operand.Value) (Some values.[gid - 1]) with
34-
| Some _ -> resultBitmap.[gid] <- 1
35-
| _ -> resultBitmap.[gid] <- 0 @>
34+
| Some _ -> resultBitmap.[gid] <- resultBitmap.[gid] + 1
35+
| _ -> () @>
3636

37-
let getBitmap = clContext.Compile <| getBitmap op
37+
let updateBitmap = clContext.Compile <| updateBitmap op
3838

39-
fun (processor: MailboxProcessor<_>) (prevSum: ClCell<int>) (operand: ClCell<'a>) (matrixRight: ClMatrix.CSR<'b>) (bitmap: ClArray<int>) ->
39+
fun (processor: MailboxProcessor<_>) (operand: ClCell<'a>) (matrixRight: ClMatrix.CSR<'b>) (bitmap: ClArray<int>) ->
4040

4141
let resultLength = matrixRight.NNZ + 1
4242

4343
let ndRange =
4444
Range1D.CreateValid(resultLength, workGroupSize)
4545

46-
let getBitmap = getBitmap.GetKernel()
46+
let updateBitmap = updateBitmap.GetKernel()
4747

4848
let numberOfZeros =
49-
matrixRight.ColumnCount * matrixRight.RowCount - matrixRight.NNZ
49+
matrixRight.ColumnCount * matrixRight.RowCount
50+
- matrixRight.NNZ
5051

5152
processor.Post(
5253
Msg.MsgSetArguments
5354
(fun () ->
54-
getBitmap.KernelFunc
55-
ndRange
56-
prevSum
57-
operand
58-
matrixRight.NNZ
59-
numberOfZeros
60-
matrixRight.Values
61-
bitmap)
55+
updateBitmap.KernelFunc ndRange operand matrixRight.NNZ numberOfZeros matrixRight.Values bitmap)
6256
)
6357

64-
processor.Post(Msg.CreateRunMsg<_, _> getBitmap)
58+
processor.Post(Msg.CreateRunMsg<_, _> updateBitmap)
6559

6660
let private getAllocationSize (clContext: ClContext) workGroupSize op =
6761

68-
let getBitmap = getBitmap clContext workGroupSize op
62+
let updateBitmap = updateBitmap clContext workGroupSize op
6963

7064
let sum =
7165
Reduce.sum <@ fun x y -> x + y @> 0 clContext workGroupSize
7266

7367
let item = ClArray.item clContext workGroupSize
7468

69+
let createClArray =
70+
ClArray.zeroCreate clContext workGroupSize
71+
7572
let opOnHost = QuotationEvaluator.Evaluate op
7673

77-
fun (queue: MailboxProcessor<_>) (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSR<'b>) ->
74+
fun (queue: MailboxProcessor<_>) (matrixZero: COO<'c> option) (matrixLeft: CSR<'a>) (matrixRight: CSR<'b>) ->
7875

7976
let nnz =
8077
match opOnHost None None with
@@ -89,28 +86,30 @@ module internal Kronecker =
8986

9087
leftZeroCount * rightZeroCount
9188
| _ -> 0
92-
|> clContext.CreateClCell
9389

9490
let bitmap =
95-
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, matrixRight.NNZ + 1)
91+
createClArray queue DeviceOnly (matrixRight.NNZ + 1)
9692

97-
let nnz =
98-
{ 0 .. matrixLeft.NNZ - 1 }
99-
|> Seq.fold
100-
(fun acc index ->
101-
let value = item queue index matrixLeft.Values
93+
for index in 0 .. matrixLeft.NNZ - 1 do
94+
let value = item queue index matrixLeft.Values
10295

103-
getBitmap queue acc value matrixRight bitmap
96+
updateBitmap queue value matrixRight bitmap
10497

105-
let nnz = sum queue bitmap
98+
value.Free queue
10699

107-
acc.Free queue
108-
value.Free queue
100+
let bitmapSum = sum queue bitmap
109101

110-
nnz)
111-
nnz
102+
bitmap.Free queue
103+
104+
let leftZeroCount =
105+
matrixLeft.ColumnCount * matrixLeft.RowCount
106+
- matrixLeft.NNZ
112107

113-
nnz.ToHostAndFree queue
108+
match matrixZero with
109+
| Some m -> m.NNZ * leftZeroCount
110+
| _ -> 0
111+
+ nnz
112+
+ bitmapSum.ToHostAndFree queue
114113

115114
let private preparePositions<'a, 'b, 'c when 'b: struct> (clContext: ClContext) workGroupSize op =
116115

@@ -241,7 +240,8 @@ module internal Kronecker =
241240

242241
fun (processor: MailboxProcessor<_>) startIndex (rowOffset: int) (columnOffset: int) (resultMatrix: COO<'c>) (sourceMatrix: COO<'c>) ->
243242

244-
let ndRange = Range1D.CreateValid(sourceMatrix.NNZ, workGroupSize)
243+
let ndRange =
244+
Range1D.CreateValid(sourceMatrix.NNZ, workGroupSize)
245245

246246
let kernel = kernel.GetKernel()
247247

@@ -278,22 +278,15 @@ module internal Kronecker =
278278
let mutable startIndex = startIndex
279279

280280
let insertMany row firstColumn count =
281-
let rec insertManyRec iter =
282-
if iter >= count then
283-
()
284-
else
285-
let rowOffset = row * matrixZero.RowCount
281+
for i in 0 .. count - 1 do
282+
let rowOffset = row * matrixZero.RowCount
286283

287-
let columnOffset =
288-
(firstColumn + iter) * matrixZero.ColumnCount
284+
let columnOffset =
285+
(firstColumn + i) * matrixZero.ColumnCount
289286

290-
copy queue startIndex rowOffset columnOffset resultMatrix matrixZero
287+
copy queue startIndex rowOffset columnOffset resultMatrix matrixZero
291288

292-
startIndex <- startIndex + matrixZero.NNZ
293-
294-
insertManyRec (iter + 1)
295-
296-
insertManyRec 0
289+
startIndex <- startIndex + matrixZero.NNZ
297290

298291
let rec insertInRowRec zeroCounts row column =
299292
match zeroCounts with
@@ -303,15 +296,8 @@ module internal Kronecker =
303296

304297
insertInRowRec tl row (h + column + 1)
305298

306-
let rec insertZeroRec row =
307-
if row >= rowCount then
308-
()
309-
else
310-
insertInRowRec zeroCounts.[row] row 0
311-
312-
insertZeroRec (row + 1)
313-
314-
insertZeroRec 0
299+
for row in 0 .. rowCount - 1 do
300+
insertInRowRec zeroCounts.[row] row 0
315301

316302
let insertNonZero (clContext: ClContext) workGroupSize op =
317303

@@ -340,12 +326,12 @@ module internal Kronecker =
340326

341327
let mutable startIndex = 0
342328

343-
let rec insertInRowRec row rightEdge index =
344-
if index > rightEdge then
345-
()
346-
else
347-
let value = item queue index leftValues
348-
let column = leftColsHost.[index]
329+
for row in 0 .. rowCount - 1 do
330+
let leftEdge, rightEdge = rowsEdges.[row]
331+
332+
for i in leftEdge .. rightEdge do
333+
let value = item queue i leftValues
334+
let column = leftColsHost.[i]
349335

350336
let rowOffset = row * matrixRight.RowCount
351337
let columnOffset = column * matrixRight.ColumnCount
@@ -354,26 +340,7 @@ module internal Kronecker =
354340

355341
value.Free queue
356342

357-
startIndex <-
358-
setPositions rowOffset columnOffset startIndex resultMatrix mappedMatrix bitmap
359-
// printfn $"resultMatrix.Values: %A{resultMatrix.Values.ToHost queue}"
360-
// printfn $"resultMatrix.Rows: %A{resultMatrix.Rows.ToHost queue}"
361-
// printfn $"resultMatrix.Columns: %A{resultMatrix.Columns.ToHost queue}"
362-
// printfn $"startIndex: %A{startIndex.ToHost queue}"
363-
364-
insertInRowRec row rightEdge (index + 1)
365-
366-
let rec insertNonZeroRec row =
367-
if row >= rowCount then
368-
()
369-
else
370-
let leftEdge, rightEdge = rowsEdges.[row]
371-
372-
insertInRowRec row rightEdge leftEdge
373-
374-
insertNonZeroRec (row + 1)
375-
376-
insertNonZeroRec 0
343+
startIndex <- setPositions rowOffset columnOffset startIndex resultMatrix mappedMatrix bitmap
377344

378345
bitmap.Free queue
379346
mappedMatrix.Free queue
@@ -440,8 +407,7 @@ module internal Kronecker =
440407
insertNonZero queue rowsEdges matrixRight matrixLeft.Values leftColumns resultMatrix
441408

442409
match matrixZero with
443-
| Some m ->
444-
insertZero queue startIndex zeroCounts m resultMatrix
410+
| Some m -> insertZero queue startIndex zeroCounts m resultMatrix
445411
| _ -> ()
446412

447413
resultMatrix
@@ -468,16 +434,8 @@ module internal Kronecker =
468434
let matrixZero =
469435
mapWithValue queue allocationMode None matrixRight
470436

471-
let size = getSize queue matrixLeft matrixRight
472-
473-
let leftZeroCount =
474-
matrixLeft.ColumnCount * matrixLeft.RowCount
475-
- matrixLeft.NNZ
476-
477437
let size =
478-
match matrixZero with
479-
| Some m -> size + m.NNZ * leftZeroCount
480-
| _ -> size
438+
getSize queue matrixZero matrixLeft matrixRight
481439

482440
if size = 0 then
483441
None

tests/GraphBLAS-sharp.Tests/Backend/Matrix/Kronecker.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ open GraphBLAS.FSharp.Objects.MatrixExtensions
1515

1616
let config =
1717
{ Utils.defaultConfig with
18-
endSize = 100
18+
endSize = 30
1919
maxTest = 20 }
2020

2121
let logger = Log.create "kronecker.Tests"

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ open GraphBLAS.FSharp.Tests
106106
// Host.Matrix.Convert.tests
107107
// Host.IO.MtxReader.test ]
108108
// |> testSequenced
109+
//
110+
// [<Tests>]
111+
// let allTests =
112+
// testList "All" [ deviceTests; hostTests ]
113+
// |> testSequenced
109114

110115
[<Tests>]
111116
let allTests =

0 commit comments

Comments
 (0)