Skip to content

Commit 91a72e2

Browse files
committed
add: Gather tests
1 parent 66c2711 commit 91a72e2

17 files changed

Lines changed: 529 additions & 606 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -349,13 +349,3 @@ module ClArray =
349349

350350
result
351351

352-
let iterate (clContext: ClContext) workGroupSize iterator =
353-
354-
let create = create clContext workGroupSize iterator
355-
356-
let scatter = Scatter.runInplace clContext workGroupSize
357-
358-
fun (processor: MailboxProcessor<_>) allocationMode (inputArray: ClArray<'a>) (resultArray: ClArray<'a>) ->
359-
360-
let positions = create processor allocationMode
361-

src/GraphBLAS-sharp.Backend/Common/Gather.fs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,28 @@ module internal Gather =
1717
let run (clContext: ClContext) workGroupSize =
1818

1919
let gather =
20-
<@ fun (ndRange: Range1D) (positions: ClArray<int>) (values: ClArray<'a>) (outputArray: ClArray<'a>) (size: int) ->
20+
<@ fun (ndRange: Range1D) positionsLength valuesLength (positions: ClArray<int>) (values: ClArray<'a>) (outputArray: ClArray<'a>) ->
2121

2222
let i = ndRange.GlobalID0
2323

24-
if i < size then
24+
if i < positionsLength then
2525
let position = positions.[i]
26-
let value = values.[position]
2726

28-
outputArray.[i] <- value @>
27+
if position >= 0 && position < valuesLength then
28+
outputArray.[i] <- values.[position] @>
2929

30-
let program = clContext.Compile(gather)
30+
let program = clContext.Compile gather
3131

32-
fun (processor: MailboxProcessor<_>) (positions: ClArray<int>) (inputArray: ClArray<'a>) (outputArray: ClArray<'a>) ->
32+
fun (processor: MailboxProcessor<_>) (positions: ClArray<int>) (values: ClArray<'a>) (outputArray: ClArray<'a>) ->
3333

34-
let size = outputArray.Length
34+
if positions.Length <> outputArray.Length then failwith "Lengths must be the same"
3535

3636
let kernel = program.GetKernel()
3737

38-
let ndRange = Range1D.CreateValid(size, workGroupSize)
38+
let ndRange = Range1D.CreateValid(positions.Length, workGroupSize)
3939

4040
processor.Post(
41-
Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange positions inputArray outputArray size)
41+
Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange positions.Length values.Length positions values outputArray)
4242
)
4343

4444
processor.Post(Msg.CreateRunMsg<_, _>(kernel))

src/GraphBLAS-sharp.Backend/Common/PrefixSum.fs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,75 @@ module PrefixSum =
270270
fun (processor: MailboxProcessor<_>) (inputArray: ClArray<int>) ->
271271

272272
scan processor inputArray 0
273+
274+
275+
module ByKey =
276+
let private sequentialSegments opWrite (clContext: ClContext) workGroupSize opAdd zero =
277+
278+
let kernel =
279+
<@ fun (ndRange: Range1D) lenght uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->
280+
let gid = ndRange.GlobalID0
281+
282+
if gid < uniqueKeysCount then
283+
let sourcePosition = offsets.[gid]
284+
let sourceKey = keys.[sourcePosition]
285+
286+
let mutable currentSum = zero
287+
let mutable previousSum = zero
288+
289+
let mutable currentPosition = sourcePosition
290+
291+
while currentPosition < lenght
292+
&& keys.[currentPosition] = sourceKey do
293+
294+
previousSum <- currentSum
295+
currentSum <- (%opAdd) currentSum values.[currentPosition]
296+
297+
values.[currentPosition] <- (%opWrite) previousSum currentSum
298+
299+
currentPosition <- currentPosition + 1 @>
300+
301+
let kernel = clContext.Compile kernel
302+
303+
fun (processor: MailboxProcessor<_>) uniqueKeysCount (values: ClArray<'a>) (keys: ClArray<int>) (offsets: ClArray<int>) ->
304+
305+
let kernel = kernel.GetKernel()
306+
307+
let ndRange =
308+
Range1D.CreateValid(values.Length, workGroupSize)
309+
310+
processor.Post(
311+
Msg.MsgSetArguments
312+
(fun () -> kernel.KernelFunc ndRange values.Length uniqueKeysCount values keys offsets)
313+
)
314+
315+
processor.Post(Msg.CreateRunMsg<_, _> kernel)
316+
317+
/// <summary>
318+
/// Exclude scan by key.
319+
/// </summary>
320+
/// <example>
321+
/// <code>
322+
/// let arr = [| 1; 1; 1; 1; 1; 1|]
323+
/// let keys = [| 1; 2; 2; 2; 3; 3 |]
324+
/// ...
325+
/// > val result = [| 0; 0; 1; 2; 0; 1 |]
326+
/// </code>
327+
/// </example>
328+
let sequentialExclude clContext =
329+
sequentialSegments (Map.fst ()) clContext
330+
331+
/// <summary>
332+
/// Include scan by key.
333+
/// </summary>
334+
/// <example>
335+
/// <code>
336+
/// let arr = [| 1; 1; 1; 1; 1; 1|]
337+
/// let keys = [| 1; 2; 2; 2; 3; 3 |]
338+
/// ...
339+
/// > val result = [| 1; 1; 2; 3; 1; 2 |]
340+
/// </code>
341+
/// </example>
342+
let sequentialInclude clContext =
343+
sequentialSegments (Map.snd ()) clContext
344+

src/GraphBLAS-sharp.Backend/Common/Scatter.fs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ module internal Scatter =
4343

4444
fun (processor: MailboxProcessor<_>) (positions: ClArray<int>) (values: ClArray<'a>) (result: ClArray<'a>) ->
4545

46+
if positions.Length <> values.Length then failwith "Lengths must be the same"
47+
4648
let positionsLength = positions.Length
4749

4850
let ndRange =

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/Matrix.fs

Lines changed: 2 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,47 +12,9 @@ open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
1212
open GraphBLAS.FSharp.Backend.Objects.ClCell
1313

1414
module Matrix =
15-
let private expandRowPointers (clContext: ClContext) workGroupSize =
16-
17-
let expandRowPointers =
18-
<@ fun (ndRange: Range1D) (rowPointers: ClArray<int>) (rowCount: int) (rows: ClArray<int>) ->
19-
20-
let i = ndRange.GlobalID0
21-
22-
if i < rowCount then
23-
let rowPointer = rowPointers.[i]
24-
25-
if rowPointer <> rowPointers.[i + 1] then
26-
rows.[rowPointer] <- i @>
27-
28-
let program = clContext.Compile(expandRowPointers)
29-
30-
let create =
31-
ClArray.zeroCreate clContext workGroupSize
32-
33-
let scan =
34-
PrefixSum.runIncludeInplace <@ max @> clContext workGroupSize
35-
36-
fun (processor: MailboxProcessor<_>) allocationMode (rowPointers: ClArray<int>) nnz rowCount ->
37-
38-
let rows = create processor allocationMode nnz
39-
40-
let kernel = program.GetKernel()
41-
42-
let ndRange =
43-
Range1D.CreateValid(rowCount, workGroupSize)
44-
45-
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange rowPointers rowCount rows))
46-
processor.Post(Msg.CreateRunMsg<_, _> kernel)
47-
48-
let total = scan processor rows 0
49-
processor.Post(Msg.CreateFreeMsg(total))
50-
51-
rows
52-
5315
let toCOO (clContext: ClContext) workGroupSize =
5416
let prepare =
55-
expandRowPointers clContext workGroupSize
17+
Common.expandRowPointers clContext workGroupSize
5618

5719
let copy = ClArray.copy clContext workGroupSize
5820

@@ -77,7 +39,7 @@ module Matrix =
7739

7840
let toCOOInplace (clContext: ClContext) workGroupSize =
7941
let prepare =
80-
expandRowPointers clContext workGroupSize
42+
Common.expandRowPointers clContext workGroupSize
8143

8244
fun (processor: MailboxProcessor<_>) allocationMode (matrix: ClMatrix.CSR<'a>) ->
8345
let rows =

0 commit comments

Comments
 (0)