Skip to content

Commit 84fb950

Browse files
committed
wip: getUniqueBitmap{first/last} occurrence
1 parent 7e09219 commit 84fb950

7 files changed

Lines changed: 162 additions & 35 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,20 @@ module ClArray =
130130

131131
outputArray
132132

133-
let getUniqueBitmap (clContext: ClContext) workGroupSize =
133+
let private getUniqueBitmapGeneral predicate (clContext: ClContext) workGroupSize =
134134

135135
let getUniqueBitmap =
136136
<@ fun (ndRange: Range1D) (inputArray: ClArray<'a>) inputLength (isUniqueBitmap: ClArray<int>) ->
137137

138-
let i = ndRange.GlobalID0
138+
let gid = ndRange.GlobalID0
139139

140-
if i < inputLength - 1
141-
&& inputArray.[i] = inputArray.[i + 1] then
142-
isUniqueBitmap.[i] <- 0
143-
else
144-
isUniqueBitmap.[i] <- 1 @>
140+
if gid < inputLength then
141+
let isUnique = (%predicate) gid inputLength inputArray // brahma error
142+
143+
if isUnique then
144+
isUniqueBitmap.[gid] <- 1
145+
else
146+
isUniqueBitmap.[gid] <- 0 @>
145147

146148
let kernel = clContext.Compile(getUniqueBitmap)
147149

@@ -163,6 +165,18 @@ module ClArray =
163165

164166
bitmap
165167

168+
let getUniqueBitmapFirstOccurrence clContext =
169+
getUniqueBitmapGeneral
170+
<| <@ fun (gid: int) (_: int) (inputArray: ClArray<'a>) ->
171+
gid = 0 || inputArray.[gid - 1] <> inputArray.[gid] @>
172+
<| clContext
173+
174+
let getUniqueBitmapLastOccurrence clContext =
175+
getUniqueBitmapGeneral
176+
<| <@ fun (gid: int) (length: int) (inputArray: ClArray<'a>) ->
177+
gid = length - 1 || inputArray.[gid] <> inputArray.[gid + 1] @>
178+
<| clContext
179+
166180
///<description>Remove duplicates form the given array.</description>
167181
///<param name="clContext">Computational context</param>
168182
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
@@ -172,7 +186,7 @@ module ClArray =
172186
let scatter =
173187
Scatter.runInplace clContext workGroupSize
174188

175-
let getUniqueBitmap = getUniqueBitmap clContext workGroupSize
189+
let getUniqueBitmap = getUniqueBitmapLastOccurrence clContext workGroupSize
176190

177191
let prefixSumExclude =
178192
PrefixSum.runExcludeInplace <@ (+) @> clContext workGroupSize
@@ -292,24 +306,33 @@ module ClArray =
292306

293307
resultArray
294308

295-
let getUniqueBitmap2<'a when 'a: equality> (clContext: ClContext) workGroupSize =
309+
let getUniqueBitmap2General<'a when 'a: equality> getUniqueBitmap (clContext: ClContext) workGroupSize =
296310

297-
let map = map2 clContext workGroupSize <@ fun x y -> if x = 1 && y = 1 then 1 else 0 @>
311+
let map = map2 clContext workGroupSize <@ fun x y -> x ||| y @>
298312

299-
let getUniqueBitmap = getUniqueBitmap clContext workGroupSize
313+
let firstGetBitmap = getUniqueBitmap clContext workGroupSize
300314

301315
fun (processor: MailboxProcessor<_>) allocationMode (firstArray: ClArray<'a>) (secondArray: ClArray<'a>) ->
302-
let firstBitmap = getUniqueBitmap processor DeviceOnly firstArray
316+
let firstBitmap = firstGetBitmap processor DeviceOnly firstArray
303317

304-
let secondBitmap = getUniqueBitmap processor DeviceOnly secondArray
318+
let secondBitmap = firstGetBitmap processor DeviceOnly secondArray
305319

306320
let result = map processor allocationMode firstBitmap secondBitmap
307321

322+
printfn $"first bitmap: %A{firstBitmap.ToHost processor}"
323+
printfn $"second bitmap: %A{secondBitmap.ToHost processor}"
324+
308325
firstBitmap.Free processor
309326
secondBitmap.Free processor
310327

311328
result
312329

330+
let getUniqueBitmap2FirstOccurrence clContext =
331+
getUniqueBitmap2General getUniqueBitmapFirstOccurrence clContext
332+
333+
let getUniqueBitmap2LastOccurrence clContext =
334+
getUniqueBitmap2General getUniqueBitmapLastOccurrence clContext
335+
313336
let choose<'a, 'b> (clContext: ClContext) workGroupSize (predicate: Expr<'a -> 'b option>) =
314337
let getBitmap =
315338
map<'a, int> clContext workGroupSize

src/GraphBLAS-sharp.Backend/Common/Sort/Radix.fs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,25 +264,28 @@ module Radix =
264264
let scatterByKey =
265265
scatterByKey clContext workGroupSize mask
266266

267-
fun (processor: MailboxProcessor<_>) (keys: Indices) (values: ClArray<'a>) ->
267+
fun (processor: MailboxProcessor<_>) allocationMode (keys: Indices) (values: ClArray<'a>) ->
268268
if values.Length <> keys.Length then
269269
failwith "Mismatch of key lengths and value. Lengths must be the same"
270270

271271
if values.Length <= 1 then
272-
values
272+
dataCopy processor allocationMode values
273273
else
274274
let firstKeys = copy processor DeviceOnly keys
275275

276276
let secondKeys =
277277
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, keys.Length)
278278

279-
let secondValues = dataCopy processor DeviceOnly values
279+
let firstValues = dataCopy processor DeviceOnly values
280+
281+
let secondValues =
282+
clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, values.Length)
280283

281284
let workGroupCount =
282285
clContext.CreateClCell((keys.Length - 1) / workGroupSize + 1)
283286

284287
let mutable keysPair = (firstKeys, secondKeys)
285-
let mutable valuesPair = (values, secondValues)
288+
let mutable valuesPair = (firstValues, secondValues)
286289

287290
let swap (x, y) = y, x
288291
// compute bound of iterations

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,16 +184,16 @@ module Expand =
184184

185185
fun (processor: MailboxProcessor<_>) (values: ClArray<'a>) (columns: Indices) (rows: Indices) ->
186186
// sort by columns
187-
let valuesSortedByColumns = sortByKeyValues processor columns values
187+
let valuesSortedByColumns = sortByKeyValues processor DeviceOnly columns values
188188

189-
let rowsSortedByColumns = sortByKeyIndices processor columns rows
189+
let rowsSortedByColumns = sortByKeyIndices processor DeviceOnly columns rows
190190

191191
let sortedColumns = sortKeys processor columns
192192

193193
// sort by rows
194-
let valuesSortedByRows = sortByKeyValues processor rows valuesSortedByColumns
194+
let valuesSortedByRows = sortByKeyValues processor DeviceOnly rows valuesSortedByColumns
195195

196-
let columnsSortedByRows = sortByKeyIndices processor rows sortedColumns
196+
let columnsSortedByRows = sortByKeyIndices processor DeviceOnly rows sortedColumns
197197

198198
let sortedRows = sortKeys processor rowsSortedByColumns
199199

@@ -208,21 +208,36 @@ module Expand =
208208
let reduce = Reduce.ByKey2D.segmentSequential clContext workGroupSize opAdd
209209

210210
let getUniqueBitmap =
211-
ClArray.getUniqueBitmap2 clContext workGroupSize
211+
ClArray.getUniqueBitmap2FirstOccurrence clContext workGroupSize
212212

213213
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
214214

215-
let removeDuplicates = ClArray.removeDuplications clContext workGroupSize
215+
let init = ClArray.init clContext workGroupSize Map.id // TODO(fuse)
216+
217+
let scatter = Scatter.runInplace clContext workGroupSize
216218

217219
fun (processor: MailboxProcessor<_>) allocationMode (values: ClArray<'a>) (columns: Indices) (rows: Indices) ->
218220

219221
let bitmap = getUniqueBitmap processor DeviceOnly columns rows
220222

223+
printfn $"key bitmap: %A{bitmap.ToHost processor}"
224+
221225
let uniqueKeysCount = (prefixSum processor bitmap).ToHostAndFree processor
222226

223-
let offsets = removeDuplicates processor bitmap
227+
printfn $"key bitmap after prefix sum: %A{bitmap.ToHost processor}"
228+
229+
let positions = init processor DeviceOnly bitmap.Length
230+
231+
printfn $"positions: %A{positions.ToHost processor}"
232+
233+
let offsets = clContext.CreateClArrayWithSpecificAllocationMode(DeviceOnly, uniqueKeysCount)
234+
235+
scatter processor bitmap positions offsets
236+
237+
printfn $"offsets: %A{offsets.ToHost processor}"
224238

225239
bitmap.Free processor
240+
positions.Free processor
226241

227242
let reducedColumns, reducedRows, reducedValues =
228243
reduce processor allocationMode uniqueKeysCount offsets columns rows values
@@ -231,7 +246,7 @@ module Expand =
231246

232247
reducedValues, reducedColumns, reducedRows
233248

234-
let run (clContext: ClContext) workGroupSize opMul opAdd =
249+
let run (clContext: ClContext) workGroupSize opAdd opMul =
235250

236251
let getSegmentPointers = getSegmentPointers clContext workGroupSize
237252

@@ -248,18 +263,31 @@ module Expand =
248263
let values, columns, rows =
249264
expand processor length segmentPointers leftMatrix rightMatrix
250265

266+
printfn $"expanded values: %A{values.ToHost processor}"
267+
printfn $"expanded columns: %A{columns.ToHost processor}"
268+
printfn $"expanded rows: %A{rows.ToHost processor}"
269+
251270
let sortedValues, sortedColumns, sortedRows =
252271
sort processor values columns rows
253272

273+
printfn $"sorted values: %A{sortedValues.ToHost processor}"
274+
printfn $"sorted columns: %A{sortedColumns.ToHost processor}"
275+
printfn $"sorted rows: %A{sortedRows.ToHost processor}"
276+
254277
values.Free processor
255278
columns.Free processor
256279
rows.Free processor
257280

258281
let reducedValues, reducedColumns, reducedRows =
259282
reduce processor allocationMode sortedValues sortedColumns sortedRows
260283

284+
printfn $"reduced values: %A{reducedValues.ToHost processor}"
285+
printfn $"reduced columns: %A{reducedColumns.ToHost processor}"
286+
printfn $"reduced rows: %A{reducedRows.ToHost processor}"
287+
261288
sortedValues.Free processor
262289
sortedColumns.Free processor
263290
sortedRows.Free processor
264291

265292
reducedValues, reducedColumns, reducedRows
293+

tests/GraphBLAS-sharp.Tests/Common/Sort/Radix.fs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ open GraphBLAS.FSharp.Backend.Common.Sort
55
open GraphBLAS.FSharp.Tests
66
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
77
open Brahma.FSharp
8+
open GraphBLAS.FSharp.Backend.Objects.ClContext
89

910
module Radix =
1011
let config =
@@ -18,15 +19,12 @@ module Radix =
1819
let context = Context.defaultContext.ClContext
1920

2021
let checkResultByKeys (inputArray: (int * 'a) []) (actualValues: 'a []) =
21-
let expectedValues =
22-
Array.sortBy fst inputArray |> Array.map snd
22+
let expectedValues = Seq.sortBy fst inputArray |> Seq.map snd
2323

2424
"Values must be the same"
2525
|> Expect.sequenceEqual expectedValues actualValues
2626

2727
let makeTestByKeys<'a when 'a: equality> sortFun (array: (int * 'a) []) =
28-
// since Array.sort not stable
29-
let array = Array.distinctBy fst array
3028

3129
if array.Length > 0 then
3230
let keys = Array.map fst array
@@ -35,7 +33,7 @@ module Radix =
3533
let clKeys = keys.ToDevice context
3634
let clValues = values.ToDevice context
3735

38-
let clActualValues: ClArray<'a> = sortFun processor clKeys clValues
36+
let clActualValues: ClArray<'a> = sortFun processor HostInterop clKeys clValues
3937

4038
let actualValues = clActualValues.ToHostAndFree processor
4139

@@ -48,7 +46,7 @@ module Radix =
4846
makeTestByKeys<'a> sort
4947
|> testPropertyWithConfig config $"test on {typeof<'a>}"
5048

51-
let testFixturesByKeys =
49+
let testByKeys =
5250
[ createTestByKeys<int>
5351
createTestByKeys<uint>
5452

@@ -57,9 +55,7 @@ module Radix =
5755

5856
createTestByKeys<float32>
5957
createTestByKeys<bool> ]
60-
61-
let testsByKeys =
62-
testList "Radix sort by keys" testFixturesByKeys
58+
|> testList "Radix sort by keys"
6359

6460
let makeTestKeysOnly sort (keys: uint []) =
6561
if keys.Length > 0 then

tests/GraphBLAS-sharp.Tests/Helpers.fs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,18 @@ module HostPrimitives =
225225

226226
result
227227

228+
let array2DMultiplication mul add leftArray rightArray =
229+
if Array2D.length2 leftArray <> Array2D.length1 rightArray then
230+
failwith "Incompatible matrices"
231+
232+
Array2D.init
233+
<| Array2D.length1 leftArray
234+
<| Array2D.length2 rightArray
235+
<| fun i j ->
236+
(leftArray.[i, *], rightArray.[*, j])
237+
||> Array.map2 mul
238+
|> Array.reduce add
239+
228240
module Context =
229241
type TestContext =
230242
{ ClContext: ClContext

tests/GraphBLAS-sharp.Tests/Matrix/SpGeMM.fs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ open GraphBLAS.FSharp.Tests.Backend
1212
open GraphBLAS.FSharp.Objects
1313
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
1414
open Brahma.FSharp
15+
open GraphBLAS.FSharp.Backend.Objects.ClContext
1516

1617
let context = Context.defaultContext.ClContext
1718

@@ -165,3 +166,67 @@ let expandTests =
165166
createExpandTest (=) false (&&) <@ (&&) @> Expand.expand
166167
createExpandTest (=) 0uy (*) <@ (*) @> Expand.expand ]
167168
|> testList "Expand.expand"
169+
170+
let checkGeneralResult zero isEqual actualValues actualColumns actualRows mul add (leftArray: 'a [,]) (rightArray: 'a [,]) =
171+
172+
let expected =
173+
HostPrimitives.array2DMultiplication mul add leftArray rightArray
174+
|> fun array -> Utils.createMatrixFromArray2D COO array (isEqual zero)
175+
|> function Matrix.COO matrix -> matrix | _ -> failwith "format miss"
176+
177+
printfn $"leftMatrix \n %A{leftArray}"
178+
printfn $"rightMatrix \n %A{rightArray}"
179+
180+
printfn $"actual values: %A{actualValues}"
181+
printfn $"expected values: %A{expected.Values}"
182+
183+
printfn $"actualColumns: %A{actualColumns}"
184+
printfn $"expectedColumns: %A{expected.Columns}"
185+
186+
printfn $"actualRows: %A{actualRows}"
187+
printfn $"expectedRows: %A{expected.Rows}"
188+
189+
"Values must be the same"
190+
|> Utils.compareArrays isEqual actualValues expected.Values
191+
192+
"Columns must be the same"
193+
|> Utils.compareArrays (=) actualColumns expected.Columns
194+
195+
"Rows must be the same"
196+
|> Utils.compareArrays (=) actualRows expected.Rows
197+
198+
let makeGeneralTest zero isEqual opMul opAdd testFun (leftArray: 'a [,], rightArray: 'a [,]) =
199+
200+
let leftMatrix = createCSRMatrix leftArray <| isEqual zero
201+
202+
let rightMatrix = createCSRMatrix rightArray <| isEqual zero
203+
204+
if leftMatrix.NNZ > 0
205+
&& rightMatrix.NNZ > 0 then
206+
207+
let clLeftMatrix = leftMatrix.ToDevice context
208+
let clRightMatrix = rightMatrix.ToDevice context
209+
210+
let (clActualValues: ClArray<'a>), (clActualColumns: ClArray<int>), (clActualRows: ClArray<int>) =
211+
testFun processor HostInterop clLeftMatrix clRightMatrix
212+
213+
clLeftMatrix.Dispose processor
214+
clRightMatrix.Dispose processor
215+
216+
let actualValues = clActualValues.ToHostAndFree processor
217+
let actualColumns = clActualColumns.ToHostAndFree processor
218+
let actualRows = clActualRows.ToHostAndFree processor
219+
220+
checkGeneralResult zero isEqual actualValues actualColumns actualRows opMul opAdd leftArray rightArray
221+
222+
let createGeneralTest (zero: 'a) isEqual opAdd opAddQ opMul opMulQ testFun =
223+
224+
let testFun = testFun context Utils.defaultWorkGroupSize opAddQ opMulQ
225+
226+
makeGeneralTest zero isEqual opMul opAdd testFun
227+
|> testPropertyWithConfig { config with endSize = 10 } $"test on %A{typeof<'a>}"
228+
229+
let generalTests =
230+
[ createGeneralTest 0 (=) (+) <@ (+) @> (*) <@ (*) @> Expand.run ]
231+
|> testList "general"
232+

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ open GraphBLAS.FSharp.Tests.Matrix
9494
let allTests =
9595
testList
9696
"All tests"
97-
[ SpGeMM.expandTests ]
97+
[ SpGeMM.generalTests ]
9898

9999
|> testSequenced
100100

0 commit comments

Comments
 (0)