Skip to content

Commit 8ec7fd7

Browse files
committed
add: Scatter.firstOccurrence
1 parent 84fb950 commit 8ec7fd7

11 files changed

Lines changed: 107 additions & 52 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ module ClArray =
184184
let removeDuplications (clContext: ClContext) workGroupSize =
185185

186186
let scatter =
187-
Scatter.runInplace clContext workGroupSize
187+
Scatter.scatterLastOccurrence clContext workGroupSize
188188

189189
let getUniqueBitmap = getUniqueBitmapLastOccurrence clContext workGroupSize
190190

@@ -349,7 +349,7 @@ module ClArray =
349349
PrefixSum.runExcludeInplace <@ (+) @> clContext workGroupSize
350350

351351
let scatter =
352-
Scatter.runInplace clContext workGroupSize
352+
Scatter.scatterLastOccurrence clContext workGroupSize
353353

354354
fun (processor: MailboxProcessor<_>) allocationMode (array: ClArray<'a>) ->
355355

src/GraphBLAS-sharp.Backend/Common/Scatter.fs

Lines changed: 62 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,41 +3,19 @@ namespace GraphBLAS.FSharp.Backend.Common
33
open Brahma.FSharp
44

55
module internal Scatter =
6-
7-
/// <summary>
8-
/// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
9-
/// should be a value from the given one.
10-
/// </summary>
11-
/// <remarks>
12-
/// Every element of the positions array must not be less than the previous one.
13-
/// If there are several elements with the same indices, the last one of them will be at the common index.
14-
/// If index is out of bounds, the value will be ignored.
15-
/// </remarks>
16-
/// <example>
17-
/// <code>
18-
/// let positions = [| 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
19-
/// let values = [| 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
20-
/// let result = run clContext 32 processor positions values result
21-
/// ...
22-
/// > val result = [| 2.8; 5.5; 6.4; 8.2; 9.1 |]
23-
/// </code>
24-
/// </example>
25-
let runInplace<'a> (clContext: ClContext) workGroupSize =
6+
let private general<'a> predicate (clContext: ClContext) workGroupSize =
267

278
let run =
289
<@ fun (ndRange: Range1D) (positions: ClArray<int>) (positionsLength: int) (values: ClArray<'a>) (result: ClArray<'a>) (resultLength: int) ->
2910

3011
let gid = ndRange.GlobalID0
3112

3213
if gid < positionsLength then
33-
let index = positions.[gid]
14+
// positions lengths == values length
15+
let predicateResult = (%predicate) gid positionsLength positions resultLength
3416

35-
if 0 <= index && index < resultLength then
36-
if gid < positionsLength - 1 then
37-
if index <> positions.[gid + 1] then
38-
result.[index] <- values.[gid]
39-
else
40-
result.[index] <- values.[gid] @>
17+
if predicateResult then
18+
result.[positions.[gid]] <- values.[gid] @>
4119

4220
let program = clContext.Compile(run)
4321

@@ -58,3 +36,60 @@ module internal Scatter =
5836
)
5937

6038
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
39+
40+
/// <summary>
41+
/// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
42+
/// should be a value from the given one.
43+
/// </summary>
44+
/// <remarks>
45+
/// Every element of the positions array must not be less than the previous one.
46+
/// If there are several elements with the same indices, the FIRST one of them will be at the common index.
47+
/// If index is out of bounds, the value will be ignored.
48+
/// </remarks>
49+
/// <example>
50+
/// <code>
51+
/// let positions = [| 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
52+
/// let values = [| 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
53+
/// let result = run clContext 32 processor positions values result
54+
/// ...
55+
/// > val result = [| 1,9; 3.7; 6.4; 7.3; 9.1 |]
56+
/// </code>
57+
/// </example>
58+
let scatterFirstOccurrence clContext =
59+
general
60+
<| <@ fun gid _ (positions: ClArray<int>) resultLength ->
61+
let currentKey = positions.[gid]
62+
// first occurrence condition
63+
(gid = 0 || positions.[gid - 1] <> positions.[gid])
64+
// result position in valid range
65+
&& (0 <= currentKey && currentKey < resultLength) @>
66+
<| clContext
67+
68+
/// <summary>
69+
/// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
70+
/// should be a value from the given one.
71+
/// </summary>
72+
/// <remarks>
73+
/// Every element of the positions array must not be less than the previous one.
74+
/// If there are several elements with the same indices, the last one of them will be at the common index.
75+
/// If index is out of bounds, the value will be ignored.
76+
/// </remarks>
77+
/// <example>
78+
/// <code>
79+
/// let positions = [| 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
80+
/// let values = [| 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
81+
/// let result = run clContext 32 processor positions values result
82+
/// ...
83+
/// > val result = [| 2.8; 5.5; 6.4; 8.2; 9.1 |]
84+
/// </code>
85+
/// </example>
86+
let scatterLastOccurrence clContext =
87+
general
88+
<| <@ fun gid positionsLength (positions: ClArray<int>) resultLength ->
89+
let index = positions.[gid]
90+
// last occurrence condition
91+
(gid = positionsLength - 1 || index <> positions.[gid + 1])
92+
// result position in valid range
93+
&& (0 <= index && index < resultLength) @>
94+
<| clContext
95+

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ module Expand =
8181

8282
let init = ClArray.init clContext workGroupSize Map.id
8383

84-
let scatter = Scatter.runInplace clContext workGroupSize
84+
let scatter = Scatter.scatterLastOccurrence clContext workGroupSize
8585

8686
let zeroCreate = ClArray.zeroCreate clContext workGroupSize
8787

@@ -214,7 +214,7 @@ module Expand =
214214

215215
let init = ClArray.init clContext workGroupSize Map.id // TODO(fuse)
216216

217-
let scatter = Scatter.runInplace clContext workGroupSize
217+
let scatter = Scatter.scatterLastOccurrence clContext workGroupSize
218218

219219
fun (processor: MailboxProcessor<_>) allocationMode (values: ClArray<'a>) (columns: Indices) (rows: Indices) ->
220220

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMMMasked.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ module internal SpGEMMMasked =
151151
let calculate =
152152
calculate context workGroupSize opAdd opMul
153153

154-
let scatter = Scatter.runInplace context workGroupSize
155-
let scatterData = Scatter.runInplace context workGroupSize
154+
let scatter = Scatter.scatterLastOccurrence context workGroupSize
155+
let scatterData = Scatter.scatterLastOccurrence context workGroupSize
156156

157157
let scanInplace =
158158
PrefixSum.standardExcludeInplace context workGroupSize

src/GraphBLAS-sharp.Backend/Matrix/Common.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ module Common =
1111
let setPositions<'a when 'a: struct> (clContext: ClContext) workGroupSize =
1212

1313
let indicesScatter =
14-
Scatter.runInplace clContext workGroupSize
14+
Scatter.scatterLastOccurrence clContext workGroupSize
1515

1616
let valuesScatter =
17-
Scatter.runInplace clContext workGroupSize
17+
Scatter.scatterLastOccurrence clContext workGroupSize
1818

1919
let sum =
2020
PrefixSum.standardExcludeInplace clContext workGroupSize

src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ module DenseVector =
9090
let toSparse<'a when 'a: struct> (clContext: ClContext) workGroupSize =
9191

9292
let scatterValues =
93-
Scatter.runInplace clContext workGroupSize
93+
Scatter.scatterLastOccurrence clContext workGroupSize
9494

9595
let scatterIndices =
96-
Scatter.runInplace clContext workGroupSize
96+
Scatter.scatterLastOccurrence clContext workGroupSize
9797

9898
let getBitmap =
9999
ClArray.map clContext workGroupSize

src/GraphBLAS-sharp.Backend/Vector/SparseVector/Common.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ module internal Common =
1313
PrefixSum.standardExcludeInplace clContext workGroupSize
1414

1515
let valuesScatter =
16-
Scatter.runInplace clContext workGroupSize
16+
Scatter.scatterLastOccurrence clContext workGroupSize
1717

1818
let indicesScatter =
19-
Scatter.runInplace clContext workGroupSize
19+
Scatter.scatterLastOccurrence clContext workGroupSize
2020

2121
fun (processor: MailboxProcessor<_>) allocationMode (allValues: ClArray<'a>) (allIndices: ClArray<int>) (positions: ClArray<int>) ->
2222

src/GraphBLAS-sharp.Backend/Vector/Vector.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ module Vector =
3535

3636
let ofList (clContext: ClContext) workGroupSize =
3737
let scatter =
38-
Scatter.runInplace clContext workGroupSize
38+
Scatter.scatterLastOccurrence clContext workGroupSize
3939

4040
let zeroCreate =
4141
ClArray.zeroCreate clContext workGroupSize

tests/GraphBLAS-sharp.Tests/Common/Scatter.fs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ let wgSize = Tests.Utils.defaultWorkGroupSize
2121

2222
let q = defaultContext.Queue
2323

24-
let makeTest scatter (array: (int * 'a) []) (result: 'a []) =
24+
let makeTest hostScatter scatter (array: (int * 'a) []) (result: 'a []) =
2525
if array.Length > 0 then
2626
let positions, values = Array.unzip array
2727

2828
let expected =
2929
Array.copy result
30-
|> HostPrimitives.scatter positions values
30+
|> hostScatter positions values
3131

3232
let actual =
3333
use clPositions = context.CreateClArray positions
@@ -41,15 +41,29 @@ let makeTest scatter (array: (int * 'a) []) (result: 'a []) =
4141
$"Arrays should be equal. Actual is \n%A{actual}, expected \n%A{expected}"
4242
|> Tests.Utils.compareArrays (=) actual expected
4343

44-
let testFixtures<'a when 'a: equality> =
45-
Scatter.runInplace<'a> context wgSize
46-
|> makeTest
44+
let testFixturesLast<'a when 'a: equality> hostScatter =
45+
Scatter.scatterLastOccurrence<'a> context wgSize
46+
|> makeTest hostScatter
47+
|> testPropertyWithConfig { config with endSize = 10 } $"Correctness on %A{typeof<'a>}"
48+
49+
let testFixturesFirst<'a when 'a: equality> hostScatter =
50+
Scatter.scatterFirstOccurrence<'a> context wgSize
51+
|> makeTest hostScatter
4752
|> testPropertyWithConfig { config with endSize = 10 } $"Correctness on %A{typeof<'a>}"
4853

4954
let tests =
5055
q.Error.Add(fun e -> failwithf $"%A{e}")
5156

52-
[ testFixtures<int>
53-
testFixtures<byte>
54-
testFixtures<bool> ]
55-
|> testList "Backend.Common.Scatter tests"
57+
let last =
58+
[ testFixturesLast<int> HostPrimitives.scatterLastOccurrence
59+
testFixturesLast<byte> HostPrimitives.scatterLastOccurrence
60+
testFixturesLast<bool> HostPrimitives.scatterLastOccurrence ]
61+
|> testList "Last Occurrence"
62+
63+
let first =
64+
[ testFixturesFirst<int> HostPrimitives.scatterFirstOccurrence
65+
testFixturesFirst<byte> HostPrimitives.scatterFirstOccurrence
66+
testFixturesFirst<bool> HostPrimitives.scatterFirstOccurrence ]
67+
|> testList "First Occurrence"
68+
69+
testList "Scatter tests" [first; last]

tests/GraphBLAS-sharp.Tests/Helpers.fs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,11 @@ module HostPrimitives =
200200
||> Array.map2 (fun (fst, snd) value -> fst, snd, value)
201201
|> Array.unzip3
202202

203-
let scatter (positions: int array) (values: 'a array) (resultValues: 'a array) =
203+
let generalScatter getBitmap (positions: int array) (values: 'a array) (resultValues: 'a array) =
204204

205205
if positions.Length <> values.Length then failwith "Lengths must be the same"
206206

207-
let bitmap = getUniqueBitmapLastOccurrence positions
207+
let bitmap = getBitmap positions
208208

209209
Array.iteri2
210210
(fun index bit key ->
@@ -215,6 +215,10 @@ module HostPrimitives =
215215

216216
resultValues
217217

218+
let scatterLastOccurrence positions = generalScatter getUniqueBitmapLastOccurrence positions
219+
220+
let scatterFirstOccurrence positions = generalScatter getUniqueBitmapFirstOccurrence positions
221+
218222
let gather (positions: int []) (values: 'a []) (result: 'a []) =
219223
if positions.Length <> result.Length then
220224
failwith "Lengths must be the same"

0 commit comments

Comments
 (0)