11namespace GraphBLAS.FSharp.Backend.Matrix.CSR
22
33open FSharp.Quotations .Evaluator
4+ open FSharpx.Collections
45open Microsoft.FSharp .Quotations
56open Brahma.FSharp
67open GraphBLAS.FSharp .Backend .Quotes
@@ -166,21 +167,18 @@ module internal Kronecker =
166167 let setPositions < 'c when 'c : struct > ( clContext : ClContext ) workGroupSize =
167168
168169 let setPositions =
169- <@ fun ( ndRange : Range1D ) rowCount columnCount ( nnz : ClCell < int >) ( rowOffset : ClCell < int >) ( columnOffset : ClCell < int >) ( startIndex : ClCell < int >) ( bitmap : ClArray < int >) ( values : ClArray < 'c >) ( resultRows : ClArray < int >) ( resultColumns : ClArray < int >) ( resultValues : ClArray < 'c >) ->
170+ <@ fun ( ndRange : Range1D ) rowCount columnCount startIndex ( nnz : ClCell < int >) ( rowOffset : ClCell < int >) ( columnOffset : ClCell < int >) ( bitmap : ClArray < int >) ( values : ClArray < 'c >) ( resultRows : ClArray < int >) ( resultColumns : ClArray < int >) ( resultValues : ClArray < 'c >) ->
170171
171172 let gid = ndRange.GlobalID0
172173
173- if gid = 0 then
174- nnz.Value <- nnz.Value + startIndex.Value
175-
176174 if gid < rowCount * columnCount
177175 && ( gid = 0 && bitmap.[ gid] = 1
178176 || gid > 0 && bitmap.[ gid - 1 ] < bitmap.[ gid]) then
179177
180178 let columnIndex = gid % columnCount
181179 let rowIndex = gid / columnCount
182180
183- let index = startIndex.Value + bitmap.[ gid] - 1
181+ let index = startIndex + bitmap.[ gid] - 1
184182
185183 resultRows.[ index] <- rowIndex + rowOffset.Value
186184 resultColumns.[ index] <- columnIndex + columnOffset.Value
@@ -191,7 +189,7 @@ module internal Kronecker =
191189 let scan =
192190 PrefixSum.standardIncludeInPlace clContext workGroupSize
193191
194- fun ( processor : MailboxProcessor < _ >) rowCount columnCount ( rowOffset : int ) ( columnOffset : int ) ( startIndex : ClCell < int > ) ( resultMatrix : COO < 'c >) ( values : ClArray < 'c >) ( bitmap : ClArray < int >) ->
192+ fun ( processor : MailboxProcessor < _ >) rowCount columnCount ( rowOffset : int ) ( columnOffset : int ) ( startIndex : int ) ( resultMatrix : COO < 'c >) ( values : ClArray < 'c >) ( bitmap : ClArray < int >) ->
195193
196194 let sum = scan processor bitmap
197195
@@ -210,10 +208,10 @@ module internal Kronecker =
210208 ndRange
211209 rowCount
212210 columnCount
211+ startIndex
213212 sum
214213 rowOffset
215214 columnOffset
216- startIndex
217215 bitmap
218216 values
219217 resultMatrix.Rows
@@ -223,6 +221,8 @@ module internal Kronecker =
223221
224222 processor.Post( Msg.CreateRunMsg<_, _> kernel)
225223
224+ ( sum.ToHostAndFree processor) + startIndex
225+
226226 let copyToResult ( clContext : ClContext ) workGroupSize =
227227
228228 let copyToResult =
@@ -257,12 +257,12 @@ module internal Kronecker =
257257 sourceMatrix.NNZ
258258 rowOffset
259259 columnOffset
260- resultMatrix.Rows
261- resultMatrix.Columns
262- resultMatrix.Values
263260 sourceMatrix.Rows
264261 sourceMatrix.Columns
265- sourceMatrix.Values)
262+ sourceMatrix.Values
263+ resultMatrix.Rows
264+ resultMatrix.Columns
265+ resultMatrix.Values)
266266 )
267267
268268 processor.Post( Msg.CreateRunMsg<_, _> kernel)
@@ -271,7 +271,7 @@ module internal Kronecker =
271271
272272 let copy = copyToResult clContext workGroupSize
273273
274- fun queue ( startIndex : int ) ( zeroCounts : int list array ) ( matrixZero : COO < 'c >) ( matrixRight : CSR < 'b >) resultMatrix ->
274+ fun queue startIndex ( zeroCounts : int list array ) ( matrixZero : COO < 'c >) resultMatrix ->
275275
276276 let rowCount = zeroCounts.Length
277277
@@ -282,10 +282,10 @@ module internal Kronecker =
282282 if iter >= count then
283283 ()
284284 else
285- let rowOffset = row * matrixRight .RowCount
285+ let rowOffset = row * matrixZero .RowCount
286286
287287 let columnOffset =
288- ( firstColumn + iter) * matrixRight .ColumnCount
288+ ( firstColumn + iter) * matrixZero .ColumnCount
289289
290290 copy queue startIndex rowOffset columnOffset resultMatrix matrixZero
291291
@@ -338,7 +338,7 @@ module internal Kronecker =
338338 let mappedMatrix =
339339 clContext.CreateClArrayWithSpecificAllocationMode< 'c>( DeviceOnly, length)
340340
341- let startIndex = clContext.CreateClCell 0
341+ let mutable startIndex = 0
342342
343343 let rec insertInRowRec row rightEdge index =
344344 if index > rightEdge then
@@ -354,7 +354,12 @@ module internal Kronecker =
354354
355355 value.Free queue
356356
357- setPositions rowOffset columnOffset startIndex resultMatrix mappedMatrix bitmap
357+ startIndex <-
358+ setPositions rowOffset columnOffset startIndex resultMatrix mappedMatrix bitmap
359+ // printfn $"resultMatrix.Values: %A{resultMatrix.Values.ToHost queue}"
360+ // printfn $"resultMatrix.Rows: %A{resultMatrix.Rows.ToHost queue}"
361+ // printfn $"resultMatrix.Columns: %A{resultMatrix.Columns.ToHost queue}"
362+ // printfn $"startIndex: %A{startIndex.ToHost queue}"
358363
359364 insertInRowRec row rightEdge ( index + 1 )
360365
@@ -434,12 +439,9 @@ module internal Kronecker =
434439 let startIndex =
435440 insertNonZero queue rowsEdges matrixRight matrixLeft.Values leftColumns resultMatrix
436441
437- let startIndex = startIndex.ToHostAndFree queue
438-
439442 match matrixZero with
440443 | Some m ->
441- insertZero queue startIndex zeroCounts m matrixRight resultMatrix
442- m.Dispose queue
444+ insertZero queue startIndex zeroCounts m resultMatrix
443445 | _ -> ()
444446
445447 resultMatrix
@@ -483,6 +485,10 @@ module internal Kronecker =
483485 let result =
484486 mapAll queue allocationMode size matrixZero matrixLeft matrixRight
485487
488+ match matrixZero with
489+ | Some m -> m.Dispose queue
490+ | _ -> ()
491+
486492 bitonic queue result.Rows result.Columns result.Values
487493
488494 result |> Some
0 commit comments