Skip to content

Commit 0ab45fa

Browse files
committed
refactor: SparseVector.fillSubVector, elementWise
1 parent f76e8ac commit 0ab45fa

6 files changed

Lines changed: 312 additions & 167 deletions

File tree

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
<Compile Include="Matrix/CSRMatrix/CSRMatrix.fs" />
3434
<Compile Include="Matrix/CSRMatrix/SpMV.fs" />
3535
<Compile Include="Matrix/Matrix.fs" />
36+
<Compile Include="Vector\SparseVector\ElementwiseConstructor.fs" />
3637
<Compile Include="Vector/SparseVector/SparseVector.fs" />
37-
<Compile Include="Vector\DenseVector\ElementwiseQuotes.fs" />
38+
<Compile Include="Vector\DenseVector\ElementwiseConstructor.fs" />
3839
<Compile Include="Vector/DenseVector/DenseVector.fs" />
3940
<Compile Include="Vector/Vector.fs" />
4041
<!--Compile Include="Backend/CSRMatrix/GetTuples.fs" /-->

src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
namespace GraphBLAS.FSharp.Backend
1+
namespace GraphBLAS.FSharp.Backend.DenseVector
22

33
open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend
@@ -12,7 +12,7 @@ module DenseVector =
1212
(workGroupSize: int)
1313
=
1414

15-
let kernel = clContext.Compile(ElementwiseQuotes.kernel opAdd)
15+
let kernel = clContext.Compile(ElementwiseConstructor.kernel opAdd)
1616

1717
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) ->
1818

@@ -44,7 +44,7 @@ module DenseVector =
4444
(workGroupSize: int)
4545
=
4646

47-
let kernel = clContext.Compile(ElementwiseQuotes.atLeastOneKernel opAdd)
47+
let kernel = clContext.Compile(ElementwiseConstructor.atLeastOneKernel opAdd)
4848

4949
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) ->
5050

@@ -75,7 +75,7 @@ module DenseVector =
7575
(maskOp: Expr<'a option -> 'b option -> 'a -> 'a option>)
7676
(workGroupSize: int) =
7777

78-
let kernel = clContext.Compile(ElementwiseQuotes.fillSubVector maskOp)
78+
let kernel = clContext.Compile(ElementwiseConstructor.fillSubVector maskOp)
7979

8080
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (maskVector: ClArray<'b option>) (value: ClCell<'a>) ->
8181
let resultArray =

src/GraphBLAS-sharp.Backend/Vector/DenseVector/ElementwiseQuotes.fs renamed to src/GraphBLAS-sharp.Backend/Vector/DenseVector/ElementwiseConstructor.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
namespace GraphBLAS.FSharp.Backend
1+
namespace GraphBLAS.FSharp.Backend.DenseVector
22

33
open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend.Common
55

6-
module ElementwiseQuotes =
6+
module ElementwiseConstructor =
77
let private elementWiseGeneralKernel writeOp =
88
<@ fun (ndRange: Range1D) resultLength (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultVector: ClArray<'c option>) ->
99

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
namespace GraphBLAS.FSharp.Backend.SparseVector
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Backend.Common
5+
open Microsoft.FSharp.Quotations
6+
7+
module ElementwiseConstructor =
8+
let merge workGroupSize =
9+
<@ fun (ndRange: Range1D) (firstSide: int) (secondSide: int) (sumOfSides: int) (firstIndicesBuffer: ClArray<int>) (firstValuesBuffer: ClArray<'a>) (secondIndicesBuffer: ClArray<int>) (secondValuesBuffer: ClArray<'b>) (allIndicesBuffer: ClArray<int>) (firstResultValues: ClArray<'a>) (secondResultValues: ClArray<'b>) (isLeftBitMap: ClArray<int>) ->
10+
11+
let i = ndRange.GlobalID0
12+
13+
let mutable beginIdxLocal = local ()
14+
let mutable endIdxLocal = local ()
15+
let localID = ndRange.LocalID0
16+
17+
if localID < 2 then
18+
let mutable x = localID * (workGroupSize - 1) + i - 1
19+
20+
if x >= sumOfSides then
21+
x <- sumOfSides - 1
22+
23+
let diagonalNumber = x
24+
25+
let mutable leftEdge = diagonalNumber + 1 - secondSide
26+
if leftEdge < 0 then leftEdge <- 0
27+
28+
let mutable rightEdge = firstSide - 1
29+
30+
if rightEdge > diagonalNumber then
31+
rightEdge <- diagonalNumber
32+
33+
while leftEdge <= rightEdge do
34+
let middleIdx = (leftEdge + rightEdge) / 2
35+
let firstIndex = firstIndicesBuffer.[middleIdx]
36+
37+
let secondIndex =
38+
secondIndicesBuffer.[diagonalNumber - middleIdx]
39+
40+
if firstIndex <= secondIndex then
41+
leftEdge <- middleIdx + 1
42+
else
43+
rightEdge <- middleIdx - 1
44+
45+
// Here localID equals either 0 or 1
46+
if localID = 0 then
47+
beginIdxLocal <- leftEdge
48+
else
49+
endIdxLocal <- leftEdge
50+
51+
barrierLocal ()
52+
53+
let beginIdx = beginIdxLocal
54+
let endIdx = endIdxLocal
55+
let firstLocalLength = endIdx - beginIdx
56+
let mutable x = workGroupSize - firstLocalLength
57+
58+
if endIdx = firstSide then
59+
x <- secondSide - i + localID + beginIdx
60+
61+
let secondLocalLength = x
62+
63+
//First indices are from 0 to firstLocalLength - 1 inclusive
64+
//Second indices are from firstLocalLength to firstLocalLength + secondLocalLength - 1 inclusive
65+
let localIndices = localArray<int> workGroupSize
66+
67+
if localID < firstLocalLength then
68+
localIndices.[localID] <- firstIndicesBuffer.[beginIdx + localID]
69+
70+
if localID < secondLocalLength then
71+
localIndices.[firstLocalLength + localID] <- secondIndicesBuffer.[i - beginIdx]
72+
73+
barrierLocal ()
74+
75+
if i < sumOfSides then
76+
let mutable leftEdge = localID + 1 - secondLocalLength
77+
if leftEdge < 0 then leftEdge <- 0
78+
79+
let mutable rightEdge = firstLocalLength - 1
80+
81+
if rightEdge > localID then
82+
rightEdge <- localID
83+
84+
while leftEdge <= rightEdge do
85+
let middleIdx = (leftEdge + rightEdge) / 2
86+
let firstIndex = localIndices.[middleIdx]
87+
88+
let secondIndex =
89+
localIndices.[firstLocalLength + localID - middleIdx]
90+
91+
if firstIndex <= secondIndex then
92+
leftEdge <- middleIdx + 1
93+
else
94+
rightEdge <- middleIdx - 1
95+
96+
let boundaryX = rightEdge
97+
let boundaryY = localID - leftEdge
98+
99+
// boundaryX and boundaryY can't be off the right edge of array (only off the left edge)
100+
let isValidX = boundaryX >= 0
101+
let isValidY = boundaryY >= 0
102+
103+
let mutable fstIdx = 0
104+
105+
if isValidX then
106+
fstIdx <- localIndices.[boundaryX]
107+
108+
let mutable sndIdx = 0
109+
110+
if isValidY then
111+
sndIdx <- localIndices.[firstLocalLength + boundaryY]
112+
113+
if not isValidX || isValidY && fstIdx <= sndIdx then
114+
allIndicesBuffer.[i] <- sndIdx
115+
secondResultValues.[i] <- secondValuesBuffer.[i - localID - beginIdx + boundaryY]
116+
isLeftBitMap.[i] <- 0
117+
else
118+
allIndicesBuffer.[i] <- fstIdx
119+
firstResultValues.[i] <- firstValuesBuffer.[beginIdx + boundaryX]
120+
isLeftBitMap.[i] <- 1 @>
121+
122+
let private both<'c> =
123+
<@ fun index (result: 'c option) (rawPositionsBuffer: ClArray<int>) (allValuesBuffer: ClArray<'c>) ->
124+
rawPositionsBuffer.[index] <- 0
125+
126+
match result with
127+
| Some v ->
128+
allValuesBuffer.[index + 1] <- v
129+
rawPositionsBuffer.[index + 1] <- 1
130+
| None -> rawPositionsBuffer.[index + 1] <- 0 @>
131+
132+
let private leftRight<'c> =
133+
<@ fun index (leftResult: 'c option) (rightResult: 'c option) (isLeftBitmap: ClArray<int>) (allValuesBuffer: ClArray<'c>) (rawPositionsBuffer: ClArray<int>) ->
134+
if isLeftBitmap.[index] = 1 then
135+
match leftResult with
136+
| Some v ->
137+
allValuesBuffer.[index] <- v
138+
rawPositionsBuffer.[index] <- 1
139+
| None -> rawPositionsBuffer.[index] <- 0
140+
else
141+
match rightResult with
142+
| Some v ->
143+
allValuesBuffer.[index] <- v
144+
rawPositionsBuffer.[index] <- 1
145+
| None -> rawPositionsBuffer.[index] <- 0 @>
146+
147+
let preparePositionsAtLeastOne opAdd =
148+
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
149+
150+
let gid = ndRange.GlobalID0
151+
152+
if gid < length - 1
153+
&& allIndices.[gid] = allIndices.[gid + 1] then
154+
let result = (%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1]))
155+
156+
(%both) gid result positions allValues
157+
elif (gid < length
158+
&& gid > 0
159+
&& allIndices.[gid - 1] <> allIndices.[gid])
160+
|| gid = 0 then
161+
162+
let leftResult = (%opAdd) (Left(leftValues.[gid]))
163+
let rightResult = (%opAdd) (Right(rightValues.[gid]))
164+
165+
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
166+
167+
let preparePositions (opAdd: Expr<'a option -> 'b option -> 'c option>) =
168+
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
169+
170+
let gid = ndRange.GlobalID0
171+
172+
if gid < length - 1
173+
&& allIndices.[gid] = allIndices.[gid + 1] then
174+
let result = (%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1])
175+
176+
(%both) gid result positions allValues
177+
elif (gid < length
178+
&& gid > 0
179+
&& allIndices.[gid - 1] <> allIndices.[gid])
180+
|| gid = 0 then
181+
182+
let leftResult = (%opAdd) (Some leftValues.[gid]) None
183+
let rightResult = (%opAdd) None (Some rightValues.[gid])
184+
185+
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
186+
187+
let preparePositionsFillSubVectorAtLeasOne opAdd =
188+
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: ClCell<'a>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
189+
190+
let gid = ndRange.GlobalID0
191+
192+
let value = value.Value
193+
194+
if gid < length - 1
195+
&& allIndices.[gid] = allIndices.[gid + 1] then
196+
let result = (%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1])) value
197+
198+
(%both) gid result positions allValues
199+
elif (gid < length
200+
&& gid > 0
201+
&& allIndices.[gid - 1] <> allIndices.[gid])
202+
|| gid = 0 then
203+
let leftResult = (%opAdd) (Left(leftValues.[gid])) value
204+
let rightResult = (%opAdd) (Right(rightValues.[gid])) value
205+
206+
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
207+

0 commit comments

Comments
 (0)