@@ -5,23 +5,21 @@ open Expecto.Logging
55open Brahma.FSharp
66open GraphBLAS.FSharp .Tests
77open GraphBLAS.FSharp .Tests .Context
8- open GraphBLAS.FSharp
8+ open GraphBLAS.FSharp . Backend . Quotes
99open GraphBLAS.FSharp .Backend .Common
1010open GraphBLAS.FSharp .Backend .Objects .ArraysExtensions
1111
1212let logger = Log.create " Scatter.Tests"
1313
1414let context = defaultContext.ClContext
1515
16- let config =
17- { Tests.Utils.defaultConfig with
18- endSize = 1000000 }
16+ let config = { Utils.defaultConfig with endSize = 10000 }
1917
20- let wgSize = Tests. Utils.defaultWorkGroupSize
18+ let wgSize = Utils.defaultWorkGroupSize
2119
2220let q = defaultContext.Queue
2321
24- let makeTest hostScatter scatter ( array : ( int * 'a ) []) ( result : 'a []) =
22+ let makeTest < 'a when 'a : equality > hostScatter scatter ( array : ( int * 'a ) []) ( result : 'a []) =
2523 if array.Length > 0 then
2624 let positions , values = Array.unzip array
2725
@@ -30,40 +28,89 @@ let makeTest hostScatter scatter (array: (int * 'a) []) (result: 'a []) =
3028 |> hostScatter positions values
3129
3230 let actual =
33- use clPositions = context.CreateClArray positions
31+ let clPositions = context.CreateClArray positions
3432 use clValues = context.CreateClArray values
3533 use clResult = context.CreateClArray result
3634
3735 scatter q clPositions clValues clResult
3836
3937 clResult.ToHostAndFree q
4038
41- $" Arrays should be equal. Actual is \n %A {actual}, expected \n %A {expected} "
42- |> Tests. Utils.compareArrays (=) actual expected
39+ $" Arrays should be equal."
40+ |> Utils.compareArrays (=) actual expected
4341
44- let testFixturesLast < 'a when 'a : equality > hostScatter =
45- Scatter.scatterLastOccurrence< 'a > context wgSize
46- |> makeTest hostScatter
47- |> testPropertyWithConfig { config with endSize = 10 } $" Correctness on %A {typeof<'a>}"
42+ let testFixturesLast < 'a when 'a : equality > =
43+ Scatter.scatterLastOccurrence context wgSize
44+ |> makeTest< 'a > HostPrimitives.scatterLastOccurrence
45+ |> testPropertyWithConfig config $" Correctness on %A {typeof<'a>}"
4846
49- let testFixturesFirst < 'a when 'a : equality > hostScatter =
50- Scatter.scatterFirstOccurrence< 'a > context wgSize
51- |> makeTest hostScatter
52- |> testPropertyWithConfig { config with endSize = 10 } $" Correctness on %A {typeof<'a>}"
47+ let testFixturesFirst < 'a when 'a : equality > =
48+ Scatter.scatterFirstOccurrence context wgSize
49+ |> makeTest< 'a > HostPrimitives.scatterFirstOccurrence
50+ |> testPropertyWithConfig config $" Correctness on %A {typeof<'a>}"
5351
5452let tests =
5553 q.Error.Add( fun e -> failwithf $" %A {e}" )
5654
5755 let last =
58- [ testFixturesLast< int> HostPrimitives.scatterLastOccurrence
59- testFixturesLast< byte> HostPrimitives.scatterLastOccurrence
60- testFixturesLast< bool> HostPrimitives.scatterLastOccurrence ]
56+ [ testFixturesLast< int>
57+ testFixturesLast< byte>
58+ testFixturesLast< bool> ]
6159 |> testList " Last Occurrence"
6260
6361 let first =
64- [ testFixturesFirst< int> HostPrimitives.scatterFirstOccurrence
65- testFixturesFirst< byte> HostPrimitives.scatterFirstOccurrence
66- testFixturesFirst< bool> HostPrimitives.scatterFirstOccurrence ]
62+ [ testFixturesFirst< int>
63+ testFixturesFirst< byte>
64+ testFixturesFirst< bool> ]
6765 |> testList " First Occurrence"
6866
6967 testList " Scatter tests" [ first; last]
68+
69+ let makeTestInit < 'a when 'a : equality > positionsMap scatter ( values : 'a []) ( result : 'a []) =
70+ if values.Length > 0 then
71+
72+ let positionsAndValues =
73+ Array.mapi ( fun index value -> positionsMap index, value) values
74+
75+ let expected =
76+ Array.init result.Length ( fun index ->
77+ match Array.tryFind ( fst >> ((=) index)) positionsAndValues with
78+ | Some (_, value) -> value
79+ | None -> result.[ index])
80+
81+ let actual =
82+ let values = Array.map snd positionsAndValues
83+
84+ use clValues = context.CreateClArray values
85+ use clResult = context.CreateClArray result
86+
87+ scatter q clValues clResult
88+
89+ clResult.ToHostAndFree q
90+
91+ $" Arrays should be equal."
92+ |> Utils.compareArrays (=) actual expected
93+
94+ let createInitTest < 'a when 'a : equality > indexMap indexMapQ =
95+ Scatter.init< 'a> indexMapQ context Utils.defaultWorkGroupSize
96+ |> makeTestInit< 'a> indexMap
97+ |> testPropertyWithConfig config $" test on {typeof<'a>}"
98+
99+ let initTests =
100+ q.Error.Add( fun e -> failwithf $" %A {e}" )
101+
102+ let idTest =
103+ [ createInitTest< int> id Map.id
104+ createInitTest< byte> id Map.id
105+ createInitTest< bool> id Map.id ]
106+ |> testList " id"
107+
108+ let inc = ((+) 1 )
109+
110+ let incTest =
111+ [ createInitTest< int> inc Map.inc
112+ createInitTest< byte> inc Map.inc
113+ createInitTest< bool> inc Map.inc ]
114+ |> testList " increment"
115+
116+ testList " Scatter init tests" [ idTest; incTest]
0 commit comments