@@ -119,6 +119,78 @@ module ElementwiseConstructor =
119119 firstResultValues.[ i] <- firstValuesBuffer.[ beginIdx + boundaryX]
120120 isLeftBitMap.[ i] <- 1 @>
121121
122+ let private opWriteBothFill ( opAdd : Expr < 'a option -> 'b option -> 'a -> 'a option >) =
123+ <@
124+ fun gid ( leftValues : ClArray < 'a >) ( rightValues : ClArray < 'b >) ( value : 'a ) ->
125+ (% opAdd) ( Some leftValues.[ gid]) ( Some rightValues.[ gid + 1 ]) value
126+ @>
127+
128+ let private opWriteLeftFill ( opAdd : Expr < 'a option -> 'b option -> 'a -> 'a option >) =
129+ <@
130+ fun gid ( leftValues : ClArray < 'a >) ( value : 'a ) ->
131+ (% opAdd) ( Some leftValues.[ gid]) None value
132+ @>
133+
134+ let private opWriteRightFill ( opAdd : Expr < 'a option -> 'b option -> 'a -> 'a option >) =
135+ <@
136+ fun gid ( rightValues : ClArray < 'b >) ( value : 'a ) ->
137+ (% opAdd) None ( Some rightValues.[ gid + 1 ]) value
138+ @>
139+
140+ let private opWriteAtLeastOneBothFill ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'a -> 'a option >) =
141+ <@
142+ fun gid ( leftValues : ClArray < 'a >) ( rightValues : ClArray < 'b >) ( value : 'a ) ->
143+ (% opAdd) ( Both( leftValues.[ gid], rightValues.[ gid + 1 ])) value
144+ @>
145+
146+ let private opWriteAtLeastOneLeftFill ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'a -> 'a option >) =
147+ <@
148+ fun gid ( leftValues : ClArray < 'a >) ( value : 'a ) ->
149+ (% opAdd) ( Left( leftValues.[ gid])) value
150+ @>
151+
152+ let private opWriteAtLeastOneRightFill ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'a -> 'a option >) =
153+ <@
154+ fun gid ( rightValues : ClArray < 'b >) ( value : 'a ) ->
155+ (% opAdd) ( Right( rightValues.[ gid])) value
156+ @>
157+
158+ let private opWriteBoth ( opAdd : Expr < 'a option -> 'b option -> 'c option >) =
159+ <@
160+ fun gid ( leftValues : ClArray < 'a >) ( rightValues : ClArray < 'b >) ->
161+ (% opAdd) ( Some leftValues.[ gid]) ( Some rightValues.[ gid + 1 ])
162+ @>
163+
164+ let private opWriteLeft ( opAdd : Expr < 'a option -> 'b option -> 'c option >) =
165+ <@
166+ fun gid ( leftValues : ClArray < 'a >) ->
167+ (% opAdd) ( Some leftValues.[ gid]) None
168+ @>
169+
170+ let private opWriteRight ( opAdd : Expr < 'a option -> 'b option -> 'c option >) =
171+ <@
172+ fun gid ( rightValues : ClArray < 'b >) ->
173+ (% opAdd) None ( Some rightValues.[ gid + 1 ])
174+ @>
175+
176+ let private opWriteAtLeastOneBoth ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'c option >) =
177+ <@
178+ fun gid ( leftValues : ClArray < 'a >) ( rightValues : ClArray < 'b >) ->
179+ (% opAdd) ( Both( leftValues.[ gid], rightValues.[ gid + 1 ]))
180+ @>
181+
182+ let opWriteAtLeastOneLeft ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'c option >) =
183+ <@
184+ fun gid ( leftValues : ClArray < 'a >) ->
185+ (% opAdd) ( Left( leftValues.[ gid]))
186+ @>
187+
188+ let opWriteAtLeastOneRight ( opAdd : Expr < AtLeastOne < 'a , 'b > -> 'a option >) =
189+ <@
190+ fun gid ( rightValues : ClArray < 'b >) ->
191+ (% opAdd) ( Right( rightValues.[ gid]))
192+ @>
193+
122194 let private both < 'c > =
123195 <@ fun index ( result : 'c option ) ( rawPositionsBuffer : ClArray < int >) ( allValuesBuffer : ClArray < 'c >) ->
124196 rawPositionsBuffer.[ index] <- 0
@@ -144,64 +216,56 @@ module ElementwiseConstructor =
144216 rawPositionsBuffer.[ index] <- 1
145217 | None -> rawPositionsBuffer.[ index] <- 0 @>
146218
147- let preparePositionsAtLeastOne opAdd =
148- <@ fun ( ndRange : Range1D ) length ( allIndices : ClArray < int >) ( leftValues : ClArray < 'a >) ( rightValues : ClArray < 'b >) ( isLeft : ClArray < int >) ( allValues : ClArray < 'c >) ( positions : ClArray < int >) ->
149-
150- let gid = ndRange.GlobalID0
151-
152- if gid < length - 1
153- && allIndices.[ gid] = allIndices.[ gid + 1 ] then
154- let result = (% opAdd) ( Both( leftValues.[ gid], rightValues.[ gid + 1 ]))
219+ let private preparePositionsGeneral
220+ ( bothWrite : Expr <( int -> ClArray < 'a > -> ClArray < 'b > -> 'c option )>)
221+ leftWrite
222+ rightWrite
223+ =
155224
156- (% both) gid result positions allValues
157- elif ( gid < length
158- && gid > 0
159- && allIndices.[ gid - 1 ] <> allIndices.[ gid])
160- || gid = 0 then
161-
162- let leftResult = (% opAdd) ( Left( leftValues.[ gid]))
163- let rightResult = (% opAdd) ( Right( rightValues.[ gid]))
164-
165- (% leftRight) gid leftResult rightResult isLeft allValues positions @>
166-
167- let preparePositions ( opAdd : Expr < 'a option -> 'b option -> 'c option >) =
168225 <@ fun ( ndRange : Range1D ) length ( allIndices : ClArray < int >) ( leftValues : ClArray < 'a >) ( rightValues : ClArray < 'b >) ( isLeft : ClArray < int >) ( allValues : ClArray < 'c >) ( positions : ClArray < int >) ->
169226
170227 let gid = ndRange.GlobalID0
171228
172229 if gid < length - 1
173230 && allIndices.[ gid] = allIndices.[ gid + 1 ] then
174- let result = (% opAdd ) ( Some leftValues.[ gid ]) ( Some rightValues.[ gid + 1 ])
231+ let ( result : 'c option ) = (% bothWrite ) gid leftValues rightValues
175232
176233 (% both) gid result positions allValues
177234 elif ( gid < length
178235 && gid > 0
179236 && allIndices.[ gid - 1 ] <> allIndices.[ gid])
180237 || gid = 0 then
181238
182- let leftResult = (% opAdd ) ( Some leftValues .[ gid]) None
183- let rightResult = (% opAdd ) None ( Some rightValues.[ gid ])
239+ let leftResult = (% leftWrite ) gid leftValues
240+ let rightResult = (% rightWrite ) gid rightValues
184241
185242 (% leftRight) gid leftResult rightResult isLeft allValues positions @>
186243
187- let preparePositionsFillSubVectorAtLeasOne opAdd =
188- <@ fun ( ndRange : Range1D ) length ( allIndices : ClArray < int >) ( leftValues : ClArray < 'a >) ( rightValues : ClArray < 'b >) ( value : ClCell < 'a >) ( isLeft : ClArray < int >) ( allValues : ClArray < 'c >) ( positions : ClArray < int >) ->
244+ let private prepareFillVectorGeneral bothWrite leftWrite rightWrite =
245+ <@ fun ( ndRange : Range1D ) length ( allIndices : ClArray < int >) ( leftValues : ClArray < 'a >) ( rightValues : ClArray < 'b >) ( value : ClCell < 'a >) ( isLeft : ClArray < int >) ( allValues : ClArray < 'a >) ( positions : ClArray < int >) ->
189246
190247 let gid = ndRange.GlobalID0
191248
192249 let value = value.Value
193250
194251 if gid < length - 1
195252 && allIndices.[ gid] = allIndices.[ gid + 1 ] then
196- let result = (% opAdd ) ( Both ( leftValues .[ gid], rightValues .[ gid + 1 ])) value
253+ let ( result : 'a option ) = (% bothWrite ) gid leftValues rightValues value
197254
198255 (% both) gid result positions allValues
199256 elif ( gid < length
200257 && gid > 0
201258 && allIndices.[ gid - 1 ] <> allIndices.[ gid])
202259 || gid = 0 then
203- let leftResult = (% opAdd ) ( Left ( leftValues .[ gid])) value
204- let rightResult = (% opAdd ) ( Right ( rightValues .[ gid])) value
260+ let leftResult = (% leftWrite ) gid leftValues value
261+ let rightResult = (% rightWrite ) gid rightValues value
205262
206263 (% leftRight) gid leftResult rightResult isLeft allValues positions @>
207264
265+ let preparePositions opAdd = preparePositionsGeneral ( opWriteBoth opAdd) ( opWriteLeft opAdd) ( opWriteRight opAdd)
266+
267+ let preparePositionsAtLeastOne opAdd = preparePositionsGeneral ( opWriteAtLeastOneBoth opAdd) ( opWriteAtLeastOneLeft opAdd) ( opWriteAtLeastOneRight opAdd)
268+
269+ let prepareFillVector opAdd = prepareFillVectorGeneral ( opWriteBothFill opAdd) ( opWriteLeftFill opAdd) ( opWriteRightFill opAdd)
270+
271+ let prepareFillVectorAtLeastOne opAdd = prepareFillVectorGeneral ( opWriteAtLeastOneBothFill opAdd) ( opWriteAtLeastOneLeftFill opAdd) ( opWriteAtLeastOneRightFill opAdd)
0 commit comments