Skip to content

Commit 95fea31

Browse files
committed
refactor: duplication, allTests in Reduce
1 parent 96c0c08 commit 95fea31

6 files changed

Lines changed: 39 additions & 34 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ open GraphBLAS.FSharp.Backend.Objects.ClContext
66
open GraphBLAS.FSharp.Backend.Objects.ClCell
77
open GraphBLAS.FSharp.Backend.Quotes
88
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
9+
open GraphBLAS.FSharp.Backend.Quotes
910

1011
module ClArray =
1112
let init (clContext: ClContext) workGroupSize (initializer: Expr<int -> 'a>) =
@@ -167,16 +168,12 @@ module ClArray =
167168

168169
let getUniqueBitmapFirstOccurrence clContext =
169170
getUniqueBitmapGeneral
170-
<| <@ fun (gid: int) (_: int) (inputArray: ClArray<'a>) ->
171-
gid = 0
172-
|| inputArray.[gid - 1] <> inputArray.[gid] @>
171+
<| Predicates.firstOccurrence ()
173172
<| clContext
174173

175174
let getUniqueBitmapLastOccurrence clContext =
176175
getUniqueBitmapGeneral
177-
<| <@ fun (gid: int) (length: int) (inputArray: ClArray<'a>) ->
178-
gid = length - 1
179-
|| inputArray.[gid] <> inputArray.[gid + 1] @>
176+
<| Predicates.lastOccurrence ()
180177
<| clContext
181178

182179
///<description>Remove duplicates form the given array.</description>

src/GraphBLAS-sharp.Backend/Common/Scatter.fs

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,9 @@
11
namespace GraphBLAS.FSharp.Backend.Common
22

33
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Quotes
45

56
module internal Scatter =
6-
let private firstOccurencePredicate () =
7-
<@ fun gid _ (positions: ClArray<int>) ->
8-
// first occurrence condition
9-
(gid = 0 || positions.[gid - 1] <> positions.[gid]) @>
10-
11-
let private lastOccurrencePredicate () =
12-
<@ fun gid positionsLength (positions: ClArray<int>) ->
13-
// last occurrence condition
14-
(gid = positionsLength - 1
15-
|| positions.[gid] <> positions.[gid + 1]) @>
16-
17-
187
let private general<'a> predicate (clContext: ClContext) workGroupSize =
198

209
let run =
@@ -75,7 +64,7 @@ module internal Scatter =
7564
/// </code>
7665
/// </example>
7766
let firstOccurrence clContext =
78-
general <| firstOccurencePredicate () <| clContext
67+
general <| Predicates.firstOccurrence () <| clContext
7968

8069
/// <summary>
8170
/// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
@@ -96,7 +85,7 @@ module internal Scatter =
9685
/// </code>
9786
/// </example>
9887
let lastOccurrence clContext =
99-
general <| lastOccurrencePredicate () <| clContext
88+
general <| Predicates.lastOccurrence () <| clContext
10089

10190
let private generalInit<'a> predicate valueMap (clContext: ClContext) workGroupSize =
10291

@@ -156,7 +145,7 @@ module internal Scatter =
156145
/// <param name="valueMap">Maps global id to a value</param>
157146
let initFirsOccurrence<'a> valueMap =
158147
generalInit<'a>
159-
<| firstOccurencePredicate ()
148+
<| Predicates.firstOccurrence ()
160149
<| valueMap
161150

162151
/// <summary>
@@ -180,5 +169,5 @@ module internal Scatter =
180169
/// <param name="valueMap">Maps global id to a value</param>
181170
let initLastOccurrence<'a> valueMap =
182171
generalInit<'a>
183-
<| lastOccurrencePredicate ()
172+
<| Predicates.lastOccurrence ()
184173
<| valueMap

src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ module ArithmeticOperations =
159159
// multiplication
160160
let intMul = createPair 0 (*) <@ (*) @>
161161

162-
let boolMul = createPair false (&&) <@ (&&) @>
162+
let boolMul = createPair true (&&) <@ (&&) @>
163163

164164
let floatMul = createPair 0.0 (*) <@ (*) @>
165165

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,20 @@
11
namespace GraphBLAS.FSharp.Backend.Quotes
22

3+
open Brahma.FSharp
4+
35
module Predicates =
46
let isSome<'a> =
57
<@ fun (item: 'a option) ->
68
match item with
79
| Some _ -> true
810
| _ -> false @>
11+
12+
let inline lastOccurrence () =
13+
<@ fun (gid: int) (length: int) (inputArray: ClArray<'a>) ->
14+
gid = length - 1
15+
|| inputArray.[gid] <> inputArray.[gid + 1] @>
16+
17+
let inline firstOccurrence () =
18+
<@ fun (gid: int) (_: int) (inputArray: ClArray<'a>) ->
19+
gid = 0
20+
|| inputArray.[gid - 1] <> inputArray.[gid] @>

tests/GraphBLAS-sharp.Tests/Common/Reduce/ReduceByKey.fs

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@ let processor = Context.defaultContext.Queue
1515

1616
let config = Utils.defaultConfig
1717

18-
let getOffsets array =
18+
let private getOffsets array =
1919
Array.map fst array
2020
|> HostPrimitives.getUniqueBitmapFirstOccurrence
2121
|> HostPrimitives.getBitPositions
2222

23-
let getOffsets2D array =
23+
let private getOffsets2D array =
2424
Array.map (fun (fst, snd, _) -> fst, snd) array
2525
|> HostPrimitives.getUniqueBitmapFirstOccurrence
2626
|> HostPrimitives.getBitPositions
2727

28-
let checkResult isEqual actualKeys actualValues keys values reduceOp =
28+
let private checkResult isEqual actualKeys actualValues keys values reduceOp =
2929

3030
let expectedKeys, expectedValues =
3131
HostPrimitives.reduceByKey keys values reduceOp
@@ -36,7 +36,7 @@ let checkResult isEqual actualKeys actualValues keys values reduceOp =
3636
"Values must the same"
3737
|> Utils.compareArrays isEqual actualValues expectedValues
3838

39-
let makeTest isEqual reduce reduceOp (arrayAndKeys: (int * 'a) []) =
39+
let private makeTest isEqual reduce reduceOp (arrayAndKeys: (int * 'a) []) =
4040
let keys, values =
4141
Array.sortBy fst arrayAndKeys |> Array.unzip
4242

@@ -60,7 +60,7 @@ let makeTest isEqual reduce reduceOp (arrayAndKeys: (int * 'a) []) =
6060

6161
checkResult isEqual actualKeys actualValues keys values reduceOp
6262

63-
let createTestSequential<'a> (isEqual: 'a -> 'a -> bool) reduceOp reduceOpQ =
63+
let private createTestSequential<'a> (isEqual: 'a -> 'a -> bool) reduceOp reduceOpQ =
6464

6565
let reduce =
6666
Reduce.ByKey.sequential context Utils.defaultWorkGroupSize reduceOpQ
@@ -339,7 +339,7 @@ let createTestSequentialSegments2D<'a> (isEqual: 'a -> 'a -> bool) reduceOp redu
339339
arbitrary = [ typeof<Generators.ArrayOfDistinctKeys> ] }
340340
$"test on {typeof<'a>}"
341341

342-
let sequentialSegmentTests2D =
342+
let sequentialSegment2DTests =
343343
let addTests =
344344
testList
345345
"add tests"
@@ -446,7 +446,7 @@ let createTest2DOption (isEqual: 'a -> 'a -> bool) (reduceOpQ, reduceOp) =
446446
arbitrary = [ typeof<Generators.ArrayOfDistinctKeys> ] }
447447
$"test on {typeof<'a>}"
448448

449-
let testsByKey2DSegmentsSequentialOption =
449+
let testsSegmentsSequential2DOption =
450450
[ createTest2DOption (=) ArithmeticOperations.intAdd
451451

452452
if Utils.isFloat64Available context.ClDevice then
@@ -455,3 +455,13 @@ let testsByKey2DSegmentsSequentialOption =
455455
createTest2DOption Utils.float32IsEqual ArithmeticOperations.float32Add
456456
createTest2DOption (=) ArithmeticOperations.boolAdd ]
457457
|> testList "2D option"
458+
459+
let allTests =
460+
testList
461+
"Reduce.ByKey"
462+
[ sequentialTest
463+
oneWorkGroupTest
464+
sequentialSegmentTests
465+
sequential2DTest
466+
sequentialSegment2DTests
467+
testsSegmentsSequential2DOption ]

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,7 @@ let commonTests =
2323
let reduceTests =
2424
testList
2525
"Reduce"
26-
[ Common.Reduce.ByKey.sequentialTest
27-
Common.Reduce.ByKey.sequentialSegmentTests
28-
Common.Reduce.ByKey.oneWorkGroupTest
29-
Common.Reduce.ByKey.testsByKey2DSegmentsSequentialOption
26+
[ Common.Reduce.ByKey.allTests
3027
Common.Reduce.Reduce.tests
3128
Common.Reduce.Sum.tests ]
3229

0 commit comments

Comments
 (0)