Skip to content

Commit 946be1b

Browse files
committed
refactor: matrix as mask in Matrix.SpGeMM
1 parent c83259a commit 946be1b

5 files changed

Lines changed: 26 additions & 11 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/CSRMatrix.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,6 @@ module CSRMatrix =
253253
let run =
254254
SpGEMM.run clContext workGroupSize opAdd opMul
255255

256-
fun (queue: MailboxProcessor<_>) (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSC<'b>) (mask: ClMask2D) ->
256+
fun (queue: MailboxProcessor<_>) (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSC<'b>) (mask: ClMatrix.COO<_>) ->
257257

258258
run queue matrixLeft matrixRight mask

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ module internal SpGEMM =
107107

108108
let program = context.Compile(run)
109109

110-
fun (queue: MailboxProcessor<_>) (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSC<'b>) (mask: ClMask2D) ->
110+
fun (queue: MailboxProcessor<_>) (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSC<'b>) (mask: ClMatrix.COO<_>) ->
111111

112112
let values =
113113
context.CreateClArrayWithSpecificAllocationMode<'c>(DeviceOnly, mask.NNZ)
@@ -157,7 +157,7 @@ module internal SpGEMM =
157157
let scanInplace =
158158
PrefixSum.standardExcludeInplace context workGroupSize
159159

160-
fun (queue: MailboxProcessor<_>) (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSC<'b>) (mask: ClMask2D) ->
160+
fun (queue: MailboxProcessor<_>) (matrixLeft: ClMatrix.CSR<'a>) (matrixRight: ClMatrix.CSC<'b>) (mask: ClMatrix.COO<_>) ->
161161

162162
let values, bitmap =
163163
calculate queue matrixLeft matrixRight mask

src/GraphBLAS-sharp.Backend/Matrix/Matrix.fs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ module Matrix =
504504
let runCSRnCSC =
505505
CSRMatrix.spgemmCSC clContext workGroupSize opAdd opMul
506506

507-
fun (queue: MailboxProcessor<_>) (matrix1: ClMatrix<'a>) (matrix2: ClMatrix<'b>) (mask: ClMask2D) ->
508-
match matrix1, matrix2, mask.IsComplemented with
509-
| ClMatrix.CSR m1, ClMatrix.CSC m2, false -> runCSRnCSC queue m1 m2 mask |> ClMatrix.COO
507+
fun (queue: MailboxProcessor<_>) (matrix1: ClMatrix<'a>) (matrix2: ClMatrix<'b>) (mask: ClMatrix<_>) ->
508+
match matrix1, matrix2, mask with
509+
| ClMatrix.CSR m1, ClMatrix.CSC m2, ClMatrix.COO mask -> runCSRnCSC queue m1 m2 mask |> ClMatrix.COO
510510
| _ -> failwith "Matrix formats are not matching"

src/GraphBLAS-sharp.Backend/Objects/Matrix.fs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ module ClMatrix =
2323
q.Post(Msg.CreateFreeMsg<_>(this.RowPointers))
2424
q.PostAndReply(Msg.MsgNotifyMe)
2525

26+
member this.NNZ = this.Values.Length
27+
2628
type COO<'elem when 'elem: struct> =
2729
{ Context: ClContext
2830
RowCount: int
@@ -38,6 +40,8 @@ module ClMatrix =
3840
q.Post(Msg.CreateFreeMsg<_>(this.Rows))
3941
q.PostAndReply(Msg.MsgNotifyMe)
4042

43+
member this.NNZ = this.Values.Length
44+
4145
type CSC<'elem when 'elem: struct> =
4246
{ Context: ClContext
4347
RowCount: int
@@ -53,6 +57,8 @@ module ClMatrix =
5357
q.Post(Msg.CreateFreeMsg<_>(this.ColumnPointers))
5458
q.PostAndReply(Msg.MsgNotifyMe)
5559

60+
member this.NNZ = this.Values.Length
61+
5662
type Tuple<'elem when 'elem: struct> =
5763
{ Context: ClContext
5864
RowIndices: ClArray<int>
@@ -66,6 +72,8 @@ module ClMatrix =
6672
q.Post(Msg.CreateFreeMsg<_>(this.Values))
6773
q.PostAndReply(Msg.MsgNotifyMe)
6874

75+
member this.NNZ = this.Values.Length
76+
6977
[<RequireQualifiedAccess>]
7078
type ClMatrix<'a when 'a: struct> =
7179
| CSR of ClMatrix.CSR<'a>
@@ -89,3 +97,9 @@ type ClMatrix<'a when 'a: struct> =
8997
| ClMatrix.CSR matrix -> (matrix :> IDeviceMemObject).Dispose q
9098
| ClMatrix.COO matrix -> (matrix :> IDeviceMemObject).Dispose q
9199
| ClMatrix.CSC matrix -> (matrix :> IDeviceMemObject).Dispose q
100+
101+
member this.NNZ =
102+
match this with
103+
| ClMatrix.CSR matrix -> matrix.NNZ
104+
| ClMatrix.COO matrix -> matrix.NNZ
105+
| ClMatrix.CSC matrix -> matrix.NNZ

tests/GraphBLAS-sharp.Tests/Matrix/Mxm.fs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ let makeTest context q zero isEqual plus mul mxmFun (leftMatrix: 'a [,], rightMa
2424
let m2 =
2525
createMatrixFromArray2D CSC rightMatrix (isEqual zero)
2626

27+
let matrixMask =
28+
createMatrixFromArray2D COO mask ((=) false)
29+
2730
if m1.NNZ > 0 && m2.NNZ > 0 then
2831
let expected =
2932
Array2D.init
@@ -43,16 +46,14 @@ let makeTest context q zero isEqual plus mul mxmFun (leftMatrix: 'a [,], rightMa
4346
if expected.NNZ > 0 then
4447
let m1 = m1.ToDevice context
4548
let m2 = m2.ToDevice context
49+
let matrixMask = matrixMask.ToDevice context
4650

47-
let mask =
48-
Mask2D.FromArray2D(mask, not).ToBackend context
49-
50-
let (result: ClMatrix<'a>) = mxmFun q m1 m2 mask
51+
let (result: ClMatrix<'a>) = mxmFun q m1 m2 matrixMask
5152
let actual = result.ToHost q
5253

5354
m1.Dispose q
5455
m2.Dispose q
55-
mask.Dispose q
56+
matrixMask.Dispose q
5657
result.Dispose q
5758

5859
// Check result

0 commit comments

Comments
 (0)