Skip to content

Commit 5479816

Browse files
committed
add: Expand stage
1 parent 9d25601 commit 5479816

3 files changed

Lines changed: 226 additions & 220 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 129 additions & 125 deletions
Original file line number | Diff line number | Diff line change
@@ -73,99 +73,43 @@ module Expand =
7373

7474
requiredRawsLengths
7575

76-
let expandRightMatrixValuesIndices (clContext: ClContext) workGroupSize =
77-
78-
let kernel =
79-
<@ fun (ndRange: Range1D) length (globalRightMatrixValuesPositions: Indices) (requiredRightMatrixValuesPointers: Indices) (globalPositions: Indices) (result: Indices) ->
76+
let extractLeftMatrixRequiredValuesAndColumns (clContext: ClContext) workGroupSize =
8077

81-
let gid = ndRange.GlobalID0
78+
let getUniqueBitmap =
79+
ClArray.getUniqueBitmap clContext workGroupSize
8280

83-
if gid < length then
84-
// index corresponding to the position of pointers
85-
let positionIndex = globalPositions.[gid] - 1
81+
let prefixSumExclude =
82+
PrefixSum.standardExcludeInplace clContext workGroupSize
8683

87-
// the position of the beginning of a new line of pointers
88-
let sourcePosition = globalRightMatrixValuesPositions.[positionIndex]
84+
let indicesScatter =
85+
Scatter.runInplace clContext workGroupSize
8986

90-
// offset from the source pointer
91-
let offsetFromSourcePosition = gid - sourcePosition
87+
let dataScatter =
88+
Scatter.runInplace clContext workGroupSize
9289

93-
// pointer to the first element in the row of the right matrix from which
94-
// the offset will be counted to get pointers to subsequent elements in this row
95-
let sourcePointer = requiredRightMatrixValuesPointers.[positionIndex]
90+
fun (processor: MailboxProcessor<_>) (leftMatrix: ClMatrix.CSR<'a>) (globalRightMatrixRawsStartPositions: Indices) ->
9691

97-
// adding the offset to the source pointer,
98-
// we get a pointer to a specific element in the row
99-
result.[gid] <- sourcePointer + offsetFromSourcePosition @>
92+
let leftMatrixRequiredPositions, resultLength =
93+
let bitmap =
94+
getUniqueBitmap processor DeviceOnly globalRightMatrixRawsStartPositions
10095

101-
let kernel = clContext.Compile kernel
96+
let length = (prefixSumExclude processor bitmap).ToHostAndFree processor
10297

103-
fun (processor: MailboxProcessor<_>) (resultLength: int) (globalRightMatrixRawsStartPositions: Indices) (requiredRightMatrixValuesPointers: Indices) (globalPositions: Indices) ->
98+
bitmap, length
10499

105-
let globalRightMatrixValuesPointers =
100+
let requiredLeftMatrixValues =
106101
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)
107102

108-
let kernel = kernel.GetKernel()
109-
110-
let ndRange =
111-
Range1D.CreateValid(resultLength, workGroupSize)
112-
113-
processor.Post(
114-
Msg.MsgSetArguments
115-
(fun () ->
116-
kernel.KernelFunc
117-
ndRange
118-
resultLength
119-
globalRightMatrixRawsStartPositions
120-
requiredRightMatrixValuesPointers
121-
globalPositions
122-
globalRightMatrixValuesPointers)
123-
)
124-
125-
processor.Post <| Msg.CreateRunMsg<_, _> kernel
126-
127-
globalRightMatrixValuesPointers
128-
129-
let getResultRowPointers (clContext: ClContext) workGroupSize =
130-
131-
let kernel =
132-
<@ fun (ndRange: Range1D) length (leftMatrixRowPointers: Indices) (globalArrayRightMatrixRawPointers: Indices) (result: Indices) ->
133-
134-
let gid = ndRange.GlobalID0
135-
136-
if gid < length then
137-
let rowPointer = leftMatrixRowPointers.[gid]
138-
let globalPointer = globalArrayRightMatrixRawPointers.[rowPointer]
139-
140-
result.[gid] <- globalPointer
141-
@>
142-
143-
let kernel = clContext.Compile kernel
144-
145-
fun (processor: MailboxProcessor<_>) (leftMatrixRowPointers: Indices) (globalArrayRightMatrixRawPointers: Indices) ->
146-
147-
let result =
148-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, leftMatrixRowPointers.Length)
149-
150-
let kernel = kernel.GetKernel()
103+
indicesScatter processor leftMatrixRequiredPositions leftMatrix.Values requiredLeftMatrixValues
151104

152-
let ndRange =
153-
Range1D.CreateValid( leftMatrixRowPointers.Length, workGroupSize)
105+
let requiredLeftMatrixColumns =
106+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)
154107

155-
processor.Post(
156-
Msg.MsgSetArguments
157-
(fun () ->
158-
kernel.KernelFunc
159-
ndRange
160-
leftMatrixRowPointers.Length
161-
leftMatrixRowPointers
162-
globalArrayRightMatrixRawPointers
163-
result)
164-
)
108+
dataScatter processor leftMatrixRequiredPositions leftMatrix.Columns requiredLeftMatrixColumns
165109

166-
processor.Post <| Msg.CreateRunMsg<_, _> kernel
110+
leftMatrixRequiredPositions.Free processor
167111

168-
result
112+
requiredLeftMatrixColumns, requiredLeftMatrixValues
169113

170114
let getGlobalMap (clContext: ClContext) workGroupSize =
171115

@@ -183,49 +127,56 @@ module Expand =
183127
// Insert ones at the beginning of new rows (source positions)
184128
assignUnits processor globalRightMatrixValuesPositions globalPositions
185129

186-
// Apply the prefix sum,
130+
// Apply the prefix sum in place (side effect: globalPositions is modified),
187131
// get an array where different sub-arrays of pointers to elements of the same row differ in values
188132
(prefixSum processor globalPositions).Free processor
189133

190134
globalPositions
191135

192-
let extractLeftMatrixRequiredValuesAndColumns (clContext: ClContext) workGroupSize =
193-
194-
let getUniqueBitmap =
195-
ClArray.getUniqueBitmap clContext workGroupSize
136+
let getResultRowPointers (clContext: ClContext) workGroupSize =
196137

197-
let prefixSumExclude =
198-
PrefixSum.standardExcludeInplace clContext workGroupSize
138+
let kernel =
139+
<@ fun (ndRange: Range1D) length (leftMatrixRowPointers: Indices) (globalArrayRightMatrixRawPointers: Indices) (result: Indices) ->
199140

200-
let indicesScatter =
201-
Scatter.runInplace clContext workGroupSize
141+
let gid = ndRange.GlobalID0
202142

203-
let dataScatter =
204-
Scatter.runInplace clContext workGroupSize
143+
// do not touch the last element
144+
if gid < length - 1 then
145+
let rowPointer = leftMatrixRowPointers.[gid]
146+
let globalPointer = globalArrayRightMatrixRawPointers.[rowPointer]
205147

206-
fun (processor: MailboxProcessor<_>) (leftMatrix: ClMatrix.CSR<'a>) (globalRightMatrixRawsStartPositions: Indices) ->
148+
result.[gid] <- globalPointer @>
207149

208-
let leftMatrixRequiredPositions, resultLength =
209-
let bitmap =
210-
getUniqueBitmap processor DeviceOnly globalRightMatrixRawsStartPositions
150+
let kernel = clContext.Compile kernel
211151

212-
let length = (prefixSumExclude processor bitmap).ToHostAndFree processor
152+
let createResultPointersBuffer = ClArray.create clContext workGroupSize
213153

214-
bitmap, length
154+
fun (processor: MailboxProcessor<_>) (globalLength: int) (leftMatrixRowPointers: Indices) (globalRightMatrixRowPointers: Indices) ->
215155

216-
let requiredLeftMatrixValues =
217-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)
156+
// The last element must be equal to the length of the global array.
157+
let result =
158+
createResultPointersBuffer processor DeviceOnly leftMatrixRowPointers.Length globalLength
218159

219-
indicesScatter processor leftMatrixRequiredPositions leftMatrix.Values requiredLeftMatrixValues
160+
let kernel = kernel.GetKernel()
220161

221-
let requiredLeftMatrixColumns =
222-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, resultLength)
162+
// do not touch the last element
163+
let ndRange =
164+
Range1D.CreateValid(leftMatrixRowPointers.Length - 1, workGroupSize)
223165

224-
dataScatter processor leftMatrixRequiredPositions leftMatrix.Columns requiredLeftMatrixColumns
166+
processor.Post(
167+
Msg.MsgSetArguments
168+
(fun () ->
169+
kernel.KernelFunc
170+
ndRange
171+
leftMatrixRowPointers.Length
172+
leftMatrixRowPointers
173+
globalRightMatrixRowPointers
174+
result)
175+
)
225176

226-
leftMatrixRequiredPositions.Free processor
177+
processor.Post <| Msg.CreateRunMsg<_, _> kernel
227178

228-
requiredLeftMatrixColumns, requiredLeftMatrixValues
179+
result
229180

230181
let processPositions (clContext: ClContext) workGroupSize =
231182

@@ -242,6 +193,8 @@ module Expand =
242193

243194
let getGlobalPositions = getGlobalMap clContext workGroupSize
244195

196+
let getRowPointers = getResultRowPointers clContext workGroupSize
197+
245198
let getRequiredRightMatrixValuesPointers =
246199
processLeftMatrixColumnsAndRightMatrixRawPointers clContext workGroupSize requiredRawPointers
247200

@@ -250,15 +203,14 @@ module Expand =
250203
let requiredRawsLengths =
251204
getRequiredRawsLengths processor leftMatrix.Columns rightMatrix.RowPointers
252205

253-
// global expanded array length (sum of previous length)
206+
// global expanded array length (sum of the required row lengths); side effect: the prefix sum rewrites requiredRawsLengths in place
254207
let globalLength =
255208
(prefixSumExclude processor requiredRawsLengths).ToHostAndFree processor
256209

257210
// rename the array after the side effect of the in-place prefix sum (prefixSumExclude)
258211
// positions in the global array for right matrix rows, with duplicates
259212
let globalRightMatrixRowsStartPositions = requiredRawsLengths
260213

261-
262214
/// Extract required left matrix columns and values by global right matrix pointers.
263215
/// Then get required right matrix rows (pointers to rows) by required left matrix columns.
264216
@@ -277,13 +229,70 @@ module Expand =
277229
let globalRightMatrixRawsPointersWithoutDuplicates =
278230
removeDuplications processor globalRightMatrixRowsStartPositions
279231

232+
// Result row pointers into the expanded result array (the one obtained by multiplication)
233+
let resultRowPointers =
234+
getRowPointers processor globalLength leftMatrix.RowPointers globalRightMatrixRowsStartPositions
235+
280236
globalRightMatrixRowsStartPositions.Free processor
281237

282238
// int map to distinguish different rows in the global array: 1 for the first, 2 for the second, and so forth
283239
let globalMap =
284240
getGlobalPositions processor globalLength globalRightMatrixRawsPointersWithoutDuplicates
285241

286-
globalMap, globalRightMatrixRawsPointersWithoutDuplicates, requiredLeftMatrixValues, requiredRightMatrixRawPointers
242+
globalMap, globalRightMatrixRawsPointersWithoutDuplicates, requiredLeftMatrixValues, requiredRightMatrixRawPointers, resultRowPointers
243+
244+
let expandRightMatrixValuesIndices (clContext: ClContext) workGroupSize =
245+
246+
let kernel =
247+
<@ fun (ndRange: Range1D) length (globalRightMatrixValuesPositions: Indices) (requiredRightMatrixValuesPointers: Indices) (globalPositions: Indices) (result: Indices) ->
248+
249+
let gid = ndRange.GlobalID0
250+
251+
if gid < length then
252+
// index corresponding to the position of pointers
253+
let positionIndex = globalPositions.[gid] - 1
254+
255+
// the position of the beginning of a new line of pointers
256+
let sourcePosition = globalRightMatrixValuesPositions.[positionIndex]
257+
258+
// offset from the source pointer
259+
let offsetFromSourcePosition = gid - sourcePosition
260+
261+
// pointer to the first element in the row of the right matrix from which
262+
// the offset will be counted to get pointers to subsequent elements in this row
263+
let sourcePointer = requiredRightMatrixValuesPointers.[positionIndex]
264+
265+
// adding the offset to the source pointer,
266+
// we get a pointer to a specific element in the row
267+
result.[gid] <- sourcePointer + offsetFromSourcePosition @>
268+
269+
let kernel = clContext.Compile kernel
270+
271+
fun (processor: MailboxProcessor<_>) (globalRightMatrixRawsStartPositions: Indices) (requiredRightMatrixValuesPointers: Indices) (globalMap: Indices) ->
272+
273+
let globalRightMatrixValuesPointers =
274+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, globalMap.Length)
275+
276+
let kernel = kernel.GetKernel()
277+
278+
let ndRange =
279+
Range1D.CreateValid(globalMap.Length, workGroupSize)
280+
281+
processor.Post(
282+
Msg.MsgSetArguments
283+
(fun () ->
284+
kernel.KernelFunc
285+
ndRange
286+
globalMap.Length
287+
globalRightMatrixRawsStartPositions
288+
requiredRightMatrixValuesPointers
289+
globalMap
290+
globalRightMatrixValuesPointers)
291+
)
292+
293+
processor.Post <| Msg.CreateRunMsg<_, _> kernel
294+
295+
globalRightMatrixValuesPointers
287296

288297
let expandLeftMatrixValues (clContext: ClContext) workGroupSize =
289298

@@ -330,56 +339,51 @@ module Expand =
330339

331340
let gatherIndices = Gather.run clContext workGroupSize
332341

333-
fun (processor: MailboxProcessor<_>) (globalLength: int) (globalPositions: Indices) (rightMatrixValues: Values<'a>) (rightMatrixColumns: Indices) ->
342+
fun (processor: MailboxProcessor<_>) (globalPositions: Indices) (rightMatrix: ClMatrix.CSR<'a>) ->
334343
// gather all required right matrix values
335344
let extendedRightMatrixValues =
336-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, globalLength)
345+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, globalPositions.Length)
337346

338-
gatherRightMatrixData processor globalPositions rightMatrixValues extendedRightMatrixValues
347+
gatherRightMatrixData processor globalPositions rightMatrix.Values extendedRightMatrixValues
339348

340349
// gather all required right matrix column indices
341350
let extendedRightMatrixColumns =
342-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, globalLength)
351+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, globalPositions.Length)
343352

344-
gatherIndices processor globalPositions rightMatrixColumns extendedRightMatrixColumns
353+
gatherIndices processor globalPositions rightMatrix.Columns extendedRightMatrixColumns
345354

346355
extendedRightMatrixValues, extendedRightMatrixColumns
347356

348357
let run (clContext: ClContext) workGroupSize (multiplication: Expr<'a -> 'b -> 'c>) =
349358

350359
let processPositions = processPositions clContext workGroupSize
351360

352-
let getRightMatrixValuesPointers =
361+
let expandLeftMatrixValues =
362+
expandLeftMatrixValues clContext workGroupSize
363+
364+
let expandRightMatrixValuesPointers =
353365
expandRightMatrixValuesIndices clContext workGroupSize
354366

355367
let getRightMatrixColumnsAndValues =
356368
getRightMatrixColumnsAndValues clContext workGroupSize
357369

358-
let expandLeftMatrixValues =
359-
expandLeftMatrixValues clContext workGroupSize
360-
361370
let map2 = ClArray.map2 clContext workGroupSize multiplication
362371

363-
let getRawPointers = getResultRowPointers clContext workGroupSize
364-
365372
fun (processor: MailboxProcessor<_>) (leftMatrix: ClMatrix.CSR<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->
366373

367-
let globalMap, globalRightMatrixRowsPointers, requiredLeftMatrixValues, requiredRightMatrixRowPointers
374+
let globalMap, globalRightMatrixRowsPointers, requiredLeftMatrixValues, requiredRightMatrixRowPointers, resultRowPointers
368375
= processPositions processor leftMatrix rightMatrix
369376

370-
// left matrix values correspondingly to right matrix values // TODO()
377+
// left matrix values corresponding to the right matrix values
371378
let extendedLeftMatrixValues =
372-
expandLeftMatrixValues processor globalMap leftMatrix.Values
373-
374-
let resultRowPointers =
375-
getRawPointers processor leftMatrix.RowPointers globalRightMatrixRowsPointers
379+
expandLeftMatrixValues processor globalMap requiredLeftMatrixValues
376380

377381
// extended pointers to all required right matrix numbers
378382
let globalRightMatrixValuesPointers =
379-
getRightMatrixValuesPointers processor globalMap.Length globalRightMatrixRowsPointers requiredRightMatrixRowPointers globalMap
383+
expandRightMatrixValuesPointers processor globalRightMatrixRowsPointers requiredRightMatrixRowPointers globalMap
380384

381385
let extendedRightMatrixValues, extendedRightMatrixColumns =
382-
getRightMatrixColumnsAndValues processor globalMap.Length globalRightMatrixValuesPointers rightMatrix.Values rightMatrix.Columns
386+
getRightMatrixColumnsAndValues processor globalRightMatrixValuesPointers rightMatrix
383387

384388
/// Multiplication
385389
let multiplicationResult =

0 commit comments

Comments
 (0)