Skip to content

Commit 7b2e627

Browse files
committed
refactor: types in SparseVector.ElementWiseConstructor
1 parent 471cfdb commit 7b2e627

5 files changed

Lines changed: 197 additions & 183 deletions

File tree

src/GraphBLAS-sharp.Backend/Vector/SparseVector/ElementwiseConstructor.fs

Lines changed: 93 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -119,77 +119,81 @@ module ElementwiseConstructor =
119119
firstResultValues.[i] <- firstValuesBuffer.[beginIdx + boundaryX]
120120
isLeftBitMap.[i] <- 1 @>
121121

122-
let private opWriteBothFill (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
123-
<@
124-
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: 'a) ->
125-
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1]) value
126-
@>
127-
128-
let private opWriteLeftFill (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
129-
<@
130-
fun gid (leftValues: ClArray<'a>) (value: 'a) ->
131-
(%opAdd) (Some leftValues.[gid]) None value
132-
@>
133-
134-
let private opWriteRightFill (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
135-
<@
136-
fun gid (rightValues: ClArray<'b>) (value: 'a) ->
137-
(%opAdd) None (Some rightValues.[gid + 1]) value
138-
@>
139-
140-
let private opWriteAtLeastOneBothFill (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
141-
<@
142-
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: 'a) ->
143-
(%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1])) value
144-
@>
145-
146-
let private opWriteAtLeastOneLeftFill (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
147-
<@
148-
fun gid (leftValues: ClArray<'a>) (value: 'a) ->
149-
(%opAdd) (Left(leftValues.[gid])) value
150-
@>
151-
152-
let private opWriteAtLeastOneRightFill (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
153-
<@
154-
fun gid (rightValues: ClArray<'b>) (value: 'a) ->
155-
(%opAdd) (Right(rightValues.[gid])) value
156-
@>
157-
158-
let private opWriteBoth (opAdd: Expr<'a option -> 'b option -> 'c option>) =
159-
<@
160-
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) ->
161-
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1])
162-
@>
163-
164-
let private opWriteLeft (opAdd: Expr<'a option -> 'b option -> 'c option>) =
165-
<@
166-
fun gid (leftValues: ClArray<'a>)->
167-
(%opAdd) (Some leftValues.[gid]) None
168-
@>
169-
170-
let private opWriteRight (opAdd: Expr<'a option -> 'b option -> 'c option>) =
171-
<@
172-
fun gid (rightValues: ClArray<'b>) ->
173-
(%opAdd) None (Some rightValues.[gid + 1])
174-
@>
175-
176-
let private opWriteAtLeastOneBoth (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
177-
<@
178-
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) ->
179-
(%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1]))
180-
@>
181-
182-
let opWriteAtLeastOneLeft (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
183-
<@
184-
fun gid (leftValues: ClArray<'a>) ->
185-
(%opAdd) (Left(leftValues.[gid]))
186-
@>
187-
188-
let opWriteAtLeastOneRight (opAdd: Expr<AtLeastOne<'a,'b> -> 'a option>) =
189-
<@
190-
fun gid (rightValues: ClArray<'b>) ->
191-
(%opAdd) (Right(rightValues.[gid]))
192-
@>
122+
module FillSubVectorRead =
123+
let both (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
124+
<@
125+
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: 'a) ->
126+
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1]) value
127+
@>
128+
129+
let left (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
130+
<@
131+
fun gid (leftValues: ClArray<'a>) (value: 'a) ->
132+
(%opAdd) (Some leftValues.[gid]) None value
133+
@>
134+
135+
let right (opAdd: Expr<'a option -> 'b option -> 'a -> 'a option>) =
136+
<@
137+
fun gid (rightValues: ClArray<'b>) (value: 'a) ->
138+
(%opAdd) None (Some rightValues.[gid + 1]) value
139+
@>
140+
141+
module FillSubVectorAtLeasOneRead =
142+
let both (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
143+
<@
144+
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: 'a) ->
145+
(%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1])) value
146+
@>
147+
148+
let left (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
149+
<@
150+
fun gid (leftValues: ClArray<'a>) (value: 'a) ->
151+
(%opAdd) (Left(leftValues.[gid])) value
152+
@>
153+
154+
let right (opAdd: Expr<AtLeastOne<'a,'b> -> 'a -> 'a option>) =
155+
<@
156+
fun gid (rightValues: ClArray<'b>) (value: 'a) ->
157+
(%opAdd) (Right(rightValues.[gid])) value
158+
@>
159+
160+
module ElementWiseRead =
161+
let both (opAdd: Expr<'a option -> 'b option -> 'c option>) =
162+
<@
163+
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) ->
164+
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1])
165+
@>
166+
167+
let left (opAdd: Expr<'a option -> 'b option -> 'c option>) =
168+
<@
169+
fun gid (leftValues: ClArray<'a>)->
170+
(%opAdd) (Some leftValues.[gid]) None
171+
@>
172+
173+
let right (opAdd: Expr<'a option -> 'b option -> 'c option>) =
174+
<@
175+
fun gid (rightValues: ClArray<'b>) ->
176+
(%opAdd) None (Some rightValues.[gid + 1])
177+
@>
178+
179+
module ElementWiseAtLeasOneRead =
180+
let both (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
181+
<@
182+
fun gid (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) ->
183+
(%opAdd) (Both(leftValues.[gid], rightValues.[gid + 1]))
184+
@>
185+
186+
let left (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
187+
<@
188+
fun gid (leftValues: ClArray<'a>) ->
189+
(%opAdd) (Left(leftValues.[gid]))
190+
@>
191+
192+
let right (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) =
193+
<@
194+
fun gid (rightValues: ClArray<'b>) ->
195+
(%opAdd) (Right(rightValues.[gid]))
196+
@>
193197

194198
let private both<'c> =
195199
<@ fun index (result: 'c option) (rawPositionsBuffer: ClArray<int>) (allValuesBuffer: ClArray<'c>) ->
@@ -217,9 +221,9 @@ module ElementwiseConstructor =
217221
| None -> rawPositionsBuffer.[index] <- 0 @>
218222

219223
let private preparePositionsGeneral
220-
(bothWrite: Expr<(int -> ClArray<'a> -> ClArray<'b> -> 'c option)>)
221-
leftWrite
222-
rightWrite
224+
bothRead
225+
leftRead
226+
rightRead
223227
=
224228

225229
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
@@ -228,20 +232,20 @@ module ElementwiseConstructor =
228232

229233
if gid < length - 1
230234
&& allIndices.[gid] = allIndices.[gid + 1] then
231-
let (result: 'c option) = (%bothWrite) gid leftValues rightValues
235+
let (result: 'c option) = (%bothRead) gid leftValues rightValues
232236

233237
(%both) gid result positions allValues
234238
elif (gid < length
235239
&& gid > 0
236240
&& allIndices.[gid - 1] <> allIndices.[gid])
237241
|| gid = 0 then
238242

239-
let leftResult = (%leftWrite) gid leftValues
240-
let rightResult = (%rightWrite) gid rightValues
243+
let leftResult = (%leftRead) gid leftValues
244+
let rightResult = (%rightRead) gid rightValues
241245

242246
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
243247

244-
let private prepareFillVectorGeneral bothWrite leftWrite rightWrite =
248+
let private prepareFillVectorGeneral bothRead leftRead rightRead =
245249
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: ClCell<'a>) (isLeft: ClArray<int>) (allValues: ClArray<'a>) (positions: ClArray<int>) ->
246250

247251
let gid = ndRange.GlobalID0
@@ -250,22 +254,26 @@ module ElementwiseConstructor =
250254

251255
if gid < length - 1
252256
&& allIndices.[gid] = allIndices.[gid + 1] then
253-
let (result: 'a option) = (%bothWrite) gid leftValues rightValues value
257+
let (result: 'a option) = (%bothRead) gid leftValues rightValues value
254258

255259
(%both) gid result positions allValues
256260
elif (gid < length
257261
&& gid > 0
258262
&& allIndices.[gid - 1] <> allIndices.[gid])
259263
|| gid = 0 then
260-
let leftResult = (%leftWrite) gid leftValues value
261-
let rightResult = (%rightWrite) gid rightValues value
264+
let leftResult = (%leftRead) gid leftValues value
265+
let rightResult = (%rightRead) gid rightValues value
262266

263267
(%leftRight) gid leftResult rightResult isLeft allValues positions @>
264268

265-
let preparePositions opAdd = preparePositionsGeneral (opWriteBoth opAdd) (opWriteLeft opAdd) (opWriteRight opAdd)
269+
let preparePositions opAdd =
270+
preparePositionsGeneral (ElementWiseRead.both opAdd) (ElementWiseRead.left opAdd) (ElementWiseRead.right opAdd)
266271

267-
let preparePositionsAtLeastOne opAdd = preparePositionsGeneral (opWriteAtLeastOneBoth opAdd) (opWriteAtLeastOneLeft opAdd) (opWriteAtLeastOneRight opAdd)
272+
let preparePositionsAtLeastOne opAdd =
273+
preparePositionsGeneral (ElementWiseAtLeasOneRead.both opAdd) (ElementWiseAtLeasOneRead.left opAdd) (ElementWiseAtLeasOneRead.right opAdd)
268274

269-
let prepareFillVector opAdd = prepareFillVectorGeneral (opWriteBothFill opAdd) (opWriteLeftFill opAdd) (opWriteRightFill opAdd)
275+
let prepareFillVector opAdd =
276+
prepareFillVectorGeneral (FillSubVectorRead.both opAdd) (FillSubVectorRead.left opAdd) (FillSubVectorRead.right opAdd)
270277

271-
let prepareFillVectorAtLeastOne opAdd = prepareFillVectorGeneral (opWriteAtLeastOneBothFill opAdd) (opWriteAtLeastOneLeftFill opAdd) (opWriteAtLeastOneRightFill opAdd)
278+
let prepareFillVectorAtLeastOne opAdd =
279+
prepareFillVectorGeneral (FillSubVectorAtLeasOneRead.both opAdd) (FillSubVectorAtLeasOneRead.left opAdd) (FillSubVectorAtLeasOneRead.right opAdd)

src/GraphBLAS-sharp.Backend/Vector/SparseVector/SparseVector.fs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ module SparseVector =
9292
let length = allIndices.Length
9393

9494
let allValues =
95-
clContext.CreateClArray<'a>(
95+
clContext.CreateClArray<'c>(
9696
length,
9797
hostAccessMode = HostAccessMode.NotAccessible,
9898
deviceAccessMode = DeviceAccessMode.ReadWrite,
@@ -183,7 +183,7 @@ module SparseVector =
183183
let merge = merge clContext workGroupSize
184184

185185
let prepare =
186-
preparePositions clContext preparePositionsKernel workGroupSize
186+
preparePositions<'a, 'b , 'c> clContext preparePositionsKernel workGroupSize
187187

188188
let setPositions = setPositions clContext workGroupSize
189189

@@ -211,10 +211,10 @@ module SparseVector =
211211
Indices = resultIndices
212212
Size = max leftVector.Size rightVector.Size }
213213

214-
let elementWiseAtLeasOne (clContext: ClContext) (opAdd: Expr<(AtLeastOne<'a,'b> -> 'c option)>) (workGroupSize: int) =
214+
let elementWiseAtLeasOne (clContext: ClContext) (opAdd: Expr<AtLeastOne<'a,'b> -> 'c option>) (workGroupSize: int) =
215215
elementWiseGeneral clContext (ElementwiseConstructor.preparePositionsAtLeastOne opAdd) workGroupSize
216216

217-
let elementWise (clContext: ClContext) opAdd (workGroupSize: int) =
217+
let elementWise (clContext: ClContext) (opAdd: Expr<'a option ->'b option -> 'c option>) (workGroupSize: int) =
218218
elementWiseGeneral clContext (ElementwiseConstructor.preparePositions opAdd) workGroupSize
219219

220220
let private preparePositionsFillSubVector<'a, 'b, 'c when 'a: struct and 'b: struct and 'c: struct>

src/GraphBLAS-sharp.Backend/Vector/Vector.fs

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -95,66 +95,69 @@ module Vector =
9595
ClVectorDense <| toDense processor vector
9696

9797
let elementWiseAtLeastOne (clContext: ClContext) (opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>) workGroupSize =
98-
let addCoo =
99-
SparseVector.elementWiseAtLeasOne clContext opAdd workGroupSize //TODO()
98+
let addSparse =
99+
SparseVector.elementWiseAtLeasOne clContext opAdd workGroupSize
100100

101101
let addDense =
102102
DenseVector.elementWiseAtLeastOne clContext opAdd workGroupSize
103103

104104
fun (processor: MailboxProcessor<_>) (leftVector: ClVector<'a>) (rightVector: ClVector<'b>) ->
105105
match leftVector, rightVector with
106-
| ClVectorSparse left, ClVectorSparse right -> ClVectorSparse <| addCoo processor left right
106+
| ClVectorSparse left, ClVectorSparse right -> ClVectorSparse <| addSparse processor left right
107107
| ClVectorDense left, ClVectorDense right -> ClVectorDense <| addDense processor left right
108108
| _ -> failwith "Vector formats are not matching."
109109

110110
let elementWise (clContext: ClContext) (opAdd: Expr<'a option -> 'b option -> 'c option>) (workGroupSize: int) =
111111
let addDense = DenseVector.elementWise clContext opAdd workGroupSize
112112

113+
let addSparse = SparseVector.elementWise clContext opAdd workGroupSize
114+
113115
fun (processor: MailboxProcessor<_>) (leftVector: ClVector<'a>) (rightVector: ClVector<'b>) ->
114116
match leftVector, rightVector with
115-
| ClVectorDense leftVector, ClVectorDense rightVector -> addDense processor leftVector rightVector
117+
| ClVectorDense leftVector, ClVectorDense rightVector -> ClVectorDense <| addDense processor leftVector rightVector
118+
| ClVectorSparse left, ClVectorSparse right -> ClVectorSparse <| addSparse processor left right
116119
| _ -> failwith "Vector formats are not matching."
117120

118-
let fillSubVector (clContext: ClContext) mask (workGroupSize: int) =
119-
let cooFillVector =
120-
SparseVector.fillSubVector clContext mask workGroupSize
121+
let fillSubVector (clContext: ClContext) maskOp (workGroupSize: int) =
122+
let sparseFillVector =
123+
SparseVector.fillSubVector clContext maskOp workGroupSize
121124

122125
let denseFillVector =
123-
DenseVector.fillSubVector clContext mask workGroupSize
126+
DenseVector.fillSubVector clContext maskOp workGroupSize
124127

125-
let toCooVector =
128+
let toSparseVector =
126129
DenseVector.toSparse clContext workGroupSize
127130

128-
let toCooMask =
131+
let toSparseMask =
129132
DenseVector.toSparse clContext workGroupSize
130133

131134
fun (processor: MailboxProcessor<_>) (vector: ClVector<'a>) (maskVector: ClVector<'b>) (value: ClCell<'a>) ->
132135
match vector, maskVector with
133136
| ClVectorSparse vector, ClVectorSparse mask ->
134137
ClVectorSparse
135-
<| cooFillVector processor vector mask value
138+
<| sparseFillVector processor vector mask value
136139
| ClVectorSparse vector, ClVectorDense mask ->
137-
let mask = toCooMask processor mask
140+
let mask = toSparseMask processor mask
138141

139142
ClVectorSparse
140-
<| cooFillVector processor vector mask value
143+
<| sparseFillVector processor vector mask value
141144
| ClVectorDense vector, ClVectorSparse mask ->
142-
let vector = toCooVector processor vector
145+
let vector = toSparseVector processor vector
143146

144147
ClVectorSparse
145-
<| cooFillVector processor vector mask value
148+
<| sparseFillVector processor vector mask value
146149
| ClVectorDense vector, ClVectorDense mask ->
147150
ClVectorDense
148151
<| denseFillVector processor vector mask value
149152

150153
let reduce (clContext: ClContext) (workGroupSize: int) (opAdd: Expr<'a -> 'a -> 'a>) =
151-
let cooReduce =
154+
let sparseReduce =
152155
SparseVector.reduce clContext workGroupSize opAdd
153156

154157
let denseReduce =
155158
DenseVector.reduce clContext workGroupSize opAdd
156159

157160
fun (processor: MailboxProcessor<_>) (vector: ClVector<'a>) ->
158161
match vector with
159-
| ClVectorSparse vector -> cooReduce processor vector
162+
| ClVectorSparse vector -> sparseReduce processor vector
160163
| ClVectorDense vector -> denseReduce processor vector

0 commit comments

Comments
 (0)