@@ -236,36 +236,40 @@ module SpMSpV =
236236 Reduce.ByKey.Option.segmentSequential add clContext workGroupSize
237237
238238 fun ( queue : MailboxProcessor < _ >) ( matrix : ClMatrix.CSR < 'a >) ( vector : ClVector.Sparse < 'b >) ->
239+ let result =
240+ gather queue matrix vector
241+ |> Option.map
242+ ( fun ( gatherRows , gatherIndices , gatherValues ) ->
243+ sort queue gatherIndices gatherRows gatherValues
239244
240- match gather queue matrix vector with
241- | None -> None
242- | Some ( gatherRows, gatherIndices, gatherValues) ->
243- sort queue gatherIndices gatherRows gatherValues
245+ let sortedRows , sortedIndices , sortedValues = gatherRows, gatherIndices, gatherValues
244246
245- let sortedRows , sortedIndices , sortedValues = gatherRows, gatherIndices, gatherValues
247+ let multipliedValues =
248+ multiplyScalar queue sortedRows sortedValues vector
246249
247- let multipliedValues =
248- multiplyScalar queue sortedRows sortedValues vector
250+ sortedValues.Free queue
251+ sortedRows.Free queue
249252
250- sortedValues.Free queue
251- sortedRows.Free queue
253+ let result =
254+ segReduce queue DeviceOnly sortedIndices multipliedValues
255+ |> Option.map
256+ ( fun ( reducedValues , reducedKeys ) ->
252257
253- match segReduce queue DeviceOnly sortedIndices multipliedValues with
254- | Some ( reducedValues , reducedKeys ) ->
255- multipliedValues.Free queue
256- sortedIndices.Free queue
258+ { Context = clContext
259+ Indices = reducedKeys
260+ Values = reducedValues
261+ Size = matrix.ColumnCount })
257262
258- Some (
259- { Context = clContext
260- Indices = reducedKeys
261- Values = reducedValues
262- Size = matrix.ColumnCount }
263- )
264- | None ->
265- multipliedValues.Free queue
266- sortedIndices.Free queue
263+ multipliedValues.Free queue
264+ sortedIndices.Free queue
265+
266+ result )
267+
268+ //Unwrap 't option option to 't option
269+ match result with
270+ | Some result -> result
271+ | None -> None
267272
268- None
269273
270274 let runBoolStandard
271275 ( add : Expr < 'c option -> 'c option -> 'c option >)
@@ -286,22 +290,20 @@ module SpMSpV =
286290
287291 fun ( queue : MailboxProcessor < _ >) ( matrix : ClMatrix.CSR < 'a >) ( vector : ClVector.Sparse < 'b >) ->
288292
289- match gather queue matrix vector with
290- | None -> None
291- | Some ( gatherRows, gatherIndices, gatherValues) ->
292- gatherRows.Free queue
293- gatherValues.Free queue
293+ gather queue matrix vector
294+ |> Option.map
295+ ( fun ( gatherRows , gatherIndices , gatherValues ) ->
296+ gatherRows.Free queue
297+ gatherValues.Free queue
294298
295- let sortedIndices = sort queue gatherIndices
299+ let sortedIndices = sort queue gatherIndices
296300
297- let resultIndices = removeDuplicates queue sortedIndices
301+ let resultIndices = removeDuplicates queue sortedIndices
298302
299- gatherIndices.Free queue
300- sortedIndices.Free queue
303+ gatherIndices.Free queue
304+ sortedIndices.Free queue
301305
302- Some(
303306 { Context = clContext
304307 Indices = resultIndices
305308 Values = create queue DeviceOnly resultIndices.Length true
306- Size = matrix.ColumnCount }
307- )
309+ Size = matrix.ColumnCount })
0 commit comments