Skip to content

Commit 6f02570

Browse files
committed
add: init scatter
1 parent 03e7e95 commit 6f02570

4 files changed

Lines changed: 127 additions & 46 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/Scatter.fs

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ module internal Scatter =
5050
/// <code>
5151
/// let positions = [| 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
5252
/// let values = [| 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
53-
/// let result = run clContext 32 processor positions values result
53+
/// run clContext 32 processor positions values result
5454
/// ...
5555
/// > val result = [| 1,9; 3.7; 6.4; 7.3; 9.1 |]
5656
/// </code>
@@ -78,7 +78,7 @@ module internal Scatter =
7878
/// <code>
7979
/// let positions = [| 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
8080
/// let values = [| 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
81-
/// let result = run clContext 32 processor positions values result
81+
/// run clContext 32 processor positions values result
8282
/// ...
8383
/// > val result = [| 2.8; 5.5; 6.4; 8.2; 9.1 |]
8484
/// </code>
@@ -93,3 +93,50 @@ module internal Scatter =
9393
&& (0 <= index && index < resultLength) @>
9494
<| clContext
9595

96+
/// <summary>
97+
/// Writes elements from the array of values to the array at the positions indicated by the global id map.
98+
/// </summary>
99+
/// <remarks>
100+
/// If index is out of bounds, the value will be ignored.
101+
/// </remarks>
102+
/// <example>
103+
/// <code>
104+
/// let positionMap = fun x -> x + 1
105+
/// let values = [| 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
106+
/// let result = ... // create result
107+
/// run positionMap clContext 32 processor positions values result
108+
/// ...
109+
/// > val result = [| 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
110+
/// </code>
111+
/// </example>
112+
/// <param name="positionMap">Should be injective in order to avoid race conditions.</param>
113+
let init<'a> positionMap (clContext: ClContext) workGroupSize =
114+
115+
let run =
116+
<@ fun (ndRange: Range1D) (valuesLength: int) (values: ClArray<'a>) (result: ClArray<'a>) resultLength ->
117+
118+
let gid = ndRange.GlobalID0
119+
120+
if gid < valuesLength then
121+
let position = (%positionMap) gid
122+
123+
// may be race condition
124+
if 0 <= position && position < resultLength then
125+
result.[position] <- values.[gid] @>
126+
127+
let program = clContext.Compile(run)
128+
129+
fun (processor: MailboxProcessor<_>) (values: ClArray<'a>) (result: ClArray<'a>) ->
130+
131+
let ndRange =
132+
Range1D.CreateValid(values.Length, workGroupSize)
133+
134+
let kernel = program.GetKernel()
135+
136+
processor.Post(
137+
Msg.MsgSetArguments
138+
(fun () -> kernel.KernelFunc ndRange values.Length values result result.Length)
139+
)
140+
141+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
142+

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,14 @@ type Values<'a> = ClArray<'a>
1717
module Expand =
1818
let getSegmentPointers (clContext: ClContext) workGroupSize =
1919

20-
let create =
21-
ClArray.init clContext workGroupSize Map.id
20+
let subtract = ClArray.map2 clContext workGroupSize Map.subtraction
2221

23-
let createShifted =
24-
ClArray.init clContext workGroupSize Map.inc
22+
let idGather = Gather.runInit Map.id clContext workGroupSize
2523

26-
let subtract = ClArray.map2 clContext workGroupSize Map.subtraction
24+
let incGather = Gather.runInit Map.inc clContext workGroupSize
2725

2826
let gather = Gather.run clContext workGroupSize
2927

30-
let shiftedGather = Gather.runInit Map.inc clContext workGroupSize
31-
3228
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
3329

3430
fun (processor: MailboxProcessor<_>) (leftMatrix: ClMatrix.CSR<'a>) (rightMatrix: ClMatrix.CSR<'b>) ->
@@ -37,27 +33,17 @@ module Expand =
3733

3834
// extract first rightMatrix.RowPointers.Lengths - 1 indices from rightMatrix.RowPointers
3935
// (right matrix row pointers without last item)
40-
let positions = // TODO(fuse)
41-
create processor DeviceOnly positionsLength
42-
4336
let firstPointers =
4437
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, positionsLength)
4538

46-
gather processor positions rightMatrix.RowPointers firstPointers
47-
48-
positions.Free processor
39+
idGather processor rightMatrix.RowPointers firstPointers
4940

5041
// extract last rightMatrix.RowPointers.Lengths - 1 indices from rightMatrix.RowPointers
5142
// (right matrix row pointers without first item)
52-
let shiftedPositions = // TODO(fuse)
53-
createShifted processor DeviceOnly positionsLength
54-
5543
let lastPointers =
5644
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, positionsLength)
5745

58-
gather processor shiftedPositions rightMatrix.RowPointers lastPointers
59-
60-
shiftedPositions.Free processor
46+
incGather processor rightMatrix.RowPointers lastPointers
6147

6248
// subtract
6349
let rightMatrixRowsLengths =

tests/GraphBLAS-sharp.Tests/Common/Scatter.fs

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,21 @@ open Expecto.Logging
55
open Brahma.FSharp
66
open GraphBLAS.FSharp.Tests
77
open GraphBLAS.FSharp.Tests.Context
8-
open GraphBLAS.FSharp
8+
open GraphBLAS.FSharp.Backend.Quotes
99
open GraphBLAS.FSharp.Backend.Common
1010
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
1111

1212
let logger = Log.create "Scatter.Tests"
1313

1414
let context = defaultContext.ClContext
1515

16-
let config =
17-
{ Tests.Utils.defaultConfig with
18-
endSize = 1000000 }
16+
let config = { Utils.defaultConfig with endSize = 10000 }
1917

20-
let wgSize = Tests.Utils.defaultWorkGroupSize
18+
let wgSize = Utils.defaultWorkGroupSize
2119

2220
let q = defaultContext.Queue
2321

24-
let makeTest hostScatter scatter (array: (int * 'a) []) (result: 'a []) =
22+
let makeTest<'a when 'a: equality> hostScatter scatter (array: (int * 'a) []) (result: 'a []) =
2523
if array.Length > 0 then
2624
let positions, values = Array.unzip array
2725

@@ -30,40 +28,89 @@ let makeTest hostScatter scatter (array: (int * 'a) []) (result: 'a []) =
3028
|> hostScatter positions values
3129

3230
let actual =
33-
use clPositions = context.CreateClArray positions
31+
let clPositions = context.CreateClArray positions
3432
use clValues = context.CreateClArray values
3533
use clResult = context.CreateClArray result
3634

3735
scatter q clPositions clValues clResult
3836

3937
clResult.ToHostAndFree q
4038

41-
$"Arrays should be equal. Actual is \n%A{actual}, expected \n%A{expected}"
42-
|> Tests.Utils.compareArrays (=) actual expected
39+
$"Arrays should be equal."
40+
|> Utils.compareArrays (=) actual expected
4341

44-
let testFixturesLast<'a when 'a: equality> hostScatter =
45-
Scatter.scatterLastOccurrence<'a> context wgSize
46-
|> makeTest hostScatter
47-
|> testPropertyWithConfig { config with endSize = 10 } $"Correctness on %A{typeof<'a>}"
42+
let testFixturesLast<'a when 'a: equality> =
43+
Scatter.scatterLastOccurrence context wgSize
44+
|> makeTest<'a> HostPrimitives.scatterLastOccurrence
45+
|> testPropertyWithConfig config $"Correctness on %A{typeof<'a>}"
4846

49-
let testFixturesFirst<'a when 'a: equality> hostScatter =
50-
Scatter.scatterFirstOccurrence<'a> context wgSize
51-
|> makeTest hostScatter
52-
|> testPropertyWithConfig { config with endSize = 10 } $"Correctness on %A{typeof<'a>}"
47+
let testFixturesFirst<'a when 'a: equality> =
48+
Scatter.scatterFirstOccurrence context wgSize
49+
|> makeTest<'a> HostPrimitives.scatterFirstOccurrence
50+
|> testPropertyWithConfig config $"Correctness on %A{typeof<'a>}"
5351

5452
let tests =
5553
q.Error.Add(fun e -> failwithf $"%A{e}")
5654

5755
let last =
58-
[ testFixturesLast<int> HostPrimitives.scatterLastOccurrence
59-
testFixturesLast<byte> HostPrimitives.scatterLastOccurrence
60-
testFixturesLast<bool> HostPrimitives.scatterLastOccurrence ]
56+
[ testFixturesLast<int>
57+
testFixturesLast<byte>
58+
testFixturesLast<bool> ]
6159
|> testList "Last Occurrence"
6260

6361
let first =
64-
[ testFixturesFirst<int> HostPrimitives.scatterFirstOccurrence
65-
testFixturesFirst<byte> HostPrimitives.scatterFirstOccurrence
66-
testFixturesFirst<bool> HostPrimitives.scatterFirstOccurrence ]
62+
[ testFixturesFirst<int>
63+
testFixturesFirst<byte>
64+
testFixturesFirst<bool> ]
6765
|> testList "First Occurrence"
6866

6967
testList "Scatter tests" [first; last]
68+
69+
let makeTestInit<'a when 'a: equality> positionsMap scatter (values: 'a []) (result: 'a []) =
70+
if values.Length > 0 then
71+
72+
let positionsAndValues =
73+
Array.mapi (fun index value -> positionsMap index, value) values
74+
75+
let expected =
76+
Array.init result.Length (fun index ->
77+
match Array.tryFind (fst >> ((=) index)) positionsAndValues with
78+
| Some (_, value) -> value
79+
| None -> result.[index])
80+
81+
let actual =
82+
let values = Array.map snd positionsAndValues
83+
84+
use clValues = context.CreateClArray values
85+
use clResult = context.CreateClArray result
86+
87+
scatter q clValues clResult
88+
89+
clResult.ToHostAndFree q
90+
91+
$"Arrays should be equal."
92+
|> Utils.compareArrays (=) actual expected
93+
94+
let createInitTest<'a when 'a: equality> indexMap indexMapQ =
95+
Scatter.init<'a> indexMapQ context Utils.defaultWorkGroupSize
96+
|> makeTestInit<'a> indexMap
97+
|> testPropertyWithConfig config $"test on {typeof<'a>}"
98+
99+
let initTests =
100+
q.Error.Add(fun e -> failwithf $"%A{e}")
101+
102+
let idTest =
103+
[ createInitTest<int> id Map.id
104+
createInitTest<byte> id Map.id
105+
createInitTest<bool> id Map.id ]
106+
|> testList "id"
107+
108+
let inc = ((+) 1)
109+
110+
let incTest =
111+
[ createInitTest<int> inc Map.inc
112+
createInitTest<byte> inc Map.inc
113+
createInitTest<bool> inc Map.inc ]
114+
|> testList "increment"
115+
116+
testList "Scatter init tests" [idTest; incTest]

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,9 @@ open GraphBLAS.FSharp.Tests.Matrix
9494
let allTests =
9595
testList
9696
"All tests"
97-
[ // SpGeMM.generalTests
98-
Common.Gather.initTests ]
97+
[ // SpGeMM.getSegmentsTests
98+
// Common.Gather.initTests
99+
Common.Scatter.initTests ]
99100

100101
|> testSequenced
101102

0 commit comments

Comments
 (0)