Skip to content

Commit d3bc08c

Browse files
committed
None on empty vectors
1 parent 7c22108 commit d3bc08c

2 files changed

Lines changed: 33 additions & 53 deletions

File tree

  • src/GraphBLAS-sharp.Backend/Vector
  • tests/GraphBLAS-sharp.Tests/Backend/Vector

src/GraphBLAS-sharp.Backend/Vector/SpMSpV.fs

Lines changed: 23 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,7 @@ module SpMSpV =
127127
computeOffsetsInplace queue (vector.NNZ * 2 + 1) collectedRows
128128

129129
if gatherArraySize = 0 then
130-
let resultRows =
131-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, 1)
132-
133-
let resultValues =
134-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, 1)
135-
136-
let resultColumns =
137-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, 1)
138-
139-
resultRows, resultColumns, resultValues, gatherArraySize
130+
None
140131
else
141132
let ndRange =
142133
Range1D.CreateValid(vector.NNZ, workGroupSize)
@@ -173,7 +164,7 @@ module SpMSpV =
173164

174165
collectedRows.Free queue
175166

176-
resultRows, resultIndices, resultValues, gatherArraySize
167+
Some(resultRows, resultIndices, resultValues)
177168

178169

179170
let private multiplyScalar (clContext: ClContext) (mul: Expr<'a option -> 'b option -> 'c option>) workGroupSize =
@@ -246,17 +237,9 @@ module SpMSpV =
246237

247238
fun (queue: MailboxProcessor<_>) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) ->
248239

249-
let gatherRows, gatherIndices, gatherValues, gatherLength = gather queue matrix vector
250-
251-
if gatherLength <= 0 then
252-
gatherRows.Free queue
253-
gatherValues.Free queue
254-
255-
{ Context = clContext
256-
Indices = gatherIndices
257-
Values = clContext.CreateClArray 0
258-
Size = matrix.ColumnCount }
259-
else
240+
match gather queue matrix vector with
241+
| None -> None
242+
| Some (gatherRows, gatherIndices, gatherValues) ->
260243
sort queue gatherIndices gatherRows gatherValues
261244

262245
let sortedRows, sortedIndices, sortedValues = gatherRows, gatherIndices, gatherValues
@@ -272,18 +255,17 @@ module SpMSpV =
272255
multipliedValues.Free queue
273256
sortedIndices.Free queue
274257

275-
{ Context = clContext
276-
Indices = reducedKeys
277-
Values = reducedValues
278-
Size = matrix.ColumnCount }
258+
Some(
259+
{ Context = clContext
260+
Indices = reducedKeys
261+
Values = reducedValues
262+
Size = matrix.ColumnCount }
263+
)
279264
| None ->
280265
multipliedValues.Free queue
281266
sortedIndices.Free queue
282267

283-
{ Context = clContext
284-
Indices = clContext.CreateClArray 0
285-
Values = clContext.CreateClArray 0
286-
Size = matrix.ColumnCount }
268+
None
287269

288270
let runBoolStandard
289271
(add: Expr<'c option -> 'c option -> 'c option>)
@@ -304,25 +286,22 @@ module SpMSpV =
304286

305287
fun (queue: MailboxProcessor<_>) (matrix: ClMatrix.CSR<'a>) (vector: ClVector.Sparse<'b>) ->
306288

307-
let gatherRows, gatherIndices, gatherValues, gatherLength = gather queue matrix vector
308-
309-
gatherRows.Free queue
310-
gatherValues.Free queue
289+
match gather queue matrix vector with
290+
| None -> None
291+
| Some (gatherRows, gatherIndices, gatherValues) ->
292+
gatherRows.Free queue
293+
gatherValues.Free queue
311294

312-
if gatherLength <= 0 then
313-
{ Context = clContext
314-
Indices = gatherIndices
315-
Values = clContext.CreateClArray [| false |]
316-
Size = matrix.ColumnCount }
317-
else
318295
let sortedIndices = sort queue gatherIndices
319296

320297
let resultIndices = removeDuplicates queue sortedIndices
321298

322299
gatherIndices.Free queue
323300
sortedIndices.Free queue
324301

325-
{ Context = clContext
326-
Indices = resultIndices
327-
Values = create queue DeviceOnly resultIndices.Length true
328-
Size = matrix.ColumnCount }
302+
Some(
303+
{ Context = clContext
304+
Indices = resultIndices
305+
Values = create queue DeviceOnly resultIndices.Length true
306+
Size = matrix.ColumnCount }
307+
)

tests/GraphBLAS-sharp.Tests/Backend/Vector/SpMSpV.fs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ let correctnessGenericTest
6363
some
6464
sumOp
6565
mulOp
66-
(spMV: MailboxProcessor<_> -> ClMatrix.CSR<'a> -> ClVector.Sparse<'a> -> ClVector.Sparse<'a>)
66+
(spMV: MailboxProcessor<_> -> ClMatrix.CSR<'a> -> ClVector.Sparse<'a> -> ClVector.Sparse<'a> option)
6767
(isEqual: 'a -> 'a -> bool)
6868
q
6969
(testContext: TestContext)
@@ -89,15 +89,16 @@ let correctnessGenericTest
8989
| Vector.Sparse vtr, ClMatrix.CSR m ->
9090
let v = vtr.ToDevice testContext.ClContext
9191

92-
let res = spMV testContext.Queue m v
92+
match spMV testContext.Queue m v with
93+
| Some res ->
94+
(ClMatrix.CSR m).Dispose q
95+
v.Dispose q
96+
let hostResIndices = res.Indices.ToHost q
97+
let hostResValues = res.Values.ToHost q
98+
res.Dispose q
9399

94-
(ClMatrix.CSR m).Dispose q
95-
v.Dispose q
96-
let hostResIndices = res.Indices.ToHost q
97-
let hostResValues = res.Values.ToHost q
98-
res.Dispose q
99-
100-
checkResult sumOp mulOp zero matrix vector hostResIndices hostResValues
100+
checkResult sumOp mulOp zero matrix vector hostResIndices hostResValues
101+
| None -> failwith "Result should not be empty while standard operations are tested"
101102
| _ -> failwith "Impossible"
102103
with
103104
| ex when ex.Message = "InvalidBufferSize" -> ()

0 commit comments

Comments
 (0)