Skip to content

Commit 415172d

Browse files
committed
add: ClArray.map* test
1 parent 1f1aba6 commit 415172d

12 files changed

Lines changed: 195 additions & 40 deletions

File tree

src/GraphBLAS-sharp.Backend/Algorithms/BFS.fs

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ open GraphBLAS.FSharp.Backend.Vector
1010
open GraphBLAS.FSharp.Backend.Vector.Dense
1111
open GraphBLAS.FSharp.Backend.Objects.ClContext
1212
open GraphBLAS.FSharp.Backend.Objects.ArraysExtensions
13+
open GraphBLAS.FSharp.Backend.Objects.ClCell
1314

1415
module BFS =
1516
let singleSource
@@ -62,13 +63,9 @@ module BFS =
6263
maskComplementedTo queue front levels front
6364

6465
//Checking if front is empty
65-
let frontNotEmpty = Array.zeroCreate 1
66-
let sum = containsNonZero queue front
67-
68-
queue.PostAndReply(fun ch -> Msg.CreateToHostMsg(sum, frontNotEmpty, ch))
69-
|> ignore
70-
71-
stop <- not frontNotEmpty.[0]
66+
stop <-
67+
not
68+
<| (containsNonZero queue front).ToHostAndFree queue
7269

7370
front.Dispose queue
7471

src/GraphBLAS-sharp.Backend/Common/ClArray.fs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ module ClArray =
177177

178178
let outputArray = copy processor allocationMode inputArray
179179

180-
let totalSum = runExcludeInplace processor outputArray zero
180+
let totalSum =
181+
runExcludeInplace processor outputArray zero
181182

182183
outputArray, totalSum
183184

@@ -192,7 +193,8 @@ module ClArray =
192193

193194
let outputArray = copy processor allocationMode inputArray
194195

195-
let totalSum = runIncludeInplace processor outputArray zero
196+
let totalSum =
197+
runIncludeInplace processor outputArray zero
196198

197199
outputArray, totalSum
198200

@@ -339,16 +341,21 @@ module ClArray =
339341

340342
fun (processor: MailboxProcessor<_>) (leftArray: ClArray<'a>) (rightArray: ClArray<'b>) (resultArray: ClArray<'c>) ->
341343

342-
let ndRange = Range1D.CreateValid(resultArray.Length, workGroupSize)
344+
let ndRange =
345+
Range1D.CreateValid(resultArray.Length, workGroupSize)
343346

344347
let kernel = kernel.GetKernel()
345348

346-
processor.Post(Msg.MsgSetArguments(fun () -> kernel.KernelFunc ndRange resultArray.Length leftArray rightArray resultArray))
349+
processor.Post(
350+
Msg.MsgSetArguments
351+
(fun () -> kernel.KernelFunc ndRange resultArray.Length leftArray rightArray resultArray)
352+
)
347353

348354
processor.Post(Msg.CreateRunMsg<_, _>(kernel))
349355

350356
let map2<'a, 'b, 'c> (clContext: ClContext) workGroupSize map =
351-
let map2 = map2Inplace<'a, 'b, 'c> clContext workGroupSize map
357+
let map2 =
358+
map2Inplace<'a, 'b, 'c> clContext workGroupSize map
352359

353360
fun (processor: MailboxProcessor<_>) allocationMode (leftArray: ClArray<'a>) (rightArray: ClArray<'b>) ->
354361

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/SpGEMM.fs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,7 @@ module internal SpGEMM =
164164
calculate queue matrixLeft matrixRight mask
165165

166166
let resultNNZ =
167-
(scanInplace queue positions)
168-
.ToHostAndFree(queue)
167+
(scanInplace queue positions).ToHostAndFree(queue)
169168

170169
let resultRows = context.CreateClArray<int> resultNNZ
171170
let resultCols = context.CreateClArray<int> resultNNZ

src/GraphBLAS-sharp.Backend/Matrix/Common.fs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ module Common =
2323
fun (processor: MailboxProcessor<_>) allocationMode (allRows: ClArray<int>) (allColumns: ClArray<int>) (allValues: ClArray<'a>) (positions: ClArray<int>) ->
2424

2525
let resultLength =
26-
(sum processor positions)
27-
.ToHostAndFree(processor)
26+
(sum processor positions).ToHostAndFree(processor)
2827

2928
let resultRows =
3029
clContext.CreateClArrayWithSpecificAllocationMode<int>(allocationMode, resultLength)
Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
module GraphBLAS.FSharp.Backend.Quotes
1+
namespace GraphBLAS.FSharp.Backend.Quotes
22

33
module Map =
4-
let optionToValueOrZero<'a> =
4+
let optionToValueOrZero<'a> zero =
55
<@ fun (item: 'a option) ->
66
match item with
77
| Some value -> value
8-
| None -> Unchecked.defaultof<'a> @>
8+
| None -> zero @>
99

1010
let option onSome onNone =
1111
<@ function
12-
| Some _ -> onSome
13-
| None -> onNone @>
12+
| Some _ -> onSome
13+
| None -> onNone @>
14+
15+
let id<'a> = <@ fun (item: 'a) -> item @>

src/GraphBLAS-sharp.Backend/Vector/DenseVector/DenseVector.fs

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ module DenseVector =
1616
workGroupSize
1717
=
1818

19-
let map2InPlace = ClArray.map2Inplace clContext workGroupSize opAdd
19+
let map2InPlace =
20+
ClArray.map2Inplace clContext workGroupSize opAdd
2021

2122
fun (processor: MailboxProcessor<_>) (leftVector: ClArray<'a option>) (rightVector: ClArray<'b option>) (resultVector: ClArray<'c option>) ->
2223

@@ -89,18 +90,25 @@ module DenseVector =
8990

9091
let toSparse<'a when 'a: struct> (clContext: ClContext) workGroupSize =
9192

92-
let scatterValues = Scatter.runInplace clContext workGroupSize
93+
let scatterValues =
94+
Scatter.runInplace clContext workGroupSize
9395

94-
let scatterIndices = Scatter.runInplace clContext workGroupSize
96+
let scatterIndices =
97+
Scatter.runInplace clContext workGroupSize
9598

96-
let getBitmap = ClArray.map clContext workGroupSize <| Map.option 1 0
99+
let getBitmap =
100+
ClArray.map clContext workGroupSize
101+
<| Map.option 1 0
97102

98103
let prefixSum =
99104
PrefixSum.standardExcludeInplace clContext workGroupSize
100105

101-
let allIndices = ClArray.init clContext workGroupSize <@ id @>
106+
let allIndices =
107+
ClArray.init clContext workGroupSize Map.id
102108

103-
let allValues = ClArray.map clContext workGroupSize Map.optionToValueOrZero
109+
let allValues =
110+
ClArray.map clContext workGroupSize
111+
<| Map.optionToValueOrZero Unchecked.defaultof<'a>
104112

105113
fun (processor: MailboxProcessor<_>) allocationMode (vector: ClArray<'a option>) ->
106114

@@ -114,7 +122,8 @@ module DenseVector =
114122
let resultIndices =
115123
clContext.CreateClArrayWithSpecificAllocationMode<int>(allocationMode, resultLength)
116124

117-
let allIndices = allIndices processor DeviceOnly vector.Length
125+
let allIndices =
126+
allIndices processor DeviceOnly vector.Length
118127

119128
scatterIndices processor positions allIndices resultIndices
120129

@@ -139,17 +148,17 @@ module DenseVector =
139148

140149
let reduce<'a when 'a: struct> (clContext: ClContext) workGroupSize (opAdd: Expr<'a -> 'a -> 'a>) =
141150

142-
let getValuesAndIndices =
143-
ClArray.map clContext workGroupSize Map.optionToValueOrZero
151+
let map =
152+
ClArray.map clContext workGroupSize
153+
<| Map.optionToValueOrZero Unchecked.defaultof<'a>
144154

145155
let reduce =
146156
Reduce.reduce clContext workGroupSize opAdd
147157

148158
fun (processor: MailboxProcessor<_>) (vector: ClArray<'a option>) ->
149159

150160
try
151-
let values =
152-
getValuesAndIndices processor DeviceOnly vector
161+
let values = map processor DeviceOnly vector
153162

154163
let result = reduce processor values
155164

src/GraphBLAS-sharp.Backend/Vector/SparseVector/SparseVector.fs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ module SparseVector =
2727
fun (processor: MailboxProcessor<_>) allocationMode (allValues: ClArray<'a>) (allIndices: ClArray<int>) (positions: ClArray<int>) ->
2828

2929
let resultLength =
30-
(sum processor positions)
31-
.ToHostAndFree(processor)
30+
(sum processor positions).ToHostAndFree(processor)
3231

3332
let resultValues =
3433
clContext.CreateClArrayWithSpecificAllocationMode<'a>(allocationMode, resultLength)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Common.Map
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Tests
5+
open GraphBLAS.FSharp.Tests.Context
6+
open GraphBLAS.FSharp.Backend.Common
7+
open GraphBLAS.FSharp.Backend.Quotes
8+
open Expecto
9+
open GraphBLAS.FSharp.Backend.Objects.ClContext
10+
11+
let context = defaultContext.Queue
12+
13+
let wgSize = Utils.defaultWorkGroupSize
14+
15+
let config = Utils.defaultConfig
16+
17+
let mapOptionToValue zero =
18+
function
19+
| Some value -> value
20+
| None -> zero
21+
22+
let makeTest (testContext: TestContext) mapFun zero isEqual (array: 'a option []) =
23+
if array.Length > 0 then
24+
let context = testContext.ClContext
25+
let q = testContext.Queue
26+
27+
let clArray =
28+
context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, array)
29+
30+
let (actualDevice: ClArray<_>) = mapFun q HostInterop clArray
31+
32+
let actualHost = Array.zeroCreate actualDevice.Length
33+
34+
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(actualDevice, actualHost, ch))
35+
|> ignore
36+
37+
let expected = Array.map (mapOptionToValue zero) array
38+
39+
"Arrays must be the same"
40+
|> Utils.compareArrays isEqual actualHost expected
41+
42+
let createTest<'a when 'a: equality> (testContext: TestContext) (zero: 'a) isEqual =
43+
44+
let context = testContext.ClContext
45+
46+
let map =
47+
ClArray.map context wgSize
48+
<| Map.optionToValueOrZero zero
49+
50+
makeTest testContext map zero isEqual
51+
|> testPropertyWithConfig config $"Correctness on {typeof<'a>}"
52+
53+
let testFixtures (testContext: TestContext) =
54+
[ createTest<int> testContext 0 (=)
55+
createTest<bool> testContext false (=)
56+
57+
if Utils.isFloat64Available testContext.ClContext.ClDevice then
58+
createTest<float> testContext 0.0 Utils.floatIsEqual
59+
60+
createTest<float32> testContext 0.0f Utils.float32IsEqual
61+
createTest<byte> testContext 0uy (=) ]
62+
63+
let tests =
64+
TestCases.gpuTests "ClArray.map tests" testFixtures
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
module GraphBLAS.FSharp.Tests.Backend.Common.Map2
2+
3+
open Brahma.FSharp
4+
open GraphBLAS.FSharp.Tests
5+
open GraphBLAS.FSharp.Tests.Context
6+
open GraphBLAS.FSharp.Backend.Common
7+
open GraphBLAS.FSharp.Backend.Quotes
8+
open Expecto
9+
open GraphBLAS.FSharp.Backend.Objects.ClContext
10+
11+
let context = defaultContext.Queue
12+
13+
let wgSize = Utils.defaultWorkGroupSize
14+
15+
let config = Utils.defaultConfig
16+
17+
let makeTest<'a when 'a: equality> testContext clMapFun hostMapFun isEqual (leftArray: 'a [], rightArray: 'a []) =
18+
if leftArray.Length > 0 then
19+
let context = testContext.ClContext
20+
let q = testContext.Queue
21+
22+
let leftClArray =
23+
context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, leftArray)
24+
25+
let rightClArray =
26+
context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, rightArray)
27+
28+
let (actualDevice: ClArray<'a>) =
29+
clMapFun q HostInterop leftClArray rightClArray
30+
31+
let actualHost = Array.zeroCreate actualDevice.Length
32+
33+
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(actualDevice, actualHost, ch))
34+
|> ignore
35+
36+
let expected =
37+
Array.map2 hostMapFun leftArray rightArray
38+
39+
"Arrays must be the same"
40+
|> Utils.compareArrays isEqual actualHost expected
41+
42+
let createTest<'a when 'a: equality> (testContext: TestContext) isEqual hostMapFun mapFunQ =
43+
44+
let context = testContext.ClContext
45+
46+
let map = ClArray.map2 context wgSize mapFunQ
47+
48+
makeTest<'a> testContext map hostMapFun isEqual
49+
|> testPropertyWithConfig config $"Correctness on {typeof<'a>}"
50+
51+
let testFixturesAdd (testContext: TestContext) =
52+
[ createTest<int> testContext (=) (+) <@ (+) @>
53+
createTest<bool> testContext (=) (||) <@ (||) @>
54+
55+
if Utils.isFloat64Available testContext.ClContext.ClDevice then
56+
createTest<float> testContext Utils.floatIsEqual (+) <@ (+) @>
57+
58+
createTest<float32> testContext Utils.float32IsEqual (+) <@ (+) @>
59+
createTest<byte> testContext (=) (+) <@ (+) @> ]
60+
61+
let addTests =
62+
TestCases.gpuTests "ClArray.map2 add tests" testFixturesAdd
63+
64+
let testFixturesMul (testContext: TestContext) =
65+
[ createTest<int> testContext (=) (*) <@ (*) @>
66+
createTest<bool> testContext (=) (&&) <@ (&&) @>
67+
68+
if Utils.isFloat64Available testContext.ClContext.ClDevice then
69+
createTest<float> testContext Utils.floatIsEqual (*) <@ (*) @>
70+
71+
createTest<float32> testContext Utils.float32IsEqual (*) <@ (*) @>
72+
createTest<byte> testContext (=) (+) <@ (+) @> ]
73+
74+
let mulTests =
75+
TestCases.gpuTests "ClArray.map2 multiplication tests" testFixturesMul

tests/GraphBLAS-sharp.Tests/Common/PrefixSum.fs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ open Brahma.FSharp
77
open GraphBLAS.FSharp.Backend.Common
88
open GraphBLAS.FSharp.Tests.Context
99
open GraphBLAS.FSharp
10+
open GraphBLAS.FSharp.Backend.Objects.ClCell
1011

1112
let logger = Log.create "ClArray.PrefixSum.Tests"
1213

@@ -28,13 +29,11 @@ let makeTest plus zero isEqual scan (array: 'a []) =
2829

2930
let actual, actualSum =
3031
use clArray = context.CreateClArray array
31-
use total = context.CreateClCell()
32-
scan q clArray total zero |> ignore
32+
let (total: ClCell<_>) = scan q clArray zero
3333

3434
let actual = Array.zeroCreate<'a> clArray.Length
35-
let actualSum = [| zero |]
36-
q.Post(Msg.CreateToHostMsg(total, actualSum))
37-
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(clArray, actual, ch)), actualSum.[0]
35+
let actualSum = total.ToHostAndFree(q)
36+
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(clArray, actual, ch)), actualSum
3837

3938
logger.debug (
4039
eventX "Actual is {actual}\n"

0 commit comments

Comments
 (0)