@@ -5,6 +5,8 @@ open GraphBLAS.FSharp.Backend.Quotes
55open Microsoft.FSharp .Control
66open Microsoft.FSharp .Quotations
77open GraphBLAS.FSharp .Backend .Objects .ClContext
8+ open GraphBLAS.FSharp .Backend .Objects .ClCell
9+ open GraphBLAS.FSharp .Backend .Objects .ArraysExtensions
810
911module Reduce =
1012 let private runGeneral ( clContext : ClContext ) workGroupSize scan scanToCell =
@@ -235,3 +237,166 @@ module Reduce =
235237 runGeneral clContext workGroupSize scan scanToCell
236238
237239 fun ( processor : MailboxProcessor < _ >) ( array : ClArray < 'a >) -> run processor array
240+
241+ module ByKey =
242+ let sequential ( clContext : ClContext ) workGroupSize ( reduceOp : Expr < 'a -> 'a -> 'a >) =
243+
244+ let kernel =
245+ <@ fun ( ndRange : Range1D ) length ( keys : ClArray < int >) ( values : ClArray < 'a >) ( reducedValues : ClArray < 'a >) ( reducedKeys : ClArray < int >) ->
246+
247+ let gid = ndRange.GlobalID0
248+
249+ if gid = 0 then
250+ let mutable currentKey = keys.[ gid]
251+ let mutable segmentResult = values.[ gid]
252+ let mutable segmentCount = 0
253+
254+ for i in 1 .. length - 1 do
255+ if currentKey = keys.[ i] then
256+ segmentResult <- (% reduceOp) segmentResult values.[ i]
257+ else
258+ reducedValues.[ segmentCount] <- segmentResult
259+ reducedKeys.[ segmentCount] <- currentKey
260+
261+ segmentCount <- segmentCount + 1
262+ currentKey <- keys.[ i]
263+ segmentResult <- values.[ i]
264+
265+ reducedKeys.[ segmentCount] <- currentKey
266+ reducedValues.[ segmentCount] <- segmentResult @>
267+
268+ let kernel = clContext.Compile kernel
269+
270+ fun ( processor : MailboxProcessor < _ >) allocationMode ( resultLength : int ) ( keys : ClArray < int >) ( values : ClArray < 'a >) ->
271+
272+ let reducedValues = clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, resultLength)
273+
274+ let reducedKeys = clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, resultLength)
275+
276+ let ndRange = Range1D.CreateValid( resultLength, workGroupSize)
277+
278+ let kernel = kernel.GetKernel()
279+
280+ processor.Post( Msg.MsgSetArguments( fun () -> kernel.KernelFunc ndRange resultLength keys values reducedValues reducedKeys))
281+
282+ processor.Post( Msg.CreateRunMsg<_, _>( kernel))
283+
284+ let segmentSequential ( clContext : ClContext ) workGroupSize ( reduceOp : Expr < 'a -> 'a -> 'a >) =
285+
286+ let kernel =
287+ <@ fun ( ndRange : Range1D ) uniqueKeyCount ( offsets : ClArray < int >) ( keys : ClArray < int >) ( values : ClArray < 'a >) ( reducedValues : ClArray < 'a >) ( reducedKeys : ClArray < int >) ->
288+
289+ let gid = ndRange.GlobalID0
290+
291+ if gid < uniqueKeyCount then
292+ let startPosition = offsets.[ gid]
293+ let sourceKey = keys.[ startPosition]
294+
295+ let mutable nextPosition = startPosition + 1 // TODO()
296+ let mutable nextKey = keys.[ nextPosition]
297+ let mutable sum = values.[ startPosition]
298+
299+ while nextKey = sourceKey do
300+ sum <- (% reduceOp) sum values.[ nextPosition]
301+
302+ nextPosition <- nextPosition + 1
303+ nextKey <- keys.[ nextPosition]
304+
305+ reducedValues.[ gid] <- sum
306+ reducedKeys.[ gid] <- sourceKey @>
307+
308+ let kernel = clContext.Compile kernel
309+
310+ let getUniqueBitmap = ClArray.getUniqueBitmap clContext workGroupSize
311+
312+ let prefixSum = PrefixSum.runExcludeInplace <@ (+) @> clContext workGroupSize
313+
314+ let removeDuplicates = ClArray.removeDuplications clContext workGroupSize
315+
316+ fun ( processor : MailboxProcessor < _ >) allocationMode ( keys : ClArray < int >) ( values : ClArray < 'a >) ->
317+
318+ let bitmap = getUniqueBitmap processor DeviceOnly keys
319+
320+ let resultLength = ( prefixSum processor bitmap 0 ) .ToHostAndFree processor
321+
322+ let offsets = removeDuplicates processor bitmap
323+
324+ bitmap.Free processor
325+
326+ let reducedValues = clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, resultLength)
327+
328+ let reducedKeys = clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, resultLength)
329+
330+ let ndRange = Range1D.CreateValid( resultLength, workGroupSize)
331+
332+ let kernel = kernel.GetKernel()
333+
334+ processor.Post( Msg.MsgSetArguments( fun () -> kernel.KernelFunc ndRange resultLength offsets keys values reducedValues reducedKeys))
335+
336+ processor.Post( Msg.CreateRunMsg<_, _>( kernel))
337+
338+ let oneWorkGroupSegments ( clContext : ClContext ) workGroupSize ( reduceOp : Expr < 'a -> 'a -> 'a >) =
339+
340+ let kernel =
341+ <@ fun ( ndRange : Range1D ) length ( keys : ClArray < int >) ( values : ClArray < 'a >) ( reducedValues : ClArray < 'a >) ( reducedKeys : ClArray < int >) ->
342+
343+ let lid = ndRange.GlobalID0
344+
345+ // load values to local memory (may be without it)
346+ let localValues = localArray< 'a> length
347+ if lid < length then localValues.[ lid] <- values.[ lid]
348+
349+ // load keys to local memory (mb without it)
350+ let localKeys = localArray< int> length
351+ if lid < length then localKeys.[ lid] <- keys.[ lid]
352+
353+ // get unique keys bitmap
354+ let localBitmap = localArray< int> length
355+ (% PreparePositions.getUniqueBitmapLocal< int>) localKeys length lid localBitmap
356+
357+ // get positions from bitmap by prefix sum
358+ // ??? get bitmap by prefix sum in another kernel ???
359+ (% SubSum.localIntPrefixSum) lid workGroupSize localBitmap
360+ let localPositions = localBitmap
361+
362+ let uniqueKeysCount = localPositions.[ length - 1 ]
363+
364+ if lid < uniqueKeysCount then
365+ let itemKeyId = lid + 1
366+ // we can count start position by itemKeyId
367+ // but loose coalesced memory read pattern
368+
369+ let startKeyIndex =
370+ (% Search.Bin.lowerPosition) length itemKeyId localPositions
371+
372+ match startKeyIndex with
373+ | Some startPosition ->
374+ let sourcePosition = localPositions.[ startPosition]
375+ let mutable currentSum = localValues.[ startPosition]
376+ let mutable currentIndex = startPosition + 1
377+
378+ while currentIndex < length
379+ && localPositions.[ currentIndex] = sourcePosition do
380+
381+ currentSum <- (% reduceOp) currentSum localValues.[ currentIndex]
382+ currentIndex <- currentIndex + 1
383+
384+ reducedKeys.[ lid] <- localKeys.[ startPosition]
385+ reducedValues.[ lid] <- currentSum
386+ | None -> () @>
387+
388+ let kernel = clContext.Compile kernel
389+
390+ fun ( processor : MailboxProcessor < _ >) allocationMode ( resultLength : int ) ( keys : ClArray < int >) ( values : ClArray < 'a >) ->
391+
392+ let reducedValues = clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, resultLength)
393+
394+ let reducedKeys = clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, resultLength)
395+
396+ let ndRange = Range1D.CreateValid( resultLength, workGroupSize)
397+
398+ let kernel = kernel.GetKernel()
399+
400+ processor.Post( Msg.MsgSetArguments( fun () -> kernel.KernelFunc ndRange resultLength keys values reducedValues reducedKeys))
401+
402+ processor.Post( Msg.CreateRunMsg<_, _>( kernel))
0 commit comments