1- namespace GraphBLAS.FSharp.Backend
1+ namespace GraphBLAS.FSharp.Backend
22
33open Brahma.FSharp
44open GraphBLAS.FSharp .Backend
@@ -14,27 +14,18 @@ module Vector =
1414 workGroupSize
1515 =
1616 //Until LocalMemSize added to ClDevice as member
17- let error = ref Unchecked.defaultof < ClErrorCode >
17+ let localMemorySize = Utils.getLocalMemorySize clContext
1818
19- let localMemorySize =
20- Cl
21- .GetDeviceInfo( clContext.ClDevice.Device, OpenCL.Net.DeviceInfo.LocalMemSize, error)
22- .CastTo< int>()
23-
24- let localArraySize1 = workGroupSize + 1
19+ let localPointersArraySize = workGroupSize + 1
2520
2621 let localMemoryLeft =
27- localMemorySize - localArraySize1 * sizeof< int>
28-
29- let optionTypeClSizeInBytes =
30- 4 + sizeof< 'c>
31- |> Utils.ceilToMultiple ( max sizeof< 'c> sizeof< int>)
22+ localMemorySize
23+ - localPointersArraySize * sizeof< int>
3224
33- let localArraySize2 =
34- localMemoryLeft / optionTypeClSizeInBytes
35- |> Utils.floorToMultiple workGroupSize
25+ let localValuesArraySize =
26+ Utils.getClArrayOfOptionTypeSize localMemoryLeft
3627
37- let kernel1 =
28+ let multiplyValues =
3829 <@ fun ( ndRange : Range1D ) matrixLength ( matrixColumns : ClArray < int >) ( matrixValues : ClArray < 'a >) ( vectorValues : ClArray < 'b option >) ( intermediateArray : ClArray < 'c option >) ->
3930
4031 let i = ndRange.GlobalID0
@@ -44,7 +35,7 @@ module Vector =
4435 if i < matrixLength then
4536 intermediateArray.[ i] <- (% mul) ( Some value) vectorValues.[ column] @>
4637
47- let kernel2 =
38+ let reduceValuesByRows =
4839 <@ fun ( ndRange : Range1D ) ( numberOfRows : int ) ( intermediateArray : ClArray < 'c option >) ( matrixPtr : ClArray < int >) ( outputVector : ClArray < 'c option >) ->
4940
5041 let gid = ndRange.GlobalID0
@@ -54,18 +45,20 @@ module Vector =
5445 let threadsPerBlock =
5546 min ( numberOfRows - gid + lid) workGroupSize //If number of rows left is lesser than number of threads in a block
5647
57- let localPtr = localArray< int> localArraySize1
48+ let localPtr = localArray< int> localPointersArraySize
5849 localPtr.[ lid] <- matrixPtr.[ gid]
5950
6051 if lid = 0 then
6152 localPtr.[ threadsPerBlock] <- matrixPtr.[ gid + threadsPerBlock]
6253
6354 barrierLocal ()
6455
65- let localValues = localArray< 'c option> localArraySize2
56+ let localValues =
57+ localArray< 'c option> localValuesArraySize
58+
6659 let workEnd = localPtr.[ threadsPerBlock]
6760 let mutable blockLowerBound = localPtr.[ 0 ]
68- let numberOfBlocksFitting = localArraySize2 / threadsPerBlock
61+ let numberOfBlocksFitting = localValuesArraySize / threadsPerBlock
6962 let workPerIteration = threadsPerBlock * numberOfBlocksFitting
7063
7164 let mutable sum : 'c option = None
@@ -90,18 +83,17 @@ module Vector =
9083 let rowEnd =
9184 min ( localPtr.[ lid + 1 ] - blockLowerBound) workPerIteration
9285
93- for jj in rowStart .. rowEnd - 1 do
94- match (% add) sum localValues.[ jj] with
95- | Some v -> sum <- Some v
96- | None -> sum <- None
86+ for j in rowStart .. rowEnd - 1 do
87+ let newSum = (% add) sum localValues.[ j] //For some reason sum <- (%add) ... causes Brahma exception
88+ sum <- newSum
9789
9890 blockLowerBound <- blockLowerBound + workPerIteration
9991
10092 if gid < numberOfRows then
10193 outputVector.[ gid] <- sum @>
10294
103- let kernel1 = clContext.Compile kernel1
104- let kernel2 = clContext.Compile kernel2
95+ let multiplyValues = clContext.Compile multiplyValues
96+ let reduceValuesByRows = clContext.Compile reduceValuesByRows
10597
10698 fun ( queue : MailboxProcessor < _ >) ( matrix : CSRMatrix < 'a >) ( vector : ClArray < 'b option >) ->
10799
@@ -121,15 +113,21 @@ module Vector =
121113 allocationMode = AllocationMode.Default
122114 )
123115
124- let kernel1 = kernel1 .GetKernel()
116+ let multiplyValues = multiplyValues .GetKernel()
125117
126118 queue.Post(
127119 Msg.MsgSetArguments
128120 ( fun () ->
129- kernel1.KernelFunc ndRange1 matrixLength matrix.Columns matrix.Values vector intermediateArray)
121+ multiplyValues.KernelFunc
122+ ndRange1
123+ matrixLength
124+ matrix.Columns
125+ matrix.Values
126+ vector
127+ intermediateArray)
130128 )
131129
132- queue.Post( Msg.CreateRunMsg<_, _>( kernel1 ))
130+ queue.Post( Msg.CreateRunMsg<_, _>( multiplyValues ))
133131
134132 let outputArray =
135133 clContext.CreateClArray< 'c option>(
@@ -139,15 +137,20 @@ module Vector =
139137 allocationMode = AllocationMode.Default
140138 )
141139
142- let kernel2 = kernel2 .GetKernel()
140+ let reduceValuesByRows = reduceValuesByRows .GetKernel()
143141
144142 queue.Post(
145143 Msg.MsgSetArguments
146144 ( fun () ->
147- kernel2.KernelFunc ndRange2 matrix.RowCount intermediateArray matrix.RowPointers outputArray)
145+ reduceValuesByRows.KernelFunc
146+ ndRange2
147+ matrix.RowCount
148+ intermediateArray
149+ matrix.RowPointers
150+ outputArray)
148151 )
149152
150- queue.Post( Msg.CreateRunMsg<_, _>( kernel2 ))
153+ queue.Post( Msg.CreateRunMsg<_, _>( reduceValuesByRows ))
151154
152155 queue.Post( Msg.CreateFreeMsg intermediateArray)
153156
0 commit comments