|
1 | 1 | module Backend.EwiseAdd |
2 | 2 |
|
3 | | -open FsCheck |
4 | 3 | open Expecto |
5 | 4 | open Expecto.Logging |
6 | 5 | open Expecto.Logging.Message |
7 | 6 | open Brahma.FSharp.OpenCL |
8 | 7 | open GraphBLAS.FSharp.Backend |
9 | 8 | open GraphBLAS.FSharp |
10 | | -open GraphBLAS.FSharp.Tests.Generators |
11 | 9 | open GraphBLAS.FSharp.Tests.Utils |
| 10 | +open OpenCL.Net |
12 | 11 |
|
13 | 12 | let logger = Log.create "EwiseAdd.Tests" |
14 | 13 |
|
15 | | -let context = |
16 | | - let deviceType = ClDeviceType.Default |
17 | | - let platformName = ClPlatform.Any |
18 | | - ClContext(platformName, deviceType) |
19 | | - |
20 | | -let getMatricesToAdd generator size isZero mFormat = |
21 | | - let gen = generator |> Arb.toGen |
22 | | - let m1, m2 = (Gen.sample (abs size) 1 gen).[0] |
23 | | - |
24 | | - let mtx1, mtx2 = |
25 | | - createMatrixFromArray2D mFormat m1 isZero, createMatrixFromArray2D mFormat m2 isZero |
26 | | - |
27 | | - mtx1, mtx2, m1, m2 |
28 | | - |
29 | | -let checkResult op zero (baseMtx1: 'a [,]) (baseMtx2: 'a [,]) (actual: Matrix<'a>) = |
| 14 | +let checkResult isEqual op zero (baseMtx1: 'a [,]) (baseMtx2: 'a [,]) (actual: Matrix<'a>) = |
30 | 15 | let rows = Array2D.length1 baseMtx1 |
31 | 16 | let columns = Array2D.length2 baseMtx1 |
32 | 17 | Expect.equal columns actual.ColumnCount "The number of columns should be the same." |
33 | 18 | Expect.equal rows actual.RowCount "The number of rows should be the same." |
34 | 19 |
|
35 | | - let expected = Array2D.create rows columns zero |
| 20 | + let expected2D = Array2D.create rows columns zero |
36 | 21 |
|
37 | 22 | for i in 0 .. rows - 1 do |
38 | 23 | for j in 0 .. columns - 1 do |
39 | | - expected.[i, j] <- op baseMtx1.[i, j] baseMtx2.[i, j] |
40 | | - |
41 | | - let actual2D = |
42 | | - Array2D.create actual.RowCount actual.ColumnCount zero |
43 | | - |
44 | | - let actual2D = |
45 | | - match actual with |
46 | | - | MatrixCOO actual -> |
47 | | - for i in 0 .. actual.Rows.Length - 1 do |
48 | | - actual2D.[actual.Rows.[i], actual.Columns.[i]] <- actual.Values.[i] |
49 | | - |
50 | | - actual2D |
51 | | - | MatrixCSR actual -> |
52 | | - let rowIndices = |
53 | | - Array.create actual.ColumnIndices.Length 0 |
| 24 | + expected2D.[i, j] <- op baseMtx1.[i, j] baseMtx2.[i, j] |
54 | 25 |
|
55 | | - for i in 0 .. actual.RowCount - 1 do |
56 | | - if i < actual.RowCount - 1 then |
57 | | - let rowStart = actual.RowPointers.[i] |
58 | | - let rowEnd = actual.RowPointers.[i + 1] |
59 | | - let rowLength = rowEnd - rowStart |
| 26 | + let actual2D = Array2D.create rows columns zero |
60 | 27 |
|
61 | | - for j in 0 .. rowLength - 1 do |
62 | | - rowIndices.[rowStart + j] <- i |
63 | | - else |
64 | | - let rowStart = actual.RowPointers.[actual.RowCount - 1] |
65 | | - let rowLength = rowIndices.Length - rowStart |
66 | | - |
67 | | - for j in 0 .. rowLength - 1 do |
68 | | - rowIndices.[rowStart + j] <- i |
69 | | - |
70 | | - for i in 0 .. rowIndices.Length - 1 do |
71 | | - actual2D.[rowIndices.[i], actual.ColumnIndices.[i]] <- actual.Values.[i] |
72 | | - |
73 | | - actual2D |
| 28 | + match actual with |
| 29 | + | MatrixCOO actual -> |
| 30 | + for i in 0 .. actual.Rows.Length - 1 do |
| 31 | + actual2D.[actual.Rows.[i], actual.Columns.[i]] <- actual.Values.[i] |
| 32 | + | _ -> failwith "Impossible case." |
74 | 33 |
|
75 | 34 | for i in 0 .. rows - 1 do |
76 | 35 | for j in 0 .. columns - 1 do |
77 | | - Expect.equal actual2D.[i, j] expected.[i, j] "Elements of matrices should be equals." |
78 | | - |
79 | | -let testCases = |
80 | | - let q = context.Provider.CommandQueue |
| 36 | + Expect.isTrue (isEqual actual2D.[i, j] expected2D.[i, j]) "Values should be the same." |
| 37 | + |
| 38 | +let correctnessGenericTest |
| 39 | + zero |
| 40 | + op |
| 41 | + (addFun: MailboxProcessor<Msg> -> Backend.Matrix<'a> -> Backend.Matrix<'a> -> Backend.Matrix<'a>) |
| 42 | + toCOOFun |
| 43 | + (isEqual: 'a -> 'a -> bool) |
| 44 | + (case: OperationCase) |
| 45 | + (leftMatrix: 'a [,], rightMatrix: 'a [,]) |
| 46 | + = |
| 47 | + let q = case.ClContext.Provider.CommandQueue |
81 | 48 | q.Error.Add(fun e -> failwithf "%A" e) |
82 | 49 |
|
83 | | - let setSizeForAddFun mAdd = |
84 | | - fun (array: array<_>) -> |
85 | | - let wgSize = |
86 | | - [| for i in 0 .. 5 -> pown 2 i |] |
87 | | - |> Array.filter (fun i -> array.Length % i = 0) |
88 | | - |> Array.max |
89 | | - |
90 | | - mAdd (if wgSize = 1 then 2 else wgSize) q |
91 | | - |
92 | | - let makeTest (context: ClContext) generator size mFormat op qOp zero = |
93 | | - let mtx1, mtx2, baseMtx1, baseMtx2 = |
94 | | - getMatricesToAdd generator size ((=) zero) mFormat |
95 | | - |
96 | | - match mtx1, mtx2 with |
97 | | - | MatrixCOO mtx1, MatrixCOO mtx2 -> |
98 | | - if mtx1.Values.Length > 0 && mtx2.Values.Length > 0 then |
99 | | - use clRows1 = context.CreateClArray mtx1.Rows |
100 | | - use clColumns1 = context.CreateClArray mtx1.Columns |
101 | | - use clValues1 = context.CreateClArray mtx1.Values |
102 | | - |
103 | | - let m1 = |
104 | | - { Context = context |
105 | | - Backend.COOMatrix.RowCount = mtx1.RowCount |
106 | | - ColumnCount = mtx1.ColumnCount |
107 | | - Rows = clRows1 |
108 | | - Columns = clColumns1 |
109 | | - Values = clValues1 } |
110 | | - |
111 | | - use clRows2 = context.CreateClArray mtx2.Rows |
112 | | - use clColumns2 = context.CreateClArray mtx2.Columns |
113 | | - use clValues2 = context.CreateClArray mtx2.Values |
114 | | - |
115 | | - let m2 = |
116 | | - { Context = context |
117 | | - Backend.COOMatrix.RowCount = mtx2.RowCount |
118 | | - ColumnCount = mtx2.ColumnCount |
119 | | - Rows = clRows2 |
120 | | - Columns = clColumns2 |
121 | | - Values = clValues2 } |
122 | | - |
123 | | - let getAddFun = |
124 | | - COOMatrix.eWiseAdd context qOp |> setSizeForAddFun |
125 | | - |
126 | | - let add = getAddFun mtx1.Values |
127 | | - |
128 | | - let actual = |
129 | | - let res: Backend.COOMatrix<'a> = add m1 m2 |
130 | | - let actualRows = Array.zeroCreate res.Rows.Length |
131 | | - let actualColumns = Array.zeroCreate res.Columns.Length |
132 | | - let actualValues = Array.zeroCreate res.Values.Length |
133 | | - |
134 | | - let _ = |
135 | | - q.Post(Msg.CreateToHostMsg(res.Rows, actualRows)) |
136 | | - |
137 | | - let _ = |
138 | | - q.Post(Msg.CreateToHostMsg(res.Columns, actualColumns)) |
139 | | - |
140 | | - let _ = |
141 | | - q.PostAndReply(fun ch -> Msg.CreateToHostMsg(res.Values, actualValues, ch)) |
142 | | - |
143 | | - q.Post(Msg.CreateFreeMsg<_>(res.Columns)) |
144 | | - q.Post(Msg.CreateFreeMsg<_>(res.Rows)) |
145 | | - q.Post(Msg.CreateFreeMsg<_>(res.Values)) |
146 | | - |
147 | | - { RowCount = res.RowCount |
148 | | - ColumnCount = res.ColumnCount |
149 | | - Rows = actualRows |
150 | | - Columns = actualColumns |
151 | | - Values = actualValues } |
152 | | - |
153 | | - logger.debug ( |
154 | | - eventX "Actual is {actual}" |
155 | | - >> setField "actual" (sprintf "%A" actual) |
156 | | - ) |
157 | | - |
158 | | - checkResult op zero baseMtx1 baseMtx2 (MatrixCOO actual) |
| 50 | + let mtx1 = |
| 51 | + createMatrixFromArray2D case.MatrixCase leftMatrix (isEqual zero) |
159 | 52 |
|
160 | | - | MatrixCSR mtx1, MatrixCSR mtx2 -> |
161 | | - if mtx1.Values.Length > 0 && mtx2.Values.Length > 0 then |
162 | | - use clRows1 = context.CreateClArray mtx1.RowPointers |
163 | | - use clColumns1 = context.CreateClArray mtx1.ColumnIndices |
164 | | - use clValues1 = context.CreateClArray mtx1.Values |
| 53 | + let mtx2 = |
| 54 | + createMatrixFromArray2D case.MatrixCase rightMatrix (isEqual zero) |
165 | 55 |
|
166 | | - let m1 = |
167 | | - { Context = context |
168 | | - Backend.CSRMatrix.RowCount = mtx1.RowCount |
169 | | - ColumnCount = mtx1.ColumnCount |
170 | | - RowPointers = clRows1 |
171 | | - Columns = clColumns1 |
172 | | - Values = clValues1 } |
| 56 | + if mtx1.NNZCount > 0 && mtx2.NNZCount > 0 then |
| 57 | + let m1 = mtx1.ToBackend case.ClContext |
| 58 | + let m2 = mtx2.ToBackend case.ClContext |
173 | 59 |
|
174 | | - use clRows2 = context.CreateClArray mtx2.RowPointers |
175 | | - use clColumns2 = context.CreateClArray mtx2.ColumnIndices |
176 | | - use clValues2 = context.CreateClArray mtx2.Values |
| 60 | + let res = addFun q m1 m2 |
177 | 61 |
|
178 | | - let m2 = |
179 | | - { Context = context |
180 | | - Backend.CSRMatrix.RowCount = mtx2.RowCount |
181 | | - ColumnCount = mtx2.ColumnCount |
182 | | - RowPointers = clRows2 |
183 | | - Columns = clColumns2 |
184 | | - Values = clValues2 } |
185 | | - |
186 | | - let getAddFun = |
187 | | - CSRMatrix.eWiseAdd context qOp |> setSizeForAddFun |
| 62 | + m1.Dispose() |
| 63 | + m2.Dispose() |
188 | 64 |
|
189 | | - let add = getAddFun mtx1.Values |
| 65 | + let cooRes = toCOOFun q res |
| 66 | + let actual = Matrix.FromBackend q cooRes |
190 | 67 |
|
191 | | - let actual = |
192 | | - let res: Backend.CSRMatrix<'a> = add m1 m2 |
193 | | - let actualRows = Array.zeroCreate res.RowPointers.Length |
194 | | - let actualColumns = Array.zeroCreate res.Columns.Length |
195 | | - let actualValues = Array.zeroCreate res.Values.Length |
196 | | - |
197 | | - let _ = |
198 | | - q.Post(Msg.CreateToHostMsg(res.RowPointers, actualRows)) |
| 68 | + cooRes.Dispose() |
| 69 | + res.Dispose() |
199 | 70 |
|
200 | | - let _ = |
201 | | - q.Post(Msg.CreateToHostMsg(res.Columns, actualColumns)) |
| 71 | + logger.debug ( |
| 72 | + eventX "Actual is {actual}" |
| 73 | + >> setField "actual" (sprintf "%A" actual) |
| 74 | + ) |
202 | 75 |
|
203 | | - let _ = |
204 | | - q.PostAndReply(fun ch -> Msg.CreateToHostMsg(res.Values, actualValues, ch)) |
| 76 | + checkResult isEqual op zero leftMatrix rightMatrix actual |
205 | 77 |
|
206 | | - q.Post(Msg.CreateFreeMsg<_>(res.Columns)) |
207 | | - q.Post(Msg.CreateFreeMsg<_>(res.RowPointers)) |
208 | | - q.Post(Msg.CreateFreeMsg<_>(res.Values)) |
| 78 | +let testFixtures case = |
| 79 | + [ let config = defaultConfig |
| 80 | + let wgSize = 128 |
| 81 | + //Test name on multiple devices can be duplicated due to the ClContext.toString |
| 82 | + let getCorrectnessTestName datatype = |
| 83 | + sprintf "Correctness on %s, %A" datatype case |
209 | 84 |
|
210 | | - { CSRMatrix.RowCount = res.RowCount |
211 | | - ColumnCount = res.ColumnCount |
212 | | - RowPointers = actualRows |
213 | | - ColumnIndices = actualColumns |
214 | | - Values = actualValues } |
| 85 | + let boolAdd = |
| 86 | + Matrix.eWiseAdd case.ClContext <@ (||) @> wgSize |
215 | 87 |
|
216 | | - logger.debug ( |
217 | | - eventX "Actual is {actual}" |
218 | | - >> setField "actual" (sprintf "%A" actual) |
219 | | - ) |
| 88 | + let boolToCOO = Matrix.toCOO case.ClContext wgSize |
220 | 89 |
|
221 | | - checkResult op zero baseMtx1 baseMtx2 (MatrixCSR(actual)) |
| 90 | + case |
| 91 | + |> correctnessGenericTest false (||) boolAdd boolToCOO (=) |
| 92 | + |> testPropertyWithConfig config (getCorrectnessTestName "bool") |
222 | 93 |
|
223 | | - | _ -> failwith "No other types of matrices tested yet." |
| 94 | + let intAdd = |
| 95 | + Matrix.eWiseAdd case.ClContext <@ (+) @> wgSize |
224 | 96 |
|
225 | | - [ testProperty "Correctness test on random int matrices COO" |
226 | | - <| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.IntType()) size COO (+) <@ (+) @> 0) |
| 97 | + let intToCOO = Matrix.toCOO case.ClContext wgSize |
227 | 98 |
|
228 | | - testProperty "Correctness test on random bool matrices COO" |
229 | | - <| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.BoolType()) size COO (||) <@ (||) @> false) |
| 99 | + case |
| 100 | + |> correctnessGenericTest 0 (+) intAdd intToCOO (=) |
| 101 | + |> testPropertyWithConfig config (getCorrectnessTestName "int") |
230 | 102 |
|
231 | | - testProperty "Correctness test on random float matrices COO" |
232 | | - <| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.FloatType()) size COO (+) <@ (+) @> 0.0) |
| 103 | + let floatAdd = |
| 104 | + Matrix.eWiseAdd case.ClContext <@ (+) @> wgSize |
233 | 105 |
|
234 | | - testProperty "Correctness test on random byte matrices COO" |
235 | | - <| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.ByteType()) size COO (+) <@ (+) @> 0uy) |
| 106 | + let floatToCOO = Matrix.toCOO case.ClContext wgSize |
236 | 107 |
|
237 | | - testProperty "Correctness test on random int matrices CSR" |
238 | | - <| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.IntType()) size CSR (+) <@ (+) @> 0) |
| 108 | + case |
| 109 | + |> correctnessGenericTest 0.0 (+) floatAdd floatToCOO (fun x y -> abs (x - y) < Accuracy.medium.absolute) |
| 110 | + |> testPropertyWithConfig config (getCorrectnessTestName "float") |
239 | 111 |
|
240 | | - testProperty "Correctness test on random bool matrices CSR" |
241 | | - <| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.BoolType()) size CSR (||) <@ (||) @> false) |
| 112 | + let byteAdd = |
| 113 | + Matrix.eWiseAdd case.ClContext <@ (+) @> wgSize |
242 | 114 |
|
243 | | - testProperty "Correctness test on random float matrices CSR" |
244 | | - <| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.FloatType()) size CSR (+) <@ (+) @> 0.0) |
| 115 | + let byteToCOO = Matrix.toCOO case.ClContext wgSize |
245 | 116 |
|
246 | | - testProperty "Correctness test on random byte matrices CSR" |
247 | | - <| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.ByteType()) size CSR (+) <@ (+) @> 0uy) ] |
| 117 | + case |
| 118 | + |> correctnessGenericTest 0uy (+) byteAdd byteToCOO (=) |
| 119 | + |> testPropertyWithConfig config (getCorrectnessTestName "byte") ] |
248 | 120 |
|
249 | 121 | let tests = |
250 | | - testCases |> testList "Backend.EwiseAdd tests" |
| 122 | + testCases |
| 123 | + |> List.filter |
| 124 | + (fun case -> |
| 125 | + let mutable e = ErrorCode.Unknown |
| 126 | + let device = case.ClContext.Device |
| 127 | + |
| 128 | + let deviceType = |
| 129 | + Cl |
| 130 | + .GetDeviceInfo(device, DeviceInfo.Type, &e) |
| 131 | + .CastTo<DeviceType>() |
| 132 | + |
| 133 | + deviceType = DeviceType.Default) |
| 134 | + |> List.collect testFixtures |
| 135 | + |> testList "Backend.Matrix.eWiseAdd tests" |
0 commit comments