Skip to content

Commit 3b5a5e6

Browse files
committed
refactor: SpraseVector.fillSubVector, tests
1 parent 21f27f5 commit 3b5a5e6

3 files changed

Lines changed: 106 additions & 126 deletions

File tree

src/GraphBLAS-sharp.Backend/Vector/SparseVector/Map2.fs

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ module Map2 =
2929
result @>
3030

3131
let preparePositionsGeneral (op: Expr<'a option -> 'b option -> 'c option>) =
32-
<@ fun (ndRange: Range1D) vectorLength leftValuesLength rightValuesLength (leftValues: ClArray<'a>) (leftIndices: ClArray<int>) (rightValues: ClArray<'b>) (rightIndices: ClArray<int>) (resultBitmap: ClArray<int>) (resultValues: ClArray<'c>) (resultIndices: ClArray<int>) ->
32+
<@ fun (ndRange: Range1D) length leftValuesLength rightValuesLength (leftValues: ClArray<'a>) (leftIndices: ClArray<int>) (rightValues: ClArray<'b>) (rightIndices: ClArray<int>) (resultBitmap: ClArray<int>) (resultValues: ClArray<'c>) (resultIndices: ClArray<int>) ->
3333

3434
let gid = ndRange.GlobalID0
3535

36-
if gid < vectorLength then
36+
if gid < length then
3737

3838
let (leftValue: 'a option) =
3939
(%binSearch) leftValuesLength gid leftIndices leftValues
@@ -49,6 +49,29 @@ module Map2 =
4949
resultBitmap.[gid] <- 1
5050
| None -> resultBitmap.[gid] <- 0 @>
5151

52+
let prepareFillGeneral op =
53+
<@ fun (ndRange: Range1D) length leftValuesLength rightValuesLength (leftValues: ClArray<'a>) (leftIndices: ClArray<int>) (rightValues: ClArray<'b>) (rightIndices: ClArray<int>) (value: ClCell<'a>) (resultBitmap: ClArray<int>) (resultValues: ClArray<'c>) (resultIndices: ClArray<int>) ->
54+
55+
let gid = ndRange.GlobalID0
56+
57+
let value = value.Value
58+
59+
if gid < length then
60+
61+
let (leftValue: 'a option) =
62+
(%binSearch) leftValuesLength gid leftIndices leftValues
63+
64+
let (rightValue: 'b option) =
65+
(%binSearch) rightValuesLength gid rightIndices rightValues
66+
67+
match (%op) leftValue rightValue value with
68+
| Some value ->
69+
resultValues.[gid] <- value
70+
resultIndices.[gid] <- gid
71+
72+
resultBitmap.[gid] <- 1
73+
| None -> resultBitmap.[gid] <- 0 @>
74+
5275
let merge workGroupSize =
5376
<@ fun (ndRange: Range1D) (firstSide: int) (secondSide: int) (sumOfSides: int) (firstIndicesBuffer: ClArray<int>) (firstValuesBuffer: ClArray<'a>) (secondIndicesBuffer: ClArray<int>) (secondValuesBuffer: ClArray<'b>) (allIndicesBuffer: ClArray<int>) (firstResultValues: ClArray<'a>) (secondResultValues: ClArray<'b>) (isLeftBitMap: ClArray<int>) ->
5477

@@ -158,31 +181,6 @@ module Map2 =
158181
firstResultValues.[i] <- firstValuesBuffer.[beginIdx + boundaryX]
159182
isLeftBitMap.[i] <- 1 @>
160183

161-
let prepareFillVector opAdd =
162-
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: ClCell<'a>) (isLeft: ClArray<int>) (allValues: ClArray<'a>) (positions: ClArray<int>) ->
163-
164-
let gid = ndRange.GlobalID0
165-
166-
let value = value.Value
167-
168-
if gid < length - 1
169-
&& allIndices.[gid] = allIndices.[gid + 1] then
170-
let result =
171-
(%opAdd) (Some leftValues.[gid]) (Some rightValues.[gid + 1]) value
172-
173-
(%PreparePositions.both) gid result positions allValues
174-
elif (gid < length
175-
&& gid > 0
176-
&& allIndices.[gid - 1] <> allIndices.[gid])
177-
|| gid = 0 then
178-
let leftResult =
179-
(%opAdd) (Some leftValues.[gid]) None value
180-
181-
let rightResult =
182-
(%opAdd) None (Some rightValues.[gid]) value
183-
184-
(%PreparePositions.leftRight) gid leftResult rightResult isLeft allValues positions @>
185-
186184
let preparePositions opAdd =
187185
<@ fun (ndRange: Range1D) length (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (isLeft: ClArray<int>) (allValues: ClArray<'c>) (positions: ClArray<int>) ->
188186

src/GraphBLAS-sharp.Backend/Vector/SparseVector/SparseVector.fs

Lines changed: 30 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -258,20 +258,21 @@ module SparseVector =
258258
=
259259

260260
let kernel =
261-
clContext.Compile(Map2.prepareFillVector op)
261+
clContext.Compile(Map2.prepareFillGeneral op)
262262

263-
fun (processor: MailboxProcessor<_>) (allIndices: ClArray<int>) (leftValues: ClArray<'a>) (rightValues: ClArray<'b>) (value: ClCell<'a>) (isLeft: ClArray<int>) ->
263+
fun (processor: MailboxProcessor<_>) (vectorLenght: int) (leftValues: ClArray<'a>) (leftIndices: ClArray<int>) (rightValues: ClArray<'b>) (rightIndices: ClArray<int>) (value: ClCell<'a>)->
264264

265-
let length = allIndices.Length
265+
let resultBitmap =
266+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, vectorLenght)
266267

267-
let allValues =
268-
clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, length)
268+
let resultIndices =
269+
clContext.CreateClArrayWithSpecificAllocationMode<int>(DeviceOnly, vectorLenght)
269270

270-
let positions =
271-
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, length)
271+
let resultValues =
272+
clContext.CreateClArrayWithSpecificAllocationMode<'a>(DeviceOnly, vectorLenght)
272273

273274
let ndRange =
274-
Range1D.CreateValid(length, workGroupSize)
275+
Range1D.CreateValid(vectorLenght, workGroupSize)
275276

276277
let kernel = kernel.GetKernel()
277278

@@ -280,26 +281,31 @@ module SparseVector =
280281
(fun () ->
281282
kernel.KernelFunc
282283
ndRange
283-
length
284-
allIndices
284+
vectorLenght
285+
leftValues.Length
286+
rightValues.Length
285287
leftValues
288+
leftIndices
286289
rightValues
290+
rightIndices
287291
value
288-
isLeft
289-
allValues
290-
positions)
292+
resultBitmap
293+
resultValues
294+
resultIndices)
291295
)
292296

293297
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
294298

295-
allValues, positions
299+
resultBitmap, resultValues, resultIndices
296300

297301
///<param name="clContext">.</param>
298302
///<param name="op">.</param>
299303
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
300-
let fillSubVector<'a, 'b when 'a: struct and 'b: struct> (clContext: ClContext) op workGroupSize =
301-
302-
let merge = merge clContext workGroupSize
304+
let fillSubVector<'a, 'b when 'a: struct and 'b: struct>
305+
(clContext: ClContext)
306+
op
307+
workGroupSize
308+
=
303309

304310
let prepare =
305311
preparePositionsFillSubVector clContext op workGroupSize
@@ -308,27 +314,20 @@ module SparseVector =
308314

309315
fun (processor: MailboxProcessor<_>) allocationMode (leftVector: ClVector.Sparse<'a>) (rightVector: ClVector.Sparse<'b>) (value: ClCell<'a>) ->
310316

311-
let allIndices, leftValues, rightValues, isLeft =
312-
merge processor leftVector.Indices leftVector.Values rightVector.Indices rightVector.Values
313-
314-
let allValues, positions =
315-
prepare processor allIndices leftValues rightValues value isLeft
316-
317-
processor.Post(Msg.CreateFreeMsg<_>(leftValues))
318-
processor.Post(Msg.CreateFreeMsg<_>(rightValues))
319-
processor.Post(Msg.CreateFreeMsg<_>(isLeft))
317+
let bitmap, values, indices =
318+
prepare processor leftVector.Size leftVector.Values leftVector.Indices rightVector.Values rightVector.Indices value
320319

321320
let resultValues, resultIndices =
322-
setPositions processor allocationMode allValues allIndices positions
321+
setPositions processor allocationMode values indices bitmap
323322

324-
processor.Post(Msg.CreateFreeMsg<_>(allIndices))
325-
processor.Post(Msg.CreateFreeMsg<_>(allValues))
326-
processor.Post(Msg.CreateFreeMsg<_>(positions))
323+
processor.Post(Msg.CreateFreeMsg<_>(indices))
324+
processor.Post(Msg.CreateFreeMsg<_>(values))
325+
processor.Post(Msg.CreateFreeMsg<_>(bitmap))
327326

328327
{ Context = clContext
329328
Values = resultValues
330329
Indices = resultIndices
331-
Size = max leftVector.Size rightVector.Size }
330+
Size = rightVector.Size }
332331

333332
let toDense (clContext: ClContext) workGroupSize =
334333

0 commit comments

Comments
 (0)