Skip to content

Commit f0d4850

Browse files
committed
EWiseAdd tests generalization, proper float comparison, core compilation before tests
1 parent 5ca7d0d commit f0d4850

1 file changed

Lines changed: 85 additions & 200 deletions

File tree

Lines changed: 85 additions & 200 deletions
Original file line numberDiff line numberDiff line change
@@ -1,250 +1,135 @@
11
module Backend.EwiseAdd
22

3-
open FsCheck
43
open Expecto
54
open Expecto.Logging
65
open Expecto.Logging.Message
76
open Brahma.FSharp.OpenCL
87
open GraphBLAS.FSharp.Backend
98
open GraphBLAS.FSharp
10-
open GraphBLAS.FSharp.Tests.Generators
119
open GraphBLAS.FSharp.Tests.Utils
10+
open OpenCL.Net
1211

1312
let logger = Log.create "EwiseAdd.Tests"
1413

15-
let context =
16-
let deviceType = ClDeviceType.Default
17-
let platformName = ClPlatform.Any
18-
ClContext(platformName, deviceType)
19-
20-
let getMatricesToAdd generator size isZero mFormat =
21-
let gen = generator |> Arb.toGen
22-
let m1, m2 = (Gen.sample (abs size) 1 gen).[0]
23-
24-
let mtx1, mtx2 =
25-
createMatrixFromArray2D mFormat m1 isZero, createMatrixFromArray2D mFormat m2 isZero
26-
27-
mtx1, mtx2, m1, m2
28-
29-
let checkResult op zero (baseMtx1: 'a [,]) (baseMtx2: 'a [,]) (actual: Matrix<'a>) =
14+
let checkResult isEqual op zero (baseMtx1: 'a [,]) (baseMtx2: 'a [,]) (actual: Matrix<'a>) =
3015
let rows = Array2D.length1 baseMtx1
3116
let columns = Array2D.length2 baseMtx1
3217
Expect.equal columns actual.ColumnCount "The number of columns should be the same."
3318
Expect.equal rows actual.RowCount "The number of rows should be the same."
3419

35-
let expected = Array2D.create rows columns zero
20+
let expected2D = Array2D.create rows columns zero
3621

3722
for i in 0 .. rows - 1 do
3823
for j in 0 .. columns - 1 do
39-
expected.[i, j] <- op baseMtx1.[i, j] baseMtx2.[i, j]
40-
41-
let actual2D =
42-
Array2D.create actual.RowCount actual.ColumnCount zero
43-
44-
let actual2D =
45-
match actual with
46-
| MatrixCOO actual ->
47-
for i in 0 .. actual.Rows.Length - 1 do
48-
actual2D.[actual.Rows.[i], actual.Columns.[i]] <- actual.Values.[i]
49-
50-
actual2D
51-
| MatrixCSR actual ->
52-
let rowIndices =
53-
Array.create actual.ColumnIndices.Length 0
24+
expected2D.[i, j] <- op baseMtx1.[i, j] baseMtx2.[i, j]
5425

55-
for i in 0 .. actual.RowCount - 1 do
56-
if i < actual.RowCount - 1 then
57-
let rowStart = actual.RowPointers.[i]
58-
let rowEnd = actual.RowPointers.[i + 1]
59-
let rowLength = rowEnd - rowStart
26+
let actual2D = Array2D.create rows columns zero
6027

61-
for j in 0 .. rowLength - 1 do
62-
rowIndices.[rowStart + j] <- i
63-
else
64-
let rowStart = actual.RowPointers.[actual.RowCount - 1]
65-
let rowLength = rowIndices.Length - rowStart
66-
67-
for j in 0 .. rowLength - 1 do
68-
rowIndices.[rowStart + j] <- i
69-
70-
for i in 0 .. rowIndices.Length - 1 do
71-
actual2D.[rowIndices.[i], actual.ColumnIndices.[i]] <- actual.Values.[i]
72-
73-
actual2D
28+
match actual with
29+
| MatrixCOO actual ->
30+
for i in 0 .. actual.Rows.Length - 1 do
31+
actual2D.[actual.Rows.[i], actual.Columns.[i]] <- actual.Values.[i]
32+
| _ -> failwith "Impossible case."
7433

7534
for i in 0 .. rows - 1 do
7635
for j in 0 .. columns - 1 do
77-
Expect.equal actual2D.[i, j] expected.[i, j] "Elements of matrices should be equals."
78-
79-
let testCases =
80-
let q = context.Provider.CommandQueue
36+
Expect.isTrue (isEqual actual2D.[i, j] expected2D.[i, j]) "Values should be the same."
37+
38+
let correctnessGenericTest
39+
zero
40+
op
41+
(addFun: MailboxProcessor<Msg> -> Backend.Matrix<'a> -> Backend.Matrix<'a> -> Backend.Matrix<'a>)
42+
toCOOFun
43+
(isEqual: 'a -> 'a -> bool)
44+
(case: OperationCase)
45+
(leftMatrix: 'a [,], rightMatrix: 'a [,])
46+
=
47+
let q = case.ClContext.Provider.CommandQueue
8148
q.Error.Add(fun e -> failwithf "%A" e)
8249

83-
let setSizeForAddFun mAdd =
84-
fun (array: array<_>) ->
85-
let wgSize =
86-
[| for i in 0 .. 5 -> pown 2 i |]
87-
|> Array.filter (fun i -> array.Length % i = 0)
88-
|> Array.max
89-
90-
mAdd (if wgSize = 1 then 2 else wgSize) q
91-
92-
let makeTest (context: ClContext) generator size mFormat op qOp zero =
93-
let mtx1, mtx2, baseMtx1, baseMtx2 =
94-
getMatricesToAdd generator size ((=) zero) mFormat
95-
96-
match mtx1, mtx2 with
97-
| MatrixCOO mtx1, MatrixCOO mtx2 ->
98-
if mtx1.Values.Length > 0 && mtx2.Values.Length > 0 then
99-
use clRows1 = context.CreateClArray mtx1.Rows
100-
use clColumns1 = context.CreateClArray mtx1.Columns
101-
use clValues1 = context.CreateClArray mtx1.Values
102-
103-
let m1 =
104-
{ Context = context
105-
Backend.COOMatrix.RowCount = mtx1.RowCount
106-
ColumnCount = mtx1.ColumnCount
107-
Rows = clRows1
108-
Columns = clColumns1
109-
Values = clValues1 }
110-
111-
use clRows2 = context.CreateClArray mtx2.Rows
112-
use clColumns2 = context.CreateClArray mtx2.Columns
113-
use clValues2 = context.CreateClArray mtx2.Values
114-
115-
let m2 =
116-
{ Context = context
117-
Backend.COOMatrix.RowCount = mtx2.RowCount
118-
ColumnCount = mtx2.ColumnCount
119-
Rows = clRows2
120-
Columns = clColumns2
121-
Values = clValues2 }
122-
123-
let getAddFun =
124-
COOMatrix.eWiseAdd context qOp |> setSizeForAddFun
125-
126-
let add = getAddFun mtx1.Values
127-
128-
let actual =
129-
let res: Backend.COOMatrix<'a> = add m1 m2
130-
let actualRows = Array.zeroCreate res.Rows.Length
131-
let actualColumns = Array.zeroCreate res.Columns.Length
132-
let actualValues = Array.zeroCreate res.Values.Length
133-
134-
let _ =
135-
q.Post(Msg.CreateToHostMsg(res.Rows, actualRows))
136-
137-
let _ =
138-
q.Post(Msg.CreateToHostMsg(res.Columns, actualColumns))
139-
140-
let _ =
141-
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(res.Values, actualValues, ch))
142-
143-
q.Post(Msg.CreateFreeMsg<_>(res.Columns))
144-
q.Post(Msg.CreateFreeMsg<_>(res.Rows))
145-
q.Post(Msg.CreateFreeMsg<_>(res.Values))
146-
147-
{ RowCount = res.RowCount
148-
ColumnCount = res.ColumnCount
149-
Rows = actualRows
150-
Columns = actualColumns
151-
Values = actualValues }
152-
153-
logger.debug (
154-
eventX "Actual is {actual}"
155-
>> setField "actual" (sprintf "%A" actual)
156-
)
157-
158-
checkResult op zero baseMtx1 baseMtx2 (MatrixCOO actual)
50+
let mtx1 =
51+
createMatrixFromArray2D case.MatrixCase leftMatrix (isEqual zero)
15952

160-
| MatrixCSR mtx1, MatrixCSR mtx2 ->
161-
if mtx1.Values.Length > 0 && mtx2.Values.Length > 0 then
162-
use clRows1 = context.CreateClArray mtx1.RowPointers
163-
use clColumns1 = context.CreateClArray mtx1.ColumnIndices
164-
use clValues1 = context.CreateClArray mtx1.Values
53+
let mtx2 =
54+
createMatrixFromArray2D case.MatrixCase rightMatrix (isEqual zero)
16555

166-
let m1 =
167-
{ Context = context
168-
Backend.CSRMatrix.RowCount = mtx1.RowCount
169-
ColumnCount = mtx1.ColumnCount
170-
RowPointers = clRows1
171-
Columns = clColumns1
172-
Values = clValues1 }
56+
if mtx1.NNZCount > 0 && mtx2.NNZCount > 0 then
57+
let m1 = mtx1.ToBackend case.ClContext
58+
let m2 = mtx2.ToBackend case.ClContext
17359

174-
use clRows2 = context.CreateClArray mtx2.RowPointers
175-
use clColumns2 = context.CreateClArray mtx2.ColumnIndices
176-
use clValues2 = context.CreateClArray mtx2.Values
60+
let res = addFun q m1 m2
17761

178-
let m2 =
179-
{ Context = context
180-
Backend.CSRMatrix.RowCount = mtx2.RowCount
181-
ColumnCount = mtx2.ColumnCount
182-
RowPointers = clRows2
183-
Columns = clColumns2
184-
Values = clValues2 }
185-
186-
let getAddFun =
187-
CSRMatrix.eWiseAdd context qOp |> setSizeForAddFun
62+
m1.Dispose()
63+
m2.Dispose()
18864

189-
let add = getAddFun mtx1.Values
65+
let cooRes = toCOOFun q res
66+
let actual = Matrix.FromBackend q cooRes
19067

191-
let actual =
192-
let res: Backend.CSRMatrix<'a> = add m1 m2
193-
let actualRows = Array.zeroCreate res.RowPointers.Length
194-
let actualColumns = Array.zeroCreate res.Columns.Length
195-
let actualValues = Array.zeroCreate res.Values.Length
196-
197-
let _ =
198-
q.Post(Msg.CreateToHostMsg(res.RowPointers, actualRows))
68+
cooRes.Dispose()
69+
res.Dispose()
19970

200-
let _ =
201-
q.Post(Msg.CreateToHostMsg(res.Columns, actualColumns))
71+
logger.debug (
72+
eventX "Actual is {actual}"
73+
>> setField "actual" (sprintf "%A" actual)
74+
)
20275

203-
let _ =
204-
q.PostAndReply(fun ch -> Msg.CreateToHostMsg(res.Values, actualValues, ch))
76+
checkResult isEqual op zero leftMatrix rightMatrix actual
20577

206-
q.Post(Msg.CreateFreeMsg<_>(res.Columns))
207-
q.Post(Msg.CreateFreeMsg<_>(res.RowPointers))
208-
q.Post(Msg.CreateFreeMsg<_>(res.Values))
78+
let testFixtures case =
79+
[ let config = defaultConfig
80+
let wgSize = 128
81+
//Test name on multiple devices can be duplicated due to the ClContext.toString
82+
let getCorrectnessTestName datatype =
83+
sprintf "Correctness on %s, %A" datatype case
20984

210-
{ CSRMatrix.RowCount = res.RowCount
211-
ColumnCount = res.ColumnCount
212-
RowPointers = actualRows
213-
ColumnIndices = actualColumns
214-
Values = actualValues }
85+
let boolAdd =
86+
Matrix.eWiseAdd case.ClContext <@ (||) @> wgSize
21587

216-
logger.debug (
217-
eventX "Actual is {actual}"
218-
>> setField "actual" (sprintf "%A" actual)
219-
)
88+
let boolToCOO = Matrix.toCOO case.ClContext wgSize
22089

221-
checkResult op zero baseMtx1 baseMtx2 (MatrixCSR(actual))
90+
case
91+
|> correctnessGenericTest false (||) boolAdd boolToCOO (=)
92+
|> testPropertyWithConfig config (getCorrectnessTestName "bool")
22293

223-
| _ -> failwith "No other types of matrices tested yet."
94+
let intAdd =
95+
Matrix.eWiseAdd case.ClContext <@ (+) @> wgSize
22496

225-
[ testProperty "Correctness test on random int matrices COO"
226-
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.IntType()) size COO (+) <@ (+) @> 0)
97+
let intToCOO = Matrix.toCOO case.ClContext wgSize
22798

228-
testProperty "Correctness test on random bool matrices COO"
229-
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.BoolType()) size COO (||) <@ (||) @> false)
99+
case
100+
|> correctnessGenericTest 0 (+) intAdd intToCOO (=)
101+
|> testPropertyWithConfig config (getCorrectnessTestName "int")
230102

231-
testProperty "Correctness test on random float matrices COO"
232-
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.FloatType()) size COO (+) <@ (+) @> 0.0)
103+
let floatAdd =
104+
Matrix.eWiseAdd case.ClContext <@ (+) @> wgSize
233105

234-
testProperty "Correctness test on random byte matrices COO"
235-
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.ByteType()) size COO (+) <@ (+) @> 0uy)
106+
let floatToCOO = Matrix.toCOO case.ClContext wgSize
236107

237-
testProperty "Correctness test on random int matrices CSR"
238-
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.IntType()) size CSR (+) <@ (+) @> 0)
108+
case
109+
|> correctnessGenericTest 0.0 (+) floatAdd floatToCOO (fun x y -> abs (x - y) < Accuracy.medium.absolute)
110+
|> testPropertyWithConfig config (getCorrectnessTestName "float")
239111

240-
testProperty "Correctness test on random bool matrices CSR"
241-
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.BoolType()) size CSR (||) <@ (||) @> false)
112+
let byteAdd =
113+
Matrix.eWiseAdd case.ClContext <@ (+) @> wgSize
242114

243-
testProperty "Correctness test on random float matrices CSR"
244-
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.FloatType()) size CSR (+) <@ (+) @> 0.0)
115+
let byteToCOO = Matrix.toCOO case.ClContext wgSize
245116

246-
testProperty "Correctness test on random byte matrices CSR"
247-
<| (fun size -> makeTest context (PairOfSparseMatricesOfEqualSize.ByteType()) size CSR (+) <@ (+) @> 0uy) ]
117+
case
118+
|> correctnessGenericTest 0uy (+) byteAdd byteToCOO (=)
119+
|> testPropertyWithConfig config (getCorrectnessTestName "byte") ]
248120

249121
let tests =
250-
testCases |> testList "Backend.EwiseAdd tests"
122+
testCases
123+
|> List.filter
124+
(fun case ->
125+
let mutable e = ErrorCode.Unknown
126+
let device = case.ClContext.Device
127+
128+
let deviceType =
129+
Cl
130+
.GetDeviceInfo(device, DeviceInfo.Type, &e)
131+
.CastTo<DeviceType>()
132+
133+
deviceType = DeviceType.Default)
134+
|> List.collect testFixtures
135+
|> testList "Backend.Matrix.eWiseAdd tests"

0 commit comments

Comments
 (0)