Skip to content

Commit a97631a

Browse files
committed
Choose with keys
1 parent 0d09177 commit a97631a

1 file changed

Lines changed: 85 additions & 0 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,91 @@ module ClArray =
437437

438438
Some result
439439

440+
let assignWithKeysOption (op: Expr<'a -> 'b option>) (clContext: ClContext) workGroupSize =
441+
442+
let assign =
443+
<@ fun (ndRange: Range1D) length (keys: ClArray<int>) (values: ClArray<'a>) (positions: ClArray<int>) (resultKeys: ClArray<int>) (resultValues: ClArray<'b>) resultLength ->
444+
445+
let gid = ndRange.GlobalID0
446+
447+
if gid < length then
448+
let position = positions.[gid]
449+
let value = values.[gid]
450+
let key = keys.[gid]
451+
452+
// seems like scatter (option scatter) ???
453+
if 0 <= position && position < resultLength then
454+
match (%op) value with
455+
| Some value ->
456+
resultValues.[position] <- value
457+
resultKeys.[position] <- key
458+
459+
| None -> () @>
460+
461+
let kernel = clContext.Compile assign
462+
463+
fun (processor: MailboxProcessor<_>) (keys: ClArray<int>) (values: ClArray<'a>) (positions: ClArray<int>) (resultKeys: ClArray<int>) (resultValues: ClArray<'b>) ->
464+
465+
if values.Length <> positions.Length then
466+
failwith "lengths must be the same"
467+
468+
let ndRange =
469+
Range1D.CreateValid(values.Length, workGroupSize)
470+
471+
let kernel = kernel.GetKernel()
472+
473+
processor.Post(
474+
Msg.MsgSetArguments
475+
(fun () ->
476+
kernel.KernelFunc
477+
ndRange
478+
values.Length
479+
keys
480+
values
481+
positions
482+
resultKeys
483+
resultValues
484+
resultValues.Length)
485+
)
486+
487+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
488+
489+
let chooseWithKeys<'a, 'b> (predicate: Expr<'a -> 'b option>) (clContext: ClContext) workGroupSize =
490+
let getBitmap =
491+
map<'a, int> (Map.chooseBitmap predicate) clContext workGroupSize
492+
493+
let prefixSum =
494+
PrefixSum.standardExcludeInPlace clContext workGroupSize
495+
496+
let assignValues =
497+
assignWithKeysOption predicate clContext workGroupSize
498+
499+
fun (processor: MailboxProcessor<_>) allocationMode (sourceKeys: ClArray<int>) (sourceValues: ClArray<'a>) ->
500+
501+
let positions =
502+
getBitmap processor DeviceOnly sourceValues
503+
504+
let resultLength =
505+
(prefixSum processor positions)
506+
.ToHostAndFree(processor)
507+
508+
if resultLength = 0 then
509+
positions.Free processor
510+
511+
None
512+
else
513+
let resultKeys =
514+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
515+
516+
let resultValues =
517+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
518+
519+
assignValues processor sourceKeys sourceValues positions resultKeys resultValues
520+
521+
positions.Free processor
522+
523+
Some(resultKeys, resultValues)
524+
440525
let assignOption2 (op: Expr<'a -> 'b -> 'c option>) (clContext: ClContext) workGroupSize =
441526

442527
let assign =

0 commit comments

Comments
 (0)