@@ -3,6 +3,17 @@ namespace GraphBLAS.FSharp.Backend.Common
33open Brahma.FSharp
44
55module internal Scatter =
6+ let firstOccurencePredicate () =
7+ <@ fun gid _ ( positions : ClArray < int >) ->
8+ // first occurrence condition
9+ ( gid = 0 || positions.[ gid - 1 ] <> positions.[ gid]) @>
10+
11+ let lastOccurrencePredicate () =
12+ <@ fun gid positionsLength ( positions : ClArray < int >) ->
13+ // last occurrence condition
14+ ( gid = positionsLength - 1 || positions.[ gid] <> positions.[ gid + 1 ]) @>
15+
16+
617 let private general < 'a > predicate ( clContext : ClContext ) workGroupSize =
718
819 let run =
@@ -12,9 +23,12 @@ module internal Scatter =
1223
1324 if gid < positionsLength then
1425 // positions lengths == values length
15- let predicateResult = (% predicate) gid positionsLength positions resultLength
26+ let predicateResult = (% predicate) gid positionsLength positions
27+ let position = positions.[ gid]
28+
29+ if predicateResult
30+ && 0 <= position && position < resultLength then
1631
17- if predicateResult then
1832 result.[ positions.[ gid]] <- values.[ gid] @>
1933
2034 let program = clContext.Compile( run)
@@ -55,14 +69,9 @@ module internal Scatter =
5569 /// > val result = [ | 1,9; 3.7; 6.4; 7.3; 9.1 |]
5670 /// </code>
5771 /// </example>
58- let scatterFirstOccurrence clContext =
72+ let firstOccurrence clContext =
5973 general
60- <| <@ fun gid _ ( positions : ClArray < int >) resultLength ->
61- let currentKey = positions.[ gid]
62- // first occurrence condition
63- ( gid = 0 || positions.[ gid - 1 ] <> positions.[ gid])
64- // result position in valid range
65- && ( 0 <= currentKey && currentKey < resultLength) @>
74+ <| firstOccurencePredicate ()
6675 <| clContext
6776
6877 /// <summary >
@@ -71,7 +80,7 @@ module internal Scatter =
7180 /// </summary >
7281 /// <remarks >
7382 /// Every element of the positions array must not be less than the previous one.
74- /// If there are several elements with the same indices, the last one of them will be at the common index.
83+ /// If there are several elements with the same indices, the LAST one of them will be at the common index.
7584 /// If index is out of bounds, the value will be ignored.
7685 /// </remarks >
7786 /// <example >
@@ -83,60 +92,85 @@ module internal Scatter =
8392 /// > val result = [ | 2.8; 5.5; 6.4; 8.2; 9.1 |]
8493 /// </code>
8594 /// </example>
86- let scatterLastOccurrence clContext =
95+ let lastOccurrence clContext =
8796 general
88- <| <@ fun gid positionsLength ( positions : ClArray < int >) resultLength ->
89- let index = positions.[ gid]
90- // last occurrence condition
91- ( gid = positionsLength - 1 || index <> positions.[ gid + 1 ])
92- // result position in valid range
93- && ( 0 <= index && index < resultLength) @>
97+ <| lastOccurrencePredicate ()
9498 <| clContext
9599
96- /// <summary >
97- /// Writes elements from the array of values to the array at the positions indicated by the global id map.
98- /// </summary >
99- /// <remarks >
100- /// If index is out of bounds, the value will be ignored.
101- /// </remarks >
102- /// <example >
103- /// <code >
104- /// let positionMap = fun x -> x + 1
105- /// let values = [ | 1.9; 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
106- /// let result = ... // create result
107- /// run positionMap clContext 32 processor positions values result
108- /// ...
109- /// > val result = [ | 2.8; 3.7; 4.6; 5.5; 6.4; 7.3; 8.2; 9.1 |]
110- /// </code>
111- /// </example>
112- /// <param name="positionMap"> Should be injective in order to avoid race conditions.</param >
113- let init < 'a > positionMap ( clContext : ClContext ) workGroupSize =
100+ let private generalInit < 'a > predicate valueMap ( clContext : ClContext ) workGroupSize =
114101
115102 let run =
116- <@ fun ( ndRange : Range1D ) ( valuesLength : int) ( values : ClArray < 'a > ) ( result : ClArray < 'a >) resultLength ->
103+ <@ fun ( ndRange : Range1D ) ( positions : ClArray < int > ) ( positionsLength : int ) ( result : ClArray < 'a >) ( resultLength : int ) ->
117104
118105 let gid = ndRange.GlobalID0
119106
120- if gid < valuesLength then
121- let position = (% positionMap) gid
107+ if gid < positionsLength then
108+ // positions lengths == values length
109+ let predicateResult = (% predicate) gid positionsLength positions
110+
111+ let position = positions.[ gid]
112+
113+ if predicateResult
114+ && 0 <= position && position < resultLength then
122115
123- // may be race condition
124- if 0 <= position && position < resultLength then
125- result.[ position] <- values.[ gid] @>
116+ result.[ positions.[ gid]] <- (% valueMap) gid @>
126117
127118 let program = clContext.Compile( run)
128119
129- fun ( processor : MailboxProcessor < _ >) ( values : ClArray < 'a >) ( result : ClArray < 'a >) ->
120+ fun ( processor : MailboxProcessor < _ >) ( positions : ClArray < int >) ( result : ClArray < 'a >) ->
121+
122+ let positionsLength = positions.Length
130123
131124 let ndRange =
132- Range1D.CreateValid( values.Length , workGroupSize)
125+ Range1D.CreateValid( positionsLength , workGroupSize)
133126
134127 let kernel = program.GetKernel()
135128
136129 processor.Post(
137130 Msg.MsgSetArguments
138- ( fun () -> kernel.KernelFunc ndRange values.Length values result result.Length)
131+ ( fun () -> kernel.KernelFunc ndRange positions positionsLength result result.Length)
139132 )
140133
141134 processor.Post( Msg.CreateRunMsg<_, _>( kernel))
142135
136+ /// <summary >
137+ /// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
138+ /// should be a values obtained by applying the mapping to the global id.
139+ /// </summary >
140+ /// <remarks >
141+ /// Every element of the positions array must not be less than the previous one.
142+ /// If there are several elements with the same indices, the FIRST one of them will be at the common index.
143+ /// If index is out of bounds, the value will be ignored.
144+ /// </remarks >
145+ /// <example >
146+ /// <code >
147+ /// let positions = [ | 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
148+ /// let valueMap = id
149+ /// run clContext 32 processor positions values result
150+ /// ...
151+ /// > val result = [ | 0; 2; 5; 6; 8 |]
152+ /// </code>
153+ /// </example>
154+ /// <param name="valueMap"> Maps global id to a value</param >
155+ let initFirsOccurrence < 'a > valueMap = generalInit< 'a> <| firstOccurencePredicate () <| valueMap
156+
157+ /// <summary >
158+ /// Creates a new array from the given one where it is indicated by the array of positions at which position in the new array
159+ /// should be a values obtained by applying the mapping to the global id.
160+ /// </summary >
161+ /// <remarks >
162+ /// Every element of the positions array must not be less than the previous one.
163+ /// If there are several elements with the same indices, the LAST one of them will be at the common index.
164+ /// If index is out of bounds, the value will be ignored.
165+ /// </remarks >
166+ /// <example >
167+ /// <code >
168+ /// let positions = [ | 0; 0; 1; 1; 1; 2; 3; 3; 4 |]
169+ /// let valueMap = id
170+ /// run clContext 32 processor positions values result
171+ /// ...
172+ /// > val result = [ | 1; 4; 5; 7; 8 |]
173+ /// </code>
174+ /// </example>
175+ /// <param name="valueMap"> Maps global id to a value</param >
176+ let initLastOccurrence < 'a > valueMap = generalInit< 'a> <| lastOccurrencePredicate () <| valueMap
0 commit comments