Skip to content

Commit 253e52b

Browse files
committed
add: ClArray.choose
1 parent 7d84c83 commit 253e52b

6 files changed

Lines changed: 98 additions & 2 deletions

File tree

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ open Brahma.FSharp
44
open Microsoft.FSharp.Quotations
55
open GraphBLAS.FSharp.Backend.Objects.ClContext
66
open GraphBLAS.FSharp.Backend.Objects.ClCell
7+
open GraphBLAS.FSharp.Backend.Quotes
78

89
module ClArray =
910
let init (clContext: ClContext) workGroupSize (initializer: Expr<int -> 'a>) =
@@ -365,3 +366,35 @@ module ClArray =
365366
map2 processor leftArray rightArray resultArray
366367

367368
resultArray
369+
370+
let choose<'a, 'b> (clContext: ClContext) workGroupSize (predicate: Expr<'a -> 'b option>) =
371+
let getBitmap = map<'a, int> clContext workGroupSize <| Map.chooseBitmap predicate
372+
373+
let getOptionValues = map<'a, 'b option> clContext workGroupSize predicate
374+
375+
let getValues = map<'b option, 'b> clContext workGroupSize <| Map.optionToValueOrZero Unchecked.defaultof<'b>
376+
377+
let prefixSum = prefixSumExcludeInplace <@ (+) @> clContext workGroupSize
378+
379+
let scatter = Scatter.runInplace clContext workGroupSize
380+
381+
fun (processor: MailboxProcessor<_>) allocationMode (array: ClArray<'a>) ->
382+
383+
let positions = getBitmap processor DeviceOnly array
384+
385+
let resultLength =
386+
(prefixSum processor positions 0)
387+
.ToHostAndFree(processor)
388+
389+
let optionValues = getOptionValues processor DeviceOnly array
390+
391+
let values = getValues processor DeviceOnly optionValues
392+
393+
let result =
394+
clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
395+
396+
scatter processor positions values result
397+
398+
result
399+
400+

src/GraphBLAS-sharp.Backend/Quotes/Map.fs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
namespace GraphBLAS.FSharp.Backend.Quotes
22

3+
open FSharp.Quotations
4+
35
module Map =
46
let optionToValueOrZero<'a> zero =
57
<@ fun (item: 'a option) ->
@@ -13,3 +15,10 @@ module Map =
1315
| None -> onNone @>
1416

1517
let id<'a> = <@ fun (item: 'a) -> item @>
18+
19+
let chooseBitmap<'a, 'b> (map: Expr<'a -> 'b option>) =
20+
<@ fun (item: 'a) ->
21+
match (%map) item with
22+
| Some _ -> 1
23+
| None -> 0 @>
24+
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Common.Choose
2+
3+
open GraphBLAS.FSharp.Backend.Common
4+
open Expecto
5+
open GraphBLAS.FSharp.Tests
6+
open GraphBLAS.FSharp.Tests.Context
7+
open GraphBLAS.FSharp.Backend.Objects.ClContext
8+
open Brahma.FSharp
9+
open GraphBLAS.FSharp.Backend.Quotes
10+
11+
let workGroupSize = Utils.defaultWorkGroupSize
12+
13+
let config = Utils.defaultConfig
14+
15+
let makeTest<'a, 'b> testContext choose mapFun isEqual (array: 'a []) =
16+
if array.Length > 0 then
17+
let context = testContext.ClContext
18+
let q = testContext.Queue
19+
20+
let clArray = context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, array)
21+
22+
let (clResult: ClArray<'b>) = choose q HostInterop clArray
23+
24+
let hostResult = Array.zeroCreate clResult.Length
25+
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(clResult, hostResult, ch)) |> ignore
26+
27+
let expectedResult = Array.choose mapFun array
28+
29+
"Result should be the same"
30+
|> Utils.compareArrays isEqual hostResult expectedResult
31+
32+
let createTest<'a, 'b> testContext mapFun mapFunQ isEqual =
33+
let context = testContext.ClContext
34+
35+
let choose = ClArray.choose context workGroupSize mapFunQ
36+
37+
makeTest<'a, 'b> testContext choose mapFun isEqual
38+
|> testPropertyWithConfig config $"Correctness on %A{typeof<'a>} -> %A{typeof<'b>}"
39+
40+
let testFixtures testContext =
41+
let device = testContext.ClContext.ClDevice
42+
43+
[ createTest<int option, int> testContext id Map.id (=)
44+
createTest<byte option, byte> testContext id Map.id (=)
45+
createTest<bool option, bool> testContext id Map.id (=)
46+
47+
if Utils.isFloat64Available device then
48+
createTest<float option, float> testContext id Map.id Utils.floatIsEqual
49+
50+
createTest<float32 option, float32> testContext id Map.id Utils.float32IsEqual ]
51+
52+
let tests = TestCases.gpuTests "ClArray.choose id tests" testFixtures

tests/GraphBLAS-sharp.Tests/Common/Reduce.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ let testFixtures plus plusQ zero name =
5757
let reduce = Reduce.reduce context wgSize plusQ
5858

5959
makeTest reduce plus zero
60-
|> testPropertyWithConfig config (sprintf "Correctness on %s" name)
60+
|> testPropertyWithConfig config $"Correctness on %s{name}"
6161

6262
let tests =
6363
q.Error.Add(fun e -> failwithf "%A" e)

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
<Compile Include="Common/Exists.fs" />
2828
<Compile Include="Common/Map.fs" />
2929
<Compile Include="Common/Map2.fs" />
30+
<Compile Include="Common\Choose.fs" />
3031
<!--Compile Include="MatrixOperationsTests/GetTuplesTests.fs" /-->
3132
<!--Compile Include="MatrixOperationsTests/MxvTests.fs" /-->
3233
<!--Compile Include="MatrixOperationsTests/VxmTests.fs" /-->

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ let commonTests =
2727
Common.Exists.tests
2828
Common.Map.tests
2929
Common.Map2.addTests
30-
Common.Map2.mulTests ]
30+
Common.Map2.mulTests
31+
Common.Choose.tests ]
3132
|> testSequenced
3233

3334
let vectorTests =

0 commit comments

Comments
 (0)