@@ -8,6 +8,50 @@ open Brahma.FSharp.OpenCL.WorkflowBuilder.Basic
88open GlobalContext
99open TypeShape.Core
1010open GraphBLAS.FSharp .Tests
11+ open System
12+
13+ type OperationParameter =
14+ | MatrixFormatParam of MatrixBackendFormat
15+ | MaskTypeParam of MaskType
16+
17+ type OperationCase = {
18+ MatrixCase: MatrixBackendFormat
19+ MaskCase: MaskType
20+ }
21+
22+ let testCases =
23+ [
24+ Utils.listOfUnionCases< MatrixBackendFormat> |> List.map MatrixFormatParam
25+ Utils.listOfUnionCases< MaskType> |> List.map MaskTypeParam
26+ ]
27+ |> Utils.cartesian
28+ |> List.map
29+ ( fun list ->
30+ let ( MatrixFormatParam marixFormat ) = list.[ 0 ]
31+ let ( MaskTypeParam maskType ) = list.[ 1 ]
32+ {
33+ MatrixCase = marixFormat
34+ MaskCase = maskType
35+ }
36+ )
37+
38+ let createMatrix < 'a when 'a : struct and 'a : equality > matrixFormat args =
39+ match matrixFormat with
40+ | CSR ->
41+ Activator.CreateInstanceGeneric< CSRMatrix<_>>(
42+ Array.singleton typeof< 'a>, args
43+ )
44+ |> unbox< CSRMatrix< 'a>>
45+ :> Matrix< 'a>
46+ | COO ->
47+ Activator.CreateInstanceGeneric< COOMatrix<_>>(
48+ Array.singleton typeof< 'a>, args
49+ )
50+ |> unbox< COOMatrix< 'a>>
51+ :> Matrix< 'a>
52+
53+ let createCSR < 'a when 'a : struct and 'a : equality > = createMatrix< 'a> CSR
54+ let createCOO < 'a when 'a : struct and 'a : equality > = createMatrix< 'a> COO
1155
1256type PrimitiveType =
1357 | Float32
@@ -17,45 +61,73 @@ type PairOfSparseMatrices =
1761 static member Float32Type () =
1862 Arb.fromGen <| Generators.pairOfSparseMatricesGenerator
1963 Arb.generate< float32>
20- 0 f
21- ((=) 0 f )
64+ 0. f
65+ ((=) 0. f )
2266
2367 static member BoolType () =
2468 Arb.fromGen <| Generators.pairOfSparseMatricesGenerator
2569 Arb.generate< bool>
2670 false
2771 ((=) false )
2872
29- // type IPredicate =
30- // abstract Invoke : 'a -> bool
31-
32- // let equals (a: 'a) (b: 'a) = true
33-
34- // let reflexivity = {
35- // new IPredicate with
36- // member this.Invoke item = equals item item
37- // }
38-
39- // let conf = {
40- // Config.Verbose with
41- // Arbitrary = [typeof<Arbi>]
42- // }
43-
44- // let meaning randomType =
45- // match randomType with
46- // | Float32 -> typeof<float32>
47- // | Bool -> typeof<bool>
48-
49- // let check (predicate: IPredicate) (randomType: PrimitiveTypes) =
50- // let systemType = meaning randomType
51- // let shape = TypeShape.Create systemType
52- // shape.Accept {
53- // new ITypeVisitor<bool> with
54- // member this.Visit<'a>() =
55- // Check.One<'a -> bool>(conf, predicate.Invoke)
56- // true
57- // }
58-
59- // // генерится тип
60- // let checkGeneric (predicate: IPredicate) =
61- // Check.Quick<PrimitiveTypes -> bool>(check predicate)
73+ let meaning primitiveType =
74+ match primitiveType with
75+ | Float32 -> typeof< float32>
76+ | Bool -> typeof< bool>
77+
78+ let printer primitiveType =
79+ match primitiveType with
80+ | Float32 -> " float32"
81+ | Bool -> " bool"
82+
83+ let checkConcrete ( testCase : OperationCase ) ( primitiveType : PrimitiveType ) =
84+ let config = { FsCheckConfig.defaultConfig with arbitrary = [ typeof< PairOfSparseMatrices>] }
85+ let systemType = meaning primitiveType
86+ let prettyType = printer primitiveType
87+ let shape = TypeShape.Create systemType
88+ shape.Accept { new ITypeVisitor< Test> with
89+ member this.Visit < 'a >() =
90+ testPropertyWithConfig config ( sprintf " On type %s " prettyType) <|
91+ fun ( matrixA : 'a [,]) ( matrixB : 'a [,]) ->
92+ // let sparseA = createMatrix<'a> testCase.MatrixCase [|
93+ // box matrixA
94+ // box ((=) 0.)
95+ // |]
96+ // let sparseB = createMatrix<'a> testCase.MatrixCase [|
97+ // box matrixB
98+ // box ((=) 0.)
99+ // |]
100+
101+ let a = Matrix.ofArray2D matrixA ((=) 0. )
102+
103+ // let result =
104+ // opencl {
105+ // return! (sparseA + sparseB) mask stdSemiring
106+ // } |> oclContext.RunSync
107+
108+ ()
109+ // let a = LinearAlgebra.DenseMatrix.ofArray2 matrixA
110+ // let b = LinearAlgebra.DenseVector.ofArray matrixB
111+ // let c = b * a
112+ // let elementWiseDifference =
113+ // (result |> Vector.toSeq, c.AsArray() |> Seq.ofArray)
114+ // ||> Seq.zip
115+ // |> Seq.map (fun (a, b) -> a - b)
116+
117+ // Expect.all
118+ // elementWiseDifference
119+ // (fun diff -> abs diff < Accuracy.medium.absolute)
120+ // (sprintf "%A @ %A = %A\n case:\n %A" vector matrix result case)
121+ }
122+
123+ let testsI =
124+ testCases
125+ |> List.collect
126+ ( fun case ->
127+ [
128+ Utils.listOfUnionCases< PrimitiveType>
129+ |> List.map ( checkConcrete case)
130+ |> testList " Operation correctness"
131+ ]
132+ )
133+ |> testList " EWiseAdd Tests"
0 commit comments