Skip to content

Commit 5cd6610

Browse files
committed
refactor: PrefixSum
1 parent 31c0a08 commit 5cd6610

3 files changed

Lines changed: 226 additions & 210 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 6 additions & 210 deletions
Original file line numberDiff line numberDiff line change
@@ -127,212 +127,6 @@ module ClArray =
127127

128128
outputArray
129129

130-
let private update (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize =
131-
132-
let update =
133-
<@ fun (ndRange: Range1D) (inputArrayLength: int) (bunchLength: int) (resultBuffer: ClArray<'a>) (verticesBuffer: ClArray<'a>) (mirror: ClCell<bool>) ->
134-
135-
let mirror = mirror.Value
136-
137-
let mutable i = ndRange.GlobalID0 + bunchLength
138-
let gid = i
139-
140-
if mirror then
141-
i <- inputArrayLength - 1 - i
142-
143-
if gid < inputArrayLength then
144-
resultBuffer.[i] <- (%opAdd) verticesBuffer.[gid / bunchLength] resultBuffer.[i] @>
145-
146-
let program = clContext.Compile(update)
147-
148-
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) (inputArrayLength: int) (vertices: ClArray<'a>) (bunchLength: int) (mirror: bool) ->
149-
150-
let kernel = program.GetKernel()
151-
152-
let ndRange =
153-
Range1D.CreateValid(inputArrayLength - bunchLength, workGroupSize)
154-
155-
let mirror = clContext.CreateClCell mirror
156-
157-
processor.Post(
158-
Msg.MsgSetArguments
159-
(fun () -> kernel.KernelFunc ndRange inputArrayLength bunchLength inputArray vertices mirror)
160-
)
161-
162-
processor.Post(Msg.CreateRunMsg<_, _> kernel)
163-
processor.Post(Msg.CreateFreeMsg(mirror))
164-
165-
let private scanGeneral
166-
beforeLocalSumClear
167-
writeData
168-
(opAdd: Expr<'a -> 'a -> 'a>)
169-
(clContext: ClContext)
170-
workGroupSize
171-
=
172-
173-
let subSum = SubSum.treeSum opAdd
174-
175-
let scan =
176-
<@ fun (ndRange: Range1D) inputArrayLength verticesLength (resultBuffer: ClArray<'a>) (verticesBuffer: ClArray<'a>) (totalSumBuffer: ClCell<'a>) (zero: ClCell<'a>) (mirror: ClCell<bool>) ->
177-
178-
let mirror = mirror.Value
179-
180-
let resultLocalBuffer = localArray<'a> workGroupSize
181-
let mutable i = ndRange.GlobalID0
182-
let gid = i
183-
184-
if mirror then
185-
i <- inputArrayLength - 1 - i
186-
187-
let localID = ndRange.LocalID0
188-
189-
let zero = zero.Value
190-
191-
if gid < inputArrayLength then
192-
resultLocalBuffer.[localID] <- resultBuffer.[i]
193-
else
194-
resultLocalBuffer.[localID] <- zero
195-
196-
barrierLocal ()
197-
198-
(%subSum) workGroupSize localID resultLocalBuffer
199-
200-
if localID = workGroupSize - 1 then
201-
if verticesLength <= 1 && localID = gid then
202-
totalSumBuffer.Value <- resultLocalBuffer.[localID]
203-
204-
verticesBuffer.[gid / workGroupSize] <- resultLocalBuffer.[localID]
205-
(%beforeLocalSumClear) resultBuffer resultLocalBuffer.[localID] inputArrayLength gid i
206-
resultLocalBuffer.[localID] <- zero
207-
208-
let mutable step = workGroupSize
209-
210-
while step > 1 do
211-
barrierLocal ()
212-
213-
if localID < workGroupSize / step then
214-
let i = step * (localID + 1) - 1
215-
let j = i - (step >>> 1)
216-
217-
let tmp = resultLocalBuffer.[i]
218-
let buff = (%opAdd) tmp resultLocalBuffer.[j]
219-
resultLocalBuffer.[i] <- buff
220-
resultLocalBuffer.[j] <- tmp
221-
222-
step <- step >>> 1
223-
224-
barrierLocal ()
225-
226-
(%writeData) resultBuffer resultLocalBuffer inputArrayLength workGroupSize gid i localID @>
227-
228-
let program = clContext.Compile(scan)
229-
230-
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) (inputArrayLength: int) (vertices: ClArray<'a>) (verticesLength: int) (totalSum: ClCell<'a>) (zero: 'a) (mirror: bool) ->
231-
232-
// TODO: передавать zero как константу
233-
let zero = clContext.CreateClCell(zero)
234-
235-
let kernel = program.GetKernel()
236-
237-
let ndRange =
238-
Range1D.CreateValid(inputArrayLength, workGroupSize)
239-
240-
let mirror = clContext.CreateClCell mirror
241-
242-
processor.Post(
243-
Msg.MsgSetArguments
244-
(fun () ->
245-
kernel.KernelFunc
246-
ndRange
247-
inputArrayLength
248-
verticesLength
249-
inputArray
250-
vertices
251-
totalSum
252-
zero
253-
mirror)
254-
)
255-
256-
processor.Post(Msg.CreateRunMsg<_, _> kernel)
257-
processor.Post(Msg.CreateFreeMsg(zero))
258-
processor.Post(Msg.CreateFreeMsg(mirror))
259-
260-
let private scanExclusive<'a when 'a: struct> =
261-
scanGeneral
262-
<@ fun (a: ClArray<'a>) (b: 'a) (c: int) (d: int) (e: int) ->
263-
264-
() @>
265-
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (smth: int) (gid: int) (i: int) (localID: int) ->
266-
267-
if gid < inputArrayLength then
268-
resultBuffer.[i] <- resultLocalBuffer.[localID] @>
269-
270-
let private scanInclusive<'a when 'a: struct> =
271-
scanGeneral
272-
<@ fun (resultBuffer: ClArray<'a>) (value: 'a) (inputArrayLength: int) (gid: int) (i: int) ->
273-
274-
if gid < inputArrayLength then
275-
resultBuffer.[i] <- value @>
276-
<@ fun (resultBuffer: ClArray<'a>) (resultLocalBuffer: 'a []) (inputArrayLength: int) (workGroupSize: int) (gid: int) (i: int) (localID: int) ->
277-
278-
if gid < inputArrayLength
279-
&& localID < workGroupSize - 1 then
280-
resultBuffer.[i] <- resultLocalBuffer.[localID + 1] @>
281-
282-
let private runInplace (mirror: bool) scan (opAdd: Expr<'a -> 'a -> 'a>) (clContext: ClContext) workGroupSize =
283-
284-
let scan = scan opAdd clContext workGroupSize
285-
286-
let scanExclusive =
287-
scanExclusive opAdd clContext workGroupSize
288-
289-
let update = update opAdd clContext workGroupSize
290-
291-
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) (totalSum: ClCell<'a>) (zero: 'a) ->
292-
293-
let firstVertices =
294-
clContext.CreateClArray<'a>(
295-
(inputArray.Length - 1) / workGroupSize + 1,
296-
hostAccessMode = HostAccessMode.NotAccessible
297-
)
298-
299-
let secondVertices =
300-
clContext.CreateClArray<'a>(
301-
(firstVertices.Length - 1) / workGroupSize + 1,
302-
hostAccessMode = HostAccessMode.NotAccessible
303-
)
304-
305-
let mutable verticesArrays = firstVertices, secondVertices
306-
let swap (a, b) = (b, a)
307-
let mutable verticesLength = firstVertices.Length
308-
let mutable bunchLength = workGroupSize
309-
310-
scan processor inputArray inputArray.Length (fst verticesArrays) verticesLength totalSum zero mirror
311-
312-
while verticesLength > 1 do
313-
let fstVertices = fst verticesArrays
314-
let sndVertices = snd verticesArrays
315-
316-
scanExclusive
317-
processor
318-
fstVertices
319-
verticesLength
320-
sndVertices
321-
((verticesLength - 1) / workGroupSize + 1)
322-
totalSum
323-
zero
324-
false
325-
326-
update processor inputArray inputArray.Length fstVertices bunchLength mirror
327-
bunchLength <- bunchLength * workGroupSize
328-
verticesArrays <- swap verticesArrays
329-
verticesLength <- (verticesLength - 1) / workGroupSize + 1
330-
331-
processor.Post(Msg.CreateFreeMsg(firstVertices))
332-
processor.Post(Msg.CreateFreeMsg(secondVertices))
333-
334-
inputArray, totalSum
335-
336130
/// <summary>
337131
/// Exclude inplace prefix sum.
338132
/// </summary>
@@ -354,7 +148,7 @@ module ClArray =
354148
///<param name="totalSum">.</param>
355149
///<param name="plus">Associative binary operation.</param>
356150
///<param name="zero">Zero element for binary operation.</param>
357-
let prefixSumExcludeInplace plus = runInplace false scanExclusive plus
151+
let prefixSumExcludeInplace = PrefixSum.runExcludeInplace
358152

359153
/// <summary>
360154
/// Include inplace prefix sum.
@@ -377,7 +171,7 @@ module ClArray =
377171
///<param name="totalSum">.</param>
378172
///<param name="plus">Associative binary operation.</param>
379173
///<param name="zero">Zero element for binary operation.</param>
380-
let prefixSumIncludeInplace plus = runInplace false scanInclusive plus
174+
let prefixSumIncludeInplace = PrefixSum.runIncludeInplace
381175

382176
let prefixSumExclude plus (clContext: ClContext) workGroupSize =
383177

@@ -405,9 +199,11 @@ module ClArray =
405199

406200
runIncludeInplace processor outputArray totalSum zero
407201

408-
let prefixSumBackwardsExcludeInplace plus = runInplace true scanExclusive plus
202+
let prefixSumBackwardsExcludeInplace plus =
203+
PrefixSum.runBackwardsExcludeInplace plus
409204

410-
let prefixSumBackwardsIncludeInplace plus = runInplace true scanInclusive plus
205+
let prefixSumBackwardsIncludeInplace plus =
206+
PrefixSum.runBackwardsIncludeInplace plus
411207

412208
let getUniqueBitmap (clContext: ClContext) =
413209

0 commit comments

Comments
 (0)