Skip to content

Commit 0ea4d0f

Browse files
committed
SpMV
1 parent 26ad23b commit 0ea4d0f

3 files changed

Lines changed: 160 additions & 2 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/Utils.fs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@ module internal Utils =
2121
>> fun x -> x ||| (x >>> 8)
2222
>> fun x -> x ||| (x >>> 16)
2323
>> fun x -> x + 1
24+
25+
/// Rounds `x` down to a multiple of `multiple` by integer division followed
/// by multiplication. Integer division truncates toward zero, so a negative
/// `x` is moved toward zero rather than toward negative infinity.
/// `multiple` must be non-zero.
let floorToMultiple multiple x =
    let wholeMultiples = x / multiple
    wholeMultiples * multiple
26+
27+
/// Rounds `x` up to the nearest multiple of `multiple` (expected to be positive).
///
/// Values that are already a multiple are returned unchanged, including 0.
/// The previous formula `((x - 1) / multiple + 1) * multiple` depended on
/// truncating division and therefore returned `multiple` for `x = 0` and
/// wrong results for non-positive `x`; results for positive `x` are the same.
let ceilToMultiple multiple x =
    // The remainder has the sign of `x`, so it is positive only when `x` lies
    // strictly above the multiple below it.
    let remainder = x % multiple

    if remainder > 0 then
        x - remainder + multiple
    else
        // `x` is an exact multiple, or negative: truncation toward zero
        // already yields the ceiling.
        x - remainder

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<?xml version="1.0" encoding="utf-8"?>
1+
<?xml version="1.0" encoding="utf-8"?>
22
<Project Sdk="Microsoft.NET.Sdk">
33

44
<PropertyGroup>
@@ -28,7 +28,7 @@
2828
<Compile Include="Matrix/CSRMatrix/SpGEMM.fs" />
2929
<Compile Include="Matrix/CSRMatrix/SpMV.fs" />
3030
<Compile Include="Matrix/Matrix.fs" />
31-
<Folder Include="Vector" />
31+
<Compile Include="Vector/SpMV.fs" />
3232
<!--Compile Include="Backend/CSRMatrix/GetTuples.fs" /-->
3333
<!--Compile Include="Backend/CSRMatrix/SpMSpV.fs" /-->
3434
<!--Compile Include="Backend/CSRMatrix/Transpose.fs" /-->
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
namespace GraphBLAS.FSharp.Backend
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend
5+
open GraphBLAS.FSharp.Backend.ArraysExtensions
6+
open GraphBLAS.FSharp.Backend.Common
7+
open Microsoft.FSharp.Quotations
8+
9+
module Vector =
    /// Builds a sparse matrix-vector multiplication routine (CSR matrix times a
    /// vector of optional values) over the operations given by `add` and `mul`.
    ///
    /// Two kernels are compiled when `spMV` is applied to its first four arguments:
    ///   * kernel1 applies `mul` to every stored matrix value and the vector
    ///     element selected by that value's column index, writing the products
    ///     to an intermediate buffer;
    ///   * kernel2 folds the products belonging to each row with `add`, staging
    ///     them through work-group local memory.
    ///
    /// Parameters:
    ///   clContext     - context used to query the device, compile kernels and allocate buffers.
    ///   add           - quoted reduction applied to the running row sum and a product.
    ///   mul           - quoted product of a matrix value and a vector element.
    ///   workGroupSize - work-group size used for both kernel launches.
    ///
    /// Returns a function `queue -> matrix -> vector -> ClArray<'c option>` that
    /// posts the work to `queue` and returns the output buffer holding
    /// `matrix.RowCount` elements. The intermediate buffer is freed through the
    /// queue; the returned buffer is not freed here.
    let spMV
        (clContext: ClContext)
        (add: Expr<'c option -> 'c option -> 'c option>)
        (mul: Expr<'a option -> 'b option -> 'c option>)
        workGroupSize
        =
        //Until LocalMemSize added to ClDevice as member
        // NOTE(review): `error` receives the status of GetDeviceInfo but is never inspected.
        let error = ref Unchecked.defaultof<ClErrorCode>

        let localMemorySize =
            Cl
                .GetDeviceInfo(clContext.ClDevice.Device, OpenCL.Net.DeviceInfo.LocalMemSize, error)
                .CastTo<int>()

        // Local array of row pointers: one entry per work-item plus the end
        // pointer of the last row handled by the work-group.
        let localArraySize1 = workGroupSize + 1

        let localMemoryLeft =
            localMemorySize - localArraySize1 * sizeof<int>

        // Estimated device-side size of a `'c option`; presumably a 4-byte tag
        // plus the payload, padded to the larger of the two sizes — TODO confirm
        // against the struct layout generated by Brahma.FSharp.
        let optionTypeClSizeInBytes =
            4 + sizeof<'c>
            |> Utils.ceilToMultiple (max sizeof<'c> sizeof<int>)

        // Number of products that fit in the remaining local memory, rounded
        // down to a whole number of work-groups' worth of elements.
        // NOTE(review): if this evaluates to 0, `workPerIteration` in kernel2 is 0
        // and its `while` loop would not advance — confirm callers cannot hit this.
        let localArraySize2 =
            localMemoryLeft / optionTypeClSizeInBytes
            |> Utils.floorToMultiple workGroupSize

        let kernel1 =
            <@ fun (ndRange: Range1D) matrixLength (matrixColumns: ClArray<int>) (matrixValues: ClArray<'a>) (vectorValues: ClArray<'b option>) (intermediateArray: ClArray<'c option>) ->

                let i = ndRange.GlobalID0

                // The launch range may exceed matrixLength, so every buffer
                // access — reads included — stays behind the bounds check.
                if i < matrixLength then
                    let value = matrixValues.[i]
                    let column = matrixColumns.[i]

                    intermediateArray.[i] <- (%mul) (Some value) vectorValues.[column] @>

        let kernel2 =
            <@ fun (ndRange: Range1D) (numberOfRows: int) (intermediateArray: ClArray<'c option>) (matrixPtr: ClArray<int>) (outputVector: ClArray<'c option>) ->

                let gid = ndRange.GlobalID0
                let lid = ndRange.LocalID0

                // `<=` rather than `<`: the work-item at gid = numberOfRows still
                // loads the final row pointer into local memory.
                // NOTE(review): the barriers below sit inside this branch, so
                // work-items of a partially filled last work-group that fail the
                // test never reach them — confirm this is safe on target devices.
                if gid <= numberOfRows then
                    let threadsPerBlock =
                        min (numberOfRows - gid + lid) workGroupSize //If number of rows left is lesser than number of threads in a block

                    let localPtr = localArray<int> localArraySize1
                    localPtr.[lid] <- matrixPtr.[gid]

                    // The first work-item also stores the end pointer of the
                    // last row processed by this work-group.
                    if lid = 0 then
                        localPtr.[threadsPerBlock] <- matrixPtr.[gid + threadsPerBlock]

                    barrierLocal ()

                    let localValues = localArray<'c option> localArraySize2
                    let workEnd = localPtr.[threadsPerBlock]
                    let mutable blockLowerBound = localPtr.[0]
                    let numberOfBlocksFitting = localArraySize2 / threadsPerBlock
                    let workPerIteration = threadsPerBlock * numberOfBlocksFitting

                    let mutable sum: 'c option = None

                    // Each iteration stages up to workPerIteration products,
                    // starting at blockLowerBound, into local memory and lets
                    // every work-item fold the part belonging to its own row.
                    while blockLowerBound < workEnd do
                        let mutable index = blockLowerBound + lid

                        barrierLocal ()
                        //Loading values to the local memory
                        for block in 0 .. numberOfBlocksFitting - 1 do
                            if index < workEnd then
                                localValues.[lid + block * threadsPerBlock] <- intermediateArray.[index]
                                index <- index + threadsPerBlock

                        barrierLocal ()
                        //Reduction
                        //Check if any part of the row is loaded into local memory on this iteration
                        if (localPtr.[lid + 1] > blockLowerBound
                            && localPtr.[lid] < blockLowerBound + workPerIteration) then
                            // Row bounds translated to indices in localValues and
                            // clamped to the window loaded on this iteration.
                            let rowStart = max (localPtr.[lid] - blockLowerBound) 0

                            let rowEnd =
                                min (localPtr.[lid + 1] - blockLowerBound) workPerIteration

                            for jj in rowStart .. rowEnd - 1 do
                                match (%add) sum localValues.[jj] with
                                | Some v -> sum <- Some v
                                | None -> sum <- None

                        blockLowerBound <- blockLowerBound + workPerIteration

                    if gid < numberOfRows then
                        outputVector.[gid] <- sum @>

        let kernel1 = clContext.Compile kernel1
        let kernel2 = clContext.Compile kernel2

        fun (queue: MailboxProcessor<_>) (matrix: CSRMatrix<'a>) (vector: ClArray<'b option>) ->

            let matrixLength = matrix.Values.Length

            // One work-item per stored matrix value.
            let ndRange1 =
                Range1D.CreateValid(matrixLength, workGroupSize)

            // One work-item per matrix row.
            let ndRange2 =
                Range1D.CreateValid(matrix.RowCount, workGroupSize)

            // Element-wise products written by kernel1 and read by kernel2.
            let intermediateArray =
                clContext.CreateClArray<'c option>(
                    matrixLength,
                    deviceAccessMode = DeviceAccessMode.ReadWrite,
                    hostAccessMode = HostAccessMode.NotAccessible,
                    allocationMode = AllocationMode.Default
                )

            let kernel1 = kernel1.GetKernel()

            queue.Post(
                Msg.MsgSetArguments
                    (fun () ->
                        kernel1.KernelFunc ndRange1 matrixLength matrix.Columns matrix.Values vector intermediateArray)
            )

            queue.Post(Msg.CreateRunMsg<_, _>(kernel1))

            let outputArray =
                clContext.CreateClArray<'c option>(
                    matrix.RowCount,
                    deviceAccessMode = DeviceAccessMode.ReadWrite,
                    hostAccessMode = HostAccessMode.NotAccessible,
                    allocationMode = AllocationMode.Default
                )

            let kernel2 = kernel2.GetKernel()

            queue.Post(
                Msg.MsgSetArguments
                    (fun () ->
                        kernel2.KernelFunc ndRange2 matrix.RowCount intermediateArray matrix.RowPointers outputArray)
            )

            queue.Post(Msg.CreateRunMsg<_, _>(kernel2))

            // The products are no longer needed once kernel2 has been queued.
            queue.Post(Msg.CreateFreeMsg intermediateArray)

            outputArray

0 commit comments

Comments
 (0)