Skip to content

Commit 69be680

Browse files
committed
add: reduce by key option
1 parent 751ee68 commit 69be680

3 files changed

Lines changed: 222 additions & 2 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/Sum.fs

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ open Microsoft.FSharp.Control
66
open Microsoft.FSharp.Quotations
77
open GraphBLAS.FSharp.Backend.Objects.ClContext
88
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
9+
open GraphBLAS.FSharp.Backend.Objects.ClCell
910

1011
module Reduce =
1112
/// <summary>
@@ -616,3 +617,127 @@ module Reduce =
616617
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
617618

618619
firstReducedKeys, secondReducedKeys, reducedValues
620+
621+
/// <summary>
622+
/// Reduces values by key. Each segment is reduced by one work item.
623+
/// </summary>
624+
/// <param name="clContext">ClContext.</param>
625+
/// <param name="workGroupSize">Work group size.</param>
626+
/// <param name="reduceOp">Operation for reducing values.</param>
627+
/// <remarks>
628+
/// The length of the result must be calculated in advance.
629+
/// </remarks>
630+
let segmentSequentialOption<'a> (clContext: ClContext) workGroupSize (reduceOp: Expr<'a -> 'a -> 'a option>) =
631+
632+
let kernel =
633+
<@ fun (ndRange: Range1D) uniqueKeyCount keysLength (offsets: ClArray<int>) (firstKeys: ClArray<int>) (secondKeys: ClArray<int>) (values: ClArray<'a>) (reducedValues: ClArray<'a>) (firstReducedKeys: ClArray<int>) (secondReducedKeys: ClArray<int>) (resultPositions: ClArray<int>) ->
634+
635+
let gid = ndRange.GlobalID0
636+
637+
if gid < uniqueKeyCount then
638+
let startPosition = offsets.[gid]
639+
640+
let firstSourceKey = firstKeys.[startPosition]
641+
let secondSourceKey = secondKeys.[startPosition]
642+
643+
let mutable sum = Some values.[startPosition]
644+
645+
let mutable currentPosition = startPosition + 1
646+
647+
while currentPosition < keysLength
648+
&& firstSourceKey = firstKeys.[currentPosition]
649+
&& secondSourceKey = secondKeys.[currentPosition] do
650+
651+
match sum with
652+
| Some value ->
653+
let result = ((%reduceOp) value values.[currentPosition]) // brahma error
654+
655+
sum <- result
656+
| None ->
657+
sum <- Some values.[currentPosition]
658+
659+
currentPosition <- currentPosition + 1
660+
661+
match sum with
662+
| Some value ->
663+
reducedValues.[gid] <- value
664+
resultPositions.[gid] <- 1
665+
| None ->
666+
resultPositions.[gid] <- 0
667+
668+
firstReducedKeys.[gid] <- firstSourceKey
669+
secondReducedKeys.[gid] <- secondSourceKey @>
670+
671+
let kernel = clContext.Compile kernel
672+
673+
let scatterData = Scatter.lastOccurrence clContext workGroupSize
674+
675+
let scatterIndices = Scatter.lastOccurrence clContext workGroupSize
676+
677+
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
678+
679+
fun (processor: MailboxProcessor<_>) allocationMode (resultLength: int) (offsets: ClArray<int>) (firstKeys: ClArray<int>) (secondKeys: ClArray<int>) (values: ClArray<'a>) ->
680+
681+
let reducedValues =
682+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
683+
684+
let firstReducedKeys =
685+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
686+
687+
let secondReducedKeys =
688+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
689+
690+
let resultPositions =
691+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)
692+
693+
let ndRange =
694+
Range1D.CreateValid(resultLength, workGroupSize)
695+
696+
let kernel = kernel.GetKernel()
697+
698+
processor.Post(
699+
Msg.MsgSetArguments
700+
(fun () ->
701+
kernel.KernelFunc
702+
ndRange
703+
resultLength
704+
firstKeys.Length
705+
offsets
706+
firstKeys
707+
secondKeys
708+
values
709+
reducedValues
710+
firstReducedKeys
711+
secondReducedKeys
712+
resultPositions)
713+
)
714+
715+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
716+
717+
let resultLength =
718+
(prefixSum processor resultPositions).ToHostAndFree processor
719+
720+
let resultValues =
721+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
722+
723+
scatterData processor resultPositions reducedValues resultValues
724+
725+
reducedValues.Free processor
726+
727+
let resultFirstKeys =
728+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
729+
730+
scatterIndices processor resultPositions firstReducedKeys resultFirstKeys
731+
732+
firstReducedKeys.Free processor
733+
734+
let resultSecondKeys =
735+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
736+
737+
scatterIndices processor resultPositions secondReducedKeys resultSecondKeys
738+
739+
secondReducedKeys.Free processor
740+
741+
resultPositions.Free processor
742+
743+
resultFirstKeys, resultSecondKeys, resultValues

tests/GraphBLAS-sharp.Tests/Common/Reduce/ReduceByKey.fs

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module GraphBLAS.FSharp.Tests.Backend.Common.Reduce.ByKey
22

33
open Expecto
44
open GraphBLAS.FSharp.Backend.Common
5+
open GraphBLAS.FSharp.Backend.Quotes
56
open GraphBLAS.FSharp.Test
67
open GraphBLAS.FSharp.Tests
78
open GraphBLAS.FSharp.Backend.Objects.ClContext
@@ -14,6 +15,16 @@ let processor = Context.defaultContext.Queue
1415

1516
let config = Utils.defaultConfig
1617

18+
let getOffsets array =
19+
Array.map fst array
20+
|> HostPrimitives.getUniqueBitmapFirstOccurrence
21+
|> HostPrimitives.getBitPositions
22+
23+
let getOffsets2D array =
24+
Array.map (fun (fst, snd, _) -> fst, snd) array
25+
|> HostPrimitives.getUniqueBitmapFirstOccurrence
26+
|> HostPrimitives.getBitPositions
27+
1728
let checkResult isEqual actualKeys actualValues keys values reduceOp =
1829

1930
let expectedKeys, expectedValues =
@@ -336,3 +347,87 @@ let sequentialSegmentTests2D =
336347
createTestSequentialSegments2D<bool> (=) (&&) <@ (&&) @> ]
337348

338349
testList "Sequential segments 2D" [ addTests; mulTests ]
350+
351+
let checkResult2DOption isEqual firstActualKeys secondActualKeys actualValues firstKeys secondKeys values reduceOp =
352+
353+
let reduceOp left right =
354+
match left, right with
355+
| Some left, Some right ->
356+
reduceOp left right
357+
| Some value, None
358+
| None, Some value -> Some value
359+
| _ -> None
360+
361+
let expectedFirstKeys, expectedSecondKeys, expectedValues =
362+
let keys = Array.zip firstKeys secondKeys
363+
364+
Array.zip keys values
365+
|> Array.groupBy fst
366+
|> Array.map (fun (key, array) -> key, Array.map snd array)
367+
|> Array.map (fun (key, array) ->
368+
Array.map Some array
369+
|> Array.reduce reduceOp
370+
|> fun result -> key, result)
371+
|> Array.choose (fun ((fstKey, sndKey), value) ->
372+
match value with
373+
| Some value -> Some (fstKey, sndKey, value)
374+
| _ -> None )
375+
|> Array.unzip3
376+
377+
"First keys must be the same"
378+
|> Utils.compareArrays (=) firstActualKeys expectedFirstKeys
379+
380+
"Second keys must be the same"
381+
|> Utils.compareArrays (=) secondActualKeys expectedSecondKeys
382+
383+
"Values must the same"
384+
|> Utils.compareArrays isEqual actualValues expectedValues
385+
386+
let test2DOption<'a> isEqual reduce reduceOp (array: (int * int * 'a) []) =
387+
if array.Length > 0 then
388+
let array = Array.sortBy (fun (fst, snd, _) -> fst, snd) array
389+
390+
let offsets = getOffsets2D array
391+
392+
let firstKeys, secondKeys, values = Array.unzip3 array
393+
394+
let clOffsets =
395+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, offsets)
396+
397+
let clFirstKeys =
398+
context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, firstKeys)
399+
400+
let clSecondKeys =
401+
context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, secondKeys)
402+
403+
let clValues =
404+
context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, values)
405+
406+
let clFirstActualKeys, clSecondActualKeys, clReducedValues: ClArray<int> * ClArray<int> * ClArray<'a> =
407+
reduce processor DeviceOnly offsets.Length clOffsets clFirstKeys clSecondKeys clValues
408+
409+
let reducedFirsKeys = clFirstActualKeys.ToHostAndFree processor
410+
let reducesSecondKeys = clSecondActualKeys.ToHostAndFree processor
411+
let reducedValues = clReducedValues.ToHostAndFree processor
412+
413+
checkResult2DOption isEqual reducedFirsKeys reducesSecondKeys reducedValues firstKeys secondKeys values reduceOp
414+
415+
let createTest2DOption (isEqual: 'a -> 'a -> bool) (reduceOpQ, reduceOp) =
416+
let reduce =
417+
Reduce.ByKey2D.segmentSequentialOption context Utils.defaultWorkGroupSize reduceOpQ
418+
419+
test2DOption<'a> isEqual reduce reduceOp
420+
|> testPropertyWithConfig { config with arbitrary = [ typeof<Generators.ArrayOfDistinctKeys> ] } $"test on {typeof<'a>}"
421+
422+
let testsByKey2DSegmentsSequential =
423+
[ createTest2DOption (=) ArithmeticOperations.intAdd
424+
425+
if Utils.isFloat64Available context.ClDevice then
426+
createTest2DOption Utils.floatIsEqual ArithmeticOperations.floatAdd
427+
428+
createTest2DOption Utils.float32IsEqual ArithmeticOperations.float32Add
429+
createTest2DOption (=) ArithmeticOperations.boolAdd ]
430+
|> testList "2D option"
431+
432+
433+

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ let allTests =
9595
testList
9696
"All tests"
9797
[ // SpGeMM.expandTests
98-
SpGeMM.generalTests
98+
// SpGeMM.generalTests
9999
// Common.Gather.initTests
100100
// Common.ClArray.Choose.tests2 ]
101-
]
101+
Common.Reduce.ByKey.testsByKey2DSegmentsSequential ]
102102
|> testSequenced
103103

104104
[<EntryPoint>]

0 commit comments

Comments
 (0)