Skip to content

Commit 344c49e

Browse files
committed
Frontier and levels buffers reused with new methods
1 parent 1f7838e commit 344c49e

4 files changed

Lines changed: 103 additions & 67 deletions

File tree

src/GraphBLAS-sharp.Backend/Algorithms/BFS.fs

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,65 +11,62 @@ module BFS =
1111
(clContext: ClContext)
1212
(add: Expr<int option -> int option -> int option>)
1313
(mul: Expr<'a option -> 'b option -> int option>)
14-
(addNumeric: Expr<int -> int -> int>)
1514
workGroupSize
1615
=
1716

18-
let spMV = SpMV.run clContext add mul workGroupSize
17+
let spMVTo =
18+
SpMV.runTo clContext add mul workGroupSize
1919

2020
let zeroCreate =
2121
ClArray.zeroCreate clContext workGroupSize
2222

2323
let ofList = Vector.ofList clContext Dense
2424

25-
let maskComplemented =
26-
DenseVector.DenseVector.elementWise clContext StandardOperations.complementedMaskOp workGroupSize
25+
let maskComplementedTo =
26+
DenseVector.DenseVector.elementWiseTo clContext StandardOperations.complementedMaskOp workGroupSize
2727

28-
let fillSubVector =
29-
Vector.standardFillSubVector<int, int> clContext workGroupSize
28+
let fillSubVectorTo =
29+
DenseVector.DenseVector.standardFillSubVectorTo<int, int> clContext workGroupSize
3030

3131
let containsNonZero =
3232
DenseVector.DenseVector.containsNonZero clContext workGroupSize
3333

3434
fun (queue: MailboxProcessor<Msg>) (matrix: CSRMatrix<'a>) (source: int) ->
3535
let vertexCount = matrix.RowCount
3636

37-
let mutable levels: ClVector<int> =
38-
zeroCreate queue vertexCount |> ClVectorDense
37+
let levels = zeroCreate queue vertexCount
3938

40-
let mutable frontier = ofList vertexCount [ source, 1 ]
39+
let frontier = ofList vertexCount [ source, 1 ]
4140

42-
let mutable level = 0
43-
let mutable stop = false
41+
match frontier with
42+
| ClVectorDense front ->
4443

45-
while not stop do
46-
level <- level + 1
44+
let mutable level = 0
45+
let mutable stop = false
4746

48-
let newLevels =
49-
fillSubVector queue levels frontier (clContext.CreateClCell level)
47+
while not stop do
48+
level <- level + 1
5049

51-
levels.Dispose queue
52-
53-
match frontier, newLevels with
54-
| ClVectorDense f, ClVectorDense nl ->
55-
let newFrontierUnmasked = spMV queue matrix f
50+
//Assigning new level values
51+
fillSubVectorTo queue levels front (clContext.CreateClCell level) levels
52+
|> ignore
5653

57-
let newFrontier =
58-
maskComplemented queue newFrontierUnmasked nl
54+
//Getting new frontier
55+
spMVTo queue matrix front front |> ignore
5956

60-
newFrontierUnmasked.Dispose queue
61-
frontier.Dispose queue
57+
maskComplementedTo queue front levels front
58+
|> ignore
6259

60+
//Checking if front is empty
6361
let frontNotEmpty = Array.zeroCreate 1
64-
let sum = containsNonZero queue newFrontier
62+
let sum = containsNonZero queue front
6563

6664
queue.PostAndReply(fun ch -> Msg.CreateToHostMsg(sum, frontNotEmpty, ch))
6765
|> ignore
6866

69-
frontier <- newFrontier |> ClVectorDense
70-
levels <- newLevels
71-
7267
stop <- not frontNotEmpty.[0]
73-
| _ -> failwith "Not implemented"
7468

75-
levels
69+
front.Dispose queue
70+
71+
levels
72+
| _ -> failwith "Not implemented"

src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ module DenseVector =
3636

3737
result
3838

39-
let elementWise<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct>
39+
let elementWiseTo<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct>
4040
(clContext: ClContext)
4141
(opAdd: Expr<'a option -> 'b option -> 'c option>)
4242
(workGroupSize: int)
@@ -52,14 +52,7 @@ module DenseVector =
5252

5353
let kernel = clContext.Compile(elementWise)
5454

55-
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) ->
56-
let resultVector =
57-
clContext.CreateClArray(
58-
leftVector.Length,
59-
hostAccessMode = HostAccessMode.NotAccessible,
60-
deviceAccessMode = DeviceAccessMode.ReadWrite,
61-
allocationMode = AllocationMode.Default
62-
)
55+
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultVector: ClArray<'c option>) ->
6356

6457
let ndRange =
6558
Range1D.CreateValid(leftVector.Length, workGroupSize)
@@ -75,10 +68,30 @@ module DenseVector =
7568

7669
resultVector
7770

71+
let elementWise<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct>
72+
(clContext: ClContext)
73+
(opAdd: Expr<'a option -> 'b option -> 'c option>)
74+
(workGroupSize: int)
75+
=
76+
77+
let elementWiseTo =
78+
elementWiseTo clContext opAdd workGroupSize
79+
80+
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) ->
81+
let resultVector =
82+
clContext.CreateClArray(
83+
leftVector.Length,
84+
hostAccessMode = HostAccessMode.NotAccessible,
85+
deviceAccessMode = DeviceAccessMode.ReadWrite,
86+
allocationMode = AllocationMode.Default
87+
)
88+
89+
elementWiseTo processor leftVector rightVector resultVector
90+
7891
let elementWiseAtLeastOne clContext op workGroupSize =
7992
elementWise clContext (StandardOperations.atLeastOneToOption op) workGroupSize
8093

81-
let fillSubVector<'a, 'b when 'a: struct and 'b: struct>
94+
let fillSubVectorTo<'a, 'b when 'a: struct and 'b: struct>
8295
(clContext: ClContext)
8396
(maskOp: Expr<'a option -> 'b option -> 'a -> 'a option>)
8497
(workGroupSize: int)
@@ -94,14 +107,7 @@ module DenseVector =
94107

95108
let kernel = clContext.Compile(fillSubVectorKernel)
96109

97-
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (maskVector: ClArray<'b option>) (value: ClCell<'a>) ->
98-
let resultVector =
99-
clContext.CreateClArray<'a option>(
100-
leftVector.Length,
101-
hostAccessMode = HostAccessMode.NotAccessible,
102-
deviceAccessMode = DeviceAccessMode.ReadWrite,
103-
allocationMode = AllocationMode.Default
104-
)
110+
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (maskVector: ClArray<'b option>) (value: ClCell<'a>) (resultVector: ClArray<'a option>) ->
105111

106112
let ndRange =
107113
Range1D.CreateValid(leftVector.Length, workGroupSize)
@@ -117,6 +123,32 @@ module DenseVector =
117123

118124
resultVector
119125

126+
let fillSubVector<'a, 'b when 'a: struct and 'b: struct>
127+
(clContext: ClContext)
128+
(maskOp: Expr<'a option -> 'b option -> 'a -> 'a option>)
129+
(workGroupSize: int)
130+
=
131+
132+
let fillSubVectorTo =
133+
fillSubVectorTo clContext maskOp workGroupSize
134+
135+
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (maskVector: ClArray<'b option>) (value: ClCell<'a>) ->
136+
let resultVector =
137+
clContext.CreateClArray<'a option>(
138+
leftVector.Length,
139+
hostAccessMode = HostAccessMode.NotAccessible,
140+
deviceAccessMode = DeviceAccessMode.ReadWrite,
141+
allocationMode = AllocationMode.Default
142+
)
143+
144+
fillSubVectorTo processor leftVector maskVector value resultVector
145+
146+
let standardFillSubVectorTo<'a, 'b when 'a: struct and 'b: struct> (clContext: ClContext) (workGroupSize: int) =
147+
fillSubVectorTo<'a, 'b>
148+
clContext
149+
(StandardOperations.fillSubToOption StandardOperations.fillSubOp<'a>)
150+
workGroupSize
151+
120152
let private getBitmap<'a when 'a: struct> (clContext: ClContext) (workGroupSize: int) =
121153

122154
let getPositions =

src/GraphBLAS-sharp.Backend/Vector/SpMV.fs

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ open GraphBLAS.FSharp.Backend.Common
77
open Microsoft.FSharp.Quotations
88

99
module SpMV =
10-
let run
10+
let runTo
1111
(clContext: ClContext)
1212
(add: Expr<'c option -> 'c option -> 'c option>)
1313
(mul: Expr<'a option -> 'b option -> 'c option>)
@@ -95,7 +95,7 @@ module SpMV =
9595
let multiplyValues = clContext.Compile multiplyValues
9696
let reduceValuesByRows = clContext.Compile reduceValuesByRows
9797

98-
fun (queue: MailboxProcessor<_>) (matrix: CSRMatrix<'a>) (vector: ClArray<'b option>) ->
98+
fun (queue: MailboxProcessor<_>) (matrix: CSRMatrix<'a>) (vector: ClArray<'b option>) (result: ClArray<'b option>) ->
9999

100100
let matrixLength = matrix.Values.Length
101101

@@ -129,14 +129,6 @@ module SpMV =
129129

130130
queue.Post(Msg.CreateRunMsg<_, _>(multiplyValues))
131131

132-
let outputArray =
133-
clContext.CreateClArray<'c option>(
134-
matrix.RowCount,
135-
deviceAccessMode = DeviceAccessMode.ReadWrite,
136-
hostAccessMode = HostAccessMode.NotAccessible,
137-
allocationMode = AllocationMode.Default
138-
)
139-
140132
let reduceValuesByRows = reduceValuesByRows.GetKernel()
141133

142134
queue.Post(
@@ -147,11 +139,31 @@ module SpMV =
147139
matrix.RowCount
148140
intermediateArray
149141
matrix.RowPointers
150-
outputArray)
142+
result)
151143
)
152144

153145
queue.Post(Msg.CreateRunMsg<_, _>(reduceValuesByRows))
154146

155147
queue.Post(Msg.CreateFreeMsg intermediateArray)
156148

157-
outputArray
149+
result
150+
151+
let run
152+
(clContext: ClContext)
153+
(add: Expr<'c option -> 'c option -> 'c option>)
154+
(mul: Expr<'a option -> 'b option -> 'c option>)
155+
workGroupSize
156+
=
157+
let runTo = runTo clContext add mul workGroupSize
158+
159+
fun (queue: MailboxProcessor<_>) (matrix: CSRMatrix<'a>) (vector: ClArray<'b option>) ->
160+
161+
let result =
162+
clContext.CreateClArray<'b option>(
163+
matrix.RowCount,
164+
deviceAccessMode = DeviceAccessMode.ReadWrite,
165+
hostAccessMode = HostAccessMode.NotAccessible,
166+
allocationMode = AllocationMode.Default
167+
)
168+
169+
runTo queue matrix vector result

tests/GraphBLAS-sharp.Tests/Algorithms/BFS.fs

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,7 @@ let testFixtures (testContext: TestContext) =
1818
sprintf "Test on %A" testContext.ClContext
1919

2020
let bfs =
21-
Algorithms.BFS.singleSource
22-
context
23-
StandardOperations.intSum
24-
StandardOperations.intMul
25-
<@ (+) @>
26-
workGroupSize
21+
Algorithms.BFS.singleSource context StandardOperations.intSum StandardOperations.intMul workGroupSize
2722

2823
testPropertyWithConfig config testName
2924
<| fun (matrix: int [,]) ->
@@ -47,7 +42,7 @@ let testFixtures (testContext: TestContext) =
4742

4843
match matrix with
4944
| MatrixCSR mtx ->
50-
let res = bfs queue mtx source
45+
let res = bfs queue mtx source |> ClVectorDense
5146

5247
let resHost = res.ToHost queue
5348

0 commit comments

Comments
 (0)