|
| 1 | +module Backend.Sum |
| 2 | + |
| 3 | +open Expecto |
| 4 | +open Expecto.Logging |
| 5 | +open Expecto.Logging.Message |
| 6 | +open Brahma.FSharp |
| 7 | +open GraphBLAS.FSharp.Backend.Common |
| 8 | +open GraphBLAS.FSharp.Tests.Utils |
| 9 | +open FSharp.Quotations |
| 10 | + |
| 11 | +let logger = Log.create "Sum.Test" |
| 12 | + |
| 13 | +let context = defaultContext.ClContext |
| 14 | + |
| 15 | +let makeTest (q: MailboxProcessor<_>) sum plus zero isEqual (filter: 'a [] -> 'a []) (array: 'a []) = |
| 16 | + if array.Length > 0 then |
| 17 | + let array = filter array |
| 18 | + |
| 19 | + logger.debug ( |
| 20 | + eventX "Filtered array is {array}\n" |
| 21 | + >> setField "array" (sprintf "%A" array) |
| 22 | + ) |
| 23 | + |
| 24 | + let actualSum = |
| 25 | + use clArray = context.CreateClArray array |
| 26 | + use total = sum q clArray |
| 27 | + |
| 28 | + let actualSum = [| zero |] |
| 29 | + q.PostAndReply(fun ch -> Msg.CreateToHostMsg(total, actualSum, ch)).[0] |
| 30 | + |
| 31 | + logger.debug ( |
| 32 | + eventX "Actual is {actual}\n" |
| 33 | + >> setField "actual" (sprintf "%A" actualSum) |
| 34 | + ) |
| 35 | + |
| 36 | + let expectedSum = |
| 37 | + array |
| 38 | + |> Array.fold plus zero |
| 39 | + |
| 40 | + logger.debug ( |
| 41 | + eventX "Expected is {expected}\n" |
| 42 | + >> setField "expected" (sprintf "%A" expectedSum) |
| 43 | + ) |
| 44 | + |
| 45 | + "Total sums should be equal" |
| 46 | + |> Expect.equal actualSum expectedSum |
| 47 | + |
| 48 | +let testFixtures config wgSize q plus (plusQ: Expr<'a -> 'a -> 'a>) zero isEqual filter name = |
| 49 | + let sum = |
| 50 | + Sum.run context wgSize plusQ zero |
| 51 | + |
| 52 | + makeTest q sum plus zero isEqual filter |
| 53 | + |> testPropertyWithConfig config (sprintf "Correctness on %s" name) |
| 54 | + |
| 55 | +let tests = |
| 56 | + let config = defaultConfig |
| 57 | + |
| 58 | + let wgSize = 128 |
| 59 | + let q = defaultContext.Queue |
| 60 | + q.Error.Add(fun e -> failwithf "%A" e) |
| 61 | + |
| 62 | + let filterFloats = |
| 63 | + Array.filter (System.Double.IsNaN >> not) |
| 64 | + |
| 65 | + [ testFixtures config wgSize q (+) <@ (+) @> 0 (=) id "int add" |
| 66 | + testFixtures config wgSize q (+) <@ (+) @> 0uy (=) id "byte add" |
| 67 | + testFixtures config wgSize q max <@ max @> 0 (=) id "int max" |
| 68 | + testFixtures config wgSize q max <@ max @> 0.0 (=) filterFloats "float max" |
| 69 | + testFixtures config wgSize q max <@ max @> 0uy (=) id "byte max" |
| 70 | + testFixtures config wgSize q min <@ min @> System.Int32.MaxValue (=) id "int min" |
| 71 | + testFixtures config wgSize q min <@ min @> System.Double.MaxValue (=) filterFloats "float min" |
| 72 | + testFixtures config wgSize q min <@ min @> System.Byte.MaxValue (=) id "byte min" |
| 73 | + testFixtures config wgSize q (||) <@ (||) @> false (=) id "bool logic-or" |
| 74 | + testFixtures config wgSize q (&&) <@ (&&) @> true (=) id "bool logic-and" ] |
| 75 | + |> testList "Backend.Common.Sum tests" |
| 76 | + |
0 commit comments