11namespace GraphBLAS.FSharp.Backend.Common
22
33open Brahma.FSharp
4+ open GraphBLAS.FSharp .Backend
45open Microsoft.FSharp .Control
56open Microsoft.FSharp .Quotations
67
@@ -20,20 +21,9 @@ module Reduce =
2021
2122 barrierLocal ()
2223
23- let mutable step = 2
24-
2524 if gid < length then
26- while step <= workGroupSize do
27- if ( gid + workGroupSize / step) < length
28- && lid < workGroupSize / step then
29- let firstValue = localValues.[ lid]
30- let secondValue = localValues.[ lid + workGroupSize / step]
31-
32- localValues.[ lid] <- (% opAdd) firstValue secondValue
33-
34- step <- step <<< 1
3525
36- barrierLocal ()
26+ (% SubReduce.run opAdd ) length workGroupSize gid lid localValues
3727
3828 if lid = 0 then
3929 resultArray.[ gid / workGroupSize] <- localValues.[ 0 ] @>
@@ -53,10 +43,53 @@ module Reduce =
5343
5444 processor.Post( Msg.CreateRunMsg<_, _>( kernel))
5545
46+ let private scanToCell < 'a when 'a : struct >
47+ ( clContext : ClContext )
48+ ( workGroupSize : int )
49+ ( opAdd : Expr < 'a -> 'a -> 'a >)
50+ =
51+
52+ let scan =
53+ <@ fun ( ndRange : Range1D ) length ( inputArray : ClArray < 'a >) ( resultValue : ClCell < 'a >) ->
54+
55+ let gid = ndRange.GlobalID0
56+ let lid = ndRange.LocalID0
57+
58+ let localValues = localArray< 'a> workGroupSize
59+
60+ if gid < length then
61+ localValues.[ lid] <- inputArray.[ gid]
62+
63+ barrierLocal ()
64+
65+ if gid < length then
66+
67+ (% SubReduce.run opAdd) length workGroupSize gid lid localValues
68+
69+ if lid = 0 then
70+ resultValue.Value <- localValues.[ 0 ] @>
71+
72+ let kernel = clContext.Compile( scan)
73+
74+ fun ( processor : MailboxProcessor < _ >) ( valuesArray : ClArray < 'a >) valuesLength ( resultValue : ClCell < 'a >) ->
75+
76+ let ndRange =
77+ Range1D.CreateValid( valuesArray.Length, workGroupSize)
78+
79+ let kernel = kernel.GetKernel()
80+
81+ processor.Post(
82+ Msg.MsgSetArguments( fun () -> kernel.KernelFunc ndRange valuesLength valuesArray resultValue)
83+ )
84+
85+ processor.Post( Msg.CreateRunMsg<_, _>( kernel))
86+
5687 let run < 'a when 'a : struct > ( clContext : ClContext ) ( workGroupSize : int ) ( opAdd : Expr < 'a -> 'a -> 'a >) =
5788
5889 let scan = scan clContext workGroupSize opAdd
5990
91+ let scanToCell = scanToCell clContext workGroupSize opAdd
92+
6093 fun ( processor : MailboxProcessor < _ >) ( inputArray : ClArray < 'a >) ->
6194
6295 let scan = scan processor
@@ -101,14 +134,9 @@ module Reduce =
101134 let fstVertices = fst verticesArrays
102135
103136 let result =
104- clContext.CreateClArray(
105- 1 ,
106- hostAccessMode = HostAccessMode.NotAccessible,
107- deviceAccessMode = DeviceAccessMode.ReadWrite,
108- allocationMode = AllocationMode.Default
109- )
137+ clContext.CreateClCell Unchecked.defaultof< 'a>
110138
111- scan fstVertices verticesLength result
139+ scanToCell processor fstVertices verticesLength result
112140
113141 processor.Post( Msg.CreateFreeMsg( firstVerticesArray))
114142 processor.Post( Msg.CreateFreeMsg( secondVerticesArray))
0 commit comments