Skip to content

Commit 3703609

Browse files
committed
merge: dev
2 parents 048477d + e850a49 commit 3703609

7 files changed

Lines changed: 199 additions & 67 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs

Lines changed: 26 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ namespace GraphBLAS.FSharp.Backend.Common
33
open Brahma.FSharp
44
open FSharp.Quotations
55
open GraphBLAS.FSharp.Backend.Quotes
6+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
7+
open GraphBLAS.FSharp.Backend.Objects.ClCell
68

79
module PrefixSum =
810
let private update (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize =
@@ -38,7 +40,7 @@ module PrefixSum =
3840
)
3941

4042
processor.Post(Msg.CreateRunMsg<_, _> kernel)
41-
processor.Post(Msg.CreateFreeMsg(mirror))
43+
mirror.Free processor
4244

4345
let private scanGeneral
4446
beforeLocalSumClear
@@ -48,10 +50,8 @@ module PrefixSum =
4850
workGroupSize
4951
=
5052

51-
let subSum = SubSum.treeSum opAdd
52-
5353
let scan =
54-
<@ fun (ndRange: Range1D) inputArrayLength verticesLength (resultBuffer: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell<bool>) ->
54+
<@ fun (ndRange: Range1D) inputArrayLength verticesLength (inputArray: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell<bool>) ->
5555

5656
let mirror = mirror.Value
5757

@@ -62,46 +62,34 @@ module PrefixSum =
6262
if mirror then
6363
i <- inputArrayLength - 1 - i
6464

65-
let localID = ndRange.LocalID0
65+
let lid = ndRange.LocalID0
6666

6767
let zero = zero.Value
6868

6969
if gid < inputArrayLength then
70-
resultLocalBuffer.[localID] <- resultBuffer.[i]
70+
resultLocalBuffer.[lid] <- inputArray.[i]
7171
else
72-
resultLocalBuffer.[localID] <- zero
72+
resultLocalBuffer.[lid] <- zero
7373

7474
barrierLocal ()
7575

76-
(%subSum) workGroupSize localID resultLocalBuffer
77-
78-
if localID = workGroupSize - 1 then
79-
if verticesLength <= 1 && localID = gid then
80-
totalSumBuffer.Value <- resultLocalBuffer.[localID]
81-
82-
verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[localID]
83-
(%beforeLocalSumClear) resultBuffer resultLocalBuffer.[localID] inputArrayLength gid i
84-
resultLocalBuffer.[localID] <- zero
85-
86-
let mutable step = workGroupSize
76+
// Local tree reduce
77+
(%SubSum.upSweep opAdd) workGroupSize lid resultLocalBuffer
8778

88-
while step > 1 do
89-
barrierLocal ()
79+
if lid = workGroupSize - 1 then
80+
// if last iteration
81+
if verticesLength <= 1 && lid = gid then
82+
totalSumBuffer.Value <- resultLocalBuffer.[lid]
9083

91-
if localID < workGroupSize / step then
92-
let i = step * (localID + 1) - 1
93-
let j = i - (step >>> 1)
84+
verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[lid]
85+
(%beforeLocalSumClear) inputArray resultLocalBuffer.[lid] inputArrayLength gid i
86+
resultLocalBuffer.[lid] <- zero
9487

95-
let tmp = resultLocalBuffer.[i]
96-
let buff = (%opAdd) tmp resultLocalBuffer.[j]
97-
resultLocalBuffer.[i] <- buff
98-
resultLocalBuffer.[j] <- tmp
99-
100-
step <- step >>> 1
88+
(%SubSum.downSweep opAdd) workGroupSize lid resultLocalBuffer
10189

10290
barrierLocal ()
10391

104-
(%writeData) resultBuffer resultLocalBuffer inputArrayLength workGroupSize gid i localID @>
92+
(%writeData) inputArray resultLocalBuffer inputArrayLength workGroupSize gid i lid @>
10593

10694
let program = clContext.Compile(scan)
10795

@@ -132,13 +120,14 @@ module PrefixSum =
132120
)
133121

134122
processor.Post(Msg.CreateRunMsg<_, _> kernel)
135-
processor.Post(Msg.CreateFreeMsg(zero))
136-
processor.Post(Msg.CreateFreeMsg(mirror))
123+
124+
zero.Free processor
125+
mirror.Free processor
137126

138127
let private scanExclusive<'a when 'a: struct> =
139128
scanGeneral
140129
<@ fun (_: ClArray<'a>) (_: 'a) (_: int) (_: int) (_: int) -> () @>
141-
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (smth: int) (gid: int) (i: int) (localID: int) ->
130+
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (_: int) (gid: int) (i: int) (localID: int) ->
142131

143132
if gid < inputArrayLength then
144133
resultBuffer.[i] <- resultLocalBuffer.[localID] @>
@@ -206,8 +195,8 @@ module PrefixSum =
206195
verticesArrays <- swap verticesArrays
207196
verticesLength <- (verticesLength - 1) / workGroupSize + 1
208197

209-
processor.Post(Msg.CreateFreeMsg(firstVertices))
210-
processor.Post(Msg.CreateFreeMsg(secondVertices))
198+
firstVertices.Free processor
199+
secondVertices.Free processor
211200

212201
totalSum
213202

@@ -226,7 +215,7 @@ module PrefixSum =
226215
/// <code>
227216
/// let arr = [| 1; 1; 1; 1 |]
228217
/// let sum = [| 0 |]
229-
/// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
218+
/// runExcludeInplace clContext workGroupSize processor arr sum (+) 0
230219
/// |> ignore
231220
/// ...
232221
/// > val arr = [| 0; 1; 2; 3 |]
@@ -252,7 +241,7 @@ module PrefixSum =
252241
/// <code>
253242
/// let arr = [| 1; 1; 1; 1 |]
254243
/// let sum = [| 0 |]
255-
/// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
244+
/// runExcludeInplace clContext workGroupSize processor arr sum (+) 0
256245
/// |> ignore
257246
/// ...
258247
/// > val arr = [| 1; 2; 3; 4 |]
@@ -271,7 +260,6 @@ module PrefixSum =
271260

272261
scan processor inputArray 0
273262

274-
275263
module ByKey =
276264
let private sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero =
277265

src/GraphBLAS-sharp.Backend/Quotes/SubSum.fs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,30 @@ module SubSum =
3131

3232
barrierLocal () @>
3333

34-
let sequentialSum<'a> opAdd =
35-
sumGeneral<'a> <| sequentialAccess<'a> opAdd
34+
let sequentialSum<'a> = sumGeneral<'a> << sequentialAccess<'a>
3635

37-
let treeSum<'a> opAdd = sumGeneral<'a> <| treeAccess<'a> opAdd
36+
let upSweep<'a> = sumGeneral<'a> << treeAccess<'a>
37+
38+
let downSweep opAdd =
39+
<@ fun wgSize lid (localBuffer: 'a []) ->
40+
let mutable step = wgSize
41+
42+
while step > 1 do
43+
barrierLocal ()
44+
45+
if lid < wgSize / step then
46+
let i = step * (lid + 1) - 1
47+
let j = i - (step >>> 1)
48+
49+
let tmp = localBuffer.[i]
50+
51+
let operand = localBuffer.[j] // brahma error
52+
let buff = (%opAdd) tmp operand
53+
54+
localBuffer.[i] <- buff
55+
localBuffer.[j] <- tmp
56+
57+
step <- step >>> 1 @>
3858

3959
let localPrefixSum opAdd =
4060
<@ fun (lid: int) (workGroupSize: int) (array: 'a []) ->
@@ -52,4 +72,6 @@ module SubSum =
5272
barrierLocal ()
5373
array.[lid] <- value @>
5474

75+
76+
5577
let localIntPrefixSum = localPrefixSum <@ (+) @>
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Common.Scan.ByKey
2+
3+
open GraphBLAS.FSharp.Backend.Common
4+
open GraphBLAS.FSharp.Backend.Objects.ClContext
5+
open Expecto
6+
open GraphBLAS.FSharp.Tests
7+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
8+
9+
let context = Context.defaultContext.ClContext
10+
11+
let processor = Context.defaultContext.Queue
12+
13+
let checkResult isEqual keysAndValues actual hostScan =
14+
15+
let expected =
16+
HostPrimitives.scanByKey hostScan keysAndValues
17+
18+
"Results must be the same"
19+
|> Utils.compareArrays isEqual actual expected
20+
21+
let makeTestSequentialSegments isEqual scanHost scanDevice (keysAndValues: (int * 'a) []) =
22+
if keysAndValues.Length > 0 then
23+
let keys, values =
24+
Array.sortBy fst keysAndValues |> Array.unzip
25+
26+
let offsets =
27+
HostPrimitives.getUniqueBitmapFirstOccurrence keys
28+
|> HostPrimitives.getBitPositions
29+
30+
let uniqueKeysCount = Array.distinct keys |> Array.length
31+
32+
let clKeys =
33+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, keys)
34+
35+
let clValues =
36+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, values)
37+
38+
let clOffsets =
39+
context.CreateClArrayWithSpecificAllocationMode(HostInterop, offsets)
40+
41+
scanDevice processor uniqueKeysCount clValues clKeys clOffsets
42+
43+
let actual = clValues.ToHostAndFree processor
44+
clKeys.Free processor
45+
clOffsets.Free processor
46+
47+
let keysAndValues = Array.zip keys values
48+
49+
checkResult isEqual keysAndValues actual scanHost
50+
51+
let createTest (zero: 'a) opAddQ opAdd isEqual deviceScan hostScan =
52+
53+
let hostScan = hostScan zero opAdd
54+
55+
let deviceScan =
56+
deviceScan context Utils.defaultWorkGroupSize opAddQ zero
57+
58+
makeTestSequentialSegments isEqual hostScan deviceScan
59+
|> testPropertyWithConfig Utils.defaultConfig $"test on {typeof<'a>}"
60+
61+
let sequentialSegmentsTests =
62+
let excludeTests =
63+
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
64+
65+
if Utils.isFloat64Available context.ClDevice then
66+
createTest
67+
0.0
68+
<@ (+) @>
69+
(+)
70+
Utils.floatIsEqual
71+
PrefixSum.ByKey.sequentialExclude
72+
HostPrimitives.prefixSumExclude
73+
74+
createTest
75+
0.0f
76+
<@ (+) @>
77+
(+)
78+
Utils.float32IsEqual
79+
PrefixSum.ByKey.sequentialExclude
80+
HostPrimitives.prefixSumExclude
81+
82+
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude
83+
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialExclude HostPrimitives.prefixSumExclude ]
84+
|> testList "exclude"
85+
86+
let includeTests =
87+
[ createTest 0 <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
88+
89+
if Utils.isFloat64Available context.ClDevice then
90+
createTest
91+
0.0
92+
<@ (+) @>
93+
(+)
94+
Utils.floatIsEqual
95+
PrefixSum.ByKey.sequentialInclude
96+
HostPrimitives.prefixSumInclude
97+
98+
createTest
99+
0.0f
100+
<@ (+) @>
101+
(+)
102+
Utils.float32IsEqual
103+
PrefixSum.ByKey.sequentialInclude
104+
HostPrimitives.prefixSumInclude
105+
106+
createTest false <@ (||) @> (||) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude
107+
createTest 0u <@ (+) @> (+) (=) PrefixSum.ByKey.sequentialInclude HostPrimitives.prefixSumInclude ]
108+
109+
|> testList "include"
110+
111+
testList "Sequential segments" [ excludeTests; includeTests ]

tests/GraphBLAS-sharp.Tests/Common/ClArray/PrefixSum.fs renamed to tests/GraphBLAS-sharp.Tests/Common/Scan/PrefixSum.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module GraphBLAS.FSharp.Tests.Backend.Common.ClArray.PrefixSum
1+
module GraphBLAS.FSharp.Tests.Backend.Common.Scan.PrefixSum
22

33
open Expecto
44
open Expecto.Logging
@@ -63,7 +63,7 @@ let makeTest plus zero isEqual scan (array: 'a []) =
6363
let testFixtures plus plusQ zero isEqual name =
6464
PrefixSum.runIncludeInplace plusQ context wgSize
6565
|> makeTest plus zero isEqual
66-
|> testPropertyWithConfig config (sprintf "Correctness on %s" name)
66+
|> testPropertyWithConfig config $"Correctness on %s{name}"
6767

6868
let tests =
6969
q.Error.Add(fun e -> failwithf "%A" e)

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,17 @@
2525
<Compile Include="Common/ClArray/RemoveDuplicates.fs" />
2626
<Compile Include="Common/ClArray/Copy.fs" />
2727
<Compile Include="Common/ClArray/Replicate.fs" />
28-
<Compile Include="Common/ClArray/PrefixSum.fs" />
2928
<Compile Include="Common/Sort/Bitonic.fs" />
3029
<Compile Include="Common/Sort/Radix.fs" />
3130
<Compile Include="Common/Reduce/Sum.fs" />
3231
<Compile Include="Common/Reduce/Reduce.fs" />
3332
<Compile Include="Common/Reduce/ReduceByKey.fs" />
33+
<Compile Include="Common/Scan/PrefixSum.fs" />
34+
<Compile Include="Common/Scan/ByKey.fs" />
35+
<!--Compile Include="MatrixOperationsTests/GetTuplesTests.fs" /-->
36+
<!--Compile Include="MatrixOperationsTests/MxvTests.fs" /-->
37+
<!--Compile Include="MatrixOperationsTests/VxmTests.fs" /-->
38+
<!--Compile Include="AlgorithmsTests/BfsTests.fs" /-->
3439
<Compile Include="Vector/ZeroCreate.fs" />
3540
<Compile Include="Vector/OfList.fs" />
3641
<Compile Include="Vector/Copy.fs" />

tests/GraphBLAS-sharp.Tests/Helpers.fs

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,13 @@ module Utils =
146146
| _ -> failwith "matrix format must be CSR"
147147

148148
module HostPrimitives =
149-
let prefixSumInclude array =
150-
Array.scan (+) 0 array
151-
|> fun scanned -> scanned.[1..]
149+
let prefixSumInclude zero add array =
150+
Array.scan add zero array
151+
|> fun scanned -> scanned.[1..], Array.last scanned
152152

153-
let prefixSumExclude sourceArray =
154-
prefixSumInclude sourceArray
155-
|> Array.insertAt 0 0
153+
let prefixSumExclude zero add sourceArray =
154+
prefixSumInclude zero add sourceArray
155+
|> (fst >> Array.insertAt 0 zero)
156156
|> fun array -> Array.take sourceArray.Length array, Array.last array
157157

158158
let getUniqueBitmapLastOccurrence array =
@@ -181,18 +181,14 @@ module HostPrimitives =
181181
|> Array.mapi (fun index bit -> if bit = 1 then Some index else None)
182182
|> Array.choose id
183183

184-
let reduceByKey keys values reduceOp =
185-
let zipped = Array.zip keys values
186-
187-
Array.distinct keys
184+
let reduceByKey keys value reduceOp =
185+
Array.zip keys value
186+
|> Array.groupBy fst
188187
|> Array.map
189-
(fun key ->
190-
// extract elements corresponding to key
191-
(key,
192-
Array.map snd
193-
<| Array.filter ((=) key << fst) zipped))
194-
// reduce elements
195-
|> Array.map (fun (key, values) -> key, Array.reduce reduceOp values)
188+
(fun (key, array) ->
189+
Array.map snd array
190+
|> Array.reduce reduceOp
191+
|> fun value -> key, value)
196192
|> Array.unzip
197193

198194
let reduceByKey2D firstKeys secondKeys values reduceOp =
@@ -262,6 +258,11 @@ module HostPrimitives =
262258
| Some value -> value
263259
| None -> zero
264260

261+
let scanByKey scan keysAndValues =
262+
Array.groupBy fst keysAndValues
263+
|> Array.map (fun (_, array) -> Array.map snd array |> scan |> fst)
264+
|> Array.concat
265+
265266
module Context =
266267
type TestContext =
267268
{ ClContext: ClContext

0 commit comments

Comments
 (0)