Skip to content

Commit 86abbce

Browse files
committed
add: CommonQuotes, refactor: COOMatrix, Sum
1 parent 1c5663c commit 86abbce

5 files changed

Lines changed: 256 additions & 212 deletions

File tree

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
namespace GraphBLAS.FSharp.Backend
2+
3+
open Brahma.FSharp
4+
5+
module SubSum =
    /// One step of the tree-shaped in-place reduction.
    /// For the given `step`, work item `lid` combines the element at
    /// `step * (lid + 1) - 1` with the element half a step before it and stores
    /// the result back at `step * (lid + 1) - 1`.
    /// `wgSize` is not read here; it is accepted so that this quotation has the
    /// same shape as `sequentialAccess` and can be spliced into `sumGeneral`.
    let private treeAccess<'a> opAdd =
        <@
            fun step lid wgSize (localBuffer: 'a []) ->
                let right = step * (lid + 1) - 1
                let left = right - (step >>> 1)

                localBuffer.[right] <- (%opAdd) localBuffer.[left] localBuffer.[right]
        @>

    /// One step of the sequential-addressing in-place reduction.
    /// Work item `lid` combines its own slot with the slot `wgSize / step`
    /// positions further and stores the result back into its own slot.
    let private sequentialAccess<'a> opAdd =
        <@
            fun step lid wgSize (localBuffer: 'a []) ->
                let partner = lid + wgSize / step

                localBuffer.[lid] <- (%opAdd) localBuffer.[lid] localBuffer.[partner]
        @>

    /// Generic reduction driver over a work-group local buffer.
    /// Doubles the stride from 2 up to `wgSize`; on every pass only the first
    /// `wgSize / stride` work items invoke `memoryAccess`, and all work items
    /// meet at a local barrier before the next pass starts.
    let sumGeneral<'a> memoryAccess =
        <@
            fun wgSize lid (localBuffer: 'a []) ->
                let mutable stride = 2

                while stride <= wgSize do
                    if lid < wgSize / stride then
                        (%memoryAccess) stride lid wgSize localBuffer

                    stride <- stride <<< 1

                    // Every work item reaches this barrier on every pass,
                    // whether or not it did any work above.
                    barrierLocal ()
        @>

    /// Reduction that uses `sequentialAccess`; the last pass writes slot 0.
    let sequentialSum<'a> opAdd =
        sumGeneral<'a> (sequentialAccess<'a> opAdd)

    /// Reduction that uses `treeAccess`; the last pass writes slot `wgSize - 1`
    /// (not slot 0, unlike `sequentialSum`).
    let treeSum<'a> opAdd =
        sumGeneral<'a> (treeAccess opAdd)
43+
44+
module PreparePositions =
    /// Records an optional result into the slot that follows `index`.
    /// `rawPositionsBuffer.[index]` is always set to 0. When `result` is
    /// `Some`, the value is stored at `allValuesBuffer.[index + 1]` and
    /// `rawPositionsBuffer.[index + 1]` is set to 1; otherwise
    /// `rawPositionsBuffer.[index + 1]` is set to 0.
    let both<'c> =
        <@
            fun index (result: 'c option) (rawPositionsBuffer: ClArray<int>) (allValuesBuffer: ClArray<'c>) ->
                rawPositionsBuffer.[index] <- 0

                match result with
                | Some value ->
                    allValuesBuffer.[index + 1] <- value
                    rawPositionsBuffer.[index + 1] <- 1
                | None ->
                    rawPositionsBuffer.[index + 1] <- 0
        @>

    /// Records an optional result at `index`, taking it from `leftResult` when
    /// `isLeftBitmap.[index] = 1` and from `rightResult` otherwise.
    /// A `Some` value is stored in `allValuesBuffer.[index]` and flagged with 1
    /// in `rawPositionsBuffer.[index]`; `None` only writes the flag 0.
    let leftRight<'c> =
        <@
            fun index
                (leftResult: 'c option)
                (rightResult: 'c option)
                (isLeftBitmap: ClArray<int>)
                (allValuesBuffer: ClArray<'c>)
                (rawPositionsBuffer: ClArray<int>) ->

                if isLeftBitmap.[index] = 1 then
                    match leftResult with
                    | Some value ->
                        allValuesBuffer.[index] <- value
                        rawPositionsBuffer.[index] <- 1
                    | None ->
                        rawPositionsBuffer.[index] <- 0
                else
                    match rightResult with
                    | Some value ->
                        allValuesBuffer.[index] <- value
                        rawPositionsBuffer.[index] <- 1
                    | None ->
                        rawPositionsBuffer.[index] <- 0
        @>
Lines changed: 136 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,170 @@
11
namespace GraphBLAS.FSharp.Backend.Common
22

33
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend
45
open Microsoft.FSharp.Quotations
56

6-
module internal rec Sum =
7-
let run (inputArray: 'a []) (plus: Expr<'a -> 'a -> 'a>) (zero: 'a) =
8-
if inputArray.Length = 0 then
9-
opencl {
10-
let result = [| zero |]
7+
module internal Sum =

    /// Builds one reduction pass that folds every work group's slice of the
    /// input into a single value and stores it in `resultArray.[groupIndex]`.
    ///
    /// Returns a launcher `processor -> valuesArray -> valuesLength -> resultArray -> unit`.
    /// Only the first `valuesLength` elements of `valuesArray` are read; the
    /// remaining work items of the last group are padded with `zero`.
    /// `resultArray` must hold at least `ceil(valuesLength / workGroupSize)` elements.
    ///
    /// NOTE(review): the halving scheme in `SubSum.sequentialSum` appears to
    /// require `workGroupSize` to be a power of two — confirm with callers.
    let private scan
        (clContext: ClContext)
        (workGroupSize: int)
        (opAdd: Expr<'a -> 'a -> 'a>)
        zero
        =

        let subSum = SubSum.sequentialSum opAdd

        let scan =
            <@
                fun (ndRange: Range1D) length (inputArray: ClArray<'a>) (resultArray: ClArray<'a>) ->

                    let gid = ndRange.GlobalID0
                    let lid = ndRange.LocalID0

                    let localValues = localArray<'a> workGroupSize

                    // Work items past the logical end contribute `zero`.
                    if gid < length then
                        localValues[lid] <- inputArray[gid]
                    else
                        localValues[lid] <- zero

                    barrierLocal ()

                    (%subSum) workGroupSize lid localValues

                    // The group total ends up in slot 0; a single work item
                    // publishes it instead of every item storing the same value.
                    if lid = 0 then
                        resultArray[gid / workGroupSize] <- localValues[0]
            @>

        let kernel = clContext.Compile(scan)

        fun (processor: MailboxProcessor<_>) (valuesArray: ClArray<'a>) valuesLength (resultArray: ClArray<'a>) ->
            // Size the launch by the logical length, not by the buffer
            // capacity: `run` reuses buffers that are larger than the data
            // they currently hold.
            let ndRange = Range1D.CreateValid(valuesLength, workGroupSize)

            let kernel = kernel.GetKernel()

            processor.Post(
                Msg.MsgSetArguments(
                    fun () ->
                        kernel.KernelFunc
                            ndRange
                            valuesLength
                            valuesArray
                            resultArray)
            )

            processor.Post(Msg.CreateRunMsg<_, _>(kernel))

    /// Builds the final reduction pass, which folds the first `valuesLength`
    /// elements of `valuesArray` into a freshly created `ClCell`.
    ///
    /// Returns a launcher `processor -> valuesArray -> valuesLength -> ClCell<'a>`.
    /// The result is only well defined when `valuesLength <= workGroupSize`,
    /// i.e. when a single work group is launched: each group publishes its
    /// own total into the same cell.
    let private scanToCell
        (clContext: ClContext)
        (workGroupSize: int)
        (opAdd: Expr<'a -> 'a -> 'a>)
        zero
        =

        let subSum = SubSum.sequentialSum opAdd

        let scan =
            <@
                fun (ndRange: Range1D) length (inputArray: ClArray<'a>) (resultCell: ClCell<'a>) ->

                    let gid = ndRange.GlobalID0
                    let lid = ndRange.LocalID0

                    let localValues = localArray<'a> workGroupSize

                    // Work items past the logical end contribute `zero`.
                    if gid < length then
                        localValues[lid] <- inputArray[gid]
                    else
                        localValues[lid] <- zero

                    barrierLocal ()

                    (%subSum) workGroupSize lid localValues

                    if lid = 0 then
                        resultCell.Value <- localValues[0]
            @>

        let kernel = clContext.Compile(scan)

        fun (processor: MailboxProcessor<_>) (valuesArray: ClArray<'a>) valuesLength ->

            // Must follow the logical length: sizing by `valuesArray.Length`
            // launches extra groups when the buffer is larger than the data,
            // and each of them would overwrite the cell with `zero`.
            let ndRange = Range1D.CreateValid(valuesLength, workGroupSize)

            let resultCell = clContext.CreateClCell zero

            let kernel = kernel.GetKernel()

            processor.Post(
                Msg.MsgSetArguments(
                    fun () ->
                        kernel.KernelFunc
                            ndRange
                            valuesLength
                            valuesArray
                            resultCell)
            )

            processor.Post(Msg.CreateRunMsg<_, _>(kernel))

            resultCell

    /// Reduces a device array with `opAdd`, using `zero` as padding value.
    ///
    /// Compiles the kernels once and returns `processor -> inputArray -> ClCell<'a>`.
    /// The input is folded group by group into two device-only scratch
    /// buffers that are used alternately until at most `workGroupSize`
    /// partial results remain; those are folded into the returned cell.
    /// Both scratch buffers are released through the processor queue before
    /// returning; the caller owns the returned cell.
    let run
        (clContext: ClContext)
        (workGroupSize: int)
        (opAdd: Expr<'a -> 'a -> 'a>)
        (zero: 'a)
        =

        let scan = scan clContext workGroupSize opAdd zero
        let scanToCell = scanToCell clContext workGroupSize opAdd zero

        fun (processor: MailboxProcessor<_>) (inputArray: ClArray<'a>) ->

            let scan = scan processor

            // Number of partial results after the first pass (ceil division).
            let firstLength = (inputArray.Length - 1) / workGroupSize + 1

            let firstVerticesArray =
                clContext.CreateClArray(
                    firstLength,
                    hostAccessMode = HostAccessMode.NotAccessible,
                    deviceAccessMode = DeviceAccessMode.ReadWrite,
                    allocationMode = AllocationMode.Default
                )

            let secondLength = (firstLength - 1) / workGroupSize + 1

            let secondVerticesArray =
                clContext.CreateClArray(
                    secondLength,
                    hostAccessMode = HostAccessMode.NotAccessible,
                    deviceAccessMode = DeviceAccessMode.ReadWrite,
                    allocationMode = AllocationMode.Default
                )

            // (source, destination) of the next pass; swapped after each pass.
            let mutable verticesArrays = firstVerticesArray, secondVerticesArray
            let swap (a, b) = (b, a)

            scan inputArray inputArray.Length (fst verticesArrays)

            let mutable verticesLength = firstLength

            while verticesLength > workGroupSize do
                let fstVertices = fst verticesArrays
                let sndVertices = snd verticesArrays

                scan fstVertices verticesLength sndVertices

                verticesArrays <- swap verticesArrays
                verticesLength <- (verticesLength - 1) / workGroupSize + 1

            let fstVertices = fst verticesArrays
            let result = scanToCell processor fstVertices verticesLength

            processor.Post(Msg.CreateFreeMsg(firstVerticesArray))
            processor.Post(Msg.CreateFreeMsg(secondVerticesArray))

            result

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
<ItemGroup>
1313
<Compile Include="AssemblyInfo.fs" />
14+
<Compile Include="Common/CommonQuotes.fs" />
1415
<Compile Include="Common/Utils.fs" />
1516
<Compile Include="Common/ClArray.fs" />
1617
<Compile Include="Common/Sum.fs" />

0 commit comments

Comments
 (0)