@@ -86,65 +86,26 @@ module DenseVector =
8686
8787 resultVector
8888
89- let private getBitmap < 'a when 'a : struct > ( clContext : ClContext ) workGroupSize =
90-
91- let getPositions =
92- <@ fun ( ndRange : Range1D ) length ( vector : ClArray < 'a option >) ( positions : ClArray < int >) ->
93-
94- let gid = ndRange.GlobalID0
95-
96- if gid < length then
97- match vector.[ gid] with
98- | Some _ -> positions.[ gid] <- 1
99- | None -> positions.[ gid] <- 0 @>
100-
101- let kernel = clContext.Compile( getPositions)
102-
103- fun ( processor : MailboxProcessor < _ >) allocationMode ( vector : ClArray < 'a option >) ->
104- let positions =
105- clContext.CreateClArrayWithSpecificAllocationMode( allocationMode, vector.Length)
106-
107- let ndRange =
108- Range1D.CreateValid( vector.Length, workGroupSize)
109-
110- let kernel = kernel.GetKernel()
111-
112- processor.Post( Msg.MsgSetArguments( fun () -> kernel.KernelFunc ndRange vector.Length vector positions))
113-
114- processor.Post( Msg.CreateRunMsg( kernel))
115-
116- positions
117-
118- let private getValuesAndIndices < 'a when 'a : struct > ( clContext : ClContext ) workGroupSize =
119-
120- let getValuesAndIndices =
121- <@ fun ( ndRange : Range1D ) length ( denseVector : ClArray < 'a option >) ( positions : ClArray < int >) ( resultValues : ClArray < 'a >) ( resultIndices : ClArray < int >) ->
122-
123- let gid = ndRange.GlobalID0
124-
125- if gid = length - 1
126- || gid < length
127- && positions.[ gid] <> positions.[ gid + 1 ] then
128- let index = positions.[ gid]
89+ let toSparse < 'a when 'a : struct > ( clContext : ClContext ) workGroupSize =
12990
130- match denseVector.[ gid] with
131- | Some value ->
132- resultValues.[ index] <- value
133- resultIndices.[ index] <- gid
134- | None -> () @>
91+ let scatterValues = Scatter.runInplace clContext workGroupSize
13592
136- let kernel = clContext.Compile ( getValuesAndIndices )
93+ let scatterIndices = Scatter.runInplace clContext workGroupSize
13794
138- let getPositions = getBitmap clContext workGroupSize
95+ let getBitmap = ClArray.map clContext workGroupSize <| Map.option 1 0
13996
14097 let prefixSum =
14198 PrefixSum.standardExcludeInplace clContext workGroupSize
14299
100+ let allIndices = ClArray.init clContext workGroupSize <@ id @>
101+
102+ let allValues = ClArray.map clContext workGroupSize Map.optionToValueOrZero
103+
143104 let resultLength = Array.zeroCreate< int> 1
144105
145106 fun ( processor : MailboxProcessor < _ >) allocationMode ( vector : ClArray < 'a option >) ->
146107
147- let positions = getPositions processor DeviceOnly vector
108+ let positions = getBitmap processor DeviceOnly vector
148109
149110 let resultLengthGpu = clContext.CreateClCell 0
150111
@@ -159,60 +120,49 @@ module DenseVector =
159120
160121 res.[ 0 ]
161122
162- let resultValues =
163- clContext.CreateClArrayWithSpecificAllocationMode< 'a>( allocationMode, resultLength)
164-
123+ // compute result indices
165124 let resultIndices =
166125 clContext.CreateClArrayWithSpecificAllocationMode< int>( allocationMode, resultLength)
167126
168- let ndRange =
169- Range1D.CreateValid( vector.Length, workGroupSize)
127+ let allIndices = allIndices processor DeviceOnly vector.Length
170128
171- let kernel = kernel.GetKernel ()
129+ scatterIndices processor positions allIndices resultIndices
172130
173- processor.Post(
174- Msg.MsgSetArguments
175- ( fun () -> kernel.KernelFunc ndRange vector.Length vector positions resultValues resultIndices)
176- )
177-
178- processor.Post( Msg.CreateRunMsg( kernel))
131+ processor.Post <| Msg.CreateFreeMsg<_>( allIndices)
179132
180- processor.Post( Msg.CreateFreeMsg<_>( positions))
133+ // compute result values
134+ let allValues = allValues processor DeviceOnly vector
181135
182- resultValues, resultIndices
183-
184- let toSparse < 'a when 'a : struct > ( clContext : ClContext ) workGroupSize =
136+ let resultValues =
137+ clContext.CreateClArrayWithSpecificAllocationMode< 'a>( allocationMode, resultLength)
185138
186- let getValuesAndIndices =
187- getValuesAndIndices clContext workGroupSize
139+ scatterValues processor positions allValues resultValues
188140
189- fun ( processor : MailboxProcessor < _ >) allocationMode ( vector : ClArray < 'a option >) ->
141+ processor.Post <| Msg.CreateFreeMsg <_>( allValues )
190142
191- let values , indices =
192- getValuesAndIndices processor allocationMode vector
143+ processor.Post <| Msg.CreateFreeMsg<_>( positions)
193144
194145 { Context = clContext
195- Indices = indices
196- Values = values
146+ Indices = resultIndices
147+ Values = resultValues
197148 Size = vector.Length }
198149
199150 let reduce < 'a when 'a : struct > ( clContext : ClContext ) workGroupSize ( opAdd : Expr < 'a -> 'a -> 'a >) =
200151
201152 let getValuesAndIndices =
202- getValuesAndIndices clContext workGroupSize
153+ ClArray.map clContext workGroupSize Map.optionToValueOrZero
203154
204155 let reduce =
205156 Reduce.reduce clContext workGroupSize opAdd
206157
207158 fun ( processor : MailboxProcessor < _ >) ( vector : ClArray < 'a option >) ->
208159
209160 try
210- let values , indices =
161+ let values =
211162 getValuesAndIndices processor DeviceOnly vector
212163
213164 let result = reduce processor values
214165
215- processor.Post( Msg.CreateFreeMsg<_>( indices))
216166 processor.Post( Msg.CreateFreeMsg<_>( values))
217167
218168 result
0 commit comments