Skip to content

Commit 251b611

Browse files
committed
refactor: Reduce.fs
1 parent a0c25a4 commit 251b611

4 files changed

Lines changed: 66 additions & 21 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/CommonQuotes.fs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,23 @@ module SubSum =
3636

3737
let treeSum<'a> opAdd = sumGeneral<'a> <| treeAccess<'a> opAdd
3838

39+
module SubReduce =
40+
let run opAdd =
41+
<@ fun length wgSize gid lid (localValues: 'a []) ->
42+
let mutable step = 2
43+
44+
while step <= wgSize do
45+
if (gid + wgSize / step) < length
46+
&& lid < wgSize / step then
47+
let firstValue = localValues.[lid]
48+
let secondValue = localValues.[lid + wgSize / step]
49+
50+
localValues.[lid] <- (%opAdd) firstValue secondValue
51+
52+
step <- step <<< 1
53+
54+
barrierLocal () @>
55+
3956
module PreparePositions =
4057
let both<'c> =
4158
<@ fun index (result: 'c option) (rawPositionsBuffer: ClArray<int>) (allValuesBuffer: ClArray<'c>) ->

src/GraphBLAS-sharp.Backend/Common/Reduce.fs

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
namespace GraphBLAS.FSharp.Backend.Common
22

33
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend
45
open Microsoft.FSharp.Control
56
open Microsoft.FSharp.Quotations
67

@@ -20,20 +21,9 @@ module Reduce =
2021

2122
barrierLocal ()
2223

23-
let mutable step = 2
24-
2524
if gid < length then
26-
while step <= workGroupSize do
27-
if (gid + workGroupSize / step) < length
28-
&& lid < workGroupSize / step then
29-
let firstValue = localValues.[lid]
30-
let secondValue = localValues.[lid + workGroupSize / step]
31-
32-
localValues.[lid] <- (%opAdd) firstValue secondValue
33-
34-
step <- step <<< 1
3525

36-
barrierLocal ()
26+
(%SubReduce.run opAdd) length workGroupSize gid lid localValues
3727

3828
if lid = 0 then
3929
resultArray.[gid / workGroupSize] <- localValues.[0] @>
@@ -53,10 +43,53 @@ module Reduce =
5343

5444
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
5545

46+
let private scanToCell<'a when 'a: struct>
47+
(clContext: ClContext)
48+
(workGroupSize: int)
49+
(opAdd: Expr<'a -> 'a -> 'a>)
50+
=
51+
52+
let scan =
53+
<@ fun (ndRange: Range1D) length (inputArray: ClArray<'a>) (resultValue: ClCell<'a>) ->
54+
55+
let gid = ndRange.GlobalID0
56+
let lid = ndRange.LocalID0
57+
58+
let localValues = localArray<'a> workGroupSize
59+
60+
if gid < length then
61+
localValues.[lid] <- inputArray.[gid]
62+
63+
barrierLocal ()
64+
65+
if gid < length then
66+
67+
(%SubReduce.run opAdd) length workGroupSize gid lid localValues
68+
69+
if lid = 0 then
70+
resultValue.Value <- localValues.[0] @>
71+
72+
let kernel = clContext.Compile(scan)
73+
74+
fun (processor: MailboxProcessor<_>) (valuesArray: ClArray<'a>) valuesLength (resultValue: ClCell<'a>) ->
75+
76+
let ndRange =
77+
Range1D.CreateValid(valuesArray.Length, workGroupSize)
78+
79+
let kernel = kernel.GetKernel()
80+
81+
processor.Post(
82+
Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange valuesLength valuesArray resultValue)
83+
)
84+
85+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
86+
5687
let run<'a when 'a: struct> (clContext: ClContext) (workGroupSize: int) (opAdd: Expr<'a -> 'a -> 'a>) =
5788

5889
let scan = scan clContext workGroupSize opAdd
5990

91+
let scanToCell = scanToCell clContext workGroupSize opAdd
92+
6093
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) ->
6194

6295
let scan = scan processor
@@ -101,14 +134,9 @@ module Reduce =
101134
let fstVertices = fst verticesArrays
102135

103136
let result =
104-
clContext.CreateClArray(
105-
1,
106-
hostAccessMode = HostAccessMode.NotAccessible,
107-
deviceAccessMode = DeviceAccessMode.ReadWrite,
108-
allocationMode = AllocationMode.Default
109-
)
137+
clContext.CreateClCell Unchecked.defaultof<'a>
110138

111-
scan fstVertices verticesLength result
139+
scanToCell processor fstVertices verticesLength result
112140

113141
processor.Post(Msg.CreateFreeMsg(firstVerticesArray))
114142
processor.Post(Msg.CreateFreeMsg(secondVerticesArray))

tests/GraphBLAS-sharp.Tests/BackendCommonTests/ReduceTests.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ let context = Context.defaultContext.ClContext
1414

1515
let makeTest
1616
(q: MailboxProcessor<_>)
17-
(reduce: MailboxProcessor<_> -> ClArray<'a> -> ClArray<'a>)
17+
(reduce: MailboxProcessor<_> -> ClArray<'a> -> ClCell<'a>)
1818
plus
1919
zero
2020
(filter: 'a [] -> 'a [])

tests/GraphBLAS-sharp.Tests/Vector/Reduce.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ let correctnessGenericTest
2727
zero
2828
op
2929
opQ
30-
(reduce: Expr<'a -> 'a -> 'a> -> MailboxProcessor<_> -> ClVector<'a> -> ClArray<'a>)
30+
(reduce: Expr<'a -> 'a -> 'a> -> MailboxProcessor<_> -> ClVector<'a> -> ClCell<'a>)
3131
filter
3232
case
3333
(array: 'a [])

0 commit comments

Comments
 (0)