Skip to content

Commit 481f047

Browse files
committed
refactor: DenseVector.toSparse
1 parent a87785a commit 481f047

3 files changed

Lines changed: 38 additions & 74 deletions

File tree

src/GraphBLAS-sharp.Backend/GraphBLAS-sharp.Backend.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
<Compile Include="Quotes/SubSum.fs" />
2525
<Compile Include="Quotes/PreparePositions.fs" />
2626
<Compile Include="Quotes/Predicates.fs" />
27+
<Compile Include="Quotes\Map.fs" />
2728
<Compile Include="Common/Scatter.fs" />
2829
<Compile Include="Common/Utils.fs" />
2930
<Compile Include="Common/Sum.fs" />
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module GraphBLAS.FSharp.Backend.Quotes
2+
3+
module Map =
4+
let optionToValueOrZero<'a> =
5+
<@ fun (item: 'a option) ->
6+
match item with
7+
| Some value -> value
8+
| None -> Unchecked.defaultof<'a> @>
9+
10+
let option onSome onNone =
11+
<@ function
12+
| Some _ -> onSome
13+
| None -> onNone @>

src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs

Lines changed: 24 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -86,65 +86,26 @@ module DenseVector =
8686

8787
resultVector
8888

89-
let private getBitmap<'a when 'a: struct> (clContext: ClContext) workGroupSize =
90-
91-
let getPositions =
92-
<@ fun (ndRange: Range1D) length (vector: ClArray<'a option>) (positions: ClArray<int>) ->
93-
94-
let gid = ndRange.GlobalID0
95-
96-
if gid < length then
97-
match vector.[gid] with
98-
| Some _ -> positions.[gid] <- 1
99-
| None -> positions.[gid] <- 0 @>
100-
101-
let kernel = clContext.Compile(getPositions)
102-
103-
fun (processor: MailboxProcessor<_>) allocationMode (vector: ClArray<'a option>) ->
104-
let positions =
105-
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, vector.Length)
106-
107-
let ndRange =
108-
Range1D.CreateValid(vector.Length, workGroupSize)
109-
110-
let kernel = kernel.GetKernel()
111-
112-
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange vector.Length vector positions))
113-
114-
processor.Post(Msg.CreateRunMsg(kernel))
115-
116-
positions
117-
118-
let private getValuesAndIndices<'a when 'a: struct> (clContext: ClContext) workGroupSize =
119-
120-
let getValuesAndIndices =
121-
<@ fun (ndRange: Range1D) length (denseVector: ClArray<'a option>) (positions: ClArray<int>) (resultValues: ClArray<'a>) (resultIndices: ClArray<int>) ->
122-
123-
let gid = ndRange.GlobalID0
124-
125-
if gid = length - 1
126-
|| gid < length
127-
&& positions.[gid] <> positions.[gid + 1] then
128-
let index = positions.[gid]
89+
let toSparse<'a when 'a: struct> (clContext: ClContext) workGroupSize =
12990

130-
match denseVector.[gid] with
131-
| Some value ->
132-
resultValues.[index] <- value
133-
resultIndices.[index] <- gid
134-
| None -> () @>
91+
let scatterValues = Scatter.runInplace clContext workGroupSize
13592

136-
let kernel = clContext.Compile(getValuesAndIndices)
93+
let scatterIndices = Scatter.runInplace clContext workGroupSize
13794

138-
let getPositions = getBitmap clContext workGroupSize
95+
let getBitmap = ClArray.map clContext workGroupSize <| Map.option 1 0
13996

14097
let prefixSum =
14198
PrefixSum.standardExcludeInplace clContext workGroupSize
14299

100+
let allIndices = ClArray.init clContext workGroupSize <@ id @>
101+
102+
let allValues = ClArray.map clContext workGroupSize Map.optionToValueOrZero
103+
143104
let resultLength = Array.zeroCreate<int> 1
144105

145106
fun (processor: MailboxProcessor<_>) allocationMode (vector: ClArray<'a option>) ->
146107

147-
let positions = getPositions processor DeviceOnly vector
108+
let positions = getBitmap processor DeviceOnly vector
148109

149110
let resultLengthGpu = clContext.CreateClCell 0
150111

@@ -159,60 +120,49 @@ module DenseVector =
159120

160121
res.[0]
161122

162-
let resultValues =
163-
clContext.CreateClArrayWithSpecificAllocationMode<'a>(allocationMode, resultLength)
164-
123+
// compute result indices
165124
let resultIndices =
166125
clContext.CreateClArrayWithSpecificAllocationMode<int>(allocationMode, resultLength)
167126

168-
let ndRange =
169-
Range1D.CreateValid(vector.Length, workGroupSize)
127+
let allIndices = allIndices processor DeviceOnly vector.Length
170128

171-
let kernel = kernel.GetKernel()
129+
scatterIndices processor positions allIndices resultIndices
172130

173-
processor.Post(
174-
Msg.MsgSetArguments
175-
(fun () -> kernel.KernelFunc ndRange vector.Length vector positions resultValues resultIndices)
176-
)
177-
178-
processor.Post(Msg.CreateRunMsg(kernel))
131+
processor.Post <| Msg.CreateFreeMsg<_>(allIndices)
179132

180-
processor.Post(Msg.CreateFreeMsg<_>(positions))
133+
// compute result values
134+
let allValues = allValues processor DeviceOnly vector
181135

182-
resultValues, resultIndices
183-
184-
let toSparse<'a when 'a: struct> (clContext: ClContext) workGroupSize =
136+
let resultValues =
137+
clContext.CreateClArrayWithSpecificAllocationMode<'a>(allocationMode, resultLength)
185138

186-
let getValuesAndIndices =
187-
getValuesAndIndices clContext workGroupSize
139+
scatterValues processor positions allValues resultValues
188140

189-
fun (processor: MailboxProcessor<_>) allocationMode (vector: ClArray<'a option>) ->
141+
processor.Post <| Msg.CreateFreeMsg<_>(allValues)
190142

191-
let values, indices =
192-
getValuesAndIndices processor allocationMode vector
143+
processor.Post <| Msg.CreateFreeMsg<_>(positions)
193144

194145
{ Context = clContext
195-
Indices = indices
196-
Values = values
146+
Indices = resultIndices
147+
Values = resultValues
197148
Size = vector.Length }
198149

199150
let reduce<'a when 'a: struct> (clContext: ClContext) workGroupSize (opAdd: Expr<'a -> 'a -> 'a>) =
200151

201152
let getValuesAndIndices =
202-
getValuesAndIndices clContext workGroupSize
153+
ClArray.map clContext workGroupSize Map.optionToValueOrZero
203154

204155
let reduce =
205156
Reduce.reduce clContext workGroupSize opAdd
206157

207158
fun (processor: MailboxProcessor<_>) (vector: ClArray<'a option>) ->
208159

209160
try
210-
let values, indices =
161+
let values =
211162
getValuesAndIndices processor DeviceOnly vector
212163

213164
let result = reduce processor values
214165

215-
processor.Post(Msg.CreateFreeMsg<_>(indices))
216166
processor.Post(Msg.CreateFreeMsg<_>(values))
217167

218168
result

0 commit comments

Comments
 (0)