Skip to content

Commit 29c564c

Browse files
committed
add: scatter init value
1 parent 6f02570 commit 29c564c

10 files changed

Lines changed: 132 additions & 96 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ module ClArray =
184184
let removeDuplications (clContext: ClContext) workGroupSize =
185185

186186
let scatter =
187-
Scatter.scatterLastOccurrence clContext workGroupSize
187+
Scatter.lastOccurrence clContext workGroupSize
188188

189189
let getUniqueBitmap = getUniqueBitmapLastOccurrence clContext workGroupSize
190190

@@ -349,7 +349,7 @@ module ClArray =
349349
PrefixSum.runExcludeInplace <@ (+) @> clContext workGroupSize
350350

351351
let scatter =
352-
Scatter.scatterLastOccurrence clContext workGroupSize
352+
Scatter.lastOccurrence clContext workGroupSize
353353

354354
fun (processor: MailboxProcessor<_>) allocationMode (array: ClArray<'a>) ->
355355

src/GraphBLAS-sharp.Backend/Common/Scatter.fs

Lines changed: 78 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@ namespace GraphBLAS.FSharp.Backend.Common
33
open Brahma.FSharp
44

55
module internal Scatter =
6+
let firstOccurencePredicate () =
7+
<@ fun gid _ (positions: ClArray<int>) ->
8+
// first occurrence condition
9+
(gid = 0 || positions.[gid - 1] <> positions.[gid]) @>
10+
11+
let lastOccurrencePredicate () =
12+
<@ fun gid positionsLength (positions: ClArray<int>) ->
13+
// last occurrence condition
14+
(gid = positionsLength - 1 || positions.[gid] <> positions.[gid + 1]) @>
15+
16+
617
let private general<'a> predicate (clContext: ClContext) workGroupSize =
718

819
let run =
@@ -12,9 +23,12 @@ module internal Scatter =
1223

1324
if gid < positionsLength then
1425
// positions lengths == values length
15-
let predicateResult = (%predicate) gid positionsLength positions resultLength
26+
let predicateResult = (%predicate) gid positionsLength positions
27+
let position = positions.[gid]
28+
29+
if predicateResult
30+
&& 0 <= position && position < resultLength then
1631

17-
if predicateResult then
1832
result.[positions.[gid]] <- values.[gid] @>
1933

2034
let program = clContext.Compile(run)
@@ -55,14 +69,9 @@ module internal Scatter =
5569
/// > val result = [| 1,9; 3.7; 6.4; 7.3; 9.1 |]
5670
/// </code>
5771
/// </example>
58-
let scatterFirstOccurrence clContext =
72+
let firstOccurrence clContext =
5973
general
60-
<| <@ fun gid _ (positions: ClArray<int>) resultLength ->
61-
let currentKey = positions.[gid]
62-
// first occurrence condition
63-
(gid = 0 || positions.[gid - 1] <> positions.[gid])
64-
// result position in valid range
65-
&& (0 <= currentKey && currentKey < resultLength) @>
74+
<| firstOccurencePredicate ()
6675
<| clContext
6776

6877
/// <summary>
@@ -71,7 +80,7 @@ module internal Scatter =
7180
/// </summary>
7281
/// <remarks>
7382
/// Every element of the positions array must not be less than the previous one.
74-
/// If there are several elements with the same indices, the last one of them will be at the common index.
83+
/// If there are several elements with the same indices, the LAST one of them will be at the common index.
7584
/// If index is out of bounds, the value will be ignored.
7685
/// </remarks>
7786
/// <example>
@@ -83,60 +92,85 @@ module internal Scatter =
8392
/// > val result = [| 2.8; 5.5; 6.4; 8.2; 9.1 |]
8493
/// </code>
8594
/// </example>
86-
let scatterLastOccurrence clContext =
95+
let lastOccurrence clContext =
8796
general
88-
<| <@ fun gid positionsLength (positions: ClArray<int>) resultLength ->
89-
let index = positions.[gid]
90-
// last occurrence condition
91-
(gid = positionsLength - 1 || index <> positions.[gid + 1])
92-
// result position in valid range
93-
&& (0 <= index && index < resultLength) @>
97+
<| lastOccurrencePredicate ()
9498
<| clContext
9599

96-
/// <summary>
97-
/// Writes elements from the array of values to the array at the positions indicated by the global id map.
98-
/// </summary>
99-
/// <remarks>
100-
/// If index is out of bounds, the value will be ignored.
101-
/// </remarks>
102-
/// <example>
103-
/// <code>
104-
/// let positionMap = fun x -> x + 1
105-
/// let values = [| 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
106-
/// let result = ... // create result
107-
/// run positionMap clContext 32 processor positions values result
108-
/// ...
109-
/// > val result = [| 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
110-
/// </code>
111-
/// </example>
112-
/// <param name="positionMap">Should be injective in order to avoid race conditions.</param>
113-
let init<'a> positionMap (clContext: ClContext) workGroupSize =
100+
let private generalInit<'a> predicate valueMap (clContext: ClContext) workGroupSize =
114101

115102
let run =
116-
<@ fun (ndRange: Range1D) (valuesLength: int) (values: ClArray<'a>) (result: ClArray<'a>) resultLength ->
103+
<@ fun (ndRange: Range1D) (positions: ClArray<int>) (positionsLength: int) (result: ClArray<'a>) (resultLength: int) ->
117104

118105
let gid = ndRange.GlobalID0
119106

120-
if gid < valuesLength then
121-
let position = (%positionMap) gid
107+
if gid < positionsLength then
108+
// positions lengths == values length
109+
let predicateResult = (%predicate) gid positionsLength positions
110+
111+
let position = positions.[gid]
112+
113+
if predicateResult
114+
&& 0 <= position && position < resultLength then
122115

123-
// may be race condition
124-
if 0 <= position && position < resultLength then
125-
result.[position] <- values.[gid] @>
116+
result.[positions.[gid]] <- (%valueMap) gid @>
126117

127118
let program = clContext.Compile(run)
128119

129-
fun (processor: MailboxProcessor<_>) (values: ClArray<'a>) (result: ClArray<'a>) ->
120+
fun (processor: MailboxProcessor<_>) (positions: ClArray<int>) (result: ClArray<'a>) ->
121+
122+
let positionsLength = positions.Length
130123

131124
let ndRange =
132-
Range1D.CreateValid(values.Length, workGroupSize)
125+
Range1D.CreateValid(positionsLength, workGroupSize)
133126

134127
let kernel = program.GetKernel()
135128

136129
processor.Post(
137130
Msg.MsgSetArguments
138-
(fun () -> kernel.KernelFunc ndRange values.Length values result result.Length)
131+
(fun () -> kernel.KernelFunc ndRange positions positionsLength result result.Length)
139132
)
140133

141134
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
142135

136+
/// <summary>
137+
/// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
138+
/// should be a values obtained by applying the mapping to the global id.
139+
/// </summary>
140+
/// <remarks>
141+
/// Every element of the positions array must not be less than the previous one.
142+
/// If there are several elements with the same indices, the FIRST one of them will be at the common index.
143+
/// If index is out of bounds, the value will be ignored.
144+
/// </remarks>
145+
/// <example>
146+
/// <code>
147+
/// let positions = [| 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
148+
/// let valueMap = id
149+
/// run clContext 32 processor positions values result
150+
/// ...
151+
/// > val result = [| 0; 2; 5; 6; 8 |]
152+
/// </code>
153+
/// </example>
154+
/// <param name="valueMap">Maps global id to a value</param>
155+
let initFirsOccurrence<'a> valueMap = generalInit<'a> <| firstOccurencePredicate () <| valueMap
156+
157+
/// <summary>
158+
/// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
159+
/// should be a values obtained by applying the mapping to the global id.
160+
/// </summary>
161+
/// <remarks>
162+
/// Every element of the positions array must not be less than the previous one.
163+
/// If there are several elements with the same indices, the LAST one of them will be at the common index.
164+
/// If index is out of bounds, the value will be ignored.
165+
/// </remarks>
166+
/// <example>
167+
/// <code>
168+
/// let positions = [| 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
169+
/// let valueMap = id
170+
/// run clContext 32 processor positions values result
171+
/// ...
172+
/// > val result = [| 1; 4; 5; 7; 8 |]
173+
/// </code>
174+
/// </example>
175+
/// <param name="valueMap">Maps global id to a value</param>
176+
let initLastOccurrence<'a> valueMap = generalInit<'a> <| lastOccurrencePredicate () <| valueMap

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ module Expand =
6969

7070
let init = ClArray.init clContext workGroupSize Map.id
7171

72-
let scatter = Scatter.scatterLastOccurrence clContext workGroupSize
72+
let scatter = Scatter.lastOccurrence clContext workGroupSize
7373

7474
let zeroCreate = ClArray.zeroCreate clContext workGroupSize
7575

@@ -202,7 +202,7 @@ module Expand =
202202

203203
let init = ClArray.init clContext workGroupSize Map.id // TODO(fuse)
204204

205-
let scatter = Scatter.scatterFirstOccurrence clContext workGroupSize
205+
let scatter = Scatter.firstOccurrence clContext workGroupSize
206206

207207
fun (processor: MailboxProcessor<_>) allocationMode (values: ClArray<'a>) (columns: Indices) (rows: Indices) ->
208208

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMMMasked.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ module internal SpGEMMMasked =
151151
let calculate =
152152
calculate context workGroupSize opAdd opMul
153153

154-
let scatter = Scatter.scatterLastOccurrence context workGroupSize
155-
let scatterData = Scatter.scatterLastOccurrence context workGroupSize
154+
let scatter = Scatter.lastOccurrence context workGroupSize
155+
let scatterData = Scatter.lastOccurrence context workGroupSize
156156

157157
let scanInplace =
158158
PrefixSum.standardExcludeInplace context workGroupSize

src/GraphBLAS-sharp.Backend/Matrix/Common.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ module Common =
1111
let setPositions<'a when 'a: struct> (clContext: ClContext) workGroupSize =
1212

1313
let indicesScatter =
14-
Scatter.scatterLastOccurrence clContext workGroupSize
14+
Scatter.lastOccurrence clContext workGroupSize
1515

1616
let valuesScatter =
17-
Scatter.scatterLastOccurrence clContext workGroupSize
17+
Scatter.lastOccurrence clContext workGroupSize
1818

1919
let sum =
2020
PrefixSum.standardExcludeInplace clContext workGroupSize

src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,10 @@ module DenseVector =
9090
let toSparse<'a when 'a: struct> (clContext: ClContext) workGroupSize =
9191

9292
let scatterValues =
93-
Scatter.scatterLastOccurrence clContext workGroupSize
93+
Scatter.lastOccurrence clContext workGroupSize
9494

9595
let scatterIndices =
96-
Scatter.scatterLastOccurrence clContext workGroupSize
96+
Scatter.lastOccurrence clContext workGroupSize
9797

9898
let getBitmap =
9999
ClArray.map clContext workGroupSize

src/GraphBLAS-sharp.Backend/Vector/SparseVector/Common.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ module internal Common =
1313
PrefixSum.standardExcludeInplace clContext workGroupSize
1414

1515
let valuesScatter =
16-
Scatter.scatterLastOccurrence clContext workGroupSize
16+
Scatter.lastOccurrence clContext workGroupSize
1717

1818
let indicesScatter =
19-
Scatter.scatterLastOccurrence clContext workGroupSize
19+
Scatter.lastOccurrence clContext workGroupSize
2020

2121
fun (processor: MailboxProcessor<_>) allocationMode (allValues: ClArray<'a>) (allIndices: ClArray<int>) (positions: ClArray<int>) ->
2222

src/GraphBLAS-sharp.Backend/Vector/Vector.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ module Vector =
3535

3636
let ofList (clContext: ClContext) workGroupSize =
3737
let scatter =
38-
Scatter.scatterLastOccurrence clContext workGroupSize
38+
Scatter.lastOccurrence clContext workGroupSize
3939

4040
let zeroCreate =
4141
ClArray.zeroCreate clContext workGroupSize

0 commit comments

Comments
 (0)