@@ -3,6 +3,8 @@ namespace GraphBLAS.FSharp.Backend.Common
33open Brahma.FSharp
44open FSharp.Quotations
55open GraphBLAS.FSharp .Backend .Quotes
6+ open GraphBLAS.FSharp .Backend .Objects .ArraysExtensions
7+ open GraphBLAS.FSharp .Backend .Objects .ClCell
68
79module PrefixSum =
810 let private update ( opAdd : Expr < 'a -> 'a -> 'a >) ( clContext : ClContext ) workGroupSize =
@@ -38,7 +40,7 @@ module PrefixSum =
3840 )
3941
4042 processor.Post( Msg.CreateRunMsg<_, _> kernel)
41- processor.Post ( Msg.CreateFreeMsg ( mirror))
43+ mirror.Free processor
4244
4345 let private scanGeneral
4446 beforeLocalSumClear
@@ -48,10 +50,8 @@ module PrefixSum =
4850 workGroupSize
4951 =
5052
51- let subSum = SubSum.treeSum opAdd
52-
5353 let scan =
54- <@ fun ( ndRange : Range1D ) inputArrayLength verticesLength ( resultBuffer : ClArray < 'a >) ( verticesBuffer : ClArray < 'a >) ( totalSumBuffer : ClCell < 'a >) ( zero : ClCell < 'a >) ( mirror : ClCell < bool >) ->
54+ <@ fun ( ndRange : Range1D ) inputArrayLength verticesLength ( inputArray : ClArray < 'a >) ( verticesBuffer : ClArray < 'a >) ( totalSumBuffer : ClCell < 'a >) ( zero : ClCell < 'a >) ( mirror : ClCell < bool >) ->
5555
5656 let mirror = mirror.Value
5757
@@ -62,46 +62,34 @@ module PrefixSum =
6262 if mirror then
6363 i <- inputArrayLength - 1 - i
6464
65- let localID = ndRange.LocalID0
65+ let lid = ndRange.LocalID0
6666
6767 let zero = zero.Value
6868
6969 if gid < inputArrayLength then
70- resultLocalBuffer.[ localID ] <- resultBuffer .[ i]
70+ resultLocalBuffer.[ lid ] <- inputArray .[ i]
7171 else
72- resultLocalBuffer.[ localID ] <- zero
72+ resultLocalBuffer.[ lid ] <- zero
7373
7474 barrierLocal ()
7575
76- (% subSum) workGroupSize localID resultLocalBuffer
77-
78- if localID = workGroupSize - 1 then
79- if verticesLength <= 1 && localID = gid then
80- totalSumBuffer.Value <- resultLocalBuffer.[ localID]
81-
82- verticesBuffer.[ gid / workGroupSize] <- resultLocalBuffer.[ localID]
83- (% beforeLocalSumClear) resultBuffer resultLocalBuffer.[ localID] inputArrayLength gid i
84- resultLocalBuffer.[ localID] <- zero
85-
86- let mutable step = workGroupSize
76+ // Local tree reduce
77+ (% SubSum.upSweep opAdd) workGroupSize lid resultLocalBuffer
8778
88- while step > 1 do
89- barrierLocal ()
79+ if lid = workGroupSize - 1 then
80+ // if last iteration
81+ if verticesLength <= 1 && lid = gid then
82+ totalSumBuffer.Value <- resultLocalBuffer.[ lid]
9083
91- if localID < workGroupSize / step then
92- let i = step * ( localID + 1 ) - 1
93- let j = i - ( step >>> 1 )
84+ verticesBuffer .[ gid / workGroupSize] <- resultLocalBuffer .[ lid ]
85+ (% beforeLocalSumClear ) inputArray resultLocalBuffer .[ lid ] inputArrayLength gid i
86+ resultLocalBuffer .[ lid ] <- zero
9487
95- let tmp = resultLocalBuffer.[ i]
96- let buff = (% opAdd) tmp resultLocalBuffer.[ j]
97- resultLocalBuffer.[ i] <- buff
98- resultLocalBuffer.[ j] <- tmp
99-
100- step <- step >>> 1
88+ (% SubSum.downSweep opAdd) workGroupSize lid resultLocalBuffer
10189
10290 barrierLocal ()
10391
104- (% writeData) resultBuffer resultLocalBuffer inputArrayLength workGroupSize gid i localID @>
92+ (% writeData) inputArray resultLocalBuffer inputArrayLength workGroupSize gid i lid @>
10593
10694 let program = clContext.Compile( scan)
10795
@@ -132,13 +120,14 @@ module PrefixSum =
132120 )
133121
134122 processor.Post( Msg.CreateRunMsg<_, _> kernel)
135- processor.Post( Msg.CreateFreeMsg( zero))
136- processor.Post( Msg.CreateFreeMsg( mirror))
123+
124+ zero.Free processor
125+ mirror.Free processor
137126
138127 let private scanExclusive < 'a when 'a : struct > =
139128 scanGeneral
140129 <@ fun ( _ : ClArray < 'a >) ( _ : 'a ) ( _ : int ) ( _ : int ) ( _ : int ) -> () @>
141- <@ fun ( resultBuffer : ClArray < 'a >) ( resultLocalBuffer : 'a []) ( inputArrayLength : int ) ( smth : int ) ( gid : int ) ( i : int ) ( localID : int ) ->
130+ <@ fun ( resultBuffer : ClArray < 'a >) ( resultLocalBuffer : 'a []) ( inputArrayLength : int ) ( _ : int ) ( gid : int ) ( i : int ) ( localID : int ) ->
142131
143132 if gid < inputArrayLength then
144133 resultBuffer.[ i] <- resultLocalBuffer.[ localID] @>
@@ -206,8 +195,8 @@ module PrefixSum =
206195 verticesArrays <- swap verticesArrays
207196 verticesLength <- ( verticesLength - 1 ) / workGroupSize + 1
208197
209- processor.Post ( Msg.CreateFreeMsg ( firstVertices))
210- processor.Post ( Msg.CreateFreeMsg ( secondVertices))
198+ firstVertices.Free processor
199+ secondVertices.Free processor
211200
212201 totalSum
213202
@@ -226,7 +215,7 @@ module PrefixSum =
226215 /// <code >
227216 /// let arr = [ | 1; 1; 1; 1 |]
228217 /// let sum = [ | 0 |]
229- /// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
218+ /// runExcludeInplace clContext workGroupSize processor arr sum (+) 0
230219 /// |> ignore
231220 /// ...
232221 /// > val arr = [ | 0; 1; 2; 3 |]
@@ -252,7 +241,7 @@ module PrefixSum =
252241 /// <code >
253242 /// let arr = [ | 1; 1; 1; 1 |]
254243 /// let sum = [ | 0 |]
255- /// runExcludeInplace clContext workGroupSize processor arr sum <@ (+) @> 0
244+ /// runIncludeInplace clContext workGroupSize processor arr sum (+) 0
256245 /// |> ignore
257246 /// ...
258247 /// > val arr = [ | 1; 2; 3; 4 |]
@@ -271,7 +260,6 @@ module PrefixSum =
271260
272261 scan processor inputArray 0
273262
274-
275263 module ByKey =
276264 let private sequentialSegments opWrite ( clContext : ClContext ) workGroupSize opAdd zero =
277265
0 commit comments