@@ -73,99 +73,43 @@ module Expand =
7373
7474 requiredRawsLengths
7575
76- let expandRightMatrixValuesIndices ( clContext : ClContext ) workGroupSize =
77-
78- let kernel =
79- <@ fun ( ndRange : Range1D ) length ( globalRightMatrixValuesPositions : Indices ) ( requiredRightMatrixValuesPointers : Indices ) ( globalPositions : Indices ) ( result : Indices ) ->
76+ let extractLeftMatrixRequiredValuesAndColumns ( clContext : ClContext ) workGroupSize =
8077
81- let gid = ndRange.GlobalID0
78+ let getUniqueBitmap =
79+ ClArray.getUniqueBitmap clContext workGroupSize
8280
83- if gid < length then
84- // index corresponding to the position of pointers
85- let positionIndex = globalPositions.[ gid] - 1
81+ let prefixSumExclude =
82+ PrefixSum.standardExcludeInplace clContext workGroupSize
8683
87- // the position of the beginning of a new line of pointers
88- let sourcePosition = globalRightMatrixValuesPositions .[ positionIndex ]
84+ let indicesScatter =
85+ Scatter.runInplace clContext workGroupSize
8986
90- // offset from the source pointer
91- let offsetFromSourcePosition = gid - sourcePosition
87+ let dataScatter =
88+ Scatter.runInplace clContext workGroupSize
9289
93- // pointer to the first element in the row of the right matrix from which
94- // the offset will be counted to get pointers to subsequent elements in this row
95- let sourcePointer = requiredRightMatrixValuesPointers.[ positionIndex]
90+ fun ( processor : MailboxProcessor < _ >) ( leftMatrix : ClMatrix.CSR < 'a >) ( globalRightMatrixRawsStartPositions : Indices ) ->
9691
97- // adding up the mix with the source pointer,
98- // we get a pointer to a specific element in the raw
99- result .[ gid ] <- sourcePointer + offsetFromSourcePosition @>
92+ let leftMatrixRequiredPositions , resultLength =
93+ let bitmap =
94+ getUniqueBitmap processor DeviceOnly globalRightMatrixRawsStartPositions
10095
101- let kernel = clContext.Compile kernel
96+ let length = ( prefixSumExclude processor bitmap ) .ToHostAndFree processor
10297
103- fun ( processor : MailboxProcessor < _ >) ( resultLength : int ) ( globalRightMatrixRawsStartPositions : Indices ) ( requiredRightMatrixValuesPointers : Indices ) ( globalPositions : Indices ) ->
98+ bitmap , length
10499
105- let globalRightMatrixValuesPointers =
100+ let requiredLeftMatrixValues =
106101 clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, resultLength)
107102
108- let kernel = kernel.GetKernel()
109-
110- let ndRange =
111- Range1D.CreateValid( resultLength, workGroupSize)
112-
113- processor.Post(
114- Msg.MsgSetArguments
115- ( fun () ->
116- kernel.KernelFunc
117- ndRange
118- resultLength
119- globalRightMatrixRawsStartPositions
120- requiredRightMatrixValuesPointers
121- globalPositions
122- globalRightMatrixValuesPointers)
123- )
124-
125- processor.Post <| Msg.CreateRunMsg<_, _> kernel
126-
127- globalRightMatrixValuesPointers
128-
129- let getResultRowPointers ( clContext : ClContext ) workGroupSize =
130-
131- let kernel =
132- <@ fun ( ndRange : Range1D ) length ( leftMatrixRowPointers : Indices ) ( globalArrayRightMatrixRawPointers : Indices ) ( result : Indices ) ->
133-
134- let gid = ndRange.GlobalID0
135-
136- if gid < length then
137- let rowPointer = leftMatrixRowPointers.[ gid]
138- let globalPointer = globalArrayRightMatrixRawPointers.[ rowPointer]
139-
140- result.[ gid] <- globalPointer
141- @>
142-
143- let kernel = clContext.Compile kernel
144-
145- fun ( processor : MailboxProcessor < _ >) ( leftMatrixRowPointers : Indices ) ( globalArrayRightMatrixRawPointers : Indices ) ->
146-
147- let result =
148- clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, leftMatrixRowPointers.Length)
149-
150- let kernel = kernel.GetKernel()
103+ indicesScatter processor leftMatrixRequiredPositions leftMatrix.Values requiredLeftMatrixValues
151104
152- let ndRange =
153- Range1D.CreateValid ( leftMatrixRowPointers.Length , workGroupSize )
105+ let requiredLeftMatrixColumns =
106+ clContext.CreateClArrayWithSpecificAllocationMode ( DeviceOnly , resultLength )
154107
155- processor.Post(
156- Msg.MsgSetArguments
157- ( fun () ->
158- kernel.KernelFunc
159- ndRange
160- leftMatrixRowPointers.Length
161- leftMatrixRowPointers
162- globalArrayRightMatrixRawPointers
163- result)
164- )
108+ dataScatter processor leftMatrixRequiredPositions leftMatrix.Columns requiredLeftMatrixColumns
165109
166- processor.Post <| Msg.CreateRunMsg <_, _> kernel
110+ leftMatrixRequiredPositions.Free processor
167111
168- result
112+ requiredLeftMatrixColumns , requiredLeftMatrixValues
169113
170114 let getGlobalMap ( clContext : ClContext ) workGroupSize =
171115
@@ -183,49 +127,56 @@ module Expand =
183127 // Insert units at the beginning of new lines (source positions)
184128 assignUnits processor globalRightMatrixValuesPositions globalPositions
185129
186- // Apply the prefix sum,
130+ // Apply the prefix sum, SIDE EFFECT!!!
187131 // get an array where different sub-arrays of pointers to elements of the same row differ in values
188132 ( prefixSum processor globalPositions) .Free processor
189133
190134 globalPositions
191135
192- let extractLeftMatrixRequiredValuesAndColumns ( clContext : ClContext ) workGroupSize =
193-
194- let getUniqueBitmap =
195- ClArray.getUniqueBitmap clContext workGroupSize
136+ let getResultRowPointers ( clContext : ClContext ) workGroupSize =
196137
197- let prefixSumExclude =
198- PrefixSum.standardExcludeInplace clContext workGroupSize
138+ let kernel =
139+ <@ fun ( ndRange : Range1D ) length ( leftMatrixRowPointers : Indices ) ( globalArrayRightMatrixRawPointers : Indices ) ( result : Indices ) ->
199140
200- let indicesScatter =
201- Scatter.runInplace clContext workGroupSize
141+ let gid = ndRange.GlobalID0
202142
203- let dataScatter =
204- Scatter.runInplace clContext workGroupSize
143+ // do not touch the last element
144+ if gid < length - 1 then
145+ let rowPointer = leftMatrixRowPointers.[ gid]
146+ let globalPointer = globalArrayRightMatrixRawPointers.[ rowPointer]
205147
206- fun ( processor : MailboxProcessor < _ >) ( leftMatrix : ClMatrix.CSR < 'a >) ( globalRightMatrixRawsStartPositions : Indices ) ->
148+ result .[ gid ] <- globalPointer @>
207149
208- let leftMatrixRequiredPositions , resultLength =
209- let bitmap =
210- getUniqueBitmap processor DeviceOnly globalRightMatrixRawsStartPositions
150+ let kernel = clContext.Compile kernel
211151
212- let length = ( prefixSumExclude processor bitmap ) .ToHostAndFree processor
152+ let createResultPointersBuffer = ClArray.create clContext workGroupSize
213153
214- bitmap , length
154+ fun ( processor : MailboxProcessor < _ >) ( globalLength : int ) ( leftMatrixRowPointers : Indices ) ( globalRightMatrixRowPointers : Indices ) ->
215155
216- let requiredLeftMatrixValues =
217- clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, resultLength)
156+ // The last element must be equal to the length of the global array.
157+ let result =
158+ createResultPointersBuffer processor DeviceOnly leftMatrixRowPointers.Length globalLength
218159
219- indicesScatter processor leftMatrixRequiredPositions leftMatrix.Values requiredLeftMatrixValues
160+ let kernel = kernel.GetKernel ()
220161
221- let requiredLeftMatrixColumns =
222- clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, resultLength)
162+ // do not touch the last element
163+ let ndRange =
164+ Range1D.CreateValid( leftMatrixRowPointers.Length - 1 , workGroupSize)
223165
224- dataScatter processor leftMatrixRequiredPositions leftMatrix.Columns requiredLeftMatrixColumns
166+ processor.Post(
167+ Msg.MsgSetArguments
168+ ( fun () ->
169+ kernel.KernelFunc
170+ ndRange
171+ leftMatrixRowPointers.Length
172+ leftMatrixRowPointers
173+ globalRightMatrixRowPointers
174+ result)
175+ )
225176
226- leftMatrixRequiredPositions.Free processor
177+ processor.Post <| Msg.CreateRunMsg <_, _> kernel
227178
228- requiredLeftMatrixColumns , requiredLeftMatrixValues
179+ result
229180
230181 let processPositions ( clContext : ClContext ) workGroupSize =
231182
@@ -242,6 +193,8 @@ module Expand =
242193
243194 let getGlobalPositions = getGlobalMap clContext workGroupSize
244195
196+ let getRowPointers = getResultRowPointers clContext workGroupSize
197+
245198 let getRequiredRightMatrixValuesPointers =
246199 processLeftMatrixColumnsAndRightMatrixRawPointers clContext workGroupSize requiredRawPointers
247200
@@ -250,15 +203,14 @@ module Expand =
250203 let requiredRawsLengths =
251204 getRequiredRawsLengths processor leftMatrix.Columns rightMatrix.RowPointers
252205
253- // global expanded array length (sum of previous length)
206+ // global expanded array length (sum of previous length) SIDE EFFECT!!!
254207 let globalLength =
255208 ( prefixSumExclude processor requiredRawsLengths) .ToHostAndFree processor
256209
257210 // rename array after side effect of prefix sum include
258211 // positions in global array for right matrix raws with duplicates
259212 let globalRightMatrixRowsStartPositions = requiredRawsLengths
260213
261-
262214 /// Extract required left matrix columns and values by global right matrix pointers.
263215 /// Then get required right matrix rows (pointers to rows) by required left matrix columns.
264216
@@ -277,13 +229,70 @@ module Expand =
277229 let globalRightMatrixRawsPointersWithoutDuplicates =
278230 removeDuplications processor globalRightMatrixRowsStartPositions
279231
232+ // RESULT row pointers into result expanded (obtained by multiplication) array
233+ let resultRowPointers =
234+ getRowPointers processor globalLength leftMatrix.RowPointers globalRightMatrixRowsStartPositions
235+
280236 globalRightMatrixRowsStartPositions.Free processor
281237
282238 // int map to distinguish different raws in a general array. 1 for first, 2 for second and so forth...
283239 let globalMap =
284240 getGlobalPositions processor globalLength globalRightMatrixRawsPointersWithoutDuplicates
285241
286- globalMap, globalRightMatrixRawsPointersWithoutDuplicates, requiredLeftMatrixValues, requiredRightMatrixRawPointers
242+ globalMap, globalRightMatrixRawsPointersWithoutDuplicates, requiredLeftMatrixValues, requiredRightMatrixRawPointers, resultRowPointers
243+
244+ let expandRightMatrixValuesIndices ( clContext : ClContext ) workGroupSize =
245+
246+ let kernel =
247+ <@ fun ( ndRange : Range1D ) length ( globalRightMatrixValuesPositions : Indices ) ( requiredRightMatrixValuesPointers : Indices ) ( globalPositions : Indices ) ( result : Indices ) ->
248+
249+ let gid = ndRange.GlobalID0
250+
251+ if gid < length then
252+ // index corresponding to the position of pointers
253+ let positionIndex = globalPositions.[ gid] - 1
254+
255+ // the position of the beginning of a new line of pointers
256+ let sourcePosition = globalRightMatrixValuesPositions.[ positionIndex]
257+
258+ // offset from the source pointer
259+ let offsetFromSourcePosition = gid - sourcePosition
260+
261+ // pointer to the first element in the row of the right matrix from which
262+ // the offset will be counted to get pointers to subsequent elements in this row
263+ let sourcePointer = requiredRightMatrixValuesPointers.[ positionIndex]
264+
265+ // adding up the mix with the source pointer,
266+ // we get a pointer to a specific element in the raw
267+ result.[ gid] <- sourcePointer + offsetFromSourcePosition @>
268+
269+ let kernel = clContext.Compile kernel
270+
271+ fun ( processor : MailboxProcessor < _ >) ( globalRightMatrixRawsStartPositions : Indices ) ( requiredRightMatrixValuesPointers : Indices ) ( globalMap : Indices ) ->
272+
273+ let globalRightMatrixValuesPointers =
274+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, globalMap.Length)
275+
276+ let kernel = kernel.GetKernel()
277+
278+ let ndRange =
279+ Range1D.CreateValid( globalMap.Length, workGroupSize)
280+
281+ processor.Post(
282+ Msg.MsgSetArguments
283+ ( fun () ->
284+ kernel.KernelFunc
285+ ndRange
286+ globalMap.Length
287+ globalRightMatrixRawsStartPositions
288+ requiredRightMatrixValuesPointers
289+ globalMap
290+ globalRightMatrixValuesPointers)
291+ )
292+
293+ processor.Post <| Msg.CreateRunMsg<_, _> kernel
294+
295+ globalRightMatrixValuesPointers
287296
288297 let expandLeftMatrixValues ( clContext : ClContext ) workGroupSize =
289298
@@ -330,56 +339,51 @@ module Expand =
330339
331340 let gatherIndices = Gather.run clContext workGroupSize
332341
333- fun ( processor : MailboxProcessor < _ >) ( globalLength : int ) ( globalPositions : Indices ) ( rightMatrixValues : Values <'a >) ( rightMatrixColumns : Indices ) ->
342+ fun ( processor : MailboxProcessor < _ >) ( globalPositions : Indices ) ( rightMatrix : ClMatrix.CSR <'a >) ->
334343 // gather all required right matrix values
335344 let extendedRightMatrixValues =
336- clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, globalLength )
345+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, globalPositions.Length )
337346
338- gatherRightMatrixData processor globalPositions rightMatrixValues extendedRightMatrixValues
347+ gatherRightMatrixData processor globalPositions rightMatrix.Values extendedRightMatrixValues
339348
340349 // gather all required right matrix column indices
341350 let extendedRightMatrixColumns =
342- clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, globalLength )
351+ clContext.CreateClArrayWithSpecificAllocationMode( DeviceOnly, globalPositions.Length )
343352
344- gatherIndices processor globalPositions rightMatrixColumns extendedRightMatrixColumns
353+ gatherIndices processor globalPositions rightMatrix.Columns extendedRightMatrixColumns
345354
346355 extendedRightMatrixValues, extendedRightMatrixColumns
347356
348357 let run ( clContext : ClContext ) workGroupSize ( multiplication : Expr < 'a -> 'b -> 'c >) =
349358
350359 let processPositions = processPositions clContext workGroupSize
351360
352- let getRightMatrixValuesPointers =
361+ let expandLeftMatrixValues =
362+ expandLeftMatrixValues clContext workGroupSize
363+
364+ let expandRightMatrixValuesPointers =
353365 expandRightMatrixValuesIndices clContext workGroupSize
354366
355367 let getRightMatrixColumnsAndValues =
356368 getRightMatrixColumnsAndValues clContext workGroupSize
357369
358- let expandLeftMatrixValues =
359- expandLeftMatrixValues clContext workGroupSize
360-
361370 let map2 = ClArray.map2 clContext workGroupSize multiplication
362371
363- let getRawPointers = getResultRowPointers clContext workGroupSize
364-
365372 fun ( processor : MailboxProcessor < _ >) ( leftMatrix : ClMatrix.CSR < 'a >) ( rightMatrix : ClMatrix.CSR < 'b >) ->
366373
367- let globalMap , globalRightMatrixRowsPointers , requiredLeftMatrixValues , requiredRightMatrixRowPointers
374+ let globalMap , globalRightMatrixRowsPointers , requiredLeftMatrixValues , requiredRightMatrixRowPointers , resultRowPointers
368375 = processPositions processor leftMatrix rightMatrix
369376
370- // left matrix values correspondingly to right matrix values // TODO()
377+ // left matrix values correspondingly to right matrix values
371378 let extendedLeftMatrixValues =
372- expandLeftMatrixValues processor globalMap leftMatrix.Values
373-
374- let resultRowPointers =
375- getRawPointers processor leftMatrix.RowPointers globalRightMatrixRowsPointers
379+ expandLeftMatrixValues processor globalMap requiredLeftMatrixValues
376380
377381 // extended pointers to all required right matrix numbers
378382 let globalRightMatrixValuesPointers =
379- getRightMatrixValuesPointers processor globalMap.Length globalRightMatrixRowsPointers requiredRightMatrixRowPointers globalMap
383+ expandRightMatrixValuesPointers processor globalRightMatrixRowsPointers requiredRightMatrixRowPointers globalMap
380384
381385 let extendedRightMatrixValues , extendedRightMatrixColumns =
382- getRightMatrixColumnsAndValues processor globalMap.Length globalRightMatrixValuesPointers rightMatrix.Values rightMatrix.Columns
386+ getRightMatrixColumnsAndValues processor globalRightMatrixValuesPointers rightMatrix
383387
384388 /// Multiplication
385389 let multiplicationResult =
0 commit comments