Skip to content

Commit ab946ea

Browse files
committed
fix: kronecker
1 parent 32b27ba commit ab946ea

2 files changed

Lines changed: 42 additions & 45 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSR/Kronecker.fs

Lines changed: 26 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
namespace GraphBLAS.FSharp.Backend.Matrix.CSR
22

33
open FSharp.Quotations.Evaluator
4+
open FSharpx.Collections
45
open Microsoft.FSharp.Quotations
56
open Brahma.FSharp
67
open GraphBLAS.FSharp.Backend.Quotes
@@ -166,21 +167,18 @@ module internal Kronecker =
166167
let setPositions<'c when 'c: struct> (clContext: ClContext) workGroupSize =
167168

168169
let setPositions =
169-
<@ fun (ndRange: Range1D) rowCount columnCount (nnz: ClCell<int>) (rowOffset: ClCell<int>) (columnOffset: ClCell<int>) (startIndex: ClCell<int>) (bitmap: ClArray<int>) (values: ClArray<'c>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) (resultValues: ClArray<'c>) ->
170+
<@ fun (ndRange: Range1D) rowCount columnCount startIndex (nnz: ClCell<int>) (rowOffset: ClCell<int>) (columnOffset: ClCell<int>) (bitmap: ClArray<int>) (values: ClArray<'c>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) (resultValues: ClArray<'c>) ->
170171

171172
let gid = ndRange.GlobalID0
172173

173-
if gid = 0 then
174-
nnz.Value <- nnz.Value + startIndex.Value
175-
176174
if gid < rowCount * columnCount
177175
&& (gid = 0 && bitmap.[gid] = 1
178176
|| gid > 0 && bitmap.[gid - 1] < bitmap.[gid]) then
179177

180178
let columnIndex = gid % columnCount
181179
let rowIndex = gid / columnCount
182180

183-
let index = startIndex.Value + bitmap.[gid] - 1
181+
let index = startIndex + bitmap.[gid] - 1
184182

185183
resultRows.[index] <- rowIndex + rowOffset.Value
186184
resultColumns.[index] <- columnIndex + columnOffset.Value
@@ -191,7 +189,7 @@ module internal Kronecker =
191189
let scan =
192190
PrefixSum.standardIncludeInPlace clContext workGroupSize
193191

194-
fun (processor: MailboxProcessor<_>) rowCount columnCount (rowOffset: int) (columnOffset: int) (startIndex: ClCell<int>) (resultMatrix: COO<'c>) (values: ClArray<'c>) (bitmap: ClArray<int>) ->
192+
fun (processor: MailboxProcessor<_>) rowCount columnCount (rowOffset: int) (columnOffset: int) (startIndex: int) (resultMatrix: COO<'c>) (values: ClArray<'c>) (bitmap: ClArray<int>) ->
195193

196194
let sum = scan processor bitmap
197195

@@ -210,10 +208,10 @@ module internal Kronecker =
210208
ndRange
211209
rowCount
212210
columnCount
211+
startIndex
213212
sum
214213
rowOffset
215214
columnOffset
216-
startIndex
217215
bitmap
218216
values
219217
resultMatrix.Rows
@@ -223,6 +221,8 @@ module internal Kronecker =
223221

224222
processor.Post(Msg.CreateRunMsg<_, _> kernel)
225223

224+
(sum.ToHostAndFree processor) + startIndex
225+
226226
let copyToResult (clContext: ClContext) workGroupSize =
227227

228228
let copyToResult =
@@ -257,12 +257,12 @@ module internal Kronecker =
257257
sourceMatrix.NNZ
258258
rowOffset
259259
columnOffset
260-
resultMatrix.Rows
261-
resultMatrix.Columns
262-
resultMatrix.Values
263260
sourceMatrix.Rows
264261
sourceMatrix.Columns
265-
sourceMatrix.Values)
262+
sourceMatrix.Values
263+
resultMatrix.Rows
264+
resultMatrix.Columns
265+
resultMatrix.Values)
266266
)
267267

268268
processor.Post(Msg.CreateRunMsg<_, _> kernel)
@@ -271,7 +271,7 @@ module internal Kronecker =
271271

272272
let copy = copyToResult clContext workGroupSize
273273

274-
fun queue (startIndex: int) (zeroCounts: int list array) (matrixZero: COO<'c>) (matrixRight: CSR<'b>) resultMatrix ->
274+
fun queue startIndex (zeroCounts: int list array) (matrixZero: COO<'c>) resultMatrix ->
275275

276276
let rowCount = zeroCounts.Length
277277

@@ -282,10 +282,10 @@ module internal Kronecker =
282282
if iter >= count then
283283
()
284284
else
285-
let rowOffset = row * matrixRight.RowCount
285+
let rowOffset = row * matrixZero.RowCount
286286

287287
let columnOffset =
288-
(firstColumn + iter) * matrixRight.ColumnCount
288+
(firstColumn + iter) * matrixZero.ColumnCount
289289

290290
copy queue startIndex rowOffset columnOffset resultMatrix matrixZero
291291

@@ -338,7 +338,7 @@ module internal Kronecker =
338338
let mappedMatrix =
339339
clContext.CreateClArrayWithSpecificAllocationMode<'c>(DeviceOnly, length)
340340

341-
let startIndex = clContext.CreateClCell 0
341+
let mutable startIndex = 0
342342

343343
let rec insertInRowRec row rightEdge index =
344344
if index > rightEdge then
@@ -354,7 +354,12 @@ module internal Kronecker =
354354

355355
value.Free queue
356356

357-
setPositions rowOffset columnOffset startIndex resultMatrix mappedMatrix bitmap
357+
startIndex <-
358+
setPositions rowOffset columnOffset startIndex resultMatrix mappedMatrix bitmap
359+
// printfn $"resultMatrix.Values: %A{resultMatrix.Values.ToHost queue}"
360+
// printfn $"resultMatrix.Rows: %A{resultMatrix.Rows.ToHost queue}"
361+
// printfn $"resultMatrix.Columns: %A{resultMatrix.Columns.ToHost queue}"
362+
// printfn $"startIndex: %A{startIndex.ToHost queue}"
358363

359364
insertInRowRec row rightEdge (index + 1)
360365

@@ -434,12 +439,9 @@ module internal Kronecker =
434439
let startIndex =
435440
insertNonZero queue rowsEdges matrixRight matrixLeft.Values leftColumns resultMatrix
436441

437-
let startIndex = startIndex.ToHostAndFree queue
438-
439442
match matrixZero with
440443
| Some m ->
441-
insertZero queue startIndex zeroCounts m matrixRight resultMatrix
442-
m.Dispose queue
444+
insertZero queue startIndex zeroCounts m resultMatrix
443445
| _ -> ()
444446

445447
resultMatrix
@@ -483,6 +485,10 @@ module internal Kronecker =
483485
let result =
484486
mapAll queue allocationMode size matrixZero matrixLeft matrixRight
485487

488+
match matrixZero with
489+
| Some m -> m.Dispose queue
490+
| _ -> ()
491+
486492
bitonic queue result.Rows result.Columns result.Values
487493

488494
result |> Some

tests/GraphBLAS-sharp.Tests/Backend/Matrix/Kronecker.fs

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,14 @@ open GraphBLAS.FSharp.Objects.MatrixExtensions
1515

1616
let config =
1717
{ Utils.defaultConfig with
18-
endSize = 8
19-
maxTest = 50 }
18+
endSize = 100
19+
maxTest = 20 }
2020

2121
let logger = Log.create "kronecker.Tests"
2222

2323
let workGroupSize = Utils.defaultWorkGroupSize
2424

2525
let makeTest context processor zero isEqual op kroneckerFun (leftMatrix: 'a [,], rightMatrix: 'a [,]) =
26-
let leftMatrix = [[false; false]
27-
[false; false]
28-
[true; false]] |> array2D
29-
30-
let rightMatrix = [[false; false]
31-
[false; true]
32-
[false; false]] |> array2D
33-
3426
let m1 =
3527
Utils.createMatrixFromArray2D CSR leftMatrix (isEqual zero)
3628

@@ -83,22 +75,21 @@ let generalTests (testContext: TestContext) =
8375
let queue = testContext.Queue
8476
queue.Error.Add(fun e -> failwithf "%A" e)
8577

86-
// createGeneralTest context queue false (=) (&&) ArithmeticOperations.boolMulOption "mul"
78+
createGeneralTest context queue false (=) (&&) ArithmeticOperations.boolMulOption "mul"
8779
createGeneralTest context queue false (=) (||) ArithmeticOperations.boolSumOption "sum"
88-
//
89-
// createGeneralTest context queue 0 (=) (*) ArithmeticOperations.intMulOption "mul"
90-
// createGeneralTest context queue 0 (=) (+) ArithmeticOperations.intSumOption "sum"
91-
//
92-
// createGeneralTest context queue 0uy (=) (*) ArithmeticOperations.byteMulOption "mul"
93-
// createGeneralTest context queue 0uy (=) (+) ArithmeticOperations.byteSumOption "sum"
94-
95-
// createGeneralTest context queue 0.0f Utils.float32IsEqual (*) ArithmeticOperations.float32MulOption "mul"
96-
// createGeneralTest context? queue 0.0f Utils.float32IsEqual (+) ArithmeticOperations.float32SumOption "sum"
97-
98-
// if Utils.isFloat64Available context.ClDevice then
99-
// createGeneralTest context queue 0.0 Utils.floatIsEqual (*) ArithmeticOperations.floatMulOption "mul"
100-
// createGeneralTest context queue 0.0 Utils.floatIsEqual (+) ArithmeticOperations.floatSumOption "sum"
101-
]
80+
81+
createGeneralTest context queue 0 (=) (*) ArithmeticOperations.intMulOption "mul"
82+
createGeneralTest context queue 0 (=) (+) ArithmeticOperations.intSumOption "sum"
83+
84+
createGeneralTest context queue 0uy (=) (*) ArithmeticOperations.byteMulOption "mul"
85+
createGeneralTest context queue 0uy (=) (+) ArithmeticOperations.byteSumOption "sum"
86+
87+
createGeneralTest context queue 0.0f Utils.float32IsEqual (*) ArithmeticOperations.float32MulOption "mul"
88+
createGeneralTest context queue 0.0f Utils.float32IsEqual (+) ArithmeticOperations.float32SumOption "sum"
89+
90+
if Utils.isFloat64Available context.ClDevice then
91+
createGeneralTest context queue 0.0 Utils.floatIsEqual (*) ArithmeticOperations.floatMulOption "mul"
92+
createGeneralTest context queue 0.0 Utils.floatIsEqual (+) ArithmeticOperations.floatSumOption "sum" ]
10293

10394
let tests =
10495
gpuTests "Backend.Matrix.kronecker tests" generalTests

0 commit comments

Comments
 (0)