Skip to content

Commit 3abd3d4

Browse files
committed
SpMV tests
1 parent 0ea4d0f commit 3abd3d4

3 files changed

Lines changed: 139 additions & 0 deletions

File tree

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
module Backend.SpMV
2+
3+
open Expecto
4+
open Brahma.FSharp
5+
open GraphBLAS.FSharp.Backend
6+
open GraphBLAS.FSharp.Backend.ArraysExtensions
7+
open GraphBLAS.FSharp
8+
open GraphBLAS.FSharp.Tests.Utils
9+
open Microsoft.FSharp.Collections
10+
open Microsoft.FSharp.Core
11+
open OpenCL.Net
12+
open Backend.Common.StandardOperations
13+
14+
let checkResult isEqual sumOp mulOp zero (baseMtx: 'a [,]) (baseVtr: 'b []) (actual: 'c array) =
15+
let rows = Array2D.length1 baseMtx
16+
let columns = Array2D.length2 baseMtx
17+
18+
let expected = Array.create rows zero
19+
20+
for i in 0 .. rows - 1 do
21+
let mutable sum = zero
22+
23+
for v in 0 .. columns - 1 do
24+
sum <- sumOp sum (mulOp baseMtx.[i, v] baseVtr.[v])
25+
26+
expected.[i] <- sum
27+
28+
for i in 0 .. actual.Size - 1 do
29+
match actual.[i] with
30+
| Some v ->
31+
if isEqual zero v then
32+
failwith "Resulting zeroes should be implicit."
33+
| None -> ()
34+
35+
for i in 0 .. actual.Size - 1 do
36+
match actual.[i] with
37+
| Some v ->
38+
Expect.isTrue (isEqual v expected.[i]) $"Values should be the same. Actual is {v}, expected {expected.[i]}."
39+
| None ->
40+
Expect.isTrue
41+
(isEqual zero expected.[i])
42+
$"Values should be the same. Actual is {zero}, expected {expected.[i]}."
43+
44+
let correctnessGenericTest
45+
zero
46+
sumOp
47+
mulOp
48+
(spMV: MailboxProcessor<_> -> Backend.CSRMatrix<'a> -> ClArray<'b option> -> ClArray<'c option>)
49+
(isEqual: 'a -> 'a -> bool)
50+
q
51+
(testContext: TestContext)
52+
(matrix: 'a [,], vector: 'a [], mask: bool [])
53+
=
54+
55+
let mtx =
56+
createMatrixFromArray2D CSR matrix (isEqual zero)
57+
58+
let vtr =
59+
createVectorFromArray Dense vector (isEqual zero)
60+
61+
if mtx.NNZCount > 0 && vtr.Size > 0 then
62+
try
63+
let m = mtx.ToBackend testContext.ClContext
64+
65+
match vtr, m with
66+
| VectorDense vtr, Backend.MatrixCSR m ->
67+
let v = vtr.ToDevice testContext.ClContext
68+
69+
let res = spMV testContext.Queue m v
70+
71+
(Backend.MatrixCSR m).Dispose q
72+
v.Dispose q
73+
let hostRes = res.ToHost q
74+
res.Dispose q
75+
76+
checkResult isEqual sumOp mulOp zero matrix vector hostRes
77+
| _ -> failwith "Impossible"
78+
with
79+
| ex when ex.Message = "InvalidBufferSize" -> ()
80+
| ex -> raise ex
81+
82+
let testFixturesSpMV (testContext: TestContext) =
83+
[ let config = defaultConfig
84+
let wgSize = 32
85+
86+
let getCorrectnessTestName datatype = sprintf "Correctness on %s" datatype
87+
88+
let context = testContext.ClContext
89+
let q = testContext.Queue
90+
q.Error.Add(fun e -> failwithf "%A" e)
91+
92+
let boolSpMV =
93+
Vector.spMV context boolSum boolMul wgSize
94+
95+
testContext
96+
|> correctnessGenericTest false (||) (&&) boolSpMV (=) q
97+
|> testPropertyWithConfig config (getCorrectnessTestName "bool")
98+
99+
let intSpMV = Vector.spMV context intSum intMul wgSize
100+
101+
testContext
102+
|> correctnessGenericTest 0 (+) (*) intSpMV (=) q
103+
|> testPropertyWithConfig config (getCorrectnessTestName "int")
104+
105+
let floatSpMV =
106+
Vector.spMV context floatSum floatMul wgSize
107+
108+
testContext
109+
|> correctnessGenericTest 0.0 (+) (*) floatSpMV (fun x y -> abs (x - y) < Accuracy.medium.absolute) q
110+
|> testPropertyWithConfig config (getCorrectnessTestName "float")
111+
112+
let byteAdd =
113+
Vector.spMV context byteSum byteMul wgSize
114+
115+
let byteToCOO = Matrix.toCOO context wgSize
116+
117+
testContext
118+
|> correctnessGenericTest 0uy (+) (*) byteAdd (=) q
119+
|> testPropertyWithConfig config (getCorrectnessTestName "byte") ]
120+
121+
let tests =
122+
availableContexts ""
123+
|> List.ofSeq
124+
|> List.filter
125+
(fun testContext ->
126+
let mutable e = ErrorCode.Unknown
127+
let device = testContext.ClContext.ClDevice.Device
128+
129+
let deviceType =
130+
Cl
131+
.GetDeviceInfo(device, DeviceInfo.Type, &e)
132+
.CastTo<DeviceType>()
133+
134+
deviceType = DeviceType.Gpu)
135+
|> List.distinctBy (fun testContext -> testContext.ClContext.ClDevice.DeviceType)
136+
|> List.collect testFixturesSpMV
137+
|> testList "Backend.Common.SpMV tests"

tests/GraphBLAS-sharp.Tests/GraphBLAS-sharp.Tests.fsproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
<Compile Include="BackendCommonTests/MatrixElementwiseTests.fs" />
2121
<Compile Include="BackendCommonTests/ConvertTests.fs" />
2222
<Compile Include="BackendCommonTests/TransposeTests.fs" />
23+
<Compile Include="BackendCommonTests/SpMVTests.fs" />
2324
<!--Compile Include="MatrixOperationsTests/GetTuplesTests.fs" /-->
2425
<!--Compile Include="MatrixOperationsTests/MxmTests.fs" /-->
2526
<!--Compile Include="MatrixOperationsTests/MxvTests.fs" /-->

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ let allTests =
2222
Backend.Elementwise.elementwiseAddAtLeastOneToCOOTests
2323
Backend.Elementwise.elementwiseMulAtLeastOneTests
2424
Backend.Transpose.tests
25+
Backend.SpMV.tests
2526
//Matrix.GetTuples.tests
2627
//Matrix.Mxv.tests
2728
//Algo.Bfs.tests

0 commit comments

Comments
 (0)