1- namespace GraphBLAS.FSharp.Backend.Matrix.CSRMatrix .SpGEMM
1+ namespace GraphBLAS.FSharp.Backend.Matrix.CSR .SpGEMM
22
33open Brahma.FSharp
44open GraphBLAS.FSharp .Backend .Common
55open GraphBLAS.FSharp .Backend .Predefined
66open GraphBLAS.FSharp .Backend .Objects .ClContext
77open GraphBLAS.FSharp .Backend .Objects
88open GraphBLAS.FSharp .Backend .Objects .ClCell
9+ open FSharp.Quotations
910
1011type Indices = ClArray< int>
1112
@@ -143,7 +144,6 @@ module Expand =
143144 )
144145
145146 processor.Post <| Msg.CreateRunMsg<_, _> kernel
146- processor.Post <| Msg.CreateFreeMsg globalPositions
147147
148148 globalRightMatrixValuesPointers
149149
@@ -157,7 +157,7 @@ module Expand =
157157 if gid < globalLength then
158158 let valuePosition = globalPositions.[ gid] - 1
159159
160- result.[ gid] <- rightMatrixValues.[ valuePosition]@>
160+ result.[ gid] <- rightMatrixValues.[ valuePosition] @>
161161
162162 let kernel = clContext.Compile kernel
163163
@@ -184,11 +184,51 @@ module Expand =
184184 )
185185
186186 processor.Post <| Msg.CreateRunMsg<_, _> kernel
187- processor.Post <| Msg.CreateFreeMsg globalPositions
188187
189188 resultLeftMatrixValues
190189
191- let run ( clContext : ClContext ) workGroupSize multiplication =
190+ let getResultRowPointers ( clContext : ClContext ) workGroupSize =
191+
192+ let kernel =
193+ <@ fun ( ndRange : Range1D ) length ( leftMatrixRowPointers : Indices ) ( globalArrayRightMatrixRawPointers : Indices ) ( result : Indices ) ->
194+
195+ let gid = ndRange.GlobalID0
196+
197+ if gid < length then
198+ let rowPointer = leftMatrixRowPointers.[ gid]
199+ let globalPointer = globalArrayRightMatrixRawPointers.[ rowPointer]
200+
201+ result.[ gid] <- globalPointer
202+ @>
203+
204+ let kernel = clContext.Compile kernel
205+
206+ fun ( processor : MailboxProcessor < _ >) ( leftMatrixRowPointers : Indices ) ( globalArrayRightMatrixRawPointers : Indices ) ->
207+
208+ let result =
209+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, leftMatrixRowPointers.Length)
210+
211+ let kernel = kernel.GetKernel()
212+
213+ let ndRange =
214+ Range1D.CreateValid( leftMatrixRowPointers.Length, workGroupSize)
215+
216+ processor.Post(
217+ Msg.MsgSetArguments
218+ ( fun () ->
219+ kernel.KernelFunc
220+ ndRange
221+ leftMatrixRowPointers.Length
222+ leftMatrixRowPointers
223+ globalArrayRightMatrixRawPointers
224+ result)
225+ )
226+
227+ processor.Post <| Msg.CreateRunMsg<_, _> kernel
228+
229+ result
230+
231+ let run ( clContext : ClContext ) workGroupSize ( multiplication : Expr < 'a -> 'b -> 'c >) =
192232
193233 let getRequiredRawsLengths =
194234 processLeftMatrixColumnsAndRightMatrixRawPointers clContext workGroupSize requiredRawsLengths
@@ -199,11 +239,11 @@ module Expand =
199239 let getRequiredRightMatrixValuesPointers =
200240 processLeftMatrixColumnsAndRightMatrixRawPointers clContext workGroupSize requiredRawPointers
201241
242+ let getGlobalPositions = getGlobalPositions clContext workGroupSize
243+
202244 let getRightMatrixValuesPointers =
203245 getRightMatrixPointers clContext workGroupSize
204246
205- let getGlobalPositions = getGlobalPositions clContext workGroupSize
206-
207247 let gatherRightMatrixData = Gather.run clContext workGroupSize
208248
209249 let gatherIndices = Gather.run clContext workGroupSize
@@ -213,6 +253,8 @@ module Expand =
213253
214254 let map2 = ClArray.map2 clContext workGroupSize multiplication
215255
256+ let getRawPointers = getResultRowPointers clContext workGroupSize
257+
216258 fun ( processor : MailboxProcessor < _ >) ( leftMatrix : ClMatrix.CSR < 'a >) ( rightMatrix : ClMatrix.CSR < 'b >) ->
217259
218260 let requiredRawsLengths =
@@ -252,9 +294,12 @@ module Expand =
252294
253295 // left matrix values correspondingly to right matrix values
254296 let extendedLeftMatrixValues =
255- getLeftMatrixValues processor globalLength globalPositions rightMatrix .Values
297+ getLeftMatrixValues processor globalLength globalPositions leftMatrix .Values
256298
257299 let multiplicationResult =
258300 map2 processor DeviceOnly extendedLeftMatrixValues extendedRightMatrixValues
259301
260- multiplicationResult, extendedRightMatrixColumns
302+ let rowPointers =
303+ getRawPointers processor leftMatrix.RowPointers globalRightMatrixValuesRawsStartPositions
304+
305+ multiplicationResult, extendedRightMatrixColumns, rowPointers
0 commit comments