Skip to content

Commit 1805e3a

Browse files
authored
Merge pull request #32 from kirillgarbar/net5
CSR.eWiseAdd
2 parents 57c2922 + 44e85b5 commit 1805e3a

4 files changed

Lines changed: 99 additions & 59 deletions

File tree

src/GraphBLAS-sharp.Backend/COOMatrix/COOMatrix.fs

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -399,25 +399,31 @@ module COOMatrix =
399399
nonZeroRowsPointers.[gid]
400400
- nonZeroRowsPointers.[gid - 1] @>
401401

402-
let kernelCHSR =
402+
let expandNnzPerRow =
403+
<@ fun (ndRange: Range1D) totalSum (nnzPerRowSparse: ClArray<'a>) (nonZeroRowsIndices: ClArray<int>) (expandedNnzPerRow: ClArray<'a>) ->
404+
405+
let i = ndRange.GlobalID0
406+
407+
if i < totalSum then
408+
expandedNnzPerRow.[nonZeroRowsIndices.[i] + 1] <- nnzPerRowSparse.[i] @>
409+
410+
let kernelCalcHyperSparseRows =
403411
clContext.CreateClKernel calcHyperSparseRows
404412

405-
let kernelCNPRS =
413+
let kernelCalcNnzPerRowSparse =
406414
clContext.CreateClKernel calcNnzPerRowSparse
407415

416+
let kernelExpandNnzPerRow = clContext.CreateClKernel expandNnzPerRow
417+
408418
let getUniqueBitmap = ClArray.getUniqueBitmap clContext
409419

410420
let posAndTotalSum =
411421
ClArray.prefixSumExclude clContext workGroupSize
412422

413-
let expandSparseNnzPerRow = ClArray.setPositions clContext
414-
415-
fun (processor: MailboxProcessor<_>) (rowIndices: ClArray<int>) ->
416-
417-
let nnz = rowIndices.Length
418-
419-
let ndRangeCHSR = Range1D.CreateValid(nnz, workGroupSize)
423+
let getRowPointers =
424+
ClArray.prefixSumInclude clContext workGroupSize
420425

426+
fun (processor: MailboxProcessor<_>) (rowIndices: ClArray<int>) rowCount ->
421427
let bitmap =
422428
getUniqueBitmap processor workGroupSize rowIndices
423429

@@ -430,17 +436,16 @@ module COOMatrix =
430436

431437
let totalSum = hostTotalSum.[0]
432438

433-
let ndRangeCNPRS =
434-
Range1D.CreateValid(totalSum, workGroupSize)
439+
let nonZeroRowsIndices = clContext.CreateClArray totalSum
440+
let nonZeroRowsPointers = clContext.CreateClArray totalSum
435441

436-
let zeroArray = Array.zeroCreate totalSum
437-
let nonZeroRowsIndices = clContext.CreateClArray zeroArray
438-
let nonZeroRowsPointers = clContext.CreateClArray zeroArray
442+
let nnz = rowIndices.Length
443+
let ndRangeCHSR = Range1D.CreateValid(nnz, workGroupSize)
439444

440445
processor.Post(
441446
Msg.MsgSetArguments
442447
(fun () ->
443-
kernelCHSR.SetArguments
448+
kernelCalcHyperSparseRows.SetArguments
444449
ndRangeCHSR
445450
rowIndices
446451
bitmap
@@ -450,22 +455,43 @@ module COOMatrix =
450455
nnz)
451456
)
452457

453-
processor.Post(Msg.CreateRunMsg<_, _> kernelCHSR)
458+
processor.Post(Msg.CreateRunMsg<_, _> kernelCalcHyperSparseRows)
459+
460+
let nnzPerRowSparse = clContext.CreateClArray totalSum
454461

455-
let nnzPerRowSparse = clContext.CreateClArray zeroArray
462+
let ndRangeCNPRSandENPR =
463+
Range1D.CreateValid(totalSum, workGroupSize)
456464

457465
processor.Post(
458466
Msg.MsgSetArguments
459-
(fun () -> kernelCNPRS.SetArguments ndRangeCNPRS nonZeroRowsPointers nnzPerRowSparse totalSum)
467+
(fun () ->
468+
kernelCalcNnzPerRowSparse.SetArguments
469+
ndRangeCNPRSandENPR
470+
nonZeroRowsPointers
471+
nnzPerRowSparse
472+
totalSum)
460473
)
461474

462-
processor.Post(Msg.CreateRunMsg<_, _> kernelCNPRS)
475+
processor.Post(Msg.CreateRunMsg<_, _> kernelCalcNnzPerRowSparse)
463476

464477
let expandedNnzPerRow =
465-
expandSparseNnzPerRow processor workGroupSize nnzPerRowSparse nonZeroRowsIndices totalSum
478+
clContext.CreateClArray(Array.zeroCreate rowCount)
479+
480+
processor.Post(
481+
Msg.MsgSetArguments
482+
(fun () ->
483+
kernelExpandNnzPerRow.SetArguments
484+
ndRangeCNPRSandENPR
485+
totalSum
486+
nnzPerRowSparse
487+
nonZeroRowsIndices
488+
expandedNnzPerRow)
489+
)
490+
491+
processor.Post(Msg.CreateRunMsg<_, _> kernelExpandNnzPerRow)
466492

467493
let rowPointers, _ =
468-
posAndTotalSum processor expandedNnzPerRow
494+
getRowPointers processor expandedNnzPerRow
469495

470496
rowPointers
471497

@@ -482,7 +508,8 @@ module COOMatrix =
482508
GraphBLAS.FSharp.Backend.ClArray.copy clContext
483509

484510
fun (processor: MailboxProcessor<_>) (matrix: COOMatrix<'a>) ->
485-
let compressedRows = compressRows processor matrix.Rows
511+
let compressedRows =
512+
compressRows processor matrix.Rows matrix.RowCount
486513

487514
let cols =
488515
copy processor workGroupSize matrix.Columns

tests/GraphBLAS-sharp.Tests/BackendCommonTests/COOMatrixEwiseAddTests.fs renamed to tests/GraphBLAS-sharp.Tests/BackendCommonTests/MatrixEwiseAddTests.fs

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module Backend.COOMatrix.EwiseAdd
1+
module Backend.EwiseAdd
22

33
open FsCheck
44
open Expecto
@@ -10,7 +10,7 @@ open GraphBLAS.FSharp
1010
open GraphBLAS.FSharp.Tests.Generators
1111
open GraphBLAS.FSharp.Tests.Utils
1212

13-
let logger = Log.create "COOMatrix.EwiseAdd.Tests"
13+
let logger = Log.create "EwiseAdd.Tests"
1414

1515
let context =
1616
let deviceType = ClDeviceType.Default
@@ -41,39 +41,40 @@ let checkResult op zero (baseMtx1: 'a [,]) (baseMtx2: 'a [,]) (actual: Matrix<'a
4141
let actual2D =
4242
Array2D.create actual.RowCount actual.ColumnCount zero
4343

44-
match actual with
45-
| MatrixCOO actual ->
46-
for i in 0 .. actual.Rows.Length - 1 do
47-
actual2D.[actual.Rows.[i], actual.Columns.[i]] <- actual.Values.[i]
44+
let actual2D =
45+
match actual with
46+
| MatrixCOO actual ->
47+
for i in 0 .. actual.Rows.Length - 1 do
48+
actual2D.[actual.Rows.[i], actual.Columns.[i]] <- actual.Values.[i]
49+
50+
actual2D
51+
| MatrixCSR actual ->
52+
let rowIndices =
53+
Array.create actual.ColumnIndices.Length 0
4854

49-
for i in 0 .. rows - 1 do
50-
for j in 0 .. columns - 1 do
51-
Expect.equal actual2D.[i, j] expected.[i, j] "Elements of matrices should be equals."
52-
| MatrixCSR actual ->
53-
let rowIndices =
54-
Array.create actual.ColumnIndices.Length 0
55+
for i in 0 .. actual.RowCount - 1 do
56+
if i < actual.RowCount - 1 then
57+
let rowStart = actual.RowPointers.[i]
58+
let rowEnd = actual.RowPointers.[i + 1]
59+
let rowLength = rowEnd - rowStart
5560

56-
for i in 0 .. actual.RowCount - 1 do
57-
if i < actual.RowCount - 1 then
58-
let rowStart = actual.RowPointers.[i]
59-
let rowEnd = actual.RowPointers.[i + 1]
60-
let rowLength = rowEnd - rowStart
61+
for j in 0 .. rowLength - 1 do
62+
rowIndices.[rowStart + j] <- i
63+
else
64+
let rowStart = actual.RowPointers.[actual.RowCount - 1]
65+
let rowLength = rowIndices.Length - rowStart
6166

62-
for j in 0 .. rowLength - 1 do
63-
rowIndices.[rowStart + j] <- i
64-
else
65-
let rowStart = actual.RowPointers.[actual.RowCount - 1]
66-
let rowLength = rowIndices.Length - rowStart
67+
for j in 0 .. rowLength - 1 do
68+
rowIndices.[rowStart + j] <- i
6769

68-
for j in 0 .. rowLength - 1 do
69-
rowIndices.[rowStart + j] <- i
70+
for i in 0 .. rowIndices.Length - 1 do
71+
actual2D.[rowIndices.[i], actual.ColumnIndices.[i]] <- actual.Values.[i]
7072

71-
for i in 0 .. rowIndices.Length - 1 do
72-
actual2D.[rowIndices.[i], actual.ColumnIndices.[i]] <- actual.Values.[i]
73+
actual2D
7374

74-
for i in 0 .. rows - 1 do
75-
for j in 0 .. columns - 1 do
76-
Expect.equal actual2D.[i, j] expected.[i, j] "Elements of matrices should be equals."
75+
for i in 0 .. rows - 1 do
76+
for j in 0 .. columns - 1 do
77+
Expect.equal actual2D.[i, j] expected.[i, j] "Elements of matrices should be equals."
7778

7879
let testCases =
7980
let q = context.Provider.CommandQueue
@@ -217,17 +218,29 @@ let testCases =
217218

218219
| _ -> failwith "No other types of matrices tested yet."
219220

220-
[ testProperty "Correctness test on random int arrays"
221+
[ testProperty "Correctness test on random int matrices COO"
221222
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.IntType()) size COO (+) <@ (+) @> 0)
222223

223-
testProperty "Correctness test on random bool arrays"
224+
testProperty "Correctness test on random bool matrices COO"
224225
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.BoolType()) size COO (||) <@ (||) @> false)
225226

226-
testProperty "Correctness test on random float arrays"
227+
testProperty "Correctness test on random float matrices COO"
227228
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.FloatType()) size COO (+) <@ (+) @> 0.0)
228229

229-
testProperty "Correctness test on random byte arrays"
230-
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.ByteType()) size COO (+) <@ (+) @> 0uy) ]
230+
testProperty "Correctness test on random byte matrices COO"
231+
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.ByteType()) size COO (+) <@ (+) @> 0uy)
232+
233+
testProperty "Correctness test on random int matrices CSR"
234+
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.IntType()) size CSR (+) <@ (+) @> 0)
235+
236+
testProperty "Correctness test on random bool matrices CSR"
237+
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.BoolType()) size CSR (||) <@ (||) @> false)
238+
239+
testProperty "Correctness test on random float matrices CSR"
240+
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.FloatType()) size CSR (+) <@ (+) @> 0.0)
241+
242+
testProperty "Correctness test on random byte matrices CSR"
243+
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.ByteType()) size CSR (+) <@ (+) @> 0uy) ]
231244

232245
let tests =
233-
testCases |> testList "COOMatrix.EwiseAdd tests"
246+
testCases |> testList "Backend.EwiseAdd tests"

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
<Compile Include="BackendCommonTests/RemoveDuplicatesTests.fs" />
1717
<Compile Include="BackendCommonTests/CopyTests.fs" />
1818
<Compile Include="BackendCommonTests/ReplicateTests.fs" />
19-
<Compile Include="BackendCommonTests/COOMatrixEwiseAddTests.fs" />
19+
<Compile Include="BackendCommonTests/MatrixEwiseAddTests.fs" />
2020
<!--Compile Include="MatrixOperationsTests/GetTuplesTests.fs" /-->
2121
<!--Compile Include="MatrixOperationsTests/MxmTests.fs" /-->
2222
<!--Compile Include="MatrixOperationsTests/MxvTests.fs" /-->

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ let allTests =
1616
Backend.RemoveDuplicates.tests
1717
Backend.Copy.tests
1818
Backend.Replicate.tests
19-
Backend.COOMatrix.EwiseAdd.tests
19+
Backend.EwiseAdd.tests
2020
//Matrix.EWiseAdd.tests
2121
//Matrix.GetTuples.tests
2222
//Matrix.Mxv.tests

0 commit comments

Comments
 (0)