@@ -2,6 +2,7 @@ module GraphBLAS.FSharp.Tests.Backend.Common.Reduce.ByKey
22
33open Expecto
44open GraphBLAS.FSharp .Backend .Common
5+ open GraphBLAS.FSharp .Test
56open GraphBLAS.FSharp .Tests
67open GraphBLAS.FSharp .Backend .Objects .ClContext
78open Brahma.FSharp
@@ -185,3 +186,156 @@ let sequentialSegmentTests =
185186 createTestSequentialSegments< bool> (=) (&&) <@ (&&) @> ]
186187
187188 testList " Sequential segments" [ addTests; mulTests ]
189+
190+ let checkResult2D isEqual firstActualKeys secondActualKeys actualValues firstKeys secondKeys values reduceOp =
191+
192+ let expectedFirstKeys , expectedSecondKeys , expectedValues =
193+ HostPrimitives.reduceByKey2D firstKeys secondKeys values reduceOp
194+
195+ " First keys must be the same"
196+ |> Utils.compareArrays (=) firstActualKeys expectedFirstKeys
197+
198+ " Second keys must be the same"
199+ |> Utils.compareArrays (=) secondActualKeys expectedSecondKeys
200+
201+ " Values must the same"
202+ |> Utils.compareArrays isEqual actualValues expectedValues
203+
204+ let makeTest2D isEqual reduce reduceOp ( array : ( int * int * 'a ) []) =
205+ let firstKeys , secondKeys , values =
206+ array
207+ |> Array.sortBy ( fun ( fst , snd , _ ) -> fst, snd)
208+ |> Array.unzip3
209+
210+ if firstKeys.Length > 0 then
211+ let clFirstKeys =
212+ context.CreateClArrayWithSpecificAllocationMode( DeviceOnly, firstKeys)
213+
214+ let clSecondKeys =
215+ context.CreateClArrayWithSpecificAllocationMode( DeviceOnly, secondKeys)
216+
217+ let clValues =
218+ context.CreateClArrayWithSpecificAllocationMode( DeviceOnly, values)
219+
220+ let resultLength = Array.length <| Array.distinctBy ( fun ( fst , snd , _ ) -> ( fst, snd)) array
221+
222+ let clFirstActualKeys , clSecondActualKeys , clActualValues : ClArray < int > * ClArray < int > * ClArray < 'a > =
223+ reduce processor HostInterop resultLength clFirstKeys clSecondKeys clValues
224+
225+ clValues.Free processor
226+ clFirstKeys.Free processor
227+ clSecondKeys.Free processor
228+
229+ let actualValues = clActualValues.ToHostAndFree processor
230+ let firstActualKeys = clFirstActualKeys.ToHostAndFree processor
231+ let secondActualKeys = clSecondActualKeys.ToHostAndFree processor
232+
233+ checkResult2D isEqual firstActualKeys secondActualKeys actualValues firstKeys secondKeys values reduceOp
234+
235+ let createTestSequential2D < 'a > ( isEqual : 'a -> 'a -> bool ) reduceOp reduceOpQ =
236+
237+ let reduce =
238+ Reduce.ByKey2D.sequential context Utils.defaultWorkGroupSize reduceOpQ
239+
240+ makeTest2D isEqual reduce reduceOp
241+ |> testPropertyWithConfig { config with arbitrary = [ typeof< Generators.ArrayOfDistinctKeys> ]; endSize = 10 } $" test on {typeof<'a>}"
242+
243+ let sequential2DTest =
244+ let addTests =
245+ testList
246+ " add tests"
247+ [ createTestSequential2D< int> (=) (+) <@ (+) @>
248+ createTestSequential2D< byte> (=) (+) <@ (+) @>
249+
250+ if Utils.isFloat64Available context.ClDevice then
251+ createTestSequential2D< float> Utils.floatIsEqual (+) <@ (+) @>
252+
253+ createTestSequential2D< float32> Utils.float32IsEqual (+) <@ (+) @>
254+ createTestSequential2D< bool> (=) (||) <@ (||) @> ]
255+
256+ let mulTests =
257+ testList
258+ " mul tests"
259+ [ createTestSequential2D< int> (=) (*) <@ (*) @>
260+ createTestSequential2D< byte> (=) (*) <@ (*) @>
261+
262+ if Utils.isFloat64Available context.ClDevice then
263+ createTestSequential2D< float> Utils.floatIsEqual (*) <@ (*) @>
264+
265+ createTestSequential2D< float32> Utils.float32IsEqual (*) <@ (*) @>
266+ createTestSequential2D< bool> (=) (&&) <@ (&&) @> ]
267+
268+ testList " Sequential 2D" [ addTests; mulTests ]
269+
270+ let makeTestSequentialSegments2D isEqual reduce reduceOp ( array : ( int * int * 'a ) []) =
271+
272+ let firstKeys , secondKeys , values =
273+ array
274+ |> Array.sortBy ( fun ( fst , snd , _ ) -> fst, snd)
275+ |> Array.unzip3
276+
277+ if firstKeys.Length > 0 then
278+ let offsets =
279+ array
280+ |> Array.map ( fun ( fst , snd , _ ) -> fst, snd)
281+ |> HostPrimitives.getUniqueBitmapFirstOccurrence
282+ |> HostPrimitives.getBitPositions
283+
284+ let resultLength = offsets.Length
285+
286+ let firstKeys , secondKeys , values = Array.unzip3 array
287+
288+ let clOffsets =
289+ context.CreateClArrayWithSpecificAllocationMode( HostInterop, offsets)
290+
291+ let clFirstKeys =
292+ context.CreateClArrayWithSpecificAllocationMode( DeviceOnly, firstKeys)
293+
294+ let clSecondKeys =
295+ context.CreateClArrayWithSpecificAllocationMode( DeviceOnly, secondKeys)
296+
297+ let clValues =
298+ context.CreateClArrayWithSpecificAllocationMode( DeviceOnly, values)
299+
300+ let clFirstActualKeys , clSecondActualKeys , clReducedValues : ClArray < int > * ClArray < int > * ClArray < 'a > =
301+ reduce processor DeviceOnly resultLength clOffsets clFirstKeys clSecondKeys clValues
302+
303+ let reducedFirsKeys = clFirstActualKeys.ToHostAndFree processor
304+ let reducesSecondKeys = clSecondActualKeys.ToHostAndFree processor
305+ let reducedValues = clReducedValues.ToHostAndFree processor
306+
307+ checkResult2D isEqual reducedFirsKeys reducesSecondKeys reducedValues firstKeys secondKeys values reduceOp
308+
309+ let createTestSequentialSegments2D < 'a > ( isEqual : 'a -> 'a -> bool ) reduceOp reduceOpQ =
310+ let reduce =
311+ Reduce.ByKey2D.segmentSequential context Utils.defaultWorkGroupSize reduceOpQ
312+
313+ makeTestSequentialSegments2D isEqual reduce reduceOp
314+ |> testPropertyWithConfig { config with arbitrary = [ typeof< Generators.ArrayOfDistinctKeys> ] } $" test on {typeof<'a>}"
315+
316+ let sequentialSegmentTests2D =
317+ let addTests =
318+ testList
319+ " add tests"
320+ [ createTestSequentialSegments2D< int> (=) (+) <@ (+) @>
321+ createTestSequentialSegments2D< byte> (=) (+) <@ (+) @>
322+
323+ if Utils.isFloat64Available context.ClDevice then
324+ createTestSequentialSegments2D< float> Utils.floatIsEqual (+) <@ (+) @>
325+
326+ createTestSequentialSegments2D< float32> Utils.float32IsEqual (+) <@ (+) @>
327+ createTestSequentialSegments2D< bool> (=) (||) <@ (||) @> ]
328+
329+ let mulTests =
330+ testList
331+ " mul tests"
332+ [ createTestSequentialSegments2D< int> (=) (*) <@ (*) @>
333+ createTestSequentialSegments2D< byte> (=) (*) <@ (*) @>
334+
335+ if Utils.isFloat64Available context.ClDevice then
336+ createTestSequentialSegments2D< float> Utils.floatIsEqual (*) <@ (*) @>
337+
338+ createTestSequentialSegments2D< float32> Utils.float32IsEqual (*) <@ (*) @>
339+ createTestSequentialSegments2D< bool> (=) (&&) <@ (&&) @> ]
340+
341+ testList " Sequential segments 2D" [ addTests; mulTests ]
0 commit comments