Skip to content

Commit 471cfdb

Browse files
committed
refactor: DenseVector.preparePositions
1 parent 0ab45fa commit 471cfdb

3 files changed

Lines changed: 105 additions & 37 deletions

File tree

src/GraphBLAS-sharp.Backend/Vector/SparseVector/ElementwiseConstructor.fs

Lines changed: 92 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,78 @@ module ElementwiseConstructor =
119119
firstResultValues.[i] <- firstValuesBuffer.[beginIdx + boundaryX]
120120
isLeftBitMap.[i] <- 1 @>
121121

122+
let private opWriteBothFill (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
123+
<@
124+
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: 'a) ->
125+
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1]) value
126+
@>
127+
128+
let private opWriteLeftFill (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
129+
<@
130+
fun gid (leftValues: ClArray<'a>) (value: 'a) ->
131+
(%opAdd) (Some leftValues.[gid]) None value
132+
@>
133+
134+
let private opWriteRightFill (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
135+
<@
136+
fun gid (rightValues: ClArray<'b>) (value: 'a) ->
137+
(%opAdd) None (Some rightValues.[gid + 1]) value
138+
@>
139+
140+
let private opWriteAtLeastOneBothFill (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
141+
<@
142+
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: 'a) ->
143+
(%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1])) value
144+
@>
145+
146+
let private opWriteAtLeastOneLeftFill (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
147+
<@
148+
fun gid (leftValues: ClArray<'a>) (value: 'a) ->
149+
(%opAdd) (Left(leftValues.[gid])) value
150+
@>
151+
152+
let private opWriteAtLeastOneRightFill (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
153+
<@
154+
fun gid (rightValues: ClArray<'b>) (value: 'a) ->
155+
(%opAdd) (Right(rightValues.[gid])) value
156+
@>
157+
158+
let private opWriteBoth (opAdd: Expr<'a option -> 'b option -> 'c option>) =
159+
<@
160+
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) ->
161+
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1])
162+
@>
163+
164+
let private opWriteLeft (opAdd: Expr<'a option -> 'b option -> 'c option>) =
165+
<@
166+
fun gid (leftValues: ClArray<'a>)->
167+
(%opAdd) (Some leftValues.[gid]) None
168+
@>
169+
170+
let private opWriteRight (opAdd: Expr<'a option -> 'b option -> 'c option>) =
171+
<@
172+
fun gid (rightValues: ClArray<'b>) ->
173+
(%opAdd) None (Some rightValues.[gid + 1])
174+
@>
175+
176+
let private opWriteAtLeastOneBoth (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
177+
<@
178+
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) ->
179+
(%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1]))
180+
@>
181+
182+
let opWriteAtLeastOneLeft (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
183+
<@
184+
fun gid (leftValues: ClArray<'a>) ->
185+
(%opAdd) (Left(leftValues.[gid]))
186+
@>
187+
188+
let opWriteAtLeastOneRight (opAdd: Expr<AtLeastOne<'a,'b> -> 'a option>) =
189+
<@
190+
fun gid (rightValues: ClArray<'b>) ->
191+
(%opAdd) (Right(rightValues.[gid]))
192+
@>
193+
122194
let private both<'c> =
123195
<@ fun index (result: 'c option) (rawPositionsBuffer: ClArray<int>) (allValuesBuffer: ClArray<'c>) ->
124196
rawPositionsBuffer.[index] <- 0
@@ -144,64 +216,56 @@ module ElementwiseConstructor =
144216
rawPositionsBuffer.[index] <- 1
145217
| None -> rawPositionsBuffer.[index] <- 0 @>
146218

147-
let preparePositionsAtLeastOne opAdd =
148-
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
149-
150-
let gid = ndRange.GlobalID0
151-
152-
if gid < length - 1
153-
&& allIndices.[gid] = allIndices.[gid + 1] then
154-
let result = (%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1]))
219+
let private preparePositionsGeneral
220+
(bothWrite: Expr<(int -> ClArray<'a> -> ClArray<'b> -> 'c option)>)
221+
leftWrite
222+
rightWrite
223+
=
155224

156-
(%both) gid result positions allValues
157-
elif (gid < length
158-
&& gid > 0
159-
&& allIndices.[gid - 1] <> allIndices.[gid])
160-
|| gid = 0 then
161-
162-
let leftResult = (%opAdd) (Left(leftValues.[gid]))
163-
let rightResult = (%opAdd) (Right(rightValues.[gid]))
164-
165-
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
166-
167-
let preparePositions (opAdd: Expr<'a option -> 'b option -> 'c option>) =
168225
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
169226

170227
let gid = ndRange.GlobalID0
171228

172229
if gid < length - 1
173230
&& allIndices.[gid] = allIndices.[gid + 1] then
174-
let result = (%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1])
231+
let (result: 'c option) = (%bothWrite) gid leftValues rightValues
175232

176233
(%both) gid result positions allValues
177234
elif (gid < length
178235
&& gid > 0
179236
&& allIndices.[gid - 1] <> allIndices.[gid])
180237
|| gid = 0 then
181238

182-
let leftResult = (%opAdd) (Some leftValues.[gid]) None
183-
let rightResult = (%opAdd) None (Some rightValues.[gid])
239+
let leftResult = (%leftWrite) gid leftValues
240+
let rightResult = (%rightWrite) gid rightValues
184241

185242
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
186243

187-
let preparePositionsFillSubVectorAtLeasOne opAdd =
188-
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: ClCell<'a>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
244+
let private prepareFillVectorGeneral bothWrite leftWrite rightWrite =
245+
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: ClCell<'a>) (isLeft: ClArray<int>) (allValues: ClArray<'a>) (positions: ClArray<int>) ->
189246

190247
let gid = ndRange.GlobalID0
191248

192249
let value = value.Value
193250

194251
if gid < length - 1
195252
&& allIndices.[gid] = allIndices.[gid + 1] then
196-
let result = (%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1])) value
253+
let (result: 'a option) = (%bothWrite) gid leftValues rightValues value
197254

198255
(%both) gid result positions allValues
199256
elif (gid < length
200257
&& gid > 0
201258
&& allIndices.[gid - 1] <> allIndices.[gid])
202259
|| gid = 0 then
203-
let leftResult = (%opAdd) (Left(leftValues.[gid])) value
204-
let rightResult = (%opAdd) (Right(rightValues.[gid])) value
260+
let leftResult = (%leftWrite) gid leftValues value
261+
let rightResult = (%rightWrite) gid rightValues value
205262

206263
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
207264

265+
let preparePositions opAdd = preparePositionsGeneral (opWriteBoth opAdd) (opWriteLeft opAdd) (opWriteRight opAdd)
266+
267+
let preparePositionsAtLeastOne opAdd = preparePositionsGeneral (opWriteAtLeastOneBoth opAdd) (opWriteAtLeastOneLeft opAdd) (opWriteAtLeastOneRight opAdd)
268+
269+
let prepareFillVector opAdd = prepareFillVectorGeneral (opWriteBothFill opAdd) (opWriteLeftFill opAdd) (opWriteRightFill opAdd)
270+
271+
let prepareFillVectorAtLeastOne opAdd = prepareFillVectorGeneral (opWriteAtLeastOneBothFill opAdd) (opWriteAtLeastOneLeftFill opAdd) (opWriteAtLeastOneRightFill opAdd)

src/GraphBLAS-sharp.Backend/Vector/SparseVector/SparseVector.fs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ module SparseVector =
211211
Indices = resultIndices
212212
Size = max leftVector.Size rightVector.Size }
213213

214-
let elementWiseAtLeasOne (clContext: ClContext) opAdd (workGroupSize: int) =
214+
let elementWiseAtLeasOne (clContext: ClContext) (opAdd: Expr<(AtLeastOne<'a,'b> -> 'c option)>) (workGroupSize: int) =
215215
elementWiseGeneral clContext (ElementwiseConstructor.preparePositionsAtLeastOne opAdd) workGroupSize
216216

217217
let elementWise (clContext: ClContext) opAdd (workGroupSize: int) =
@@ -301,7 +301,10 @@ module SparseVector =
301301
Size = max leftVector.Size rightVector.Size }
302302

303303
let fillSubVectorAtLeasOne (clContext: ClContext) opAdd (workGroupSize: int) =
304-
fillSubVectorGeneral clContext (ElementwiseConstructor.preparePositionsFillSubVectorAtLeasOne opAdd) workGroupSize
304+
fillSubVectorGeneral clContext (ElementwiseConstructor.prepareFillVectorAtLeastOne opAdd) workGroupSize
305+
306+
let fillSubVector (clContext: ClContext) opAdd (workGroupSize: int) =
307+
fillSubVectorGeneral clContext (ElementwiseConstructor.prepareFillVector opAdd) workGroupSize
305308

306309
let toDense (clContext: ClContext) (workGroupSize: int) =
307310

src/GraphBLAS-sharp.Backend/Vector/Vector.fs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ open Microsoft.FSharp.Control
66
open Microsoft.FSharp.Quotations
77
open GraphBLAS.FSharp.Backend.Common
88
open GraphBLAS.FSharp.Backend.DenseVector
9+
open GraphBLAS.FSharp.Backend.SparseVector
910

1011
module Vector =
1112
let zeroCreate (clContext: ClContext) (workGroupSize: int) =
@@ -95,7 +96,7 @@ module Vector =
9596

9697
let elementWiseAtLeastOne (clContext: ClContext) (opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>) workGroupSize =
9798
let addCoo =
98-
SparseVector.elementWiseAtLeastOne clContext opAdd workGroupSize
99+
SparseVector.elementWiseAtLeasOne clContext opAdd workGroupSize //TODO()
99100

100101
let addDense =
101102
DenseVector.elementWiseAtLeastOne clContext opAdd workGroupSize
@@ -114,12 +115,12 @@ module Vector =
114115
| ClVectorDense leftVector, ClVectorDense rightVector -> addDense processor leftVector rightVector
115116
| _ -> failwith "Vector formats are not matching."
116117

117-
let fillSubVector (clContext: ClContext) (workGroupSize: int) =
118+
let fillSubVector (clContext: ClContext) mask (workGroupSize: int) =
118119
let cooFillVector =
119-
SparseVector.fillSubVector clContext workGroupSize
120+
SparseVector.fillSubVector clContext mask workGroupSize
120121

121122
let denseFillVector =
122-
DenseVector.fillSubVector clContext StandardOperations.mask workGroupSize
123+
DenseVector.fillSubVector clContext mask workGroupSize
123124

124125
let toCooVector =
125126
DenseVector.toSparse clContext workGroupSize
@@ -131,17 +132,17 @@ module Vector =
131132
match vector, maskVector with
132133
| ClVectorSparse vector, ClVectorSparse mask ->
133134
ClVectorSparse
134-
<| cooFillVector value processor vector mask
135+
<| cooFillVector processor vector mask value
135136
| ClVectorSparse vector, ClVectorDense mask ->
136137
let mask = toCooMask processor mask
137138

138139
ClVectorSparse
139-
<| cooFillVector value processor vector mask
140+
<| cooFillVector processor vector mask value
140141
| ClVectorDense vector, ClVectorSparse mask ->
141142
let vector = toCooVector processor vector
142143

143144
ClVectorSparse
144-
<| cooFillVector value processor vector mask
145+
<| cooFillVector processor vector mask value
145146
| ClVectorDense vector, ClVectorDense mask ->
146147
ClVectorDense
147148
<| denseFillVector processor vector mask value

0 commit comments

Comments
 (0)