Skip to content

Commit 03e7e95

Browse files
committed
add: init gather
1 parent 63037b6 commit 03e7e95

4 files changed

Lines changed: 98 additions & 12 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/Gather.fs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,34 @@ namespace GraphBLAS.FSharp.Backend.Common
33
open Brahma.FSharp
44

55
module internal Gather =
6+
let runInit positionMap (clContext: ClContext) workGroupSize =
7+
8+
let gather =
9+
<@ fun (ndRange: Range1D) valuesLength (values: ClArray<'a>) (outputArray: ClArray<'a>) ->
10+
11+
let gid = ndRange.GlobalID0
12+
13+
if gid < valuesLength then
14+
let position = (%positionMap) gid
15+
16+
if position >= 0 && position < valuesLength then
17+
outputArray.[gid] <- values.[position] @>
18+
19+
let program = clContext.Compile gather
20+
21+
fun (processor: MailboxProcessor<_>) (values: ClArray<'a>) (outputArray: ClArray<'a>) ->
22+
23+
let kernel = program.GetKernel()
24+
25+
let ndRange = Range1D.CreateValid(outputArray.Length, workGroupSize)
26+
27+
processor.Post(
28+
Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange values.Length values outputArray)
29+
)
30+
31+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
32+
33+
634
/// <summary>
735
/// Creates a new array obtained from positions replaced with values from the given array at these positions (indices).
836
/// </summary>
@@ -19,13 +47,13 @@ module internal Gather =
1947
let gather =
2048
<@ fun (ndRange: Range1D) positionsLength valuesLength (positions: ClArray<int>) (values: ClArray<'a>) (outputArray: ClArray<'a>) ->
2149

22-
let i = ndRange.GlobalID0
50+
let gid = ndRange.GlobalID0
2351

24-
if i < positionsLength then
25-
let position = positions.[i]
52+
if gid < positionsLength then
53+
let position = positions.[gid]
2654

2755
if position >= 0 && position < valuesLength then
28-
outputArray.[i] <- values.[position] @>
56+
outputArray.[gid] <- values.[position] @>
2957

3058
let program = clContext.Compile gather
3159

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ module Expand =
2727

2828
let gather = Gather.run clContext workGroupSize
2929

30+
let shiftedGather = Gather.runInit Map.inc clContext workGroupSize
31+
3032
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
3133

3234
fun (processor: MailboxProcessor<_>) (leftMatrix: ClMatrix.CSR<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->
@@ -130,7 +132,7 @@ module Expand =
130132
// another way to get offsets ???
131133
let offsets = removeDuplicates processor segmentsPointers
132134

133-
segmentPrefixSum processor offsets.Length BPositions APositions offsets // TODO(offsets lengths in scan)
135+
segmentPrefixSum processor offsets.Length BPositions APositions offsets
134136

135137
offsets.Free processor
136138

@@ -206,7 +208,6 @@ module Expand =
206208
let reduce (clContext: ClContext) workGroupSize opAdd =
207209

208210
let reduce = Reduce.ByKey2D.segmentSequential clContext workGroupSize opAdd
209-
//let reduce = Reduce.ByKey2D.sequential clContext workGroupSize opAdd
210211

211212
let getUniqueBitmap =
212213
ClArray.getUniqueBitmap2LastOccurrence clContext workGroupSize
@@ -240,12 +241,9 @@ module Expand =
240241
bitmap.Free processor
241242
positions.Free processor
242243

243-
let reducedColumns, reducedRows, reducedValues =
244+
let reducedColumns, reducedRows, reducedValues = // by size variance TODO()
244245
reduce processor allocationMode uniqueKeysCount offsets columns rows values
245246

246-
// let reducedColumns, reducedRows, reducedValues =
247-
// reduce processor DeviceOnly uniqueKeysCount columns rows values
248-
249247
offsets.Free processor
250248

251249
reducedValues, reducedColumns, reducedRows

tests/GraphBLAS-sharp.Tests/Common/Gather.fs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ open Expecto
66
open Microsoft.FSharp.Collections
77
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
88
open GraphBLAS.FSharp.Backend.Objects.ClContext
9+
open GraphBLAS.FSharp.Backend.Quotes
910

1011
let context = Context.defaultContext.ClContext
1112

@@ -62,3 +63,62 @@ let tests =
6263
createTest<bool> (=) Gather.run
6364
createTest<uint> (=) Gather.run ]
6465
|> testList "Gather"
66+
67+
68+
let makeTestInit isEqual testFun indexMap (array: ('a * 'a) []) =
69+
if array.Length > 0 then
70+
71+
let positions, values, target =
72+
Array.mapi (fun index (first, second) -> indexMap index, first, second) array
73+
|> Array.unzip3
74+
75+
let clValues =
76+
context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, values)
77+
78+
let clTarget =
79+
context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, target)
80+
81+
testFun processor clValues clTarget
82+
83+
clValues.Free processor
84+
85+
let actual = clTarget.ToHostAndFree processor
86+
87+
check isEqual actual positions values target
88+
89+
let createTestInit<'a> (isEqual: 'a -> 'a -> bool) testFun indexMapQ indexMap =
90+
91+
let testFun = testFun indexMapQ context Utils.defaultWorkGroupSize
92+
93+
makeTestInit isEqual testFun indexMap
94+
|> testPropertyWithConfig Utils.defaultConfig $"test on {typeof<'a>}"
95+
96+
let initTests =
97+
98+
let idTests =
99+
[ createTestInit<int> (=) Gather.runInit Map.id id
100+
101+
if Utils.isFloat64Available context.ClDevice then
102+
createTestInit<float> Utils.floatIsEqual Gather.runInit Map.id id
103+
104+
createTestInit<float32> Utils.float32IsEqual Gather.runInit Map.id id
105+
createTestInit<bool> (=) Gather.runInit Map.id id
106+
createTestInit<uint> (=) Gather.runInit Map.id id]
107+
|> testList "id"
108+
109+
let inc = ((+) 1)
110+
111+
let incTests =
112+
[ createTestInit<int> (=) Gather.runInit Map.inc inc
113+
114+
if Utils.isFloat64Available context.ClDevice then
115+
createTestInit<float> Utils.floatIsEqual Gather.runInit Map.inc inc
116+
117+
createTestInit<float32> Utils.float32IsEqual Gather.runInit Map.inc inc
118+
createTestInit<bool> (=) Gather.runInit Map.inc inc
119+
createTestInit<uint> (=) Gather.runInit Map.inc inc]
120+
|> testList "inc"
121+
122+
testList "init" [idTests; incTests]
123+
124+

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ open GraphBLAS.FSharp.Tests.Matrix
9494
let allTests =
9595
testList
9696
"All tests"
97-
[ SpGeMM.generalTests
98-
]
97+
[ // SpGeMM.generalTests
98+
Common.Gather.initTests ]
9999

100100
|> testSequenced
101101

0 commit comments

Comments
 (0)