Skip to content

Commit 0d09177

Browse files
committed
SpMSpV bool only
1 parent 38ceee3 commit 0d09177

5 files changed

Lines changed: 409 additions & 0 deletions

File tree

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
<Compile Include="Vector/Sparse/Vector.fs" />
4848
<Compile Include="Vector/SpMV.fs" />
4949
<Compile Include="Vector/Vector.fs" />
50+
<Compile Include="Vector/SpMSpV.fs" />
5051
<Compile Include="Matrix/Common.fs" />
5152
<Compile Include="Matrix/COO/Map.fs" />
5253
<Compile Include="Matrix/COO/Merge.fs" />
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
namespace GraphBLAS.FSharp.Backend.Vector
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Common
5+
open GraphBLAS.FSharp.Backend.Quotes
6+
open GraphBLAS.FSharp.Backend.Vector.Sparse
7+
open Microsoft.FSharp.Quotations
8+
open GraphBLAS.FSharp.Backend.Objects
9+
open GraphBLAS.FSharp.Backend.Objects.ClVector
10+
open GraphBLAS.FSharp.Backend.Objects.ClContext
11+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
12+
open GraphBLAS.FSharp.Backend.Objects.ClCell
13+
14+
module SpMSpV =
15+
16+
/// Builds a function that, for every index v stored in `vectorIndices`, copies the pair
/// rowOffsets.[v], rowOffsets.[v + 1] into a freshly allocated array.
/// The returned array has `size * 2 + 1` cells: pair k occupies cells 2k and 2k + 1,
/// and the one extra trailing cell is set to 0.
let private collectRows (clContext: ClContext) workGroupSize =

    let kernelBody =
        <@ fun (ndRange: Range1D) inputSize (vectorIndices: ClArray<int>) (rowOffsets: ClArray<int>) (resultArray: ClArray<int>) ->

            let gid = ndRange.GlobalID0

            // Cells [0, pairCells) hold the collected pairs, cell pairCells is the trailing one
            let pairCells = inputSize * 2

            if gid = pairCells then
                resultArray.[gid] <- 0
            elif gid < pairCells then
                let pointerIndex = vectorIndices.[gid / 2]

                // Even cells receive rowOffsets.[v], odd cells receive rowOffsets.[v + 1]
                if gid % 2 = 0 then
                    resultArray.[gid] <- rowOffsets.[pointerIndex]
                else
                    resultArray.[gid] <- rowOffsets.[pointerIndex + 1] @>

    let program = clContext.Compile kernelBody

    fun (queue: MailboxProcessor<_>) size (vectorIndices: ClArray<int>) (rowOffsets: ClArray<int>) ->

        // Two cells per vector index plus the trailing cell
        let resultLength = size * 2 + 1

        let kernel = program.GetKernel()

        let collected =
            clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, resultLength)

        let ndRange =
            Range1D.CreateValid(resultLength, workGroupSize)

        queue.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange size vectorIndices rowOffsets collected))

        queue.Post(Msg.CreateRunMsg<_, _>(kernel))

        collected
55+
56+
/// Builds a function that rewrites an array of (start, end) pairs in place and then runs
/// `PrefixSum.standardExcludeInPlace` over it.
/// For every even position i that has a right neighbour, cell i + 1 becomes
/// `input.[i + 1] - input.[i]` and cell i becomes 0; an unpaired last cell is only zeroed.
/// Returns the value produced by the prefix sum, copied to the host (the cell is freed).
/// NOTE(review): presumably this is the total of all pair lengths — `gather` in this file
/// uses it as the length of the gathered arrays; confirm against PrefixSum.
let private computeOffsetsInplace (clContext: ClContext) workGroupSize =

    let prepareOffsets =
        <@ fun (ndRange: Range1D) inputSize (inputArray: ClArray<int>) ->

            let i = ndRange.GlobalID0

            if i < inputSize && i % 2 = 0 then
                // When inputSize is odd (the caller in this file passes NNZ * 2 + 1) the last
                // even cell has no partner: accessing i + 1 there would run past the buffer.
                if i + 1 < inputSize then
                    inputArray.[i + 1] <- inputArray.[i + 1] - inputArray.[i]

                inputArray.[i] <- 0 @>

    let sum =
        PrefixSum.standardExcludeInPlace clContext workGroupSize

    let prepareOffsets = clContext.Compile prepareOffsets

    fun (queue: MailboxProcessor<_>) size (input: ClArray<int>) ->

        let ndRange = Range1D.CreateValid(size, workGroupSize)

        let prepareOffsets = prepareOffsets.GetKernel()

        queue.Post(Msg.MsgSetArguments(fun () -> prepareOffsets.KernelFunc ndRange size input))

        queue.Post(Msg.CreateRunMsg<_, _>(prepareOffsets))

        // The prefix sum is posted after the preparation kernel on the same queue
        (sum queue input).ToHostAndFree queue
86+
87+
/// Builds a function that copies, for every index stored in a sparse vector, the matching
/// CSR row of the matrix into three flat arrays.
/// Returns `(rows, columns, values, length)`:
///   rows    — for each copied element, the vector index (matrix row) it was taken from;
///   columns — the element's entry from `matrix.Columns`;
///   values  — the element's entry from `matrix.Values`;
///   length  — number of copied elements.
/// When `length` is 0 the three arrays are one-cell placeholders whose contents are never written.
/// The caller owns (and must free) all three returned arrays.
let private gather (clContext: ClContext) workGroupSize =

    let gather =
        <@ fun (ndRange: Range1D) vectorNNZ (rowOffsets: ClArray<int>) (matrixRowPointers: ClArray<int>) (matrixColumns: ClArray<int>) (matrixValues: ClArray<'a>) (vectorIndices: ClArray<int>) (resultRowsArray: ClArray<int>) (resultIndicesArray: ClArray<int>) (resultValuesArray: ClArray<'a>) ->

            //Serial number of row to gather
            let row = ndRange.GlobalID0

            if row < vectorNNZ then
                // Cells 2 * row + 1 and 2 * row + 2 of rowOffsets bound this row's slice of the result
                let offsetIndex = row * 2 + 1
                let rowOffset = rowOffsets.[offsetIndex]

                //vectorIndices.[row] --- actual number of row in matrix
                let actualRow = vectorIndices.[row]
                let matrixIndexOffset = matrixRowPointers.[actualRow]

                if rowOffset <> rowOffsets.[offsetIndex + 1] then
                    let rowSize = rowOffsets.[offsetIndex + 1] - rowOffset

                    for i in 0 .. rowSize - 1 do
                        resultRowsArray.[i + rowOffset] <- actualRow
                        resultIndicesArray.[i + rowOffset] <- matrixColumns.[matrixIndexOffset + i]
                        resultValuesArray.[i + rowOffset] <- matrixValues.[matrixIndexOffset + i] @>

    let collectRows = collectRows clContext workGroupSize

    let computeOffsetsInplace =
        computeOffsetsInplace clContext workGroupSize

    let gather = clContext.Compile gather

    fun (queue: MailboxProcessor<_>) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) ->

        //Collect R[v] and R[v + 1] for each v in vector
        let collectedRows =
            collectRows queue vector.NNZ vector.Indices matrix.RowPointers

        //Place R[v + 1] - R[v] in previous array and do prefix sum, computing offsets for gather array
        let gatherArraySize =
            computeOffsetsInplace queue (vector.NNZ * 2 + 1) collectedRows

        if gatherArraySize = 0 then
            // Nothing to gather: the offsets buffer is not needed any more, so release it
            // here too (it used to be freed only on the non-empty path).
            collectedRows.Free queue

            // One-cell placeholders keep the return shape identical to the non-empty case
            let resultRows =
                clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, 1)

            let resultColumns =
                clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, 1)

            let resultValues =
                clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, 1)

            resultRows, resultColumns, resultValues, gatherArraySize
        else
            let ndRange =
                Range1D.CreateValid(vector.NNZ, workGroupSize)

            let gather = gather.GetKernel()

            let resultRows =
                clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, gatherArraySize)

            let resultIndices =
                clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, gatherArraySize)

            let resultValues =
                clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, gatherArraySize)

            queue.Post(
                Msg.MsgSetArguments
                    (fun () ->
                        gather.KernelFunc
                            ndRange
                            vector.NNZ
                            collectedRows
                            matrix.RowPointers
                            matrix.Columns
                            matrix.Values
                            vector.Indices
                            resultRows
                            resultIndices
                            resultValues)
            )

            queue.Post(Msg.CreateRunMsg<_, _>(gather))

            // Posted after the kernel run on the same queue
            collectedRows.Free queue

            resultRows, resultIndices, resultValues, gatherArraySize
177+
178+
179+
/// Builds a function that applies `mul` element-wise: for every position i it looks up the
/// vector entry whose index equals `columnIndices.[i]` and stores
/// `mul (Some matrixValues.[i]) lookupResult` in a new array of `'c option`.
/// The result array has the same length as `columnIndices`.
/// NOTE(review): the lookup is delegated to `Search.Bin.byKey`; presumably it is a binary
/// search yielding `None` when the key is absent — confirm against its definition.
let private multiplyScalar (clContext: ClContext) (mul: Expr<'a option -> 'b option -> 'c option>) workGroupSize =

    let kernelBody =
        <@ fun (ndRange: Range1D) resultLength vectorLength (rowIndices: ClArray<int>) (matrixValues: ClArray<'a>) (vectorIndices: ClArray<int>) (vectorValues: ClArray<'b>) (resultValues: ClArray<'c option>) ->
            let gid = ndRange.GlobalID0

            if gid < resultLength then
                let key = rowIndices.[gid]
                let left = matrixValues.[gid]

                let right =
                    (%Search.Bin.byKey) vectorLength key vectorIndices vectorValues

                let product = (%mul) (Some left) right
                resultValues.[gid] <- product @>

    let program = clContext.Compile kernelBody

    fun (queue: MailboxProcessor<_>) (columnIndices: ClArray<int>) (matrixValues: ClArray<'a>) (vector: Sparse<'b>) ->

        // One output cell per gathered matrix element
        let length = columnIndices.Length

        let kernel = program.GetKernel()

        let ndRange =
            Range1D.CreateValid(length, workGroupSize)

        let products =
            clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, length)

        // Note: `columnIndices` is bound to the kernel parameter named `rowIndices`
        queue.Post(
            Msg.MsgSetArguments
                (fun () ->
                    kernel.KernelFunc
                        ndRange
                        length
                        vector.NNZ
                        columnIndices
                        matrixValues
                        vector.Indices
                        vector.Values
                        products)
        )

        queue.Post(Msg.CreateRunMsg<_, _>(kernel))

        products
226+
227+
/// Builds the boolean sparse-matrix by sparse-vector product.
/// The returned function gathers the matrix rows selected by the vector's indices, keeps only
/// their column indices, sorts them, removes duplicates and returns a sparse vector whose
/// `Indices` are those columns, whose `Values` are all `true`, and whose `Size` is
/// `matrix.ColumnCount`.
/// `add` and `mul` are accepted but not used by this implementation.
let runBoolStandard
    (add: Expr<'c option -> 'c option -> 'c option>)
    (mul: Expr<'a option -> 'b option -> 'c option>)
    (clContext: ClContext)
    workGroupSize
    =

    let gatherMatrixRows = gather clContext workGroupSize

    let sortKeys =
        Sort.Radix.standardRunKeysOnly clContext workGroupSize

    let distinct =
        ClArray.removeDuplications clContext workGroupSize

    let fill = ClArray.create clContext workGroupSize

    fun (queue: MailboxProcessor<_>) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) ->

        let rows, columns, values, gatheredLength = gatherMatrixRows queue matrix vector

        // Only the column indices matter for the boolean result
        rows.Free queue
        values.Free queue

        if gatheredLength > 0 then
            let sortedColumns = sortKeys queue columns

            let uniqueColumns = distinct queue sortedColumns

            columns.Free queue
            sortedColumns.Free queue

            let trueValues =
                fill queue DeviceOnly uniqueColumns.Length true

            { Context = clContext
              Indices = uniqueColumns
              Values = trueValues
              Size = matrix.ColumnCount }
        else
            // NOTE(review): `columns` here is the one-cell placeholder returned by gather and is
            // never written, so the single index of this vector is unspecified — confirm that
            // callers treat a lone `false` value as an empty result.
            { Context = clContext
              Indices = columns
              Values = clContext.CreateClArray [| false |]
              Size = matrix.ColumnCount }

0 commit comments

Comments
 (0)