|
| 1 | +module GraphBLAS.FSharp.Tests.Backend.Common.Map2 |
| 2 | + |
| 3 | +open Brahma.FSharp |
| 4 | +open GraphBLAS.FSharp.Tests |
| 5 | +open GraphBLAS.FSharp.Tests.Context |
| 6 | +open GraphBLAS.FSharp.Backend.Common |
| 7 | +open GraphBLAS.FSharp.Backend.Quotes |
| 8 | +open Expecto |
| 9 | +open GraphBLAS.FSharp.Backend.Objects.ClContext |
| 10 | + |
| 11 | +let context = defaultContext.Queue |
| 12 | + |
| 13 | +let wgSize = Utils.defaultWorkGroupSize |
| 14 | + |
| 15 | +let config = Utils.defaultConfig |
| 16 | + |
| 17 | +let makeTest<'a when 'a: equality> testContext clMapFun hostMapFun isEqual (leftArray: 'a [], rightArray: 'a []) = |
| 18 | + if leftArray.Length > 0 then |
| 19 | + let context = testContext.ClContext |
| 20 | + let q = testContext.Queue |
| 21 | + |
| 22 | + let leftClArray = |
| 23 | + context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, leftArray) |
| 24 | + |
| 25 | + let rightClArray = |
| 26 | + context.CreateClArrayWithSpecificAllocationMode(DeviceOnly, rightArray) |
| 27 | + |
| 28 | + let (actualDevice: ClArray<'a>) = |
| 29 | + clMapFun q HostInterop leftClArray rightClArray |
| 30 | + |
| 31 | + let actualHost = Array.zeroCreate actualDevice.Length |
| 32 | + |
| 33 | + q.PostAndReply(fun ch -> Msg.CreateToHostMsg(actualDevice, actualHost, ch)) |
| 34 | + |> ignore |
| 35 | + |
| 36 | + let expected = |
| 37 | + Array.map2 hostMapFun leftArray rightArray |
| 38 | + |
| 39 | + "Arrays must be the same" |
| 40 | + |> Utils.compareArrays isEqual actualHost expected |
| 41 | + |
| 42 | +let createTest<'a when 'a: equality> (testContext: TestContext) isEqual hostMapFun mapFunQ = |
| 43 | + |
| 44 | + let context = testContext.ClContext |
| 45 | + |
| 46 | + let map = ClArray.map2 context wgSize mapFunQ |
| 47 | + |
| 48 | + makeTest<'a> testContext map hostMapFun isEqual |
| 49 | + |> testPropertyWithConfig config $"Correctness on {typeof<'a>}" |
| 50 | + |
| 51 | +let testFixturesAdd (testContext: TestContext) = |
| 52 | + [ createTest<int> testContext (=) (+) <@ (+) @> |
| 53 | + createTest<bool> testContext (=) (||) <@ (||) @> |
| 54 | + |
| 55 | + if Utils.isFloat64Available testContext.ClContext.ClDevice then |
| 56 | + createTest<float> testContext Utils.floatIsEqual (+) <@ (+) @> |
| 57 | + |
| 58 | + createTest<float32> testContext Utils.float32IsEqual (+) <@ (+) @> |
| 59 | + createTest<byte> testContext (=) (+) <@ (+) @> ] |
| 60 | + |
| 61 | +let addTests = |
| 62 | + TestCases.gpuTests "ClArray.map2 add tests" testFixturesAdd |
| 63 | + |
| 64 | +let testFixturesMul (testContext: TestContext) = |
| 65 | + [ createTest<int> testContext (=) (*) <@ (*) @> |
| 66 | + createTest<bool> testContext (=) (&&) <@ (&&) @> |
| 67 | + |
| 68 | + if Utils.isFloat64Available testContext.ClContext.ClDevice then |
| 69 | + createTest<float> testContext Utils.floatIsEqual (*) <@ (*) @> |
| 70 | + |
| 71 | + createTest<float32> testContext Utils.float32IsEqual (*) <@ (*) @> |
| 72 | + createTest<byte> testContext (=) (+) <@ (+) @> ] |
| 73 | + |
| 74 | +let mulTests = |
| 75 | + TestCases.gpuTests "ClArray.map2 multiplication tests" testFixturesMul |
0 commit comments