@@ -3,6 +3,7 @@ namespace GraphBLAS.FSharp.Backend
33open Brahma.FSharp
44open GraphBLAS.FSharp .Backend
55open GraphBLAS.FSharp .Backend .Common
6+ open GraphBLAS.FSharp .Backend .Elementwise
67open Microsoft.FSharp .Quotations
78
89module CSRMatrix =
@@ -82,13 +83,14 @@ module CSRMatrix =
8283 Columns = matrix.Columns
8384 Values = matrix.Values }
8485
85- let eWiseAdd ( clContext : ClContext ) ( opAdd : Expr < 'a option -> 'b option -> 'c option >) workGroupSize =
86+ ///<remarks >Old version</remarks >
87+ let elementwiseWithCOO ( clContext : ClContext ) ( opAdd : Expr < 'a option -> 'b option -> 'c option >) workGroupSize =
8688
8789 let prepareRows =
8890 expandRowPointers clContext workGroupSize
8991
9092 let eWiseCOO =
91- COOMatrix.eWiseAdd clContext opAdd workGroupSize
93+ COOMatrix.elementwise clContext opAdd workGroupSize
9294
9395 let toCSRInplace =
9496 COOMatrix.toCSRInplace clContext workGroupSize
@@ -117,13 +119,18 @@ module CSRMatrix =
117119
118120 toCSRInplace processor m3COO
119121
120- let eWiseAddAtLeastOne ( clContext : ClContext ) ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'c option >) workGroupSize =
122+ ///<remarks >Old version</remarks >
123+ let elementwiseAtLeastOneWithCOO
124+ ( clContext : ClContext )
125+ ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'c option >)
126+ workGroupSize
127+ =
121128
122129 let prepareRows =
123130 expandRowPointers clContext workGroupSize
124131
125132 let eWiseCOO =
126- COOMatrix.eWiseAddAtLeastOne clContext opAdd workGroupSize
133+ COOMatrix.elementwiseAtLeastOne clContext opAdd workGroupSize
127134
128135 let toCSRInplace =
129136 COOMatrix.toCSRInplace clContext workGroupSize
@@ -183,6 +190,142 @@ module CSRMatrix =
183190 let transposedCoo = transposeInplace queue coo
184191 toCSRInplace queue transposedCoo
185192
193+ let elementwiseToCOO < 'a , 'b , 'c when 'a : struct and 'b : struct and 'c : struct and 'c : equality >
194+ ( clContext : ClContext )
195+ ( opAdd : Expr < 'a option -> 'b option -> 'c option >)
196+ workGroupSize
197+ =
198+
199+ let merge = merge clContext workGroupSize
200+
201+ let preparePositions =
202+ preparePositions clContext opAdd Utils.defaultWorkGroupSize
203+
204+ let setPositions =
205+ setPositions< 'c> clContext Utils.defaultWorkGroupSize
206+
207+ fun ( queue : MailboxProcessor < _ >) ( matrixLeft : CSRMatrix < 'a >) ( matrixRight : CSRMatrix < 'b >) ->
208+
209+ let allRows , allColumns , leftMergedValues , rightMergedValues , isRowEnd , isLeft =
210+ merge
211+ queue
212+ matrixLeft.RowPointers
213+ matrixLeft.Columns
214+ matrixLeft.Values
215+ matrixRight.RowPointers
216+ matrixRight.Columns
217+ matrixRight.Values
218+
219+ let positions , allValues =
220+ preparePositions queue allColumns leftMergedValues rightMergedValues isRowEnd isLeft
221+
222+ queue.Post( Msg.CreateFreeMsg<_>( leftMergedValues))
223+ queue.Post( Msg.CreateFreeMsg<_>( rightMergedValues))
224+
225+ let resultRows , resultColumns , resultValues , positions , positionsSum =
226+ setPositions queue allRows allColumns allValues positions
227+
228+ queue.Post( Msg.CreateFreeMsg<_>( allRows))
229+ queue.Post( Msg.CreateFreeMsg<_>( isLeft))
230+ queue.Post( Msg.CreateFreeMsg<_>( isRowEnd))
231+ queue.Post( Msg.CreateFreeMsg<_>( positions))
232+ queue.Post( Msg.CreateFreeMsg<_>( allColumns))
233+ queue.Post( Msg.CreateFreeMsg<_>( allValues))
234+
235+ { Context = clContext
236+ RowCount = matrixLeft.RowCount
237+ ColumnCount = matrixLeft.ColumnCount
238+ Rows = resultRows
239+ Columns = resultColumns
240+ Values = resultValues }
241+
242+ let elementwise < 'a , 'b , 'c when 'a : struct and 'b : struct and 'c : struct and 'c : equality >
243+ ( clContext : ClContext )
244+ ( opAdd : Expr < 'a option -> 'b option -> 'c option >)
245+ workGroupSize
246+ =
247+
248+ let elementwiseToCOO =
249+ elementwiseToCOO clContext opAdd workGroupSize
250+
251+ let toCSRInplace =
252+ COOMatrix.toCSRInplace clContext Utils.defaultWorkGroupSize
253+
254+ fun ( queue : MailboxProcessor < _ >) ( matrixLeft : CSRMatrix < 'a >) ( matrixRight : CSRMatrix < 'b >) ->
255+
256+ let cooRes =
257+ elementwiseToCOO queue matrixLeft matrixRight
258+
259+ toCSRInplace queue cooRes
260+
261+ let elementwiseAtLeastOneToCOO < 'a , 'b , 'c when 'a : struct and 'b : struct and 'c : struct and 'c : equality >
262+ ( clContext : ClContext )
263+ ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'c option >)
264+ workGroupSize
265+ =
266+
267+ let merge = merge clContext workGroupSize
268+
269+ let preparePositions =
270+ preparePositionsAtLeastOne clContext opAdd Utils.defaultWorkGroupSize
271+
272+ let setPositions =
273+ setPositions< 'c> clContext Utils.defaultWorkGroupSize
274+
275+ fun ( queue : MailboxProcessor < _ >) ( matrixLeft : CSRMatrix < 'a >) ( matrixRight : CSRMatrix < 'b >) ->
276+
277+ let allRows , allColumns , leftMergedValues , rightMergedValues , isRowEnd , isLeft =
278+ merge
279+ queue
280+ matrixLeft.RowPointers
281+ matrixLeft.Columns
282+ matrixLeft.Values
283+ matrixRight.RowPointers
284+ matrixRight.Columns
285+ matrixRight.Values
286+
287+ let positions , allValues =
288+ preparePositions queue allColumns leftMergedValues rightMergedValues isRowEnd isLeft
289+
290+ queue.Post( Msg.CreateFreeMsg<_>( leftMergedValues))
291+ queue.Post( Msg.CreateFreeMsg<_>( rightMergedValues))
292+
293+ let resultRows , resultColumns , resultValues , positions , positionsSum =
294+ setPositions queue allRows allColumns allValues positions
295+
296+ queue.Post( Msg.CreateFreeMsg<_>( allRows))
297+ queue.Post( Msg.CreateFreeMsg<_>( isLeft))
298+ queue.Post( Msg.CreateFreeMsg<_>( isRowEnd))
299+ queue.Post( Msg.CreateFreeMsg<_>( positions))
300+ queue.Post( Msg.CreateFreeMsg<_>( allColumns))
301+ queue.Post( Msg.CreateFreeMsg<_>( allValues))
302+
303+ { Context = clContext
304+ RowCount = matrixLeft.RowCount
305+ ColumnCount = matrixLeft.ColumnCount
306+ Rows = resultRows
307+ Columns = resultColumns
308+ Values = resultValues }
309+
310+ let elementwiseAtLeastOne < 'a , 'b , 'c when 'a : struct and 'b : struct and 'c : struct and 'c : equality >
311+ ( clContext : ClContext )
312+ ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'c option >)
313+ workGroupSize
314+ =
315+
316+ let elementwiseAtLeastOneToCOO =
317+ elementwiseAtLeastOneToCOO clContext opAdd workGroupSize
318+
319+ let toCSRInplace =
320+ COOMatrix.toCSRInplace clContext Utils.defaultWorkGroupSize
321+
322+ fun ( queue : MailboxProcessor < _ >) ( matrixLeft : CSRMatrix < 'a >) ( matrixRight : CSRMatrix < 'b >) ->
323+
324+ let cooRes =
325+ elementwiseAtLeastOneToCOO queue matrixLeft matrixRight
326+
327+ toCSRInplace queue cooRes
328+
186329 let spgemmCSC
187330 ( clContext : ClContext )
188331 workGroupSize
0 commit comments