Skip to content

Commit 6f08dad

Browse files
committed
SSSP dense
1 parent b494dab commit 6f08dad

8 files changed

Lines changed: 248 additions & 3 deletions

File tree

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
namespace GraphBLAS.FSharp.Backend.Algorithms
2+
3+
open GraphBLAS.FSharp.Backend
4+
open Brahma.FSharp
5+
open FSharp.Quotations
6+
open GraphBLAS.FSharp.Backend.Objects
7+
open GraphBLAS.FSharp.Backend.Common
8+
open GraphBLAS.FSharp.Backend.Quotes
9+
open GraphBLAS.FSharp.Backend.Vector
10+
open GraphBLAS.FSharp.Backend.Vector.Dense
11+
open GraphBLAS.FSharp.Backend.Objects.ClContext
12+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
13+
open GraphBLAS.FSharp.Backend.Objects.ClCell
14+
15+
module SSSP =
16+
let run (clContext: ClContext) workGroupSize =
17+
18+
let less = ArithmeticOperations.less<int>
19+
let min = ArithmeticOperations.min<int>
20+
let plus = ArithmeticOperations.intSumAsMul
21+
22+
let spMVTo =
23+
SpMV.runTo min plus clContext workGroupSize
24+
25+
let create = ClArray.create clContext workGroupSize
26+
27+
let createMask = ClArray.create clContext workGroupSize
28+
29+
let ofList = Vector.ofList clContext workGroupSize
30+
31+
let eWiseMulLess =
32+
ClArray.map2InPlace less clContext workGroupSize
33+
34+
let eWiseAddMin =
35+
ClArray.map2InPlace min clContext workGroupSize
36+
37+
let fillSubVectorTo =
38+
Vector.assignByMaskInPlace (Convert.assignToOption Mask.assignComplemented) clContext workGroupSize
39+
40+
let containsNonZero =
41+
ClArray.exists Predicates.isSome clContext workGroupSize
42+
43+
fun (queue: MailboxProcessor<Msg>) (matrix: ClMatrix.CSR<int>) (source: int) ->
44+
let vertexCount = matrix.RowCount
45+
46+
//None is System.Int32.MaxValue
47+
let distance =
48+
ofList queue DeviceOnly Dense vertexCount [ source, 0 ]
49+
50+
let mutable f1 =
51+
ofList queue DeviceOnly Dense vertexCount [ source, 0 ]
52+
53+
let mutable f2 =
54+
create queue DeviceOnly vertexCount None
55+
|> ClVector.Dense
56+
57+
let m =
58+
createMask queue DeviceOnly vertexCount None
59+
|> ClVector.Dense
60+
61+
let mutable stop = false
62+
63+
while not stop do
64+
match f1, f2, distance, m with
65+
| ClVector.Dense front1, ClVector.Dense front2, ClVector.Dense dist, ClVector.Dense mask ->
66+
//Getting new frontier
67+
spMVTo queue matrix front1 front2
68+
69+
//Checking which distances were updated
70+
eWiseMulLess queue front2 dist mask
71+
//Updating
72+
eWiseAddMin queue dist front2 dist
73+
74+
//Filtering unproductive vertices
75+
fillSubVectorTo queue front2 mask (clContext.CreateClCell 0) front2
76+
77+
//Swap fronts
78+
let temp = f1
79+
f1 <- f2
80+
f2 <- temp
81+
82+
//Checking if no distances were updated
83+
stop <-
84+
not
85+
<| (containsNonZero queue mask).ToHostAndFree(queue)
86+
87+
| _ -> failwith "not implemented"
88+
89+
f1.Dispose queue
90+
f2.Dispose queue
91+
m.Dispose queue
92+
93+
match distance with
94+
| ClVector.Dense dist -> dist
95+
| _ -> failwith "not implemented"

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<?xml version="1.0" encoding="utf-8"?>
1+
<?xml version="1.0" encoding="utf-8"?>
22
<Project Sdk="Microsoft.NET.Sdk">
33

44
<PropertyGroup>
@@ -63,6 +63,7 @@
6363
<Compile Include="Matrix/Matrix.fs" />
6464

6565
<Compile Include="Algorithms/BFS.fs" />
66+
<Compile Include="Algorithms/SSSP.fs" />
6667
</ItemGroup>
6768
<Import Project="..\..\.paket\Paket.Restore.targets" />
6869
</Project>

src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ module ArithmeticOperations =
3636

3737
if res = zero then None else Some res @>
3838

39+
let inline mkNumericSumAsMul zero =
40+
<@ fun (x: 't option) (y: 't option) ->
41+
let mutable res = zero
42+
43+
match x, y with
44+
| Some f, Some s -> res <- f + s
45+
| _ -> ()
46+
47+
if res = zero then None else Some res @>
48+
3949
let inline mkNumericMul zero =
4050
<@ fun (x: 't option) (y: 't option) ->
4151
let mutable res = zero
@@ -173,6 +183,8 @@ module ArithmeticOperations =
173183
let floatMulAtLeastOne = mkNumericMulAtLeastOne 0.0
174184
let float32MulAtLeastOne = mkNumericMulAtLeastOne 0f
175185

186+
let intSumAsMul = mkNumericSumAsMul System.Int32.MaxValue
187+
176188
let notOption =
177189
<@ fun x ->
178190
match x with
@@ -216,3 +228,20 @@ module ArithmeticOperations =
216228
let floatMul = createPair 0.0 (*) <@ (*) @>
217229

218230
let float32Mul = createPair 0.0f (*) <@ (*) @>
231+
232+
// other
233+
let less<'a when 'a: comparison> =
234+
<@ fun (x: 'a option) (y: 'a option) ->
235+
match x, y with
236+
| Some x, Some y -> if (x < y) then Some 1 else None
237+
| Some x, None -> Some 1
238+
| _ -> None @>
239+
240+
//TODO: noneValue
241+
let min<'a when 'a: comparison> =
242+
<@ fun (x: 'a option) (y: 'a option) ->
243+
match x, y with
244+
| Some x, Some y -> Some(min x y)
245+
| Some x, None -> Some x
246+
| None, Some y -> Some y
247+
| _ -> None @>

src/GraphBLAS-sharp.Backend/Quotes/Mask.fs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ module Mask =
77
| _, None -> left
88
| _ -> right @>
99

10+
let assignComplemented<'a when 'a: struct> =
11+
<@ fun (left: 'a option) (right: 'a option) ->
12+
match left, right with
13+
| _, None -> right
14+
| _ -> left @>
15+
1016
let op<'a, 'b when 'a: struct and 'b: struct> =
1117
<@ fun (left: 'a option) (right: 'b option) ->
1218
match right with
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Algorithms.SSSP
2+
3+
open Expecto
4+
open GraphBLAS.FSharp.Backend
5+
open GraphBLAS.FSharp.Backend.Common
6+
open GraphBLAS.FSharp.Backend.Quotes
7+
open GraphBLAS.FSharp.Tests
8+
open GraphBLAS.FSharp.Tests.Context
9+
open GraphBLAS.FSharp.Tests.Backend.QuickGraph.Algorithms
10+
open GraphBLAS.FSharp.Tests.Backend.QuickGraph.CreateGraph
11+
open GraphBLAS.FSharp.Backend.Objects
12+
open GraphBLAS.FSharp.Objects.ClVectorExtensions
13+
open GraphBLAS.FSharp.Objects
14+
15+
let testFixtures (testContext: TestContext) =
16+
[ let config = Utils.undirectedAlgoConfig
17+
let context = testContext.ClContext
18+
let queue = testContext.Queue
19+
let workGroupSize = Utils.defaultWorkGroupSize
20+
21+
let testName =
22+
sprintf "Test on %A" testContext.ClContext
23+
24+
let ssspDense =
25+
Algorithms.SSSP.run context workGroupSize
26+
27+
testPropertyWithConfig config testName
28+
<| fun (matrix: int [,]) ->
29+
30+
let matrix = Array2D.map (fun x -> abs x) matrix
31+
32+
let graph = undirectedFromArray2D matrix 0
33+
34+
let largestComponent =
35+
ConnectedComponents.largestComponent graph
36+
37+
if largestComponent.Length > 0 then
38+
let source = largestComponent.[0]
39+
40+
let expected =
41+
SSSP.runUndirected matrix (directedFromArray2D matrix 0) source
42+
|> Array.map
43+
(fun x ->
44+
match x with
45+
| Some x -> Some(int x)
46+
| None -> None)
47+
48+
let matrixHost =
49+
Utils.createMatrixFromArray2D CSR matrix ((=) 0)
50+
51+
let matrix = matrixHost.ToDevice context
52+
53+
match matrix with
54+
| ClMatrix.CSR mtx ->
55+
let resDense =
56+
ssspDense queue mtx source |> ClVector.Dense
57+
58+
let resHost = resDense.ToHost queue
59+
60+
(mtx :> IDeviceMemObject).Dispose queue
61+
resDense.Dispose queue
62+
63+
match resHost with
64+
| Vector.Dense resHost ->
65+
let actual = resHost
66+
67+
Expect.sequenceEqual actual expected "Sequences must be equal"
68+
| _ -> failwith "Not implemented"
69+
| _ -> failwith "Not implemented" ]
70+
71+
let tests =
72+
TestCases.gpuTests "SSSP tests" testFixtures
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
namespace GraphBLAS.FSharp.Tests.Backend.QuickGraph.Algorithms
2+
3+
open QuikGraph
4+
open QuikGraph.Algorithms.ShortestPath
5+
open QuikGraph.Algorithms.Observers
6+
7+
module SSSP =
8+
let runUndirected (matrix: int [,]) (graph: AdjacencyGraph<int, Edge<int>>) source =
9+
let weight =
10+
fun (e: Edge<int>) -> float matrix.[e.Source, e.Target]
11+
12+
let dijkstra =
13+
DijkstraShortestPathAlgorithm<int, Edge<int>>(graph, weight)
14+
15+
// Attach a distance observer to give us the shortest path distances
16+
let distObserver =
17+
VertexDistanceRecorderObserver<int, Edge<int>>(weight)
18+
19+
distObserver.Attach(dijkstra) |> ignore
20+
21+
// Attach a Vertex Predecessor Recorder Observer to give us the paths
22+
let predecessorObserver =
23+
VertexPredecessorRecorderObserver<int, Edge<int>>()
24+
25+
predecessorObserver.Attach(dijkstra) |> ignore
26+
27+
// Run the algorithm with A set to be the source
28+
dijkstra.Compute(source)
29+
30+
let res: array<float option> =
31+
Array.zeroCreate (Array2D.length1 matrix)
32+
33+
for kvp in distObserver.Distances do
34+
res.[kvp.Key] <- Some kvp.Value
35+
36+
res.[source] <- Some 0.0
37+
res

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<?xml version="1.0" encoding="utf-8"?>
1+
<?xml version="1.0" encoding="utf-8"?>
22
<Project Sdk="Microsoft.NET.Sdk">
33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
@@ -15,8 +15,10 @@
1515
<Compile Include="Helpers.fs" />
1616
<Compile Include="Backend/QuickGraph/Algorithms/BFS.fs" />
1717
<Compile Include="Backend/QuickGraph/Algorithms/ConnectedComponents.fs" />
18+
<Compile Include="Backend/QuickGraph/Algorithms/SSSP.fs" />
1819
<Compile Include="Backend/QuickGraph/CreateGraph.fs" />
1920
<Compile Include="Backend/Algorithms/BFS.fs" />
21+
<Compile Include="Backend/Algorithms/SSSP.fs" />
2022
<Compile Include="Backend/Common/ClArray/Blit.fs" />
2123
<Compile Include="Backend/Common/ClArray/Choose.fs" />
2224
<Compile Include="Backend/Common/ClArray/ChunkBySize.fs" />

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ let vectorTests =
8787
|> testSequenced
8888

8989
let algorithmsTests =
90-
testList "Algorithms tests" [ Algorithms.BFS.tests ]
90+
testList
91+
"Algorithms tests"
92+
[ Algorithms.BFS.tests
93+
Algorithms.SSSP.tests ]
9194
|> testSequenced
9295

9396
let deviceTests =

0 commit comments

Comments
 (0)