Skip to content

Commit ea88cc6

Browse files
committed
add: atLeastOneToNormalForm fun, refactor
1 parent 375918c commit ea88cc6

4 files changed

Lines changed: 12 additions & 278 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/StandardOperations.fs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,11 @@ module StandardOperations =
101101
let byteMulAtLeastOne = mkNumericMulAtLeastOne 0uy
102102
let floatMulAtLeastOne = mkNumericMulAtLeastOne 0.0
103103
let float32MulAtLeastOne = mkNumericMulAtLeastOne 0f
104+
105+
let atLeastOneToNormalForm op =
106+
<@ fun (leftItem: 'a option) (rightItem: 'b option) ->
107+
match leftItem, rightItem with
108+
| Some left, Some right -> (%op) (Both(left, right))
109+
| None, Some right -> (%op) (Right right)
110+
| Some left, None -> (%op) (Left left)
111+
| None, None -> None @>

src/GraphBLAS-sharp.Backend/Matrix/COOMatrix/COOMatrix.fs

Lines changed: 1 addition & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -498,84 +498,6 @@ module COOMatrix =
498498
Columns = matrix.Columns
499499
Values = matrix.Values }
500500

501-
let private preparePositionsAtLeastOne<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
502-
(clContext: ClContext)
503-
(opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>)
504-
workGroupSize
505-
=
506-
507-
let preparePositions =
508-
<@ fun (ndRange: Range1D) length (allRowsBuffer: ClArray<int>) (allColumnsBuffer: ClArray<int>) (leftValuesBuffer: ClArray<'a>) (rightValuesBuffer: ClArray<'b>) (allValuesBuffer: ClArray<'c>) (rawPositionsBuffer: ClArray<int>) (isLeftBitmap: ClArray<int>) ->
509-
510-
let i = ndRange.GlobalID0
511-
512-
if (i < length - 1
513-
&& allRowsBuffer.[i] = allRowsBuffer.[i + 1]
514-
&& allColumnsBuffer.[i] = allColumnsBuffer.[i + 1]) then
515-
516-
let result =
517-
(%opAdd) (Both(leftValuesBuffer.[i + 1], rightValuesBuffer.[i]))
518-
519-
(%PreparePositions.both) i result rawPositionsBuffer allValuesBuffer
520-
elif (i > 0
521-
&& i < length
522-
&& (allRowsBuffer.[i] <> allRowsBuffer.[i - 1]
523-
|| allColumnsBuffer.[i] <> allColumnsBuffer.[i - 1]))
524-
|| i = 0 then
525-
526-
let leftResult = (%opAdd) (Left leftValuesBuffer.[i])
527-
let rightResult = (%opAdd) (Right rightValuesBuffer.[i])
528-
529-
(%PreparePositions.leftRight)
530-
i
531-
leftResult
532-
rightResult
533-
isLeftBitmap
534-
allValuesBuffer
535-
rawPositionsBuffer @>
536-
537-
let kernel = clContext.Compile(preparePositions)
538-
539-
fun (processor: MailboxProcessor<_>) (allRows: ClArray<int>) (allColumns: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) ->
540-
let length = leftValues.Length
541-
542-
let ndRange =
543-
Range1D.CreateValid(length, workGroupSize)
544-
545-
let rawPositionsGpu =
546-
clContext.CreateClArray<int>(
547-
length,
548-
hostAccessMode = HostAccessMode.NotAccessible,
549-
allocationMode = AllocationMode.Default
550-
)
551-
552-
let allValues =
553-
clContext.CreateClArray<'c>(
554-
length,
555-
hostAccessMode = HostAccessMode.NotAccessible,
556-
allocationMode = AllocationMode.Default
557-
)
558-
559-
let kernel = kernel.GetKernel()
560-
561-
processor.Post(
562-
Msg.MsgSetArguments
563-
(fun () ->
564-
kernel.KernelFunc
565-
ndRange
566-
length
567-
allRows
568-
allColumns
569-
leftValues
570-
rightValues
571-
allValues
572-
rawPositionsGpu
573-
isLeft)
574-
)
575-
576-
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
577-
rawPositionsGpu, allValues
578-
579501
///<param name="clContext">.</param>
580502
///<param name="opAdd">.</param>
581503
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
@@ -585,46 +507,7 @@ module COOMatrix =
585507
workGroupSize
586508
=
587509

588-
let merge = merge clContext workGroupSize
589-
590-
let preparePositions =
591-
preparePositionsAtLeastOne clContext opAdd workGroupSize
592-
593-
let setPositions = setPositions<'c> clContext workGroupSize
594-
595-
fun (queue: MailboxProcessor<_>) (matrixLeft: COOMatrix<'a>) (matrixRight: COOMatrix<'b>) ->
596-
597-
let allRows, allColumns, leftMergedValues, rightMergedValues, isLeft =
598-
merge
599-
queue
600-
matrixLeft.Rows
601-
matrixLeft.Columns
602-
matrixLeft.Values
603-
matrixRight.Rows
604-
matrixRight.Columns
605-
matrixRight.Values
606-
607-
let rawPositions, allValues =
608-
preparePositions queue allRows allColumns leftMergedValues rightMergedValues isLeft
609-
610-
queue.Post(Msg.CreateFreeMsg<_>(leftMergedValues))
611-
queue.Post(Msg.CreateFreeMsg<_>(rightMergedValues))
612-
613-
let resultRows, resultColumns, resultValues, resultLength =
614-
setPositions queue allRows allColumns allValues rawPositions
615-
616-
queue.Post(Msg.CreateFreeMsg<_>(isLeft))
617-
queue.Post(Msg.CreateFreeMsg<_>(rawPositions))
618-
queue.Post(Msg.CreateFreeMsg<_>(allRows))
619-
queue.Post(Msg.CreateFreeMsg<_>(allColumns))
620-
queue.Post(Msg.CreateFreeMsg<_>(allValues))
621-
622-
{ Context = clContext
623-
RowCount = matrixLeft.RowCount
624-
ColumnCount = matrixLeft.ColumnCount
625-
Rows = resultRows
626-
Columns = resultColumns
627-
Values = resultValues }
510+
elementwise clContext (StandardOperations.atLeastOneToNormalForm opAdd) workGroupSize
628511

629512
let transposeInplace (clContext: ClContext) workGroupSize =
630513

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/CSRMatrix.fs

Lines changed: 3 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -126,38 +126,7 @@ module CSRMatrix =
126126
workGroupSize
127127
=
128128

129-
let prepareRows =
130-
expandRowPointers clContext workGroupSize
131-
132-
let eWiseCOO =
133-
COOMatrix.elementwiseAtLeastOne clContext opAdd workGroupSize
134-
135-
let toCSRInplace =
136-
COOMatrix.toCSRInplace clContext workGroupSize
137-
138-
fun (processor: MailboxProcessor<_>) (m1: CSRMatrix<'a>) (m2: CSRMatrix<'b>) ->
139-
let m1COO =
140-
{ Context = clContext
141-
RowCount = m1.RowCount
142-
ColumnCount = m1.ColumnCount
143-
Rows = prepareRows processor m1.RowPointers m1.Values.Length m1.RowCount
144-
Columns = m1.Columns
145-
Values = m1.Values }
146-
147-
let m2COO =
148-
{ Context = clContext
149-
RowCount = m2.RowCount
150-
ColumnCount = m2.ColumnCount
151-
Rows = prepareRows processor m2.RowPointers m2.Values.Length m2.RowCount
152-
Columns = m2.Columns
153-
Values = m2.Values }
154-
155-
let m3COO = eWiseCOO processor m1COO m2COO
156-
157-
processor.Post(Msg.CreateFreeMsg(m1COO.Rows))
158-
processor.Post(Msg.CreateFreeMsg(m2COO.Rows))
159-
160-
toCSRInplace processor m3COO
129+
elementwiseWithCOO clContext (StandardOperations.atLeastOneToNormalForm opAdd) workGroupSize
161130

162131
let transposeInplace (clContext: ClContext) workGroupSize =
163132

@@ -174,7 +143,6 @@ module CSRMatrix =
174143
let transposedCoo = transposeInplace queue coo
175144
toCSRInplace queue transposedCoo
176145

177-
178146
let transpose (clContext: ClContext) workGroupSize =
179147

180148
let toCOO = toCOO clContext workGroupSize
@@ -264,67 +232,15 @@ module CSRMatrix =
264232
workGroupSize
265233
=
266234

267-
let merge = merge clContext workGroupSize
268-
269-
let preparePositions =
270-
preparePositionsAtLeastOne clContext opAdd Utils.defaultWorkGroupSize
271-
272-
let setPositions =
273-
setPositions<'c> clContext Utils.defaultWorkGroupSize
274-
275-
fun (queue: MailboxProcessor<_>) (matrixLeft: CSRMatrix<'a>) (matrixRight: CSRMatrix<'b>) ->
276-
277-
let allRows, allColumns, leftMergedValues, rightMergedValues, isRowEnd, isLeft =
278-
merge
279-
queue
280-
matrixLeft.RowPointers
281-
matrixLeft.Columns
282-
matrixLeft.Values
283-
matrixRight.RowPointers
284-
matrixRight.Columns
285-
matrixRight.Values
286-
287-
let positions, allValues =
288-
preparePositions queue allColumns leftMergedValues rightMergedValues isRowEnd isLeft
289-
290-
queue.Post(Msg.CreateFreeMsg<_>(leftMergedValues))
291-
queue.Post(Msg.CreateFreeMsg<_>(rightMergedValues))
292-
293-
let resultRows, resultColumns, resultValues, positions, positionsSum =
294-
setPositions queue allRows allColumns allValues positions
295-
296-
queue.Post(Msg.CreateFreeMsg<_>(allRows))
297-
queue.Post(Msg.CreateFreeMsg<_>(isLeft))
298-
queue.Post(Msg.CreateFreeMsg<_>(isRowEnd))
299-
queue.Post(Msg.CreateFreeMsg<_>(positions))
300-
queue.Post(Msg.CreateFreeMsg<_>(allColumns))
301-
queue.Post(Msg.CreateFreeMsg<_>(allValues))
302-
303-
{ Context = clContext
304-
RowCount = matrixLeft.RowCount
305-
ColumnCount = matrixLeft.ColumnCount
306-
Rows = resultRows
307-
Columns = resultColumns
308-
Values = resultValues }
235+
elementwiseToCOO clContext (StandardOperations.atLeastOneToNormalForm opAdd) workGroupSize
309236

310237
let elementwiseAtLeastOne<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
311238
(clContext: ClContext)
312239
(opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>)
313240
workGroupSize
314241
=
315242

316-
let elementwiseAtLeastOneToCOO =
317-
elementwiseAtLeastOneToCOO clContext opAdd workGroupSize
318-
319-
let toCSRInplace =
320-
COOMatrix.toCSRInplace clContext Utils.defaultWorkGroupSize
321-
322-
fun (queue: MailboxProcessor<_>) (matrixLeft: CSRMatrix<'a>) (matrixRight: CSRMatrix<'b>) ->
323-
324-
let cooRes =
325-
elementwiseAtLeastOneToCOO queue matrixLeft matrixRight
326-
327-
toCSRInplace queue cooRes
243+
elementwise clContext (StandardOperations.atLeastOneToNormalForm opAdd) workGroupSize
328244

329245
let spgemmCSC
330246
(clContext: ClContext)

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/Elementwise.fs

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -82,79 +82,6 @@ module internal Elementwise =
8282
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
8383
rowPositions, allValues
8484

85-
let preparePositionsAtLeastOne<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct and 'c: equality>
86-
(clContext: ClContext)
87-
(opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>)
88-
workGroupSize
89-
=
90-
91-
let preparePositions =
92-
<@ fun (ndRange: Range1D) length (allColumns: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (allValues: ClArray<'c>) (rawPositions: ClArray<int>) (isEndOfRowBitmap: ClArray<int>) (isLeftBitmap: ClArray<int>) ->
93-
94-
let i = ndRange.GlobalID0
95-
96-
if (i < length - 1
97-
&& allColumns.[i] = allColumns.[i + 1]
98-
&& isEndOfRowBitmap.[i] = 0) then
99-
100-
let result =
101-
(%opAdd) (Both(leftValues.[i + 1], rightValues.[i]))
102-
103-
(%PreparePositions.both) i result rawPositions allValues
104-
elif i = 0
105-
|| (i < length
106-
&& (allColumns.[i] <> allColumns.[i - 1]
107-
|| isEndOfRowBitmap.[i - 1] = 1)) then
108-
109-
let leftResult = (%opAdd) (Left leftValues.[i])
110-
let rightResult = (%opAdd) (Right rightValues.[i])
111-
112-
(%PreparePositions.leftRight) i leftResult rightResult isLeftBitmap allValues rawPositions @>
113-
114-
let kernel = clContext.Compile(preparePositions)
115-
116-
fun (processor: MailboxProcessor<_>) (allColumns: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isEndOfRow: ClArray<int>) (isLeft: ClArray<int>) ->
117-
let length = leftValues.Length
118-
119-
let ndRange =
120-
Range1D.CreateValid(length, workGroupSize)
121-
122-
let rowPositions =
123-
clContext.CreateClArray<int>(
124-
length,
125-
deviceAccessMode = DeviceAccessMode.ReadWrite,
126-
hostAccessMode = HostAccessMode.NotAccessible,
127-
allocationMode = AllocationMode.Default
128-
)
129-
130-
let allValues =
131-
clContext.CreateClArray<'c>(
132-
length,
133-
deviceAccessMode = DeviceAccessMode.ReadWrite,
134-
hostAccessMode = HostAccessMode.NotAccessible,
135-
allocationMode = AllocationMode.Default
136-
)
137-
138-
let kernel = kernel.GetKernel()
139-
140-
processor.Post(
141-
Msg.MsgSetArguments
142-
(fun () ->
143-
kernel.KernelFunc
144-
ndRange
145-
length
146-
allColumns
147-
leftValues
148-
rightValues
149-
allValues
150-
rowPositions
151-
isEndOfRow
152-
isLeft)
153-
)
154-
155-
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
156-
rowPositions, allValues
157-
15885
let setPositions<'a when 'a: struct> (clContext: ClContext) workGroupSize =
15986

16087
let sum =

0 commit comments

Comments
 (0)