Skip to content

Commit 73d755f

Browse files
committed
refactor: deforestation in ClArray.choose
1 parent 972b392 commit 73d755f

3 files changed

Lines changed: 43 additions & 29 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -319,9 +319,6 @@ module ClArray =
319319

320320
let result = map processor allocationMode firstBitmap secondBitmap
321321

322-
printfn $"first bitmap: %A{firstBitmap.ToHost processor}"
323-
printfn $"second bitmap: %A{secondBitmap.ToHost processor}"
324-
325322
firstBitmap.Free processor
326323
secondBitmap.Free processor
327324

@@ -333,42 +330,60 @@ module ClArray =
333330
let getUniqueBitmap2LastOccurrence clContext =
334331
getUniqueBitmap2General getUniqueBitmapLastOccurrence clContext
335332

333+
let private assignOption (clContext: ClContext) workGroupSize (op: Expr<'a -> 'b option>) =
334+
335+
let assign =
336+
<@ fun (ndRange: Range1D) length (values: ClArray<'a>) (positions: ClArray<int>) (result: ClArray<'b>) resultLength ->
337+
338+
let gid = ndRange.GlobalID0
339+
340+
if gid < length then
341+
let position = positions.[gid]
342+
let value = values.[gid]
343+
344+
// seems like scatter (option scatter) ???
345+
if 0 <= position && position < resultLength then
346+
match (%op) value with
347+
| Some value ->
348+
result.[position] <- value
349+
| None -> () @>
350+
351+
let kernel = clContext.Compile assign
352+
353+
fun (processor: MailboxProcessor<_>) (values: ClArray<'a>) (positions: ClArray<int>) (result: ClArray<'b>) ->
354+
355+
let ndRange =
356+
Range1D.CreateValid(values.Length, workGroupSize)
357+
358+
let kernel = kernel.GetKernel()
359+
360+
processor.Post(
361+
Msg.MsgSetArguments
362+
(fun () -> kernel.KernelFunc ndRange values.Length values positions result result.Length)
363+
)
364+
365+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
366+
336367
let choose<'a, 'b> (clContext: ClContext) workGroupSize (predicate: Expr<'a -> 'b option>) =
337368
let getBitmap =
338369
map<'a, int> clContext workGroupSize
339370
<| Map.chooseBitmap predicate
340371

341-
let getOptionValues =
342-
map<'a, 'b option> clContext workGroupSize predicate
343-
344-
let getValues =
345-
map<'b option, 'b> clContext workGroupSize
346-
<| Map.optionToValueOrZero Unchecked.defaultof<'b>
372+
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
347373

348-
let prefixSum =
349-
PrefixSum.runExcludeInplace <@ (+) @> clContext workGroupSize
350-
351-
let scatter =
352-
Scatter.lastOccurrence clContext workGroupSize
374+
let assignValues = assignOption clContext workGroupSize predicate
353375

354-
fun (processor: MailboxProcessor<_>) allocationMode (array: ClArray<'a>) ->
376+
fun (processor: MailboxProcessor<_>) allocationMode (sourceValues: ClArray<'a>) ->
355377

356-
let positions = getBitmap processor DeviceOnly array
378+
let positions = getBitmap processor DeviceOnly sourceValues
357379

358380
let resultLength =
359-
(prefixSum processor positions 0)
381+
(prefixSum processor positions)
360382
.ToHostAndFree(processor)
361383

362-
let optionValues =
363-
getOptionValues processor DeviceOnly array
364-
365-
let values =
366-
getValues processor DeviceOnly optionValues
367-
368-
let result =
369-
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
384+
let result = clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
370385

371-
scatter processor positions values result
386+
assignValues processor sourceValues positions result
372387

373388
result
374389

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM/Expand.fs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,6 @@ module Expand =
6767

6868
let expand (clContext: ClContext) workGroupSize opMul =
6969

70-
let init = ClArray.init clContext workGroupSize Map.id
71-
7270
let idScatter = Scatter.initLastOccurrence Map.id clContext workGroupSize
7371

7472
let scatter = Scatter.lastOccurrence clContext workGroupSize

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,10 @@ open GraphBLAS.FSharp.Tests.Matrix
9494
let allTests =
9595
testList
9696
"All tests"
97-
[ SpGeMM.generalTests
97+
[ // SpGeMM.generalTests
9898
// Common.Gather.initTests
9999
//Common.Scatter.allTests ]
100+
Common.ClArray.Choose.tests
100101
]
101102

102103
|> testSequenced

0 commit comments

Comments
 (0)