@@ -428,7 +428,7 @@ module Reduce =
428428 let itemKeyId = lid + 1
429429
430430 let startKeyIndex =
431- (% Search.Bin.lowerPosition ) length itemKeyId localBitmap
431+ (% Search.Bin.lowerPositionLocal ) length itemKeyId localBitmap
432432
433433 match startKeyIndex with
434434 | Some startPosition ->
@@ -473,16 +473,144 @@ module Reduce =
473473 reducedValues, reducedKeys
474474
475475 module Option =
476+ /// <summary >
477+ /// Reduces values by key. Each segment is reduced by one work item.
478+ /// </summary >
479+ /// <param name =" clContext " >ClContext.</param >
480+ /// <param name =" workGroupSize " >Work group size.</param >
481+ /// <param name =" reduceOp " >Operation for reducing values.</param >
482+ let segmentSequential < 'a >
483+ ( reduceOp : Expr < 'a option -> 'a option -> 'a option >)
484+ ( clContext : ClContext )
485+ workGroupSize
486+ =
487+
488+ let kernel =
489+ <@ fun ( ndRange : Range1D ) uniqueKeyCount keysLength ( offsets : ClArray < int >) ( keys : ClArray < int >) ( values : ClArray < 'a option >) ( reducedValues : ClArray < 'a >) ( firstReducedKeys : ClArray < int >) ( resultPositions : ClArray < int >) ->
490+
491+ let gid = ndRange.GlobalID0
492+
493+ if gid < uniqueKeyCount then
494+ let startPosition =
495+ (% Search.Bin.lowerPosition) keysLength gid offsets
496+
497+ match startPosition with
498+ | Some startPosition ->
499+ let firstSourceKey = keys.[ startPosition]
500+
501+ let mutable sum = None
502+
503+ let mutable currentPosition = startPosition
504+
505+ while currentPosition < keysLength
506+ && firstSourceKey = keys.[ currentPosition] do
507+ let result = (% reduceOp) sum values.[ currentPosition] // brahma error
508+ sum <- result
509+ currentPosition <- currentPosition + 1
510+
511+ match sum with
512+ | Some value ->
513+ reducedValues.[ gid] <- value
514+ resultPositions.[ gid] <- 1
515+ | None -> resultPositions.[ gid] <- 0
516+
517+ firstReducedKeys.[ gid] <- firstSourceKey
518+ | None -> () @> // not possible if done correctly
519+
520+ let kernel = clContext.Compile kernel
521+
522+ let getUniqueBitmap =
523+ ClArray.Bitmap.lastOccurrence clContext workGroupSize
524+
525+ let scatterData =
526+ Scatter.lastOccurrence clContext workGroupSize
527+
528+ let scatterIndices =
529+ Scatter.lastOccurrence clContext workGroupSize
530+
531+ let prefixSum =
532+ PrefixSum.standardExcludeInPlace clContext workGroupSize
533+
534+ fun ( processor : MailboxProcessor < _ >) allocationMode ( keys : ClArray < int >) ( values : ClArray < 'a option >) ->
535+
536+ let offsets =
537+ getUniqueBitmap processor DeviceOnly keys
538+
539+ let uniqueKeysCount =
540+ ( prefixSum processor offsets)
541+ .ToHostAndFree processor
542+
543+ let reducedValues =
544+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, uniqueKeysCount)
545+
546+ let reducedKeys =
547+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, uniqueKeysCount)
548+
549+ let resultPositions =
550+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, uniqueKeysCount)
551+
552+ let ndRange =
553+ Range1D.CreateValid( uniqueKeysCount, workGroupSize)
554+
555+ let kernel = kernel.GetKernel()
556+
557+ processor.Post(
558+ Msg.MsgSetArguments
559+ ( fun () ->
560+ kernel.KernelFunc
561+ ndRange
562+ uniqueKeysCount
563+ keys.Length
564+ offsets
565+ keys
566+ values
567+ reducedValues
568+ reducedKeys
569+ resultPositions)
570+ )
571+
572+ processor.Post( Msg.CreateRunMsg<_, _>( kernel))
573+
574+ let resultLength =
575+ ( prefixSum processor resultPositions)
576+ .ToHostAndFree processor
577+
578+ if resultLength = 0 then
579+ None
580+ else
581+ // write values
582+ let resultValues =
583+ clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, resultLength)
584+
585+ scatterData processor resultPositions reducedValues resultValues
586+
587+ reducedValues.Free processor
588+
589+ // write keys
590+ let resultKeys =
591+ clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, resultLength)
592+
593+ scatterIndices processor resultPositions reducedKeys resultKeys
594+
595+ reducedKeys.Free processor
596+ resultPositions.Free processor
597+
598+ Some( resultValues, resultKeys)
599+
476600 /// <summary >
477601 /// Reduces values by key. Each segment is reduced by one work item.
478602 /// </summary >
479603 /// <param name =" clContext " >ClContext.</param >
480604 /// <param name =" workGroupSize " >Work group size.</param >
481605 /// <param name =" reduceOp " >Operation for reducing values.</param >
482606 /// <remarks >
483- /// The length of the result must be calculated in advance.
607+ /// The length of the result and offsets for each segment must be calculated in advance.
484608 /// </remarks >
485- let segmentSequential < 'a > ( reduceOp : Expr < 'a -> 'a -> 'a option >) ( clContext : ClContext ) workGroupSize =
609+ let segmentSequentialByOffsets < 'a >
610+ ( reduceOp : Expr < 'a -> 'a -> 'a option >)
611+ ( clContext : ClContext )
612+ workGroupSize
613+ =
486614
487615 let kernel =
488616 <@ fun ( ndRange : Range1D ) uniqueKeyCount keysLength ( offsets : ClArray < int >) ( keys : ClArray < int >) ( values : ClArray < 'a >) ( reducedValues : ClArray < 'a >) ( firstReducedKeys : ClArray < int >) ( resultPositions : ClArray < int >) ->
0 commit comments