Skip to content

Commit f34e590

Browse files
committed
add: ClArray.choose2
1 parent 73d755f commit f34e590

13 files changed

Lines changed: 182 additions & 49 deletions

File tree

benchmarks/GraphBLAS-sharp.Benchmarks/BenchmarksBFS.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ type BFSBenchmarks<'matrixT, 'elem when 'matrixT :> IDeviceMemObject and 'elem :
103103
type BFSBenchmarksWithoutDataTransfer() =
104104

105105
inherit BFSBenchmarks<ClMatrix.CSR<int>, int>(
106-
(fun context wgSize -> BFS.singleSource context ArithmeticOperations.intSum ArithmeticOperations.intMul wgSize),
106+
(fun context wgSize -> BFS.singleSource context ArithmeticOperations.intSumOption ArithmeticOperations.intMulOption wgSize),
107107
int,
108108
(fun _ -> Utils.nextInt (System.Random())),
109109
Matrix.ToBackendCSR)

benchmarks/GraphBLAS-sharp.Benchmarks/BenchmarksEWiseAdd.fs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ module M =
195195
type EWiseAddBenchmarks4Float32COOWithoutDataTransfer() =
196196

197197
inherit EWiseAddBenchmarksWithoutDataTransfer<ClMatrix.COO<float32>,float32>(
198-
(fun context wgSize -> COO.Matrix.map2 context ArithmeticOperations.float32Sum wgSize),
198+
(fun context wgSize -> COO.Matrix.map2 context ArithmeticOperations.float32SumOption wgSize),
199199
float32,
200200
(fun _ -> Utils.nextSingle (System.Random())),
201201
Matrix.ToBackendCOO
@@ -207,7 +207,7 @@ type EWiseAddBenchmarks4Float32COOWithoutDataTransfer() =
207207
type EWiseAddBenchmarks4Float32COOWithDataTransfer() =
208208

209209
inherit EWiseAddBenchmarksWithDataTransfer<ClMatrix.COO<float32>,float32>(
210-
(fun context wgSize -> COO.Matrix.map2 context ArithmeticOperations.float32Sum wgSize),
210+
(fun context wgSize -> COO.Matrix.map2 context ArithmeticOperations.float32SumOption wgSize),
211211
float32,
212212
(fun _ -> Utils.nextSingle (System.Random())),
213213
Matrix.ToBackendCOO<float32>,
@@ -234,7 +234,7 @@ type EWiseAddBenchmarks4BoolCOOWithoutDataTransfer() =
234234
type EWiseAddBenchmarks4Float32CSRWithoutDataTransfer() =
235235

236236
inherit EWiseAddBenchmarksWithoutDataTransfer<ClMatrix.CSR<float32>,float32>(
237-
(fun context wgSize -> CSR.Matrix.map2 context ArithmeticOperations.float32Sum wgSize),
237+
(fun context wgSize -> CSR.Matrix.map2 context ArithmeticOperations.float32SumOption wgSize),
238238
float32,
239239
(fun _ -> Utils.nextSingle (System.Random())),
240240
Matrix.ToBackendCSR

benchmarks/GraphBLAS-sharp.Benchmarks/VectorEWiseAddGen.fs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,53 +159,53 @@ type VectorEWiseBenchmarksWithDataTransfer<'elem when 'elem : struct>(
159159
type VectorEWiseBenchmarks4FloatSparseWithoutDataTransfer() =
160160

161161
inherit VectorEWiseBenchmarksWithoutDataTransfer<float>(
162-
(fun context -> Vector.map2 context ArithmeticOperations.floatSum),
162+
(fun context -> Vector.map2 context ArithmeticOperations.floatSumOption),
163163
VectorGenerator.floatPair Sparse)
164164

165165
type VectorEWiseBenchmarks4Int32SparseWithoutDataTransfer() =
166166

167167
inherit VectorEWiseBenchmarksWithoutDataTransfer<int32>(
168-
(fun context -> Vector.map2 context ArithmeticOperations.intSum),
168+
(fun context -> Vector.map2 context ArithmeticOperations.intSumOption),
169169
VectorGenerator.intPair Sparse)
170170

171171
/// General
172172
173173
type VectorEWiseGeneralBenchmarks4FloatSparseWithoutDataTransfer() =
174174

175175
inherit VectorEWiseBenchmarksWithoutDataTransfer<float>(
176-
(fun context -> Vector.map2 context ArithmeticOperations.floatSum),
176+
(fun context -> Vector.map2 context ArithmeticOperations.floatSumOption),
177177
VectorGenerator.floatPair Sparse)
178178

179179
type VectorEWiseGeneralBenchmarks4Int32SparseWithoutDataTransfer() =
180180

181181
inherit VectorEWiseBenchmarksWithoutDataTransfer<int32>(
182-
(fun context -> Vector.map2 context ArithmeticOperations.intSum),
182+
(fun context -> Vector.map2 context ArithmeticOperations.intSumOption),
183183
VectorGenerator.intPair Sparse)
184184

185185
/// With data transfer
186186
187187
type VectorEWiseBenchmarks4FloatSparseWithDataTransfer() =
188188

189189
inherit VectorEWiseBenchmarksWithDataTransfer<float>(
190-
(fun context -> Vector.map2 context ArithmeticOperations.floatSum),
190+
(fun context -> Vector.map2 context ArithmeticOperations.floatSumOption),
191191
VectorGenerator.floatPair Sparse)
192192

193193
type VectorEWiseBenchmarks4Int32SparseWithDataTransfer() =
194194

195195
inherit VectorEWiseBenchmarksWithDataTransfer<int32>(
196-
(fun context -> Vector.map2 context ArithmeticOperations.intSum),
196+
(fun context -> Vector.map2 context ArithmeticOperations.intSumOption),
197197
VectorGenerator.intPair Sparse)
198198

199199
/// General with data transfer
200200
201201
type VectorEWiseGeneralBenchmarks4FloatSparseWithDataTransfer() =
202202

203203
inherit VectorEWiseBenchmarksWithDataTransfer<float>(
204-
(fun context -> Vector.map2 context ArithmeticOperations.floatSum),
204+
(fun context -> Vector.map2 context ArithmeticOperations.floatSumOption),
205205
VectorGenerator.floatPair Sparse)
206206

207207
type VectorEWiseGeneralBenchmarks4Int32SparseWithDataTransfer() =
208208

209209
inherit VectorEWiseBenchmarksWithDataTransfer<int32>(
210-
(fun context -> Vector.map2 context ArithmeticOperations.intSum),
210+
(fun context -> Vector.map2 context ArithmeticOperations.intSumOption),
211211
VectorGenerator.intPair Sparse)

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,8 @@ module ClArray =
352352

353353
fun (processor: MailboxProcessor<_>) (values: ClArray<'a>) (positions: ClArray<int>) (result: ClArray<'b>) ->
354354

355+
if values.Length <> positions.Length then failwith "lengths must be the same"
356+
355357
let ndRange =
356358
Range1D.CreateValid(values.Length, workGroupSize)
357359

@@ -387,3 +389,65 @@ module ClArray =
387389

388390
result
389391

392+
let private assignOption2 (clContext: ClContext) workGroupSize (op: Expr<'a -> 'b -> 'c option>) =
393+
394+
let assign =
395+
<@ fun (ndRange: Range1D) length (firstValues: ClArray<'a>) (secondValues: ClArray<'b>) (positions: ClArray<int>) (result: ClArray<'c>) resultLength ->
396+
397+
let gid = ndRange.GlobalID0
398+
399+
if gid < length then
400+
let position = positions.[gid]
401+
402+
let leftValue = firstValues.[gid]
403+
let rightValue = secondValues.[gid]
404+
405+
// seems like scatter2 (option scatter2) ???
406+
if 0 <= position && position < resultLength then
407+
match (%op) leftValue rightValue with
408+
| Some value ->
409+
result.[position] <- value
410+
| None -> () @>
411+
412+
let kernel = clContext.Compile assign
413+
414+
fun (processor: MailboxProcessor<_>) (firstValues: ClArray<'a>) (secondValues: ClArray<'b>) (positions: ClArray<int>) (result: ClArray<'c>) ->
415+
416+
if firstValues.Length <> secondValues.Length
417+
|| secondValues.Length <> positions.Length then
418+
failwith "lengths must be the same"
419+
420+
let ndRange =
421+
Range1D.CreateValid(firstValues.Length, workGroupSize)
422+
423+
let kernel = kernel.GetKernel()
424+
425+
processor.Post(
426+
Msg.MsgSetArguments
427+
(fun () -> kernel.KernelFunc ndRange firstValues.Length firstValues secondValues positions result result.Length)
428+
)
429+
430+
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
431+
432+
let choose2 (clContext: ClContext) workGroupSize (predicate: Expr<'a -> 'b -> 'c option>) =
433+
let getBitmap =
434+
map2<'a, 'b, int> clContext workGroupSize
435+
<| Map.chooseBitmap2 predicate
436+
437+
let prefixSum = PrefixSum.standardExcludeInplace clContext workGroupSize
438+
439+
let assignValues = assignOption2 clContext workGroupSize predicate
440+
441+
fun (processor: MailboxProcessor<_>) allocationMode (firstValues: ClArray<'a>) (secondValues: ClArray<'b>) ->
442+
443+
let positions = getBitmap processor DeviceOnly firstValues secondValues
444+
445+
let resultLength =
446+
(prefixSum processor positions)
447+
.ToHostAndFree(processor)
448+
449+
let result = clContext.CreateClArrayWithSpecificAllocationMode(allocationMode, resultLength)
450+
451+
assignValues processor firstValues secondValues positions result
452+
453+
result

src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ module ArithmeticOperations =
7272
let inline addRightConst zero constant =
7373
mkUnaryOp zero <@ fun x -> x + constant @>
7474

75-
let intSum = mkNumericSum 0
76-
let byteSum = mkNumericSum 0uy
77-
let floatSum = mkNumericSum 0.0
78-
let float32Sum = mkNumericSum 0f
75+
let intSumOption = mkNumericSum 0
76+
let byteSumOption = mkNumericSum 0uy
77+
let floatSumOption = mkNumericSum 0.0
78+
let float32SumOption = mkNumericSum 0f
7979

8080
let boolSumAtLeastOne =
8181
<@ fun (_: AtLeastOne<bool, bool>) -> Some true @>
@@ -85,7 +85,7 @@ module ArithmeticOperations =
8585
let floatSumAtLeastOne = mkNumericSumAtLeastOne 0.0
8686
let float32SumAtLeastOne = mkNumericSumAtLeastOne 0f
8787

88-
let boolMul =
88+
let boolMulOption =
8989
<@ fun (x: bool option) (y: bool option) ->
9090
let mutable res = false
9191

@@ -101,10 +101,10 @@ module ArithmeticOperations =
101101
let inline mulRightConst zero constant =
102102
mkUnaryOp zero <@ fun x -> x * constant @>
103103

104-
let intMul = mkNumericMul 0
105-
let byteMul = mkNumericMul 0uy
106-
let floatMul = mkNumericMul 0.0
107-
let float32Mul = mkNumericMul 0f
104+
let intMulOption = mkNumericMul 0
105+
let byteMulOption = mkNumericMul 0uy
106+
let floatMulOption = mkNumericMul 0.0
107+
let float32MulOption = mkNumericMul 0f
108108

109109
let boolMulAtLeastOne =
110110
<@ fun (values: AtLeastOne<bool, bool>) ->
@@ -121,8 +121,30 @@ module ArithmeticOperations =
121121
let floatMulAtLeastOne = mkNumericMulAtLeastOne 0.0
122122
let float32MulAtLeastOne = mkNumericMulAtLeastOne 0f
123123

124-
let notQ =
124+
let notOption =
125125
<@ fun x ->
126126
match x with
127127
| Some true -> None
128128
| _ -> Some true @>
129+
130+
let inline private binOpQ zero op =
131+
<@ fun (left: 'a) (right: 'a) ->
132+
let result = (%op) left right
133+
134+
if result = zero then None else Some result @>
135+
136+
let inline private binOp zero op =
137+
fun left right ->
138+
let result = op left right
139+
140+
if result = zero then None else Some result
141+
142+
let inline createPair zero op opQ = binOpQ zero opQ, binOp zero op
143+
144+
let intAdd = createPair 0 (+) <@ (+) @>
145+
146+
let boolAdd = createPair false (||) <@ (||) @>
147+
148+
let floatAdd = createPair 0.0 (+) <@ (+) @>
149+
150+
let float32Add = createPair 0.0f (+) <@ (+) @>

src/GraphBLAS-sharp.Backend/Quotes/Map.fs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ module Map =
2222
| Some _ -> 1
2323
| None -> 0 @>
2424

25+
let chooseBitmap2<'a, 'b, 'c> (map: Expr<'a -> 'b -> 'c option>) =
26+
<@ fun (leftItem: 'a) (rightItem: 'b) ->
27+
match (%map) leftItem rightItem with
28+
| Some _ -> 1
29+
| None -> 0 @>
30+
2531
let inc = <@ fun item -> item + 1 @>
2632

2733
let subtraction = <@ fun first second -> first - second @>

tests/GraphBLAS-sharp.Tests/Algorithms/BFS.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ let testFixtures (testContext: TestContext) =
2222
sprintf "Test on %A" testContext.ClContext
2323

2424
let bfs =
25-
Algorithms.BFS.singleSource context ArithmeticOperations.intSum ArithmeticOperations.intMul workGroupSize
25+
Algorithms.BFS.singleSource context ArithmeticOperations.intSumOption ArithmeticOperations.intMulOption workGroupSize
2626

2727
testPropertyWithConfig config testName
2828
<| fun (matrix: int [,]) ->

tests/GraphBLAS-sharp.Tests/Common/ClArray/Choose.fs

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,16 @@ open GraphBLAS.FSharp.Tests.Context
77
open GraphBLAS.FSharp.Backend.Objects.ClContext
88
open Brahma.FSharp
99
open GraphBLAS.FSharp.Backend.Quotes
10+
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
1011

1112
let workGroupSize = Utils.defaultWorkGroupSize
1213

1314
let config = Utils.defaultConfig
1415

16+
let context = Context.defaultContext.ClContext
17+
18+
let processor = defaultContext.Queue
19+
1520
let makeTest<'a, 'b> testContext choose mapFun isEqual (array: 'a []) =
1621
if array.Length > 0 then
1722
let context = testContext.ClContext
@@ -39,7 +44,7 @@ let createTest<'a, 'b> testContext mapFun mapFunQ isEqual =
3944
ClArray.choose context workGroupSize mapFunQ
4045

4146
makeTest<'a, 'b> testContext choose mapFun isEqual
42-
|> testPropertyWithConfig config $"Correctness on %A{typeof<'a>} -> %A{typeof<'b>}"
47+
|> testPropertyWithConfig config $"test on %A{typeof<'a>} -> %A{typeof<'b>}"
4348

4449
let testFixtures testContext =
4550
let device = testContext.ClContext.ClDevice
@@ -54,4 +59,42 @@ let testFixtures testContext =
5459
createTest<float32 option, float32> testContext id Map.id Utils.float32IsEqual ]
5560

5661
let tests =
57-
TestCases.gpuTests "ClArray.choose id tests" testFixtures
62+
TestCases.gpuTests "choose id" testFixtures
63+
64+
let makeTest2 isEqual opMap testFun (firstArray: 'a [], secondArray: 'a []) =
65+
if firstArray.Length > 0
66+
&& secondArray.Length > 0 then
67+
68+
let expected =
69+
Array.map2 opMap firstArray secondArray
70+
|> Array.choose id
71+
72+
let clFirstArray = context.CreateClArray firstArray
73+
let clSecondArray = context.CreateClArray secondArray
74+
75+
let (clActual: ClArray<_>) = testFun processor HostInterop clFirstArray clSecondArray
76+
77+
let actual = clActual.ToHostAndFree processor
78+
clFirstArray.Free processor
79+
clSecondArray.Free processor
80+
81+
"Results must be the same"
82+
|> Utils.compareArrays isEqual actual expected
83+
84+
let createTest2 (isEqual: 'a -> 'a -> bool) (opMapQ, opMap) testFun =
85+
let testFun = testFun context Utils.defaultWorkGroupSize opMapQ
86+
87+
makeTest2 isEqual opMap testFun
88+
|> testPropertyWithConfig { config with maxTest = 1000 } $"test on %A{typeof<'a>}"
89+
90+
let tests2 =
91+
[ createTest2 (=) ArithmeticOperations.intAdd ClArray.choose2
92+
93+
if Utils.isFloat64Available context.ClDevice then
94+
createTest2 (=) ArithmeticOperations.floatAdd ClArray.choose2
95+
96+
createTest2 (=) ArithmeticOperations.float32Add ClArray.choose2
97+
createTest2 (=) ArithmeticOperations.boolAdd ClArray.choose2 ]
98+
|> testList "choose2 add"
99+
100+
let allTests = testList "Choose" [ tests; tests2 ]

tests/GraphBLAS-sharp.Tests/Matrix/Map.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ let testFixturesMapNot case =
108108
[ let q = case.TestContext.Queue
109109
q.Error.Add(fun e -> failwithf "%A" e)
110110

111-
createTestMap case false true (fun _ -> not) (=) (fun _ _ -> ArithmeticOperations.notQ) ]
111+
createTestMap case false true (fun _ -> not) (=) (fun _ _ -> ArithmeticOperations.notOption) ]
112112

113113
let notTests =
114114
operationGPUTests "Backend.Matrix.map not tests" testFixturesMapNot

tests/GraphBLAS-sharp.Tests/Matrix/Map2.fs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,13 +113,13 @@ let testFixturesMap2Add case =
113113
q.Error.Add(fun e -> failwithf "%A" e)
114114

115115
creatTestMap2Add case false (||) (=) ArithmeticOperations.boolSum Matrix.map2
116-
creatTestMap2Add case 0 (+) (=) ArithmeticOperations.intSum Matrix.map2
116+
creatTestMap2Add case 0 (+) (=) ArithmeticOperations.intSumOption Matrix.map2
117117

118118
if Utils.isFloat64Available context.ClDevice then
119-
creatTestMap2Add case 0.0 (+) Utils.floatIsEqual ArithmeticOperations.floatSum Matrix.map2
119+
creatTestMap2Add case 0.0 (+) Utils.floatIsEqual ArithmeticOperations.floatSumOption Matrix.map2
120120

121-
creatTestMap2Add case 0.0f (+) Utils.float32IsEqual ArithmeticOperations.float32Sum Matrix.map2
122-
creatTestMap2Add case 0uy (+) (=) ArithmeticOperations.byteSum Matrix.map2 ]
121+
creatTestMap2Add case 0.0f (+) Utils.float32IsEqual ArithmeticOperations.float32SumOption Matrix.map2
122+
creatTestMap2Add case 0uy (+) (=) ArithmeticOperations.byteSumOption Matrix.map2 ]
123123

124124
let addTests =
125125
operationGPUTests "Backend.Matrix.map2 add tests" testFixturesMap2Add

0 commit comments

Comments
 (0)