Skip to content

Commit a1f38e0

Browse files
committed
refactor: gen Sum and Reduce
1 parent 591310a commit a1f38e0

11 files changed

Lines changed: 232 additions & 215 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/Reduce.fs

Lines changed: 0 additions & 144 deletions
This file was deleted.

src/GraphBLAS-sharp.Backend/Common/Sum.fs

Lines changed: 146 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,65 @@
11
namespace GraphBLAS.FSharp.Backend.Common
22

33
open Brahma.FSharp
4-
open GraphBLAS.FSharp.Backend
54
open GraphBLAS.FSharp.Backend.Quotes
5+
open Microsoft.FSharp.Control
66
open Microsoft.FSharp.Quotations
77

8-
module internal Sum =
9-
let private scan (clContext: ClContext) (workGroupSize: int) (opAdd: Expr<'a -> 'a -> 'a>) zero =
8+
module Fold =
9+
/// Generic multi-level fold driver shared by the concrete folds of this module.
///
/// `scan` is invoked as `scan processor source length target` and is expected to
/// collapse `length` items of `source` into one item per work group in `target`.
/// `scanToCell` is invoked as `scanToCell processor source length`; whatever it
/// returns is the result of the whole fold.
///
/// Two intermediate device buffers are allocated per call and used alternately as
/// source and target until the remaining length fits into a single work group.
/// Both buffers are released via `CreateFreeMsg` after `scanToCell` has been called.
let private runGeneral (clContext: ClContext) (workGroupSize: int) scan scanToCell =

    fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) ->

        let scan = scan processor

        // Number of work groups needed to cover `length` items.
        let groupsCount length = (length - 1) / workGroupSize + 1

        // Device-only scratch buffer; the host never reads it.
        let createBuffer (length: int) : ClArray<'a> =
            clContext.CreateClArray(
                length,
                hostAccessMode = HostAccessMode.NotAccessible,
                deviceAccessMode = DeviceAccessMode.ReadWrite,
                allocationMode = AllocationMode.Default
            )

        let firstLength = groupsCount inputArray.Length

        let firstVerticesArray = createBuffer firstLength
        let secondVerticesArray = createBuffer (groupsCount firstLength)

        // Level 0: input -> first buffer.
        scan inputArray inputArray.Length firstVerticesArray

        // Keep collapsing, swapping the roles of the two buffers on every level,
        // until one work group is enough; then hand over to `scanToCell`.
        let rec collapse source target length =
            if length > workGroupSize then
                scan source length target
                collapse target source (groupsCount length)
            else
                scanToCell processor source length

        let result =
            collapse firstVerticesArray secondVerticesArray firstLength

        processor.Post(Msg.CreateFreeMsg(firstVerticesArray))
        processor.Post(Msg.CreateFreeMsg(secondVerticesArray))

        result
61+
62+
let scanSum (clContext: ClContext) (workGroupSize: int) (opAdd: Expr<'a -> 'a -> 'a>) zero =
1063

1164
let subSum = SubSum.sequentialSum opAdd
1265

@@ -43,7 +96,7 @@ module internal Sum =
4396

4497
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
4598

46-
let private scanToCell (clContext: ClContext) (workGroupSize: int) (opAdd: Expr<'a -> 'a -> 'a>) zero =
99+
let scanToCellSum (clContext: ClContext) (workGroupSize: int) (opAdd: Expr<'a -> 'a -> 'a>) zero =
47100

48101
let subSum = SubSum.sequentialSum opAdd
49102

@@ -83,60 +136,111 @@ module internal Sum =
83136

84137
resultCell
85138

86-
let run (clContext: ClContext) (workGroupSize: int) (opAdd: Expr<'a -> 'a -> 'a>) (zero: 'a) =
139+
/// Builds a fold over a `ClArray<'a>` from the `scanSum` / `scanToCellSum` kernels.
///
/// `op` and `zero` are forwarded unchanged to both kernel builders; the kernels are
/// built once, when `sum` is applied to its four arguments, and reused on every call
/// of the returned function.
///
/// Returns a function taking the command queue (`processor`) and the input array and
/// yielding whatever `scanToCellSum` produces for the last reduction level.
let sum (clContext: ClContext) workGroupSize op zero =

    // Per-work-group partial results for the intermediate levels.
    let partialKernel = scanSum clContext workGroupSize op zero

    // Final level: collapses the remaining items into the result.
    let finalKernel =
        scanToCellSum clContext workGroupSize op zero

    let fold =
        runGeneral clContext workGroupSize partialKernel finalKernel

    fun (processor: MailboxProcessor<_>) (array: ClArray<'a>) -> fold processor array
96150

97-
let firstLength =
98-
(inputArray.Length - 1) / workGroupSize + 1
151+
/// Compiles a kernel that reduces the first `valuesLength` items of an array with
/// `opAdd`, one work group at a time, and stores one value per work group in
/// `resultArray` at index `gid / workGroupSize`.
///
/// Returns a function `processor -> valuesArray -> valuesLength -> resultArray -> unit`
/// that enqueues the kernel on `processor`. Nothing is read back to the host.
let private scanReduce<'a when 'a: struct>
    (clContext: ClContext)
    (workGroupSize: int)
    (opAdd: Expr<'a -> 'a -> 'a>)
    =

    let scan =
        <@ fun (ndRange: Range1D) length (inputArray: ClArray<'a>) (resultArray: ClArray<'a>) ->

            let gid = ndRange.GlobalID0
            let lid = ndRange.LocalID0

            let localValues = localArray<'a> workGroupSize

            // Only lanes that map to a valid input element load a value.
            if gid < length then
                localValues.[lid] <- inputArray.[gid]

            barrierLocal ()

            if gid < length then

                (%SubReduce.run opAdd) length workGroupSize gid lid localValues

                // Lane 0 publishes the value left in slot 0 for its work group.
                if lid = 0 then
                    resultArray.[gid / workGroupSize] <- localValues.[0] @>

    let kernel = clContext.Compile(scan)

    fun (processor: MailboxProcessor<_>) (valuesArray: ClArray<'a>) (valuesLength: int) (resultArray: ClArray<'a>) ->

        // Size the launch by the number of valid elements, not by the buffer capacity:
        // callers may pass a buffer that is longer than `valuesLength`, and work groups
        // beyond `valuesLength` have nothing to reduce.
        let ndRange =
            Range1D.CreateValid(valuesLength, workGroupSize)

        let kernel = kernel.GetKernel()

        processor.Post(
            Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange valuesLength valuesArray resultArray)
        )

        processor.Post(Msg.CreateRunMsg<_, _>(kernel))
191+
192+
/// Compiles a kernel that reduces the first `valuesLength` items of an array with
/// `opAdd` and stores the value left in local slot 0 into a freshly created `ClCell`.
///
/// Returns a function `processor -> valuesArray -> valuesLength -> ClCell<'a>` that
/// enqueues the kernel on `processor` and returns the cell immediately; the cell is
/// created with `Unchecked.defaultof<'a>` as its initial content.
///
/// NOTE(review): the result is a single cell, so this is only meaningful when
/// `valuesLength` fits into one work group - confirm against the caller.
let private scanToCellReduce<'a when 'a: struct>
    (clContext: ClContext)
    (workGroupSize: int)
    (opAdd: Expr<'a -> 'a -> 'a>)
    =

    let scan =
        <@ fun (ndRange: Range1D) length (inputArray: ClArray<'a>) (resultValue: ClCell<'a>) ->

            let gid = ndRange.GlobalID0
            let lid = ndRange.LocalID0

            let localValues = localArray<'a> workGroupSize

            // Only lanes that map to a valid input element load a value.
            if gid < length then
                localValues.[lid] <- inputArray.[gid]

            barrierLocal ()

            if gid < length then

                (%SubReduce.run opAdd) length workGroupSize gid lid localValues

                if lid = 0 then
                    resultValue.Value <- localValues.[0] @>

    let kernel = clContext.Compile(scan)

    fun (processor: MailboxProcessor<_>) (valuesArray: ClArray<'a>) (valuesLength: int) ->

        // Size the launch by the number of valid elements, not by the buffer capacity.
        // With a longer buffer, several work groups would be started although only one
        // of them may legitimately write the single result cell.
        let ndRange =
            Range1D.CreateValid(valuesLength, workGroupSize)

        let resultCell =
            clContext.CreateClCell Unchecked.defaultof<'a>

        let kernel = kernel.GetKernel()

        processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange valuesLength valuesArray resultCell))

        processor.Post(Msg.CreateRunMsg<_, _>(kernel))

        resultCell
235+
236+
/// Builds a reduction over a `ClArray<'a>` from the `scanReduce` / `scanToCellReduce`
/// kernels; unlike `sum` it takes no neutral element.
///
/// Both kernels are built once, when `reduce` is applied to its three arguments.
/// The returned function takes the command queue (`processor`) and the input array
/// and yields the cell produced by `scanToCellReduce` for the last reduction level.
let reduce (clContext: ClContext) workGroupSize op =

    // Per-work-group partial results for the intermediate levels.
    let partialKernel = scanReduce clContext workGroupSize op

    // Final level: collapses the remaining items into a single cell.
    let finalKernel =
        scanToCellReduce clContext workGroupSize op

    let fold =
        runGeneral clContext workGroupSize partialKernel finalKernel

    fun (processor: MailboxProcessor<_>) (array: ClArray<'a>) -> fold processor array

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
<Compile Include="Common/PrefixSum.fs" />
3131
<Compile Include="Common/ClArray.fs" />
3232
<Compile Include="Common/BitonicSort.fs" />
33-
<Compile Include="Common/Reduce.fs" />
3433
<Compile Include="Predefined/PrefixSum.fs" />
3534
<!--Compile Include="Matrices.fs" /-->
3635
<Compile Include="Matrix/Common.fs" />

src/GraphBLAS-sharp.Backend/Quotes/Predicates.fs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ namespace GraphBLAS.FSharp.Backend.Quotes
33
module Predicates =
44
/// Quoted predicate over `'a option`: evaluates to `true` for `Some _` and to
/// `false` for every other case.
let containsNonZero<'a> =
55
<@ fun (item: 'a option) ->
6-
match item with
7-
| Some _ -> true
8-
| _ -> false @>
6+
match item with
7+
| Some _ -> true
8+
| _ -> false @>

0 commit comments

Comments
 (0)