Skip to content

Commit 0c4add8

Browse files
committed
Methods to compute local array size and SpMV vars names
1 parent ef998aa commit 0c4add8

2 files changed

Lines changed: 51 additions & 33 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/Utils.fs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,18 @@ module internal Utils =
2525
/// Rounds `x` towards zero to a multiple of `multiple` using integer division.
/// A zero `multiple` makes the division fail.
let floorToMultiple multiple x =
    //Integer division drops the remainder; scaling back yields the multiple
    let wholeGroups = x / multiple
    wholeGroups * multiple
2626

2727
/// Rounds a positive `x` up to the nearest multiple of `multiple`.
/// Note: integer division truncates towards zero, so for `x = 0` and
/// `multiple > 1` the result is `multiple`, not 0.
let ceilToMultiple multiple x =
    //Number of complete groups strictly below x
    let groupsBelow = (x - 1) / multiple
    //One more group is needed to cover x itself
    (groupsBelow + 1) * multiple
28+
29+
/// Queries the `LocalMemSize` device info entry for the device behind
/// `clContext` and returns it cast to `int` (presumably a size in bytes —
/// callers divide it by `sizeof`; confirm against the OpenCL binding).
let getLocalMemorySize (clContext: ClContext) =
    //Receives the status code of the query
    //NOTE(review): the status code is never inspected after the call
    let statusCode = ref Unchecked.defaultof<ClErrorCode>
    let device = clContext.ClDevice.Device

    let localMemSizeInfo =
        Cl.GetDeviceInfo(device, OpenCL.Net.DeviceInfo.LocalMemSize, statusCode)

    localMemSizeInfo.CastTo<int>()
35+
36+
/// Number of elements of the value type `'a` that fit into `localMemorySize`,
/// computed as `localMemorySize` divided (integer division) by `sizeof<'a>`.
let getClArrayOfValueTypeSize<'a when 'a: struct> localMemorySize =
    let elementSize = sizeof<'a>
    localMemorySize / elementSize
37+
38+
/// Number of `'a option` elements that fit into `localMemorySize`.
//An option value in the C code is a structure: the payload plus an additional integer field
let getClArrayOfOptionTypeSize<'a> localMemorySize =
    let payloadSize = sizeof<'a>
    let tagSize = sizeof<int>
    //The raw size is rounded up to a multiple of the larger of the two fields
    let alignment = max payloadSize tagSize
    let optionSize = ceilToMultiple alignment (tagSize + payloadSize)
    localMemorySize / optionSize

src/GraphBLAS-sharp.Backend/Vector/SpMV.fs

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
namespace GraphBLAS.FSharp.Backend
1+
namespace GraphBLAS.FSharp.Backend
22

33
open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend
@@ -14,27 +14,18 @@ module Vector =
1414
workGroupSize
1515
=
1616
//Until LocalMemSize is added to ClDevice as a member
17-
let error = ref Unchecked.defaultof<ClErrorCode>
17+
let localMemorySize = Utils.getLocalMemorySize clContext
1818

19-
let localMemorySize =
20-
Cl
21-
.GetDeviceInfo(clContext.ClDevice.Device, OpenCL.Net.DeviceInfo.LocalMemSize, error)
22-
.CastTo<int>()
23-
24-
let localArraySize1 = workGroupSize + 1
19+
let localPointersArraySize = workGroupSize + 1
2520

2621
let localMemoryLeft =
27-
localMemorySize - localArraySize1 * sizeof<int>
28-
29-
let optionTypeClSizeInBytes =
30-
4 + sizeof<'c>
31-
|> Utils.ceilToMultiple (max sizeof<'c> sizeof<int>)
22+
localMemorySize
23+
- localPointersArraySize * sizeof<int>
3224

33-
let localArraySize2 =
34-
localMemoryLeft / optionTypeClSizeInBytes
35-
|> Utils.floorToMultiple workGroupSize
25+
let localValuesArraySize =
26+
Utils.getClArrayOfOptionTypeSize localMemoryLeft
3627

37-
let kernel1 =
28+
let multiplyValues =
3829
<@ fun (ndRange: Range1D) matrixLength (matrixColumns: ClArray<int>) (matrixValues: ClArray<'a>) (vectorValues: ClArray<'b option>) (intermediateArray: ClArray<'c option>) ->
3930

4031
let i = ndRange.GlobalID0
@@ -44,7 +35,7 @@ module Vector =
4435
if i < matrixLength then
4536
intermediateArray.[i] <- (%mul) (Some value) vectorValues.[column] @>
4637

47-
let kernel2 =
38+
let reduceValuesByRows =
4839
<@ fun (ndRange: Range1D) (numberOfRows: int) (intermediateArray: ClArray<'c option>) (matrixPtr: ClArray<int>) (outputVector: ClArray<'c option>) ->
4940

5041
let gid = ndRange.GlobalID0
@@ -54,18 +45,20 @@ module Vector =
5445
let threadsPerBlock =
5546
min (numberOfRows - gid + lid) workGroupSize //If number of rows left is lesser than number of threads in a block
5647

57-
let localPtr = localArray<int> localArraySize1
48+
let localPtr = localArray<int> localPointersArraySize
5849
localPtr.[lid] <- matrixPtr.[gid]
5950

6051
if lid = 0 then
6152
localPtr.[threadsPerBlock] <- matrixPtr.[gid + threadsPerBlock]
6253

6354
barrierLocal ()
6455

65-
let localValues = localArray<'c option> localArraySize2
56+
let localValues =
57+
localArray<'c option> localValuesArraySize
58+
6659
let workEnd = localPtr.[threadsPerBlock]
6760
let mutable blockLowerBound = localPtr.[0]
68-
let numberOfBlocksFitting = localArraySize2 / threadsPerBlock
61+
let numberOfBlocksFitting = localValuesArraySize / threadsPerBlock
6962
let workPerIteration = threadsPerBlock * numberOfBlocksFitting
7063

7164
let mutable sum: 'c option = None
@@ -90,18 +83,17 @@ module Vector =
9083
let rowEnd =
9184
min (localPtr.[lid + 1] - blockLowerBound) workPerIteration
9285

93-
for jj in rowStart .. rowEnd - 1 do
94-
match (%add) sum localValues.[jj] with
95-
| Some v -> sum <- Some v
96-
| None -> sum <- None
86+
for j in rowStart .. rowEnd - 1 do
87+
let newSum = (%add) sum localValues.[j] //For some reason sum <- (%add) ... causes Brahma exception
88+
sum <- newSum
9789

9890
blockLowerBound <- blockLowerBound + workPerIteration
9991

10092
if gid < numberOfRows then
10193
outputVector.[gid] <- sum @>
10294

103-
let kernel1 = clContext.Compile kernel1
104-
let kernel2 = clContext.Compile kernel2
95+
let multiplyValues = clContext.Compile multiplyValues
96+
let reduceValuesByRows = clContext.Compile reduceValuesByRows
10597

10698
fun (queue: MailboxProcessor<_>) (matrix: CSRMatrix<'a>) (vector: ClArray<'b option>) ->
10799

@@ -121,15 +113,21 @@ module Vector =
121113
allocationMode = AllocationMode.Default
122114
)
123115

124-
let kernel1 = kernel1.GetKernel()
116+
let multiplyValues = multiplyValues.GetKernel()
125117

126118
queue.Post(
127119
Msg.MsgSetArguments
128120
(fun () ->
129-
kernel1.KernelFunc ndRange1 matrixLength matrix.Columns matrix.Values vector intermediateArray)
121+
multiplyValues.KernelFunc
122+
ndRange1
123+
matrixLength
124+
matrix.Columns
125+
matrix.Values
126+
vector
127+
intermediateArray)
130128
)
131129

132-
queue.Post(Msg.CreateRunMsg<_, _>(kernel1))
130+
queue.Post(Msg.CreateRunMsg<_, _>(multiplyValues))
133131

134132
let outputArray =
135133
clContext.CreateClArray<'c option>(
@@ -139,15 +137,20 @@ module Vector =
139137
allocationMode = AllocationMode.Default
140138
)
141139

142-
let kernel2 = kernel2.GetKernel()
140+
let reduceValuesByRows = reduceValuesByRows.GetKernel()
143141

144142
queue.Post(
145143
Msg.MsgSetArguments
146144
(fun () ->
147-
kernel2.KernelFunc ndRange2 matrix.RowCount intermediateArray matrix.RowPointers outputArray)
145+
reduceValuesByRows.KernelFunc
146+
ndRange2
147+
matrix.RowCount
148+
intermediateArray
149+
matrix.RowPointers
150+
outputArray)
148151
)
149152

150-
queue.Post(Msg.CreateRunMsg<_, _>(kernel2))
153+
queue.Post(Msg.CreateRunMsg<_, _>(reduceValuesByRows))
151154

152155
queue.Post(Msg.CreateFreeMsg intermediateArray)
153156

0 commit comments

Comments
 (0)