@@ -3,41 +3,19 @@ namespace GraphBLAS.FSharp.Backend.Common
33open Brahma.FSharp
44
55module internal Scatter =
6-
7- /// <summary >
8- /// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
9- /// should be a value from the given one.
10- /// </summary >
11- /// <remarks >
12- /// Every element of the positions array must not be less than the previous one.
13- /// If there are several elements with the same indices, the last one of them will be at the common index.
14- /// If index is out of bounds, the value will be ignored.
15- /// </remarks >
16- /// <example >
17- /// <code >
18- /// let positions = [ | 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
19- /// let values = [ | 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
20- /// let result = run clContext 32 processor positions values result
21- /// ...
22- /// > val result = [ | 2.8; 5.5; 6.4; 8.2; 9.1 |]
23- /// </code>
24- /// </example>
25- let runInplace < 'a > ( clContext : ClContext ) workGroupSize =
6+ let private general < 'a > predicate ( clContext : ClContext ) workGroupSize =
267
278 let run =
289 <@ fun ( ndRange : Range1D ) ( positions : ClArray < int >) ( positionsLength : int ) ( values : ClArray < 'a >) ( result : ClArray < 'a >) ( resultLength : int ) ->
2910
3011 let gid = ndRange.GlobalID0
3112
3213 if gid < positionsLength then
33- let index = positions.[ gid]
14+ // positions lengths == values length
15+ let predicateResult = (% predicate) gid positionsLength positions resultLength
3416
35- if 0 <= index && index < resultLength then
36- if gid < positionsLength - 1 then
37- if index <> positions.[ gid + 1 ] then
38- result.[ index] <- values.[ gid]
39- else
40- result.[ index] <- values.[ gid] @>
17+ if predicateResult then
18+ result.[ positions.[ gid]] <- values.[ gid] @>
4119
4220 let program = clContext.Compile( run)
4321
@@ -58,3 +36,60 @@ module internal Scatter =
5836 )
5937
6038 processor.Post( Msg.CreateRunMsg<_, _>( kernel))
39+
40+ /// <summary >
41+ /// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
42+ /// should be a value from the given one.
43+ /// </summary >
44+ /// <remarks >
45+ /// Every element of the positions array must not be less than the previous one.
46+ /// If there are several elements with the same indices, the FIRST one of them will be at the common index.
47+ /// If index is out of bounds, the value will be ignored.
48+ /// </remarks >
49+ /// <example >
50+ /// <code >
51+ /// let positions = [ | 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
52+ /// let values = [ | 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
53+ /// let result = run clContext 32 processor positions values result
54+ /// ...
55+ /// > val result = [ | 1,9; 3.7; 6.4; 7.3; 9.1 |]
56+ /// </code>
57+ /// </example>
58+ let scatterFirstOccurrence clContext =
59+ general
60+ <| <@ fun gid _ ( positions : ClArray < int >) resultLength ->
61+ let currentKey = positions.[ gid]
62+ // first occurrence condition
63+ ( gid = 0 || positions.[ gid - 1 ] <> positions.[ gid])
64+ // result position in valid range
65+ && ( 0 <= currentKey && currentKey < resultLength) @>
66+ <| clContext
67+
68+ /// <summary >
69+ /// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
70+ /// should be a value from the given one.
71+ /// </summary >
72+ /// <remarks >
73+ /// Every element of the positions array must not be less than the previous one.
74+ /// If there are several elements with the same indices, the last one of them will be at the common index.
75+ /// If index is out of bounds, the value will be ignored.
76+ /// </remarks >
77+ /// <example >
78+ /// <code >
79+ /// let positions = [ | 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
80+ /// let values = [ | 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
81+ /// let result = run clContext 32 processor positions values result
82+ /// ...
83+ /// > val result = [ | 2.8; 5.5; 6.4; 8.2; 9.1 |]
84+ /// </code>
85+ /// </example>
86+ let scatterLastOccurrence clContext =
87+ general
88+ <| <@ fun gid positionsLength ( positions : ClArray < int >) resultLength ->
89+ let index = positions.[ gid]
90+ // last occurrence condition
91+ ( gid = positionsLength - 1 || index <> positions.[ gid + 1 ])
92+ // result position in valid range
93+ && ( 0 <= index && index < resultLength) @>
94+ <| clContext
95+
0 commit comments