Skip to content

Commit 5fe00e7

Browse files
committed
Refactor after merge
1 parent f0b17c7 commit 5fe00e7

11 files changed

Lines changed: 305 additions & 242 deletions

File tree

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
namespace GraphBLAS.FSharp.Backend.COOMatrix
22

33
open Brahma.FSharp.OpenCL.WorkflowBuilder.Basic
4-
open Brahma.FSharp.OpenCL.WorkflowBuilder.Evaluation
54
open GraphBLAS.FSharp
65
open GraphBLAS.FSharp.Backend.Common
6+
open GraphBLAS.FSharp.Backend.COOMatrix.Utilities
77

88
module internal EWiseAdd =
9-
let cooNotEmpty (matrixLeft: COOFormat<'a>) (matrixRight: COOFormat<'a>) (mask: Mask2D option) (semiring: Semiring<'a>) : OpenCLEvaluation<COOFormat<'a>> = opencl {
10-
let! allRows, allColumns, allValues = Merge.runForMatrix matrixLeft matrixRight mask
9+
let private runNonEmpty (matrixLeft: COOMatrix<'a>) (matrixRight: COOMatrix<'a>) (mask: Mask2D option) (semiring: ISemiring<'a>) = opencl {
10+
let! allRows, allColumns, allValues = merge matrixLeft matrixRight mask
1111

12-
let (BinaryOp append) = semiring.PlusMonoid.Append
13-
let! rawPositions = PreparePositions.runForMatrix allRows allColumns allValues append
12+
let (ClosedBinaryOp plus) = semiring.Plus
13+
let! rawPositions = preparePositions allRows allColumns allValues plus
1414

15-
let! resultRows, resultColumns, resultValues = SetPositions.runForMatrix allRows allColumns allValues rawPositions
15+
let! resultRows, resultColumns, resultValues = setPositions allRows allColumns allValues rawPositions
1616

1717
return {
1818
RowCount = matrixLeft.RowCount
@@ -23,7 +23,7 @@ module internal EWiseAdd =
2323
}
2424
}
2525

26-
let coo (matrixLeft: COOFormat<'a>) (matrixRight: COOFormat<'a>) (mask: Mask2D option) (semiring: Semiring<'a>) : OpenCLEvaluation<COOFormat<'a>> =
26+
let run (matrixLeft: COOMatrix<'a>) (matrixRight: COOMatrix<'a>) (mask: Mask2D option) (semiring: ISemiring<'a>) =
2727
if matrixLeft.Values.Length = 0 then
2828
opencl {
2929
let! resultRows = Copy.run matrixRight.Rows
@@ -38,6 +38,7 @@ module internal EWiseAdd =
3838
Values = resultValues
3939
}
4040
}
41+
4142
elif matrixRight.Values.Length = 0 then
4243
opencl {
4344
let! resultRows = Copy.run matrixLeft.Rows
@@ -52,30 +53,6 @@ module internal EWiseAdd =
5253
Values = resultValues
5354
}
5455
}
55-
else cooNotEmpty matrixLeft matrixRight mask semiring
56-
57-
let sparseNotEmpty (leftIndices: int[]) (leftValues: 'a[]) (rightIndices: int[]) (rightValues: 'a[]) (mask: Mask1D option) (semiring: Semiring<'a>) : OpenCLEvaluation<int[] * 'a[]> = opencl {
58-
let! allIndices, allValues = Merge.runForVector leftIndices leftValues rightIndices rightValues mask
59-
60-
let (BinaryOp append) = semiring.PlusMonoid.Append
61-
let! rawPositions = PreparePositions.runForVector allIndices allValues append
6256

63-
return! SetPositions.runForVector allIndices allValues rawPositions
64-
}
65-
66-
let sparse (leftIndices: int[]) (leftValues: 'a[]) (rightIndices: int[]) (rightValues: 'a[]) (mask: Mask1D option) (semiring: Semiring<'a>) : OpenCLEvaluation<int[] * 'a[]> =
67-
if leftValues.Length = 0 then
68-
opencl {
69-
let! resultIndices = Copy.run rightIndices
70-
let! resultValues = Copy.run rightValues
71-
72-
return resultIndices, resultValues
73-
}
74-
elif rightIndices.Length = 0 then
75-
opencl {
76-
let! resultIndices = Copy.run leftIndices
77-
let! resultValues = Copy.run leftValues
78-
79-
return resultIndices, resultValues
80-
}
81-
else sparseNotEmpty leftIndices leftValues rightIndices rightValues mask semiring
57+
else
58+
runNonEmpty matrixLeft matrixRight mask semiring
Lines changed: 3 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
namespace GraphBLAS.FSharp.Backend
1+
namespace GraphBLAS.FSharp.Backend.COOMatrix.Utilities
22

33
open Brahma.OpenCL
44
open Brahma.FSharp.OpenCL.WorkflowBuilder.Basic
55
open Brahma.FSharp.OpenCL.WorkflowBuilder.Evaluation
66
open GraphBLAS.FSharp
77
open GraphBLAS.FSharp.Backend.Common
88

9+
[<AutoOpen>]
910
module internal Merge =
10-
let runForMatrix (matrixLeft: COOFormat<'a>) (matrixRight: COOFormat<'a>) (mask: Mask2D option) : OpenCLEvaluation<int[] * int[] * 'a[]> = opencl {
11+
let merge (matrixLeft: COOMatrix<'a>) (matrixRight: COOMatrix<'a>) (mask: Mask2D option) : OpenCLEvaluation<int[] * int[] * 'a[]> = opencl {
1112
let workGroupSize = Utils.workGroupSize
1213
let firstSide = matrixLeft.Values.Length
1314
let secondSide = matrixRight.Values.Length
@@ -125,113 +126,3 @@ module internal Merge =
125126

126127
return allRows, allColumns, allValues
127128
}
128-
129-
let runForVector (leftIndices: int[]) (leftValues: 'a[]) (rightIndices: int[]) (rightValues: 'a[]) (mask: Mask1D option) : OpenCLEvaluation<int[] * 'a[]> = opencl {
130-
let workGroupSize = Utils.workGroupSize
131-
let firstSide = leftValues.Length
132-
let secondSide = rightValues.Length
133-
let sumOfSides = firstSide + secondSide
134-
135-
let merge =
136-
<@
137-
fun (ndRange: _1D)
138-
(firstIndicesBuffer: int[])
139-
(firstValuesBuffer: 'a[])
140-
(secondIndicesBuffer: int[])
141-
(secondValuesBuffer: 'a[])
142-
(allIndicesBuffer: int[])
143-
(allValuesBuffer: 'a[]) ->
144-
145-
let i = ndRange.GlobalID0
146-
147-
let mutable beginIdxLocal = local ()
148-
let mutable endIdxLocal = local ()
149-
let localID = ndRange.LocalID0
150-
if localID < 2 then
151-
let mutable x = localID * (workGroupSize - 1) + i - 1
152-
if x >= sumOfSides then x <- sumOfSides - 1
153-
let diagonalNumber = x
154-
155-
let mutable leftEdge = diagonalNumber + 1 - secondSide
156-
if leftEdge < 0 then leftEdge <- 0
157-
158-
let mutable rightEdge = firstSide - 1
159-
if rightEdge > diagonalNumber then rightEdge <- diagonalNumber
160-
161-
while leftEdge <= rightEdge do
162-
let middleIdx = (leftEdge + rightEdge) / 2
163-
let firstIndex = firstIndicesBuffer.[middleIdx]
164-
let secondIndex = secondIndicesBuffer.[diagonalNumber - middleIdx]
165-
if firstIndex < secondIndex then leftEdge <- middleIdx + 1 else rightEdge <- middleIdx - 1
166-
167-
// Here localID equals either 0 or 1
168-
if localID = 0 then beginIdxLocal <- leftEdge else endIdxLocal <- leftEdge
169-
barrier ()
170-
171-
let beginIdx = beginIdxLocal
172-
let endIdx = endIdxLocal
173-
let firstLocalLength = endIdx - beginIdx
174-
let mutable x = workGroupSize - firstLocalLength
175-
if endIdx = firstSide then x <- secondSide - i + localID + beginIdx
176-
let secondLocalLength = x
177-
178-
//First indices are from 0 to firstLocalLength - 1 inclusive
179-
//Second indices are from firstLocalLength to firstLocalLength + secondLocalLength - 1 inclusive
180-
let localIndices = localArray<int> workGroupSize
181-
182-
if localID < firstLocalLength then
183-
localIndices.[localID] <- firstIndicesBuffer.[beginIdx + localID]
184-
if localID < secondLocalLength then
185-
localIndices.[firstLocalLength + localID] <- secondIndicesBuffer.[i - beginIdx]
186-
barrier ()
187-
188-
if i < sumOfSides then
189-
let mutable leftEdge = localID + 1 - secondLocalLength
190-
if leftEdge < 0 then leftEdge <- 0
191-
192-
let mutable rightEdge = firstLocalLength - 1
193-
if rightEdge > localID then rightEdge <- localID
194-
195-
while leftEdge <= rightEdge do
196-
let middleIdx = (leftEdge + rightEdge) / 2
197-
let firstIndex = localIndices.[middleIdx]
198-
let secondIndex = localIndices.[firstLocalLength + localID - middleIdx]
199-
if firstIndex < secondIndex then leftEdge <- middleIdx + 1 else rightEdge <- middleIdx - 1
200-
201-
let boundaryX = rightEdge
202-
let boundaryY = localID - leftEdge
203-
204-
// boundaryX and boundaryY can't be off the right edge of array (only off the left edge)
205-
let isValidX = boundaryX >= 0
206-
let isValidY = boundaryY >= 0
207-
208-
let mutable fstIdx = 0
209-
if isValidX then fstIdx <- localIndices.[boundaryX]
210-
211-
let mutable sndIdx = 0
212-
if isValidY then sndIdx <- localIndices.[firstLocalLength + boundaryY]
213-
214-
if not isValidX || isValidY && fstIdx < sndIdx then
215-
allIndicesBuffer.[i] <- sndIdx
216-
allValuesBuffer.[i] <- secondValuesBuffer.[i - localID - beginIdx + boundaryY]
217-
else
218-
allIndicesBuffer.[i] <- fstIdx
219-
allValuesBuffer.[i] <- firstValuesBuffer.[beginIdx + boundaryX]
220-
@>
221-
222-
let allIndices = Array.zeroCreate sumOfSides
223-
let allValues = Array.create sumOfSides Unchecked.defaultof<'a>
224-
225-
do! RunCommand merge <| fun kernelPrepare ->
226-
let ndRange = _1D(Utils.workSize sumOfSides, workGroupSize)
227-
kernelPrepare
228-
ndRange
229-
leftIndices
230-
leftValues
231-
rightIndices
232-
rightValues
233-
allIndices
234-
allValues
235-
236-
return allIndices, allValues
237-
}

src/GraphBLAS-sharp/Backend/PreparePositions.fs renamed to src/GraphBLAS-sharp/Backend/COOMatrix/Utilities/PreparePositions.fs

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
namespace GraphBLAS.FSharp.Backend
1+
namespace GraphBLAS.FSharp.Backend.COOMatrix.Utilities
22

33
open Brahma.OpenCL
44
open Brahma.FSharp.OpenCL.WorkflowBuilder.Basic
55
open Brahma.FSharp.OpenCL.WorkflowBuilder.Evaluation
66
open GraphBLAS.FSharp.Backend.Common
77
open Microsoft.FSharp.Quotations
88

9+
[<AutoOpen>]
910
module internal PreparePositions =
10-
let runForMatrix (allRows: int[]) (allColumns: int[]) (allValues: 'a[]) (plus: Expr<'a -> 'a -> 'a>) : OpenCLEvaluation<int[]> = opencl {
11+
let preparePositions (allRows: int[]) (allColumns: int[]) (allValues: 'a[]) (plus: Expr<'a -> 'a -> 'a>) : OpenCLEvaluation<int[]> = opencl {
1112
let length = allValues.Length
1213

1314
let preparePositions =
@@ -44,39 +45,3 @@ module internal PreparePositions =
4445

4546
return rawPositions
4647
}
47-
48-
let runForVector (allIndices: int[]) (allValues: 'a[]) (plus: Expr<'a -> 'a -> 'a>) : OpenCLEvaluation<int[]> = opencl {
49-
let length = allValues.Length
50-
51-
let preparePositions =
52-
<@
53-
fun (ndRange: _1D)
54-
(allIndicesBuffer: int[])
55-
(allValuesBuffer: 'a[])
56-
(rawPositionsBuffer: int[]) ->
57-
58-
let i = ndRange.GlobalID0
59-
60-
if i < length - 1 && allIndicesBuffer.[i] = allIndicesBuffer.[i + 1] then
61-
rawPositionsBuffer.[i] <- 0
62-
63-
//Do not drop explicit zeroes
64-
allValuesBuffer.[i + 1] <- (%plus) allValuesBuffer.[i] allValuesBuffer.[i + 1]
65-
66-
//Drop explicit zeroes
67-
// let localResultBuffer = (%plus) allValuesBuffer.[i] allValuesBuffer.[i + 1]
68-
// if localResultBuffer = zero then rawPositionsBuffer.[i + 1] <- 0 else allValuesBuffer.[i + 1] <- localResultBuffer
69-
@>
70-
71-
let rawPositions = Array.create length 1
72-
73-
do! RunCommand preparePositions <| fun kernelPrepare ->
74-
let ndRange = _1D(Utils.workSize (length - 1), Utils.workGroupSize)
75-
kernelPrepare
76-
ndRange
77-
allIndices
78-
allValues
79-
rawPositions
80-
81-
return rawPositions
82-
}
Lines changed: 3 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
namespace GraphBLAS.FSharp.Backend
1+
namespace GraphBLAS.FSharp.Backend.COOMatrix.Utilities
22

33
open Brahma.OpenCL
44
open Brahma.FSharp.OpenCL.WorkflowBuilder.Basic
55
open Brahma.FSharp.OpenCL.WorkflowBuilder.Evaluation
66
open GraphBLAS.FSharp.Backend.Common
77

8+
[<AutoOpen>]
89
module internal SetPositions =
9-
let runForMatrix (allRows: int[]) (allColumns: int[]) (allValues: 'a[]) (positions: int[]) : OpenCLEvaluation<int[] * int[] * 'a[]> = opencl {
10+
let setPositions (allRows: int[]) (allColumns: int[]) (allValues: 'a[]) (positions: int[]) : OpenCLEvaluation<int[] * int[] * 'a[]> = opencl {
1011
let prefixSumArrayLength = positions.Length
1112

1213
let setPositions =
@@ -54,46 +55,3 @@ module internal SetPositions =
5455

5556
return resultRows, resultColumns, resultValues
5657
}
57-
58-
let runForVector (allIndices: int[]) (allValues: 'a[]) (positions: int[]) : OpenCLEvaluation<int[] * 'a[]> = opencl {
59-
let prefixSumArrayLength = positions.Length
60-
61-
let setPositions =
62-
<@
63-
fun (ndRange: _1D)
64-
(allIndicesBuffer: int[])
65-
(allValuesBuffer: 'a[])
66-
(prefixSumArrayBuffer: int[])
67-
(resultIndicesBuffer: int[])
68-
(resultValuesBuffer: 'a[]) ->
69-
70-
let i = ndRange.GlobalID0
71-
72-
if i = prefixSumArrayLength - 1 || i < prefixSumArrayLength && prefixSumArrayBuffer.[i] <> prefixSumArrayBuffer.[i + 1] then
73-
let index = prefixSumArrayBuffer.[i]
74-
75-
resultIndicesBuffer.[index] <- allIndicesBuffer.[i]
76-
resultValuesBuffer.[index] <- allValuesBuffer.[i]
77-
@>
78-
79-
let resultLength = Array.zeroCreate 1
80-
81-
do! PrefixSum.run positions resultLength
82-
let! _ = ToHost resultLength
83-
let resultLength = resultLength.[0]
84-
85-
let resultIndices = Array.zeroCreate resultLength
86-
let resultValues = Array.create resultLength Unchecked.defaultof<'a>
87-
88-
do! RunCommand setPositions <| fun kernelPrepare ->
89-
let ndRange = _1D(Utils.workSize positions.Length, Utils.workGroupSize)
90-
kernelPrepare
91-
ndRange
92-
allIndices
93-
allValues
94-
positions
95-
resultIndices
96-
resultValues
97-
98-
return resultIndices, resultValues
99-
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
namespace GraphBLAS.FSharp.Backend.COOVector
2+
3+
open Brahma.FSharp.OpenCL.WorkflowBuilder.Basic
4+
open Brahma.FSharp.OpenCL.WorkflowBuilder.Evaluation
5+
open GraphBLAS.FSharp
6+
open GraphBLAS.FSharp.Backend.Common
7+
open GraphBLAS.FSharp.Backend.COOVector.Utilities
8+
9+
module internal EWiseAdd =
10+
let private runNonEmpty (leftIndices: int[]) (leftValues: 'a[]) (rightIndices: int[]) (rightValues: 'a[]) (mask: Mask1D option) (semiring: ISemiring<'a>) : OpenCLEvaluation<int[] * 'a[]> = opencl {
11+
let! allIndices, allValues = merge leftIndices leftValues rightIndices rightValues mask
12+
13+
let (ClosedBinaryOp plus) = semiring.Plus
14+
let! rawPositions = preparePositions allIndices allValues plus
15+
16+
return! setPositions allIndices allValues rawPositions
17+
}
18+
19+
let run (leftIndices: int[]) (leftValues: 'a[]) (rightIndices: int[]) (rightValues: 'a[]) (mask: Mask1D option) (semiring: ISemiring<'a>) : OpenCLEvaluation<int[] * 'a[]> =
20+
if leftValues.Length = 0 then
21+
opencl {
22+
let! resultIndices = Copy.run rightIndices
23+
let! resultValues = Copy.run rightValues
24+
25+
return resultIndices, resultValues
26+
}
27+
28+
elif rightIndices.Length = 0 then
29+
opencl {
30+
let! resultIndices = Copy.run leftIndices
31+
let! resultValues = Copy.run leftValues
32+
33+
return resultIndices, resultValues
34+
}
35+
36+
else
37+
runNonEmpty leftIndices leftValues rightIndices rightValues mask semiring

0 commit comments

Comments
 (0)