Skip to content

Commit 3d4d102

Browse files
authored
Merge pull request #39 from kirillgarbar/net5
MatrixEwiseAddTests reworked
2 parents 1c013e9 + 0cab5ce commit 3d4d102

6 files changed

Lines changed: 246 additions & 211 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrices.fs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ type Matrix<'a when 'a: struct> =
2323
| MatrixCSR matrix -> matrix.ColumnCount
2424
| MatrixCOO matrix -> matrix.ColumnCount
2525

26+
member this.Dispose() =
27+
match this with
28+
| MatrixCSR matrix -> (matrix :> IDeviceMemObject).Dispose()
29+
| MatrixCOO matrix -> (matrix :> IDeviceMemObject).Dispose()
30+
2631
and CSRMatrix<'elem when 'elem: struct> =
2732
{ Context: ClContext
2833
RowCount: int

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/CSRMatrix.fs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,19 @@ open Microsoft.FSharp.Quotations
77
module CSRMatrix =
88
let private expandRows (clContext: ClContext) =
99
let expandRows =
10-
<@ fun (range: Range1D) workGroupSize (rowPointers: ClArray<int>) (rowIndices: ClArray<int>) ->
10+
<@ fun (range: Range1D) workGroupSize (rowPointers: ClArray<int>) (rowIndices: ClArray<int>) rowCount nnz ->
1111

1212
let lid = range.LocalID0
1313
let groupId = range.GlobalID0 / workGroupSize
1414

1515
let rowStart = rowPointers.[groupId]
16-
let rowEnd = rowPointers.[groupId + 1]
16+
17+
let rowEnd =
18+
if groupId <> rowCount - 1 then
19+
rowPointers.[groupId + 1]
20+
else
21+
nnz
22+
1723
let rowLength = rowEnd - rowStart
1824

1925
let mutable i = lid
@@ -36,7 +42,8 @@ module CSRMatrix =
3642
)
3743

3844
processor.Post(
39-
Msg.MsgSetArguments(fun () -> kernel.SetArguments ndRange workGroupSize rowPointers rowIndices)
45+
Msg.MsgSetArguments
46+
(fun () -> kernel.SetArguments ndRange workGroupSize rowPointers rowIndices rowCount nnz)
4047
)
4148

4249
processor.Post(Msg.CreateRunMsg<_, _> kernel)

src/GraphBLAS-sharp.Backend/Matrix/Matrix.fs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,53 @@ open Brahma.FSharp.OpenCL
44
open Microsoft.FSharp.Quotations
55

66
module Matrix =
7+
let copy (clContext: ClContext) =
8+
let copy =
9+
GraphBLAS.FSharp.Backend.ClArray.copy clContext
10+
11+
let copyData =
12+
GraphBLAS.FSharp.Backend.ClArray.copy clContext
13+
14+
fun (processor: MailboxProcessor<_>) workGroupSize (matrix: Matrix<'a>) ->
15+
match matrix with
16+
| MatrixCOO m ->
17+
let res =
18+
{ Context = clContext
19+
RowCount = m.RowCount
20+
ColumnCount = m.ColumnCount
21+
Rows = copy processor workGroupSize m.Rows
22+
Columns = copy processor workGroupSize m.Columns
23+
Values = copyData processor workGroupSize m.Values }
24+
25+
MatrixCOO res
26+
| MatrixCSR m ->
27+
let res =
28+
{ Context = clContext
29+
RowCount = m.RowCount
30+
ColumnCount = m.ColumnCount
31+
RowPointers = copy processor workGroupSize m.RowPointers
32+
Columns = copy processor workGroupSize m.Columns
33+
Values = copyData processor workGroupSize m.Values }
34+
35+
MatrixCSR res
36+
737
let toCSR (clContext: ClContext) workGroupSize =
838
let toCSR = COOMatrix.toCSR clContext workGroupSize
39+
let copy = copy clContext
940

1041
fun (processor: MailboxProcessor<_>) (matrix: Matrix<'a>) ->
1142
match matrix with
1243
| MatrixCOO m -> toCSR processor m |> MatrixCSR
13-
| MatrixCSR _ -> matrix
44+
| MatrixCSR _ -> copy processor workGroupSize matrix
1445

1546
let toCOO (clContext: ClContext) workGroupSize =
16-
let toCOO = CSRMatrix.toCOO clContext
47+
let toCOO = CSRMatrix.toCOO clContext workGroupSize
48+
let copy = copy clContext
1749

1850
fun (processor: MailboxProcessor<_>) (matrix: Matrix<'a>) ->
1951
match matrix with
20-
| MatrixCOO _ -> matrix
21-
| MatrixCSR m -> toCOO workGroupSize processor m |> MatrixCOO
52+
| MatrixCOO _ -> copy processor workGroupSize matrix
53+
| MatrixCSR m -> toCOO processor m |> MatrixCOO
2254

2355
let eWiseAdd (clContext: ClContext) (opAdd: Expr<'a -> 'a -> 'a>) workGroupSize =
2456
let COOeWiseAdd =

src/GraphBLAS-sharp/Objects/Matrix.fs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
namespace GraphBLAS.FSharp
22

3+
open Brahma.FSharp.OpenCL
4+
open GraphBLAS.FSharp.Backend
5+
36
type MatrixFromat =
47
| CSR
58
| COO
@@ -18,6 +21,89 @@ type Matrix<'a when 'a: struct> =
1821
| MatrixCSR matrix -> matrix.ColumnCount
1922
| MatrixCOO matrix -> matrix.ColumnCount
2023

24+
member this.NNZCount =
25+
match this with
26+
| MatrixCOO m -> m.Values.Length
27+
| MatrixCSR m -> m.Values.Length
28+
29+
member this.ToBackend(context: ClContext) =
30+
match this with
31+
| MatrixCOO m ->
32+
let rows = context.CreateClArray m.Rows
33+
let columns = context.CreateClArray m.Columns
34+
let values = context.CreateClArray m.Values
35+
36+
let result =
37+
{ Backend.COOMatrix.Context = context
38+
RowCount = m.RowCount
39+
ColumnCount = m.ColumnCount
40+
Rows = rows
41+
Columns = columns
42+
Values = values }
43+
44+
Backend.MatrixCOO result
45+
| MatrixCSR m ->
46+
let rows = context.CreateClArray m.RowPointers
47+
let columns = context.CreateClArray m.ColumnIndices
48+
let values = context.CreateClArray m.Values
49+
50+
let result =
51+
{ Backend.CSRMatrix.Context = context
52+
RowCount = m.RowCount
53+
ColumnCount = m.ColumnCount
54+
RowPointers = rows
55+
Columns = columns
56+
Values = values }
57+
58+
Backend.MatrixCSR result
59+
60+
static member FromBackend (q: MailboxProcessor<_>) matrix =
61+
match matrix with
62+
| Backend.MatrixCOO m ->
63+
let rows = Array.zeroCreate m.Rows.Length
64+
let columns = Array.zeroCreate m.Columns.Length
65+
let values = Array.zeroCreate m.Values.Length
66+
67+
let _ =
68+
q.Post(Msg.CreateToHostMsg(m.Rows, rows))
69+
70+
let _ =
71+
q.Post(Msg.CreateToHostMsg(m.Columns, columns))
72+
73+
let _ =
74+
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(m.Values, values, ch))
75+
76+
let result =
77+
{ RowCount = m.RowCount
78+
ColumnCount = m.ColumnCount
79+
Rows = rows
80+
Columns = columns
81+
Values = values }
82+
83+
MatrixCOO result
84+
| Backend.MatrixCSR m ->
85+
let rows = Array.zeroCreate m.RowPointers.Length
86+
let columns = Array.zeroCreate m.Columns.Length
87+
let values = Array.zeroCreate m.Values.Length
88+
89+
let _ =
90+
q.Post(Msg.CreateToHostMsg(m.RowPointers, rows))
91+
92+
let _ =
93+
q.Post(Msg.CreateToHostMsg(m.Columns, columns))
94+
95+
let _ =
96+
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(m.Values, values, ch))
97+
98+
let result =
99+
{ RowCount = m.RowCount
100+
ColumnCount = m.ColumnCount
101+
RowPointers = rows
102+
ColumnIndices = columns
103+
Values = values }
104+
105+
MatrixCSR result
106+
21107
and CSRMatrix<'a> =
22108
{ RowCount: int
23109
ColumnCount: int

0 commit comments

Comments
 (0)