Skip to content

Commit 7c22108

Browse files
committed
SpMSpV general
1 parent a87949e commit 7c22108

2 files changed

Lines changed: 96 additions & 3 deletions

File tree

  • src/GraphBLAS-sharp.Backend/Vector
  • tests/GraphBLAS-sharp.Tests/Backend/Vector

src/GraphBLAS-sharp.Backend/Vector/SpMSpV.fs

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,67 @@ module SpMSpV =
224224

225225
result
226226

227+
let run
228+
(add: Expr<'c option -> 'c option -> 'c option>)
229+
(mul: Expr<'a option -> 'b option -> 'c option>)
230+
(clContext: ClContext)
231+
workGroupSize
232+
=
233+
234+
//TODO: Common.Gather?
235+
let gather = gather clContext workGroupSize
236+
237+
//TODO: Radix sort
238+
let sort =
239+
Sort.Bitonic.sortKeyValuesInplace clContext workGroupSize
240+
241+
let multiplyScalar =
242+
multiplyScalar clContext mul workGroupSize
243+
244+
let segReduce =
245+
Reduce.ByKey.Option.segmentSequential add clContext workGroupSize
246+
247+
fun (queue: MailboxProcessor<_>) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) ->
248+
249+
let gatherRows, gatherIndices, gatherValues, gatherLength = gather queue matrix vector
250+
251+
if gatherLength <= 0 then
252+
gatherRows.Free queue
253+
gatherValues.Free queue
254+
255+
{ Context = clContext
256+
Indices = gatherIndices
257+
Values = clContext.CreateClArray 0
258+
Size = matrix.ColumnCount }
259+
else
260+
sort queue gatherIndices gatherRows gatherValues
261+
262+
let sortedRows, sortedIndices, sortedValues = gatherRows, gatherIndices, gatherValues
263+
264+
let multipliedValues =
265+
multiplyScalar queue sortedRows sortedValues vector
266+
267+
sortedValues.Free queue
268+
sortedRows.Free queue
269+
270+
match segReduce queue DeviceOnly sortedIndices multipliedValues with
271+
| Some (reducedValues, reducedKeys) ->
272+
multipliedValues.Free queue
273+
sortedIndices.Free queue
274+
275+
{ Context = clContext
276+
Indices = reducedKeys
277+
Values = reducedValues
278+
Size = matrix.ColumnCount }
279+
| None ->
280+
multipliedValues.Free queue
281+
sortedIndices.Free queue
282+
283+
{ Context = clContext
284+
Indices = clContext.CreateClArray 0
285+
Values = clContext.CreateClArray 0
286+
Size = matrix.ColumnCount }
287+
227288
let runBoolStandard
228289
(add: Expr<'c option -> 'c option -> 'c option>)
229290
(mul: Expr<'a option -> 'b option -> 'c option>)

tests/GraphBLAS-sharp.Tests/Backend/Vector/SpMSpV.fs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module GraphBLAS.FSharp.Tests.Backend.Vector.SpMSpV
22

3+
open System
34
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
45
open Expecto
56
open Brahma.FSharp
@@ -131,9 +132,40 @@ let testFixturesSpMSpV (testContext: TestContext) =
131132
(&&)
132133
ArithmeticOperations.boolSumOption
133134
ArithmeticOperations.boolMulOption
134-
//createTest testContext 0 1 (=) (+) (*) ArithmeticOperations.intSum ArithmeticOperations.intMul
135-
//createTest testContext 0.0f 1f (=) (+) (*) ArithmeticOperations.float32Sum ArithmeticOperations.float32Mul
136-
]
135+
136+
createTest
137+
SpMSpV.run
138+
testContext
139+
0
140+
1
141+
(=)
142+
(+)
143+
(*)
144+
ArithmeticOperations.intSumOption
145+
ArithmeticOperations.intMulOption
146+
147+
createTest
148+
SpMSpV.run
149+
testContext
150+
0.0f
151+
1f
152+
(=)
153+
(+)
154+
(*)
155+
ArithmeticOperations.float32SumOption
156+
ArithmeticOperations.float32MulOption
157+
158+
if Utils.isFloat64Available context.ClDevice then
159+
createTest
160+
SpMSpV.run
161+
testContext
162+
0.0
163+
1
164+
(=)
165+
(+)
166+
(*)
167+
ArithmeticOperations.floatSumOption
168+
ArithmeticOperations.floatMulOption ]
137169

138170
let tests =
139171
gpuTests "Backend.Vector.SpMSpV tests" testFixturesSpMSpV

0 commit comments

Comments
 (0)