Skip to content

Commit 146b17f

Browse files
committed
add: Matrix.map
1 parent 5409b2e commit 146b17f

9 files changed

Lines changed: 503 additions & 0 deletions

File tree

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@
3737
<Compile Include="Matrix/Common.fs" />
3838
<Compile Include="Matrix/COOMatrix/Map2.fs" />
3939
<Compile Include="Matrix/COOMatrix/Map2AtLeastOne.fs" />
40+
<Compile Include="Matrix/COOMatrix/Map.fs" />
4041
<Compile Include="Matrix/COOMatrix/Matrix.fs" />
4142
<Compile Include="Matrix/CSRMatrix/Map2AtLeastOne.fs" />
4243
<Compile Include="Matrix/CSRMatrix/SpGEMM.fs" />
44+
<Compile Include="Matrix/CSRMatrix/Map.fs" />
4345
<Compile Include="Matrix/CSRMatrix/Matrix.fs" />
4446
<Compile Include="Matrix/Matrix.fs" />
4547
<Compile Include="Vector/SparseVector/Common.fs" />
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
namespace GraphBLAS.FSharp.Backend.Matrix.COO
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Matrix
5+
open Microsoft.FSharp.Quotations
6+
open GraphBLAS.FSharp.Backend.Objects
7+
open GraphBLAS.FSharp.Backend
8+
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
9+
open GraphBLAS.FSharp.Backend.Objects.ClContext
10+
11+
12+
module Map =
13+
let preparePositions<'a, 'b> (clContext: ClContext) workGroupSize opAdd =
14+
15+
let preparePositions (op: Expr<'a option -> 'b option>) =
16+
<@ fun (ndRange: Range1D) rowCount columnCount valuesLength (values: ClArray<'a>) (rowPointers: ClArray<int>) (columns: ClArray<int>) (resultBitmap: ClArray<int>) (resultValues: ClArray<'b>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) ->
17+
18+
let gid = ndRange.GlobalID0
19+
20+
if gid < rowCount * columnCount then
21+
22+
let columnIndex = gid % columnCount
23+
let rowIndex = gid / columnCount
24+
25+
let index =
26+
(uint64 rowIndex <<< 32) ||| (uint64 columnIndex)
27+
28+
let value =
29+
(%Map2.binSearch) valuesLength index rowPointers columns values
30+
31+
match (%op) value with
32+
| Some resultValue ->
33+
resultValues.[gid] <- resultValue
34+
resultRows.[gid] <- rowIndex
35+
resultColumns.[gid] <- columnIndex
36+
37+
resultBitmap.[gid] <- 1
38+
| None -> resultBitmap.[gid] <- 0 @>
39+
40+
41+
let kernel =
42+
clContext.Compile <| preparePositions opAdd
43+
44+
fun (processor: MailboxProcessor<_>) rowCount columnCount (values: ClArray<'a>) (rowPointers: ClArray<int>) (columns: ClArray<int>) ->
45+
46+
let (resultLength: int) = columnCount * rowCount
47+
48+
let resultBitmap =
49+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)
50+
51+
let resultRows =
52+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)
53+
54+
let resultColumns =
55+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)
56+
57+
let resultValues =
58+
clContext.CreateClArrayWithSpecificAllocationMode<'b>(DeviceOnly, resultLength)
59+
60+
let ndRange =
61+
Range1D.CreateValid(resultLength, workGroupSize)
62+
63+
let kernel = kernel.GetKernel()
64+
65+
processor.Post(
66+
Msg.MsgSetArguments
67+
(fun () ->
68+
kernel.KernelFunc
69+
ndRange
70+
rowCount
71+
columnCount
72+
values.Length
73+
values
74+
rowPointers
75+
columns
76+
resultBitmap
77+
resultValues
78+
resultRows
79+
resultColumns)
80+
)
81+
82+
processor.Post(Msg.CreateRunMsg<_, _> kernel)
83+
84+
resultBitmap, resultValues, resultRows, resultColumns
85+
86+
87+
let run<'a, 'b when 'a: struct and 'b: struct and 'b: equality>
88+
(clContext: ClContext)
89+
(opAdd: Expr<'a option -> 'b option>)
90+
workGroupSize
91+
=
92+
93+
let map =
94+
preparePositions clContext workGroupSize opAdd
95+
96+
let setPositions =
97+
Common.setPositions<'b> clContext workGroupSize
98+
99+
fun (queue: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.COO<'a>) ->
100+
101+
let bitmap, values, rows, columns =
102+
map queue matrix.RowCount matrix.ColumnCount matrix.Values matrix.Rows matrix.Columns
103+
104+
let resultRows, resultColumns, resultValues, _ =
105+
setPositions queue allocationMode rows columns values bitmap
106+
107+
queue.Post(Msg.CreateFreeMsg<_>(bitmap))
108+
queue.Post(Msg.CreateFreeMsg<_>(values))
109+
queue.Post(Msg.CreateFreeMsg<_>(rows))
110+
queue.Post(Msg.CreateFreeMsg<_>(columns))
111+
112+
{ Context = clContext
113+
RowCount = matrix.RowCount
114+
ColumnCount = matrix.ColumnCount
115+
Rows = resultRows
116+
Columns = resultColumns
117+
Values = resultValues }

src/GraphBLAS-sharp.Backend/Matrix/COOMatrix/Matrix.fs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ open GraphBLAS.FSharp.Backend.Objects
88
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
99

1010
module Matrix =
11+
let map = Map.run
12+
1113
let map2 = Map2.run
1214

1315
///<param name="clContext">.</param>
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
namespace GraphBLAS.FSharp.Backend.Matrix.CSR
2+
3+
open Brahma.FSharp
4+
open FSharp.Quotations
5+
open GraphBLAS.FSharp.Backend
6+
open GraphBLAS.FSharp.Backend.Matrix
7+
open GraphBLAS.FSharp.Backend.Matrix.COO
8+
open GraphBLAS.FSharp.Backend.Objects
9+
open GraphBLAS.FSharp.Backend.Objects.ClMatrix
10+
open GraphBLAS.FSharp.Backend.Objects.ClContext
11+
12+
module Map =
13+
let binSearch<'a> =
14+
<@ fun startIndex nnzInRow sourceColumn (columnIndices: ClArray<int>) (values: ClArray<'a>) ->
15+
16+
let mutable leftEdge = startIndex
17+
let mutable rightEdge = startIndex + nnzInRow - 1
18+
19+
let mutable result = None
20+
21+
while leftEdge <= rightEdge do
22+
let middleIdx = (leftEdge + rightEdge) / 2
23+
24+
let currentColumn = columnIndices.[middleIdx]
25+
26+
if sourceColumn = currentColumn then
27+
result <- Some values.[middleIdx]
28+
29+
rightEdge <- -1 // TODO() break
30+
elif sourceColumn < currentColumn then
31+
rightEdge <- middleIdx - 1
32+
else
33+
leftEdge <- middleIdx + 1
34+
35+
result @>
36+
37+
let preparePositions<'a, 'b> (clContext: ClContext) workGroupSize opAdd =
38+
39+
let preparePositions (op: Expr<'a option -> 'b option>) =
40+
<@ fun (ndRange: Range1D) rowCount columnCount (values: ClArray<'a>) (rowPointers: ClArray<int>) (columns: ClArray<int>) (resultBitmap: ClArray<int>) (resultValues: ClArray<'b>) (resultRows: ClArray<int>) (resultColumns: ClArray<int>) ->
41+
42+
let gid = ndRange.GlobalID0
43+
44+
if gid < rowCount * columnCount then
45+
46+
let columnIndex = gid % columnCount
47+
let rowIndex = gid / columnCount
48+
49+
let nnzInRow =
50+
rowPointers.[rowIndex + 1]
51+
- rowPointers.[rowIndex]
52+
53+
let value =
54+
(%binSearch) rowPointers.[rowIndex] nnzInRow columnIndex columns values
55+
56+
match (%op) value with
57+
| Some resultValue ->
58+
resultValues.[gid] <- resultValue
59+
resultRows.[gid] <- rowIndex
60+
resultColumns.[gid] <- columnIndex
61+
62+
resultBitmap.[gid] <- 1
63+
| None -> resultBitmap.[gid] <- 0 @>
64+
65+
let kernel =
66+
clContext.Compile <| preparePositions opAdd
67+
68+
fun (processor: MailboxProcessor<_>) rowCount columnCount (values: ClArray<'a>) (rowPointers: ClArray<int>) (columns: ClArray<int>) ->
69+
70+
let (resultLength: int) = columnCount * rowCount
71+
72+
let resultBitmap =
73+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)
74+
75+
let resultRows =
76+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)
77+
78+
let resultColumns =
79+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)
80+
81+
let resultValues =
82+
clContext.CreateClArrayWithSpecificAllocationMode<'b>(DeviceOnly, resultLength)
83+
84+
let ndRange =
85+
Range1D.CreateValid(resultLength, workGroupSize)
86+
87+
let kernel = kernel.GetKernel()
88+
89+
processor.Post(
90+
Msg.MsgSetArguments
91+
(fun () ->
92+
kernel.KernelFunc
93+
ndRange
94+
rowCount
95+
columnCount
96+
values
97+
rowPointers
98+
columns
99+
resultBitmap
100+
resultValues
101+
resultRows
102+
resultColumns)
103+
)
104+
105+
processor.Post(Msg.CreateRunMsg<_, _> kernel)
106+
107+
resultBitmap, resultValues, resultRows, resultColumns
108+
109+
110+
let runToCOO<'a, 'b when 'a: struct and 'b: struct and 'b: equality>
111+
(clContext: ClContext)
112+
(opAdd: Expr<'a option -> 'b option>)
113+
workGroupSize
114+
=
115+
116+
let map =
117+
preparePositions clContext workGroupSize opAdd
118+
119+
let setPositions =
120+
Common.setPositions<'b> clContext workGroupSize
121+
122+
fun (queue: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
123+
124+
let bitmap, values, rows, columns =
125+
map queue matrix.RowCount matrix.ColumnCount matrix.Values matrix.RowPointers matrix.Columns
126+
127+
let resultRows, resultColumns, resultValues, _ =
128+
setPositions queue allocationMode rows columns values bitmap
129+
130+
queue.Post(Msg.CreateFreeMsg<_>(bitmap))
131+
queue.Post(Msg.CreateFreeMsg<_>(values))
132+
queue.Post(Msg.CreateFreeMsg<_>(rows))
133+
queue.Post(Msg.CreateFreeMsg<_>(columns))
134+
135+
{ Context = clContext
136+
RowCount = matrix.RowCount
137+
ColumnCount = matrix.ColumnCount
138+
Rows = resultRows
139+
Columns = resultColumns
140+
Values = resultValues }
141+
142+
let run<'a, 'b when 'a: struct and 'b: struct and 'b: equality>
143+
(clContext: ClContext)
144+
(opAdd: Expr<'a option -> 'b option>)
145+
workGroupSize
146+
=
147+
148+
let elementwiseToCOO = runToCOO clContext opAdd workGroupSize
149+
150+
let toCSRInplace =
151+
Matrix.toCSRInplace clContext workGroupSize
152+
153+
fun (queue: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
154+
elementwiseToCOO queue allocationMode matrix
155+
|> toCSRInplace queue allocationMode

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/Matrix.fs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ module Matrix =
9090
Columns = matrix.Columns
9191
Values = matrix.Values }
9292

93+
let map = CSR.Map.run
94+
9395
let map2<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
9496
(clContext: ClContext)
9597
(opAdd: Expr<'a option -> 'b option -> 'c option>)

src/GraphBLAS-sharp.Backend/Matrix/Matrix.fs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,21 @@ module Matrix =
191191
.ToCSC
192192
|> ClMatrix.CSC
193193

194+
let map (clContext: ClContext) (opAdd: Expr<'a option -> 'b option>) workGroupSize =
195+
let mapCOO =
196+
COO.Matrix.map clContext opAdd workGroupSize
197+
198+
let mapCSR =
199+
CSR.Matrix.map clContext opAdd workGroupSize
200+
201+
fun (processor: MailboxProcessor<_>) allocationMode matrix ->
202+
match matrix with
203+
| ClMatrix.COO m -> mapCOO processor allocationMode m |> ClMatrix.COO
204+
| ClMatrix.CSR m -> mapCSR processor allocationMode m |> ClMatrix.CSR
205+
| ClMatrix.CSC m ->
206+
(mapCSR processor allocationMode m.ToCSR).ToCSC
207+
|> ClMatrix.CSC
208+
194209
let map2 (clContext: ClContext) (opAdd: Expr<'a option -> 'b option -> 'c option>) workGroupSize = // TODO()
195210
let map2COO =
196211
COO.Matrix.map2 clContext opAdd workGroupSize

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
<Compile Include="Matrix/Map2.fs" />
4646
<Compile Include="Matrix/Mxm.fs" />
4747
<Compile Include="Matrix/Transpose.fs" />
48+
<Compile Include="Matrix/Map.fs" />
4849
<Compile Include="Program.fs" />
4950
</ItemGroup>
5051
<Import Project="..\..\.paket\Paket.Restore.targets" />

0 commit comments

Comments
 (0)