Skip to content

Commit 7eead12

Browse files
committed
refactor: Common.setPosition in Matrix
1 parent 372fed4 commit 7eead12

5 files changed

Lines changed: 74 additions & 121 deletions

File tree

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
<Compile Include="Objects/Vector.fs" />
2828
<Compile Include="Objects/Matrix.fs" />
2929
<Compile Include="Objects\Masks.fs" />
30+
<Compile Include="Matrix\Common.fs" />
3031
<Compile Include="Matrix/COOMatrix/COOMatrix.fs" />
3132
<Compile Include="Matrix/CSRMatrix/Elementwise.fs" />
3233
<Compile Include="Matrix/CSRMatrix/SpGEMM.fs" />

src/GraphBLAS-sharp.Backend/Matrix/COOMatrix/COOMatrix.fs

Lines changed: 2 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,71 +2,11 @@ namespace GraphBLAS.FSharp.Backend.Matrix.COO
22

33
open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend.Common
5-
open GraphBLAS.FSharp.Backend.Predefined
65
open Microsoft.FSharp.Quotations
76
open GraphBLAS.FSharp.Backend.Objects
7+
open GraphBLAS.FSharp.Backend
88

99
module COOMatrix =
10-
///<param name="clContext">.</param>
11-
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
12-
let private setPositions<'a when 'a: struct> (clContext: ClContext) workGroupSize =
13-
14-
let indicesScatter =
15-
Scatter.runInplace clContext workGroupSize
16-
17-
let valuesScatter =
18-
Scatter.runInplace clContext workGroupSize
19-
20-
let sum =
21-
PrefixSum.standardExcludeInplace clContext workGroupSize
22-
23-
let resultLength = Array.zeroCreate<int> 1
24-
25-
fun (processor: MailboxProcessor<_>) (allRows: ClArray<int>) (allColumns: ClArray<int>) (allValues: ClArray<'a>) (positions: ClArray<int>) ->
26-
let resultLengthGpu = clContext.CreateClCell 0
27-
28-
let _, r = sum processor positions resultLengthGpu
29-
30-
let resultLength =
31-
let res =
32-
processor.PostAndReply(fun ch -> Msg.CreateToHostMsg<_>(r, resultLength, ch))
33-
34-
processor.Post(Msg.CreateFreeMsg<_>(r))
35-
36-
res.[0]
37-
38-
let resultRows =
39-
clContext.CreateClArray<int>(
40-
resultLength,
41-
hostAccessMode = HostAccessMode.NotAccessible,
42-
deviceAccessMode = DeviceAccessMode.WriteOnly,
43-
allocationMode = AllocationMode.Default
44-
)
45-
46-
let resultColumns =
47-
clContext.CreateClArray<int>(
48-
resultLength,
49-
hostAccessMode = HostAccessMode.NotAccessible,
50-
deviceAccessMode = DeviceAccessMode.WriteOnly,
51-
allocationMode = AllocationMode.Default
52-
)
53-
54-
let resultValues =
55-
clContext.CreateClArray(
56-
resultLength,
57-
hostAccessMode = HostAccessMode.NotAccessible,
58-
deviceAccessMode = DeviceAccessMode.WriteOnly,
59-
allocationMode = AllocationMode.Default
60-
)
61-
62-
indicesScatter processor positions allRows resultRows
63-
64-
indicesScatter processor positions allColumns resultColumns
65-
66-
valuesScatter processor positions allValues resultValues
67-
68-
resultRows, resultColumns, resultValues, resultLength
69-
7010
let private preparePositions<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
7111
(clContext: ClContext)
7212
(opAdd: Expr<'a option -> 'b option -> 'c option>)
@@ -368,7 +308,7 @@ module COOMatrix =
368308
let preparePositions =
369309
preparePositions clContext opAdd workGroupSize
370310

371-
let setPositions = setPositions<'c> clContext workGroupSize
311+
let setPositions = Matrix.Common.setPositions<'c> clContext workGroupSize
372312

373313
fun (queue: MailboxProcessor<_>) (matrixLeft: ClCOOMatrix<'a>) (matrixRight: ClCOOMatrix<'b>) ->
374314

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/CSRMatrix.fs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ module CSRMatrix =
4242
processor.Post(Msg.CreateRunMsg<_, _> kernel)
4343

4444
let total = clContext.CreateClCell()
45-
let _ = scan processor rows total 0
45+
ignore <| scan processor rows total 0
4646
processor.Post(Msg.CreateFreeMsg(total))
4747

4848
rows
@@ -172,7 +172,7 @@ module CSRMatrix =
172172
preparePositions clContext opAdd Utils.defaultWorkGroupSize
173173

174174
let setPositions =
175-
setPositions<'c> clContext Utils.defaultWorkGroupSize
175+
Matrix.Common.setPositions<'c> clContext Utils.defaultWorkGroupSize
176176

177177
fun (queue: MailboxProcessor<_>) (matrixLeft: ClCSRMatrix<'a>) (matrixRight: ClCSRMatrix<'b>) ->
178178

@@ -192,7 +192,7 @@ module CSRMatrix =
192192
queue.Post(Msg.CreateFreeMsg<_>(leftMergedValues))
193193
queue.Post(Msg.CreateFreeMsg<_>(rightMergedValues))
194194

195-
let resultRows, resultColumns, resultValues, positions, positionsSum =
195+
let resultRows, resultColumns, resultValues, _ =
196196
setPositions queue allRows allColumns allValues positions
197197

198198
queue.Post(Msg.CreateFreeMsg<_>(allRows))

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/Elementwise.fs

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -80,62 +80,6 @@ module internal Elementwise =
8080
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
8181
rowPositions, allValues
8282

83-
let setPositions<'a when 'a: struct> (clContext: ClContext) workGroupSize =
84-
85-
let sum =
86-
PrefixSum.standardExcludeInplace clContext workGroupSize
87-
88-
let indicesScatter =
89-
Scatter.runInplace clContext workGroupSize
90-
91-
let valuesScatter =
92-
Scatter.runInplace clContext workGroupSize
93-
94-
fun (processor: MailboxProcessor<_>) (allRows: ClArray<int>) (allColumns: ClArray<int>) (allValues: ClArray<'a>) (positions: ClArray<int>) ->
95-
96-
let resultLength = Array.zeroCreate<int> 1
97-
let prefixSumArrayLength = positions.Length
98-
99-
let resultLengthGpu = clContext.CreateClCell 0
100-
101-
let _, r = sum processor positions resultLengthGpu
102-
103-
processor.PostAndReply(fun ch -> Msg.CreateToHostMsg<_>(r, resultLength, ch))
104-
let resultLength = resultLength.[0]
105-
processor.Post(Msg.CreateFreeMsg<_>(r))
106-
107-
let resultRows =
108-
clContext.CreateClArray<int>(
109-
resultLength,
110-
hostAccessMode = HostAccessMode.NotAccessible,
111-
deviceAccessMode = DeviceAccessMode.WriteOnly,
112-
allocationMode = AllocationMode.Default
113-
)
114-
115-
let resultColumns =
116-
clContext.CreateClArray<int>(
117-
resultLength,
118-
hostAccessMode = HostAccessMode.NotAccessible,
119-
deviceAccessMode = DeviceAccessMode.WriteOnly,
120-
allocationMode = AllocationMode.Default
121-
)
122-
123-
let resultValues =
124-
clContext.CreateClArray(
125-
resultLength,
126-
hostAccessMode = HostAccessMode.NotAccessible,
127-
deviceAccessMode = DeviceAccessMode.WriteOnly,
128-
allocationMode = AllocationMode.Default
129-
)
130-
131-
indicesScatter processor positions allRows resultRows
132-
133-
indicesScatter processor positions allColumns resultColumns
134-
135-
valuesScatter processor positions allValues resultValues
136-
137-
resultRows, resultColumns, resultValues, positions, resultLength
138-
13983
let merge<'a, 'b when 'a: struct and 'b: struct> (clContext: ClContext) workGroupSize =
14084
let localArraySize = workGroupSize + 2
14185

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
module GraphBLAS.FSharp.Backend.Matrix
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Common
5+
open GraphBLAS.FSharp.Backend.Predefined
6+
7+
module Common =
8+
///<param name="clContext">.</param>
9+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
10+
let setPositions<'a when 'a: struct> (clContext: ClContext) workGroupSize =
11+
12+
let indicesScatter =
13+
Scatter.runInplace clContext workGroupSize
14+
15+
let valuesScatter =
16+
Scatter.runInplace clContext workGroupSize
17+
18+
let sum =
19+
PrefixSum.standardExcludeInplace clContext workGroupSize
20+
21+
let resultLength = Array.zeroCreate<int> 1
22+
23+
fun (processor: MailboxProcessor<_>) (allRows: ClArray<int>) (allColumns: ClArray<int>) (allValues: ClArray<'a>) (positions: ClArray<int>) ->
24+
let resultLengthGpu = clContext.CreateClCell 0
25+
26+
let _, r = sum processor positions resultLengthGpu
27+
28+
let resultLength =
29+
let res =
30+
processor.PostAndReply(fun ch -> Msg.CreateToHostMsg<_>(r, resultLength, ch))
31+
32+
processor.Post(Msg.CreateFreeMsg<_>(r))
33+
34+
res.[0]
35+
36+
let resultRows =
37+
clContext.CreateClArray<int>(
38+
resultLength,
39+
hostAccessMode = HostAccessMode.NotAccessible,
40+
deviceAccessMode = DeviceAccessMode.WriteOnly,
41+
allocationMode = AllocationMode.Default
42+
)
43+
44+
let resultColumns =
45+
clContext.CreateClArray<int>(
46+
resultLength,
47+
hostAccessMode = HostAccessMode.NotAccessible,
48+
deviceAccessMode = DeviceAccessMode.WriteOnly,
49+
allocationMode = AllocationMode.Default
50+
)
51+
52+
let resultValues =
53+
clContext.CreateClArray(
54+
resultLength,
55+
hostAccessMode = HostAccessMode.NotAccessible,
56+
deviceAccessMode = DeviceAccessMode.WriteOnly,
57+
allocationMode = AllocationMode.Default
58+
)
59+
60+
indicesScatter processor positions allRows resultRows
61+
62+
indicesScatter processor positions allColumns resultColumns
63+
64+
valuesScatter processor positions allValues resultValues
65+
66+
resultRows, resultColumns, resultValues, resultLength
67+
68+

0 commit comments

Comments
 (0)