@@ -15,66 +15,63 @@ open GraphBLAS.FSharp.Backend.Objects.ClContext
1515open GraphBLAS.FSharp .Backend .Objects .ArraysExtensions
1616
1717module internal Kronecker =
// NOTE(review): this span is a diff fragment. Lines marked with a minus are the removed
// getBitmap version and lines marked with a plus are the added updateBitmap version.
//
// updateBitmap compiles a kernel for the given op and returns a function that enqueues one
// run of that kernel over matrixRight.NNZ + 1 work items. In the added version the kernel
// accumulates into the bitmap in place instead of overwriting it:
//   - work item 0 evaluates op (Some operand) None and, when the result is Some, adds
//     zeroCount to the first bitmap cell;
//   - work item gid with gid - 1 < valuesLength evaluates op (Some operand) (Some values.[gid - 1])
//     and, when the result is Some, adds 1 to bitmap cell gid.
// The removed version instead wrote prevSum (plus numberOfZeros) into cell 0 and a 0/1 flag
// into every other cell, which is why the prevSum parameter is dropped here.
18- let private getBitmap ( clContext : ClContext ) workGroupSize op =
18+ let private updateBitmap ( clContext : ClContext ) workGroupSize op =
1919
20- let getBitmap ( op : Expr < 'a option -> 'b option -> 'c option >) =
21- <@ fun ( ndRange : Range1D ) ( prevSum : ClCell < int >) ( operand : ClCell < 'a >) valuesLength numberOfZeros ( values : ClArray < 'b >) ( resultBitmap : ClArray < int >) ->
20+ let updateBitmap ( op : Expr < 'a option -> 'b option -> 'c option >) =
21+ <@ fun ( ndRange : Range1D ) ( operand : ClCell < 'a >) valuesLength zeroCount ( values : ClArray < 'b >) ( resultBitmap : ClArray < int >) ->
2222
2323 let gid = ndRange.GlobalID0
2424
// Cell 0 is reserved for the contribution of op applied to the operand and a missing value.
2525 if gid = 0 then
2626
2727 match (% op) ( Some operand.Value) None with
28- | Some _ -> resultBitmap.[ 0 ] <- prevSum.Value + numberOfZeros
29- | _ -> resultBitmap .[ 0 ] <- prevSum.Value
28+ | Some _ -> resultBitmap.[ 0 ] <- resultBitmap .[ 0 ] + zeroCount
29+ | _ -> ()
3030
// Cells 1 .. valuesLength map to values.[gid - 1]; work items past that range do nothing.
3131 else if ( gid - 1 ) < valuesLength then
3232
3333 match (% op) ( Some operand.Value) ( Some values.[ gid - 1 ]) with
34- | Some _ -> resultBitmap.[ gid] <- 1
35- | _ -> resultBitmap .[ gid ] <- 0 @>
34+ | Some _ -> resultBitmap.[ gid] <- resultBitmap .[ gid ] + 1
35+ | _ -> () @>
3636
37- let getBitmap = clContext.Compile <| getBitmap op
37+ let updateBitmap = clContext.Compile <| updateBitmap op
3838
39- fun ( processor : MailboxProcessor < _ >) ( prevSum : ClCell < int >) ( operand : ClCell < 'a >) ( matrixRight : ClMatrix.CSR < 'b >) ( bitmap : ClArray < int >) ->
39+ fun ( processor : MailboxProcessor < _ >) ( operand : ClCell < 'a >) ( matrixRight : ClMatrix.CSR < 'b >) ( bitmap : ClArray < int >) ->
4040
// One extra work item beyond NNZ, matching the reserved cell 0 in the kernel.
4141 let resultLength = matrixRight.NNZ + 1
4242
4343 let ndRange =
4444 Range1D.CreateValid( resultLength, workGroupSize)
4545
46- let getBitmap = getBitmap .GetKernel()
46+ let updateBitmap = updateBitmap .GetKernel()
4747
// Count of positions in matrixRight that hold no stored value; passed to the kernel as zeroCount.
4848 let numberOfZeros =
49- matrixRight.ColumnCount * matrixRight.RowCount - matrixRight.NNZ
49+ matrixRight.ColumnCount * matrixRight.RowCount
50+ - matrixRight.NNZ
5051
5152 processor.Post(
5253 Msg.MsgSetArguments
5354 ( fun () ->
54- getBitmap.KernelFunc
55- ndRange
56- prevSum
57- operand
58- matrixRight.NNZ
59- numberOfZeros
60- matrixRight.Values
61- bitmap)
55+ updateBitmap.KernelFunc ndRange operand matrixRight.NNZ numberOfZeros matrixRight.Values bitmap)
6256 )
6357
64- processor.Post( Msg.CreateRunMsg<_, _> getBitmap )
58+ processor.Post( Msg.CreateRunMsg<_, _> updateBitmap )
6559
6660 let private getAllocationSize ( clContext : ClContext ) workGroupSize op =
6761
68- let getBitmap = getBitmap clContext workGroupSize op
62+ let updateBitmap = updateBitmap clContext workGroupSize op
6963
7064 let sum =
7165 Reduce.sum <@ fun x y -> x + y @> 0 clContext workGroupSize
7266
7367 let item = ClArray.item clContext workGroupSize
7468
69+ let createClArray =
70+ ClArray.zeroCreate clContext workGroupSize
71+
7572 let opOnHost = QuotationEvaluator.Evaluate op
7673
77- fun ( queue : MailboxProcessor < _ >) ( matrixLeft : ClMatrix. CSR < 'a >) ( matrixRight : ClMatrix. CSR < 'b >) ->
74+ fun ( queue : MailboxProcessor < _ >) ( matrixZero : COO < 'c > option ) ( matrixLeft : CSR < 'a >) ( matrixRight : CSR < 'b >) ->
7875
7976 let nnz =
8077 match opOnHost None None with
@@ -89,28 +86,30 @@ module internal Kronecker =
8986
9087 leftZeroCount * rightZeroCount
9188 | _ -> 0
92- |> clContext.CreateClCell
9389
9490 let bitmap =
95- clContext.CreateClArrayWithSpecificAllocationMode < int >( DeviceOnly, matrixRight.NNZ + 1 )
91+ createClArray queue DeviceOnly ( matrixRight.NNZ + 1 )
9692
97- let nnz =
98- { 0 .. matrixLeft.NNZ - 1 }
99- |> Seq.fold
100- ( fun acc index ->
101- let value = item queue index matrixLeft.Values
93+ for index in 0 .. matrixLeft.NNZ - 1 do
94+ let value = item queue index matrixLeft.Values
10295
103- getBitmap queue acc value matrixRight bitmap
96+ updateBitmap queue value matrixRight bitmap
10497
105- let nnz = sum queue bitmap
98+ value.Free queue
10699
107- acc.Free queue
108- value.Free queue
100+ let bitmapSum = sum queue bitmap
109101
110- nnz)
111- nnz
102+ bitmap.Free queue
103+
104+ let leftZeroCount =
105+ matrixLeft.ColumnCount * matrixLeft.RowCount
106+ - matrixLeft.NNZ
112107
113- nnz.ToHostAndFree queue
108+ match matrixZero with
109+ | Some m -> m.NNZ * leftZeroCount
110+ | _ -> 0
111+ + nnz
112+ + bitmapSum.ToHostAndFree queue
114113
115114 let private preparePositions < 'a , 'b , 'c when 'b : struct > ( clContext : ClContext ) workGroupSize op =
116115
@@ -241,7 +240,8 @@ module internal Kronecker =
241240
242241 fun ( processor : MailboxProcessor < _ >) startIndex ( rowOffset : int ) ( columnOffset : int ) ( resultMatrix : COO < 'c >) ( sourceMatrix : COO < 'c >) ->
243242
244- let ndRange = Range1D.CreateValid( sourceMatrix.NNZ, workGroupSize)
243+ let ndRange =
244+ Range1D.CreateValid( sourceMatrix.NNZ, workGroupSize)
245245
246246 let kernel = kernel.GetKernel()
247247
@@ -278,22 +278,15 @@ module internal Kronecker =
278278 let mutable startIndex = startIndex
279279
280280 let insertMany row firstColumn count =
281- let rec insertManyRec iter =
282- if iter >= count then
283- ()
284- else
285- let rowOffset = row * matrixZero.RowCount
281+ for i in 0 .. count - 1 do
282+ let rowOffset = row * matrixZero.RowCount
286283
287- let columnOffset =
288- ( firstColumn + iter ) * matrixZero.ColumnCount
284+ let columnOffset =
285+ ( firstColumn + i ) * matrixZero.ColumnCount
289286
290- copy queue startIndex rowOffset columnOffset resultMatrix matrixZero
287+ copy queue startIndex rowOffset columnOffset resultMatrix matrixZero
291288
292- startIndex <- startIndex + matrixZero.NNZ
293-
294- insertManyRec ( iter + 1 )
295-
296- insertManyRec 0
289+ startIndex <- startIndex + matrixZero.NNZ
297290
298291 let rec insertInRowRec zeroCounts row column =
299292 match zeroCounts with
@@ -303,15 +296,8 @@ module internal Kronecker =
303296
304297 insertInRowRec tl row ( h + column + 1 )
305298
306- let rec insertZeroRec row =
307- if row >= rowCount then
308- ()
309- else
310- insertInRowRec zeroCounts.[ row] row 0
311-
312- insertZeroRec ( row + 1 )
313-
314- insertZeroRec 0
299+ for row in 0 .. rowCount - 1 do
300+ insertInRowRec zeroCounts.[ row] row 0
315301
316302 let insertNonZero ( clContext : ClContext ) workGroupSize op =
317303
@@ -340,12 +326,12 @@ module internal Kronecker =
340326
341327 let mutable startIndex = 0
342328
343- let rec insertInRowRec row rightEdge index =
344- if index > rightEdge then
345- ()
346- else
347- let value = item queue index leftValues
348- let column = leftColsHost.[ index ]
329+ for row in 0 .. rowCount - 1 do
330+ let leftEdge , rightEdge = rowsEdges .[ row ]
331+
332+ for i in leftEdge .. rightEdge do
333+ let value = item queue i leftValues
334+ let column = leftColsHost.[ i ]
349335
350336 let rowOffset = row * matrixRight.RowCount
351337 let columnOffset = column * matrixRight.ColumnCount
@@ -354,26 +340,7 @@ module internal Kronecker =
354340
355341 value.Free queue
356342
357- startIndex <-
358- setPositions rowOffset columnOffset startIndex resultMatrix mappedMatrix bitmap
359- // printfn $"resultMatrix.Values: %A{resultMatrix.Values.ToHost queue}"
360- // printfn $"resultMatrix.Rows: %A{resultMatrix.Rows.ToHost queue}"
361- // printfn $"resultMatrix.Columns: %A{resultMatrix.Columns.ToHost queue}"
362- // printfn $"startIndex: %A{startIndex.ToHost queue}"
363-
364- insertInRowRec row rightEdge ( index + 1 )
365-
366- let rec insertNonZeroRec row =
367- if row >= rowCount then
368- ()
369- else
370- let leftEdge , rightEdge = rowsEdges.[ row]
371-
372- insertInRowRec row rightEdge leftEdge
373-
374- insertNonZeroRec ( row + 1 )
375-
376- insertNonZeroRec 0
343+ startIndex <- setPositions rowOffset columnOffset startIndex resultMatrix mappedMatrix bitmap
377344
378345 bitmap.Free queue
379346 mappedMatrix.Free queue
@@ -440,8 +407,7 @@ module internal Kronecker =
440407 insertNonZero queue rowsEdges matrixRight matrixLeft.Values leftColumns resultMatrix
441408
442409 match matrixZero with
443- | Some m ->
444- insertZero queue startIndex zeroCounts m resultMatrix
410+ | Some m -> insertZero queue startIndex zeroCounts m resultMatrix
445411 | _ -> ()
446412
447413 resultMatrix
@@ -468,16 +434,8 @@ module internal Kronecker =
468434 let matrixZero =
469435 mapWithValue queue allocationMode None matrixRight
470436
471- let size = getSize queue matrixLeft matrixRight
472-
473- let leftZeroCount =
474- matrixLeft.ColumnCount * matrixLeft.RowCount
475- - matrixLeft.NNZ
476-
477437 let size =
478- match matrixZero with
479- | Some m -> size + m.NNZ * leftZeroCount
480- | _ -> size
438+ getSize queue matrixZero matrixLeft matrixRight
481439
482440 if size = 0 then
483441 None
0 commit comments