@@ -117,3 +117,111 @@ module internal Merge =
117117
118118 return allIndices, allValues
119119 }
120+
/// Merges two index sequences into one output of length
/// leftValues.Length + rightIndices.Length by running an OpenCL kernel.
/// Output entries taken from the left side carry the matching element of leftValues;
/// output entries taken from the right side carry `scalar` as their value.
/// Returns the pair (allIndices, allValues).
/// NOTE(review): the binary searches below presumably require leftIndices and rightIndices
/// to be sorted in ascending order - confirm against callers.
/// NOTE(review): `mask` is accepted but never referenced anywhere in this body.
121+ let mergeWithScalar ( leftIndices : int []) ( leftValues : 'a []) ( rightIndices : int []) ( scalar : 'a ) ( mask : Mask1D option ) : OpenCLEvaluation < int [] * 'a []> = opencl {
122+ let workGroupSize = Utils.workGroupSize
// The left side length is taken from leftValues, not leftIndices.
// NOTE(review): assumes leftIndices.Length = leftValues.Length - confirm.
123+ let firstSide = leftValues.Length
124+ let secondSide = rightIndices.Length
125+ let sumOfSides = firstSide + secondSide
126+
// Kernel quotation. workGroupSize, firstSide, secondSide, sumOfSides and scalar are
// captured from the enclosing scope rather than passed as kernel arguments.
127+ let merge =
128+ <@
129+ fun ( ndRange : _1D )
130+ ( firstIndicesBuffer : int [])
131+ ( firstValuesBuffer : 'a [])
132+ ( secondIndicesBuffer : int [])
133+ ( allIndicesBuffer : int [])
134+ ( allValuesBuffer : 'a []) ->
135+
// i is this work-item's global id; it is also the output position written at the end.
136+ let i = ndRange.GlobalID0
137+
// Phase 1: work-items 0 and 1 each compute one bound of the slice of the first array
// handled by this work-group and publish it through beginIdxLocal / endIdxLocal.
138+ let mutable beginIdxLocal = local ()
139+ let mutable endIdxLocal = local ()
140+ let localID = ndRange.LocalID0
141+ if localID < 2 then
// For localID 0 this evaluates to i - 1; for localID 1 it evaluates to i + workGroupSize - 2.
// The value is then clamped to the last valid diagonal, sumOfSides - 1.
142+ let mutable x = localID * ( workGroupSize - 1 ) + i - 1
143+ if x >= sumOfSides then x <- sumOfSides - 1
144+ let diagonalNumber = x
145+
// Search range within the first array along this diagonal, clamped so that both
// middleIdx and diagonalNumber - middleIdx stay inside their arrays.
146+ let mutable leftEdge = diagonalNumber + 1 - secondSide
147+ if leftEdge < 0 then leftEdge <- 0
148+
149+ let mutable rightEdge = firstSide - 1
150+ if rightEdge > diagonalNumber then rightEdge <- diagonalNumber
151+
// Binary search along the diagonal: compares first[middleIdx] with
// second[diagonalNumber - middleIdx] and halves the range each step.
// When the range is empty from the start (e.g. diagonalNumber = -1) leftEdge stays at its clamped value.
152+ while leftEdge <= rightEdge do
153+ let middleIdx = ( leftEdge + rightEdge) / 2
154+ let firstIndex = firstIndicesBuffer.[ middleIdx]
155+ let secondIndex = secondIndicesBuffer.[ diagonalNumber - middleIdx]
156+ if firstIndex < secondIndex then leftEdge <- middleIdx + 1 else rightEdge <- middleIdx - 1
157+
158+ // Here localID equals either 0 or 1
159+ if localID = 0 then beginIdxLocal <- leftEdge else endIdxLocal <- leftEdge
// NOTE(review): every work-item reads beginIdxLocal / endIdxLocal right after this call, so the
// barrier is presumably outside the `if localID < 2` branch - confirm against the original indentation.
160+ barrier ()
161+
162+ let beginIdx = beginIdxLocal
163+ let endIdx = endIdxLocal
// Number of first-array elements handled by this work-group.
164+ let firstLocalLength = endIdx - beginIdx
// Number of second-array elements handled by this work-group: the rest of the group by default.
// When the first array is exhausted (endIdx = firstSide) it is secondSide - (i - localID) + beginIdx instead.
// NOTE(review): i - localID is presumably the global id of the group's first work-item - confirm.
165+ let mutable x = workGroupSize - firstLocalLength
166+ if endIdx = firstSide then x <- secondSide - i + localID + beginIdx
167+ let secondLocalLength = x
168+
169+ //First indices are from 0 to firstLocalLength - 1 inclusive
170+ //Second indices are from firstLocalLength to firstLocalLength + secondLocalLength - 1 inclusive
171+ let localIndices = localArray< int> workGroupSize
172+
// Phase 2: stage this group's slices of both index arrays into localIndices.
// A single work-item may copy one element from each side.
173+ if localID < firstLocalLength then
174+ localIndices.[ localID] <- firstIndicesBuffer.[ beginIdx + localID]
175+ if localID < secondLocalLength then
176+ localIndices.[ firstLocalLength + localID] <- secondIndicesBuffer.[ i - beginIdx]
// Second barrier: the search below reads localIndices slots written by other work-items.
177+ barrier ()
178+
// Phase 3: work-items whose global id is a valid output position repeat the diagonal search,
// this time on the staged copy and with localID as the diagonal number.
179+ if i < sumOfSides then
180+ let mutable leftEdge = localID + 1 - secondLocalLength
181+ if leftEdge < 0 then leftEdge <- 0
182+
183+ let mutable rightEdge = firstLocalLength - 1
184+ if rightEdge > localID then rightEdge <- localID
185+
186+ while leftEdge <= rightEdge do
187+ let middleIdx = ( leftEdge + rightEdge) / 2
188+ let firstIndex = localIndices.[ middleIdx]
189+ let secondIndex = localIndices.[ firstLocalLength + localID - middleIdx]
190+ if firstIndex < secondIndex then leftEdge <- middleIdx + 1 else rightEdge <- middleIdx - 1
191+
// Candidate positions after the search: boundaryX in the first slice, boundaryY in the second slice.
// Either one becomes -1 when that side contributes no candidate.
192+ let boundaryX = rightEdge
193+ let boundaryY = localID - leftEdge
194+
195+ // boundaryX and boundaryY can't be off the right edge of array (only off the left edge)
196+ let isValidX = boundaryX >= 0
197+ let isValidY = boundaryY >= 0
198+
// fstIdx / sndIdx keep the placeholder 0 when their candidate is invalid;
// the validity flags below ensure the placeholder is never chosen over a valid candidate.
199+ let mutable fstIdx = 0
200+ if isValidX then fstIdx <- localIndices.[ boundaryX]
201+
202+ let mutable sndIdx = 0
203+ if isValidY then sndIdx <- localIndices.[ firstLocalLength + boundaryY]
204+
// `&&` binds tighter than `||`: the condition is (not isValidX) || (isValidY && fstIdx < sndIdx).
// The second-side candidate is written, with `scalar` as its value, when there is no first-side
// candidate or when the second-side index is strictly greater. Equal indices fall to the else branch,
// so the first-side index and its value from firstValuesBuffer are written.
205+ if not isValidX || isValidY && fstIdx < sndIdx then
206+ allIndicesBuffer.[ i] <- sndIdx
207+ allValuesBuffer.[ i] <- scalar
208+ else
209+ allIndicesBuffer.[ i] <- fstIdx
210+ allValuesBuffer.[ i] <- firstValuesBuffer.[ beginIdx + boundaryX]
211+ @>
212+
// Output arrays, one slot per merged element; allValues is pre-filled with the default value of 'a.
213+ let allIndices = Array.zeroCreate sumOfSides
214+ let allValues = Array.create sumOfSides Unchecked.defaultof< 'a>
215+
// Run the kernel over Utils.workSize sumOfSides work-items in groups of workGroupSize.
// Arguments are bound in the same order as the kernel lambda's parameters.
216+ do ! RunCommand merge <| fun kernelPrepare ->
217+ let ndRange = _ 1D( Utils.workSize sumOfSides, workGroupSize)
218+ kernelPrepare
219+ ndRange
220+ leftIndices
221+ leftValues
222+ rightIndices
223+ allIndices
224+ allValues
225+
// NOTE(review): no explicit device-to-host transfer is visible in this function - confirm
// whether callers are expected to read the results back before using them on the host.
226+ return allIndices, allValues
227+ }
0 commit comments