Skip to content

Commit 4422318

Browse files
authored
Merge pull request #49 from kirillgarbar/spmv
SpMV
2 parents 1c5663c + 58e8a38 commit 4422318

18 files changed

Lines changed: 405 additions & 200 deletions

src/GraphBLAS-sharp.Backend/Common/Utils.fs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,22 @@ module internal Utils =
2121
>> fun x -> x ||| (x >>> 8)
2222
>> fun x -> x ||| (x >>> 16)
2323
>> fun x -> x + 1
24+
25+
let floorToMultiple multiple x = x / multiple * multiple
26+
27+
let ceilToMultiple multiple x = ((x - 1) / multiple + 1) * multiple
28+
29+
let getLocalMemorySize (clContext: ClContext) =
30+
let error = ref Unchecked.defaultof<ClErrorCode>
31+
32+
Cl
33+
.GetDeviceInfo(clContext.ClDevice.Device, OpenCL.Net.DeviceInfo.LocalMemSize, error)
34+
.CastTo<int>()
35+
36+
let getClArrayOfValueTypeSize<'a when 'a: struct> localMemorySize = localMemorySize / sizeof<'a>
37+
38+
//Option type in C is represented as structure with additional integer field
39+
let getClArrayOfOptionTypeSize<'a> localMemorySize =
40+
localMemorySize
41+
/ (sizeof<int> + sizeof<'a>
42+
|> ceilToMultiple (max sizeof<'a> sizeof<int>))

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<?xml version="1.0" encoding="utf-8"?>
1+
<?xml version="1.0" encoding="utf-8"?>
22
<Project Sdk="Microsoft.NET.Sdk">
33

44
<PropertyGroup>
@@ -30,9 +30,8 @@
3030
<Compile Include="Matrix/CSRMatrix/SpGEMM.fs" />
3131
<Compile Include="Matrix/CSRMatrix/CSRMatrix.fs" />
3232
<Compile Include="Matrix/CSRMatrix/CSRMatrix.fs" />
33-
<Compile Include="Matrix/CSRMatrix/SpMV.fs" />
3433
<Compile Include="Matrix/Matrix.fs" />
35-
<Folder Include="Vector" />
34+
<Compile Include="Vector/SpMV.fs" />
3635
<!--Compile Include="Backend/CSRMatrix/GetTuples.fs" /-->
3736
<!--Compile Include="Backend/CSRMatrix/SpMSpV.fs" /-->
3837
<!--Compile Include="Backend/CSRMatrix/Transpose.fs" /-->

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpMV.fs

Lines changed: 0 additions & 69 deletions
This file was deleted.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
namespace GraphBLAS.FSharp.Backend
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend
5+
open GraphBLAS.FSharp.Backend.ArraysExtensions
6+
open GraphBLAS.FSharp.Backend.Common
7+
open Microsoft.FSharp.Quotations
8+
9+
module Vector =
10+
let spMV
11+
(clContext: ClContext)
12+
(add: Expr<'c option -> 'c option -> 'c option>)
13+
(mul: Expr<'a option -> 'b option -> 'c option>)
14+
workGroupSize
15+
=
16+
//Until LocalMemSize added to ClDevice as member
17+
let localMemorySize = Utils.getLocalMemorySize clContext
18+
19+
let localPointersArraySize = workGroupSize + 1
20+
21+
let localMemoryLeft =
22+
localMemorySize
23+
- localPointersArraySize * sizeof<int>
24+
25+
let localValuesArraySize =
26+
Utils.getClArrayOfOptionTypeSize localMemoryLeft
27+
28+
let multiplyValues =
29+
<@ fun (ndRange: Range1D) matrixLength (matrixColumns: ClArray<int>) (matrixValues: ClArray<'a>) (vectorValues: ClArray<'b option>) (intermediateArray: ClArray<'c option>) ->
30+
31+
let i = ndRange.GlobalID0
32+
let value = matrixValues.[i]
33+
let column = matrixColumns.[i]
34+
35+
if i < matrixLength then
36+
intermediateArray.[i] <- (%mul) (Some value) vectorValues.[column] @>
37+
38+
let reduceValuesByRows =
39+
<@ fun (ndRange: Range1D) (numberOfRows: int) (intermediateArray: ClArray<'c option>) (matrixPtr: ClArray<int>) (outputVector: ClArray<'c option>) ->
40+
41+
let gid = ndRange.GlobalID0
42+
let lid = ndRange.LocalID0
43+
44+
if gid <= numberOfRows then
45+
let threadsPerBlock =
46+
min (numberOfRows - gid + lid) workGroupSize //If number of rows left is lesser than number of threads in a block
47+
48+
let localPtr = localArray<int> localPointersArraySize
49+
localPtr.[lid] <- matrixPtr.[gid]
50+
51+
if lid = 0 then
52+
localPtr.[threadsPerBlock] <- matrixPtr.[gid + threadsPerBlock]
53+
54+
barrierLocal ()
55+
56+
let localValues =
57+
localArray<'c option> localValuesArraySize
58+
59+
let workEnd = localPtr.[threadsPerBlock]
60+
let mutable blockLowerBound = localPtr.[0]
61+
let numberOfBlocksFitting = localValuesArraySize / threadsPerBlock
62+
let workPerIteration = threadsPerBlock * numberOfBlocksFitting
63+
64+
let mutable sum: 'c option = None
65+
66+
while blockLowerBound < workEnd do
67+
let mutable index = blockLowerBound + lid
68+
69+
barrierLocal ()
70+
//Loading values to the local memory
71+
for block in 0 .. numberOfBlocksFitting - 1 do
72+
if index < workEnd then
73+
localValues.[lid + block * threadsPerBlock] <- intermediateArray.[index]
74+
index <- index + threadsPerBlock
75+
76+
barrierLocal ()
77+
//Reduction
78+
//Check if any part of the row is loaded into local memory on this iteration
79+
if (localPtr.[lid + 1] > blockLowerBound
80+
&& localPtr.[lid] < blockLowerBound + workPerIteration) then
81+
let rowStart = max (localPtr.[lid] - blockLowerBound) 0
82+
83+
let rowEnd =
84+
min (localPtr.[lid + 1] - blockLowerBound) workPerIteration
85+
86+
for j in rowStart .. rowEnd - 1 do
87+
let newSum = (%add) sum localValues.[j] //For some reason sum <- (%add) ... causes Brahma exception
88+
sum <- newSum
89+
90+
blockLowerBound <- blockLowerBound + workPerIteration
91+
92+
if gid < numberOfRows then
93+
outputVector.[gid] <- sum @>
94+
95+
let multiplyValues = clContext.Compile multiplyValues
96+
let reduceValuesByRows = clContext.Compile reduceValuesByRows
97+
98+
fun (queue: MailboxProcessor<_>) (matrix: CSRMatrix<'a>) (vector: ClArray<'b option>) ->
99+
100+
let matrixLength = matrix.Values.Length
101+
102+
let ndRange1 =
103+
Range1D.CreateValid(matrixLength, workGroupSize)
104+
105+
let ndRange2 =
106+
Range1D.CreateValid(matrix.RowCount, workGroupSize)
107+
108+
let intermediateArray =
109+
clContext.CreateClArray<'c option>(
110+
matrixLength,
111+
deviceAccessMode = DeviceAccessMode.ReadWrite,
112+
hostAccessMode = HostAccessMode.NotAccessible,
113+
allocationMode = AllocationMode.Default
114+
)
115+
116+
let multiplyValues = multiplyValues.GetKernel()
117+
118+
queue.Post(
119+
Msg.MsgSetArguments
120+
(fun () ->
121+
multiplyValues.KernelFunc
122+
ndRange1
123+
matrixLength
124+
matrix.Columns
125+
matrix.Values
126+
vector
127+
intermediateArray)
128+
)
129+
130+
queue.Post(Msg.CreateRunMsg<_, _>(multiplyValues))
131+
132+
let outputArray =
133+
clContext.CreateClArray<'c option>(
134+
matrix.RowCount,
135+
deviceAccessMode = DeviceAccessMode.ReadWrite,
136+
hostAccessMode = HostAccessMode.NotAccessible,
137+
allocationMode = AllocationMode.Default
138+
)
139+
140+
let reduceValuesByRows = reduceValuesByRows.GetKernel()
141+
142+
queue.Post(
143+
Msg.MsgSetArguments
144+
(fun () ->
145+
reduceValuesByRows.KernelFunc
146+
ndRange2
147+
matrix.RowCount
148+
intermediateArray
149+
matrix.RowPointers
150+
outputArray)
151+
)
152+
153+
queue.Post(Msg.CreateRunMsg<_, _>(reduceValuesByRows))
154+
155+
queue.Post(Msg.CreateFreeMsg intermediateArray)
156+
157+
outputArray

tests/GraphBLAS-sharp.Tests/BackendCommonTests/BitonicSortTests.fs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ open Expecto.Logging.Message
66
open GraphBLAS.FSharp.Backend.Common
77
open Brahma.FSharp
88
open GraphBLAS.FSharp.Tests.Utils
9+
open GraphBLAS.FSharp.Tests.Context
910

1011
let logger = Log.create "BitonicSort.Tests"
1112

tests/GraphBLAS-sharp.Tests/BackendCommonTests/ConvertTests.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ open Expecto
44
open Expecto.Logging
55
open Expecto.Logging.Message
66
open GraphBLAS.FSharp.Tests.Utils
7-
7+
open GraphBLAS.FSharp.Tests.Context
88
open GraphBLAS.FSharp.Backend
99
open GraphBLAS.FSharp
1010

tests/GraphBLAS-sharp.Tests/BackendCommonTests/CopyTests.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ open GraphBLAS.FSharp.Tests
99

1010
let logger = Log.create "Copy.Tests"
1111

12-
let context = Utils.defaultContext.ClContext
12+
let context = Context.defaultContext.ClContext
1313

1414
let testCases =
15-
let q = Utils.defaultContext.Queue
15+
let q = Context.defaultContext.Queue
1616
q.Error.Add(fun e -> failwithf "%A" e)
1717

1818
let getCopyFun copy =

0 commit comments

Comments
 (0)