Skip to content

Commit 7ae6f99

Browse files
authored
Merge pull request #44 from artemgl/transpose
Fixed tests
2 parents 077f2a2 + ca2c687 commit 7ae6f99

10 files changed

Lines changed: 246 additions & 146 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/CSRMatrix.fs

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ module CSRMatrix =
8181

8282
let eWiseAdd (clContext: ClContext) (opAdd: Expr<'a option -> 'b option -> 'c option>) workGroupSize =
8383

84-
let toCOOInplaceLeft = toCOOInplace clContext workGroupSize
85-
let toCOOInplaceRight = toCOOInplace clContext workGroupSize
84+
let prepareRows = prepareRows clContext workGroupSize
8685

8786
let eWiseCOO =
8887
COOMatrix.eWiseAdd clContext opAdd workGroupSize
@@ -91,24 +90,32 @@ module CSRMatrix =
9190
COOMatrix.toCSRInplace clContext workGroupSize
9291

9392
fun (processor: MailboxProcessor<_>) (m1: CSRMatrix<'a>) (m2: CSRMatrix<'b>) ->
94-
95-
let m1COO = toCOOInplaceLeft processor m1
96-
let m2COO = toCOOInplaceRight processor m2
93+
let m1COO =
94+
{ Context = clContext
95+
RowCount = m1.RowCount
96+
ColumnCount = m1.ColumnCount
97+
Rows = prepareRows processor m1.RowPointers m1.Values.Length m1.RowCount
98+
Columns = m1.Columns
99+
Values = m1.Values }
100+
101+
let m2COO =
102+
{ Context = clContext
103+
RowCount = m2.RowCount
104+
ColumnCount = m2.ColumnCount
105+
Rows = prepareRows processor m2.RowPointers m2.Values.Length m2.RowCount
106+
Columns = m2.Columns
107+
Values = m2.Values }
97108

98109
let m3COO = eWiseCOO processor m1COO m2COO
99110

100111
processor.Post(Msg.CreateFreeMsg(m1COO.Rows))
101112
processor.Post(Msg.CreateFreeMsg(m2COO.Rows))
102113

103-
let m3 = toCSRInplace processor m3COO
104-
processor.Post(Msg.CreateFreeMsg(m3COO.Rows))
105-
106-
m3
114+
toCSRInplace processor m3COO
107115

108116
let eWiseAddAtLeastOne (clContext: ClContext) (opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>) workGroupSize =
109117

110-
let toCOOInplaceLeft = toCOOInplace clContext workGroupSize
111-
let toCOOInplaceRight = toCOOInplace clContext workGroupSize
118+
let prepareRows = prepareRows clContext workGroupSize
112119

113120
let eWiseCOO =
114121
COOMatrix.eWiseAddAtLeastOne clContext opAdd workGroupSize
@@ -117,19 +124,28 @@ module CSRMatrix =
117124
COOMatrix.toCSRInplace clContext workGroupSize
118125

119126
fun (processor: MailboxProcessor<_>) (m1: CSRMatrix<'a>) (m2: CSRMatrix<'b>) ->
120-
121-
let m1COO = toCOOInplaceLeft processor m1
122-
let m2COO = toCOOInplaceRight processor m2
127+
let m1COO =
128+
{ Context = clContext
129+
RowCount = m1.RowCount
130+
ColumnCount = m1.ColumnCount
131+
Rows = prepareRows processor m1.RowPointers m1.Values.Length m1.RowCount
132+
Columns = m1.Columns
133+
Values = m1.Values }
134+
135+
let m2COO =
136+
{ Context = clContext
137+
RowCount = m2.RowCount
138+
ColumnCount = m2.ColumnCount
139+
Rows = prepareRows processor m2.RowPointers m2.Values.Length m2.RowCount
140+
Columns = m2.Columns
141+
Values = m2.Values }
123142

124143
let m3COO = eWiseCOO processor m1COO m2COO
125144

126145
processor.Post(Msg.CreateFreeMsg(m1COO.Rows))
127146
processor.Post(Msg.CreateFreeMsg(m2COO.Rows))
128147

129-
let m3 = toCSRInplace processor m3COO
130-
processor.Post(Msg.CreateFreeMsg(m3COO.Rows))
131-
132-
m3
148+
toCSRInplace processor m3COO
133149

134150
let transposeInplace (clContext: ClContext) workGroupSize =
135151

src/GraphBLAS-sharp.Backend/Matrix/Matrix.fs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ module Matrix =
3636

3737
MatrixCSR res
3838

39+
/// <summary>
40+
/// Creates a new matrix, represented in CSR format, that is equal to the given one.
41+
/// </summary>
42+
///<param name="clContext">OpenCL context.</param>
43+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
3944
let toCSR (clContext: ClContext) workGroupSize =
4045
let toCSR = COOMatrix.toCSR clContext workGroupSize
4146
let copy = copy clContext workGroupSize
@@ -45,6 +50,12 @@ module Matrix =
4550
| MatrixCOO m -> toCSR processor m |> MatrixCSR
4651
| MatrixCSR _ -> copy processor matrix
4752

53+
/// <summary>
54+
/// Returns the matrix, represented in CSR format, that is equal to the given one.
55+
/// The given matrix should neither be used afterwards nor be disposed.
56+
/// </summary>
57+
///<param name="clContext">OpenCL context.</param>
58+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
4859
let toCSRInplace (clContext: ClContext) workGroupSize =
4960
let toCSRInplace =
5061
COOMatrix.toCSRInplace clContext workGroupSize
@@ -54,6 +65,11 @@ module Matrix =
5465
| MatrixCOO m -> toCSRInplace processor m |> MatrixCSR
5566
| MatrixCSR _ -> matrix
5667

68+
/// <summary>
69+
/// Creates a new matrix, represented in COO format, that is equal to the given one.
70+
/// </summary>
71+
///<param name="clContext">OpenCL context.</param>
72+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
5773
let toCOO (clContext: ClContext) workGroupSize =
5874
let toCOO = CSRMatrix.toCOO clContext workGroupSize
5975
let copy = copy clContext workGroupSize
@@ -63,6 +79,12 @@ module Matrix =
6379
| MatrixCOO _ -> copy processor matrix
6480
| MatrixCSR m -> toCOO processor m |> MatrixCOO
6581

82+
/// <summary>
83+
/// Returns the matrix, represented in COO format, that is equal to the given one.
84+
/// The given matrix should neither be used afterwards nor be disposed.
85+
/// </summary>
86+
///<param name="clContext">OpenCL context.</param>
87+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
6688
let toCOOInplace (clContext: ClContext) workGroupSize =
6789
let toCOOInplace =
6890
CSRMatrix.toCOOInplace clContext workGroupSize
@@ -98,7 +120,12 @@ module Matrix =
98120
| MatrixCSR m1, MatrixCSR m2 -> CSReWiseAdd processor m1 m2 |> MatrixCSR
99121
| _ -> failwith "Matrix formats are not matching"
100122

101-
123+
/// <summary>
124+
/// Transposes the given matrix and returns result. The storage format is preserved.
125+
/// The given matrix should neither be used afterwards nor be disposed.
126+
/// </summary>
127+
///<param name="clContext">OpenCL context.</param>
128+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
102129
let transposeInplace (clContext: ClContext) workGroupSize =
103130
let COOtransposeInplace =
104131
COOMatrix.transposeInplace clContext workGroupSize
@@ -111,7 +138,11 @@ module Matrix =
111138
| MatrixCOO m -> COOtransposeInplace processor m |> MatrixCOO
112139
| MatrixCSR m -> CSRtransposeInplace processor m |> MatrixCSR
113140

114-
141+
/// <summary>
142+
/// Transposes the given matrix and returns result as a new matrix. The storage format is preserved.
143+
/// </summary>
144+
///<param name="clContext">OpenCL context.</param>
145+
///<param name="workGroupSize">Should be a power of 2 and greater than 1.</param>
115146
let transpose (clContext: ClContext) workGroupSize =
116147
let COOtranspose =
117148
COOMatrix.transpose clContext workGroupSize

src/GraphBLAS-sharp/Objects/Matrix.fs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ namespace GraphBLAS.FSharp
33
open Brahma.FSharp
44
open GraphBLAS.FSharp.Backend
55

6-
type MatrixFromat =
6+
type MatrixFormat =
77
| CSR
88
| COO
99

tests/GraphBLAS-sharp.Tests/BackendCommonTests/BitonicSortTests.fs

Lines changed: 23 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,35 +6,19 @@ open Expecto.Logging.Message
66
open GraphBLAS.FSharp.Backend.Common
77
open Brahma.FSharp
88
open GraphBLAS.FSharp.Tests.Utils
9-
open OpenCL.Net
109

1110
let logger = Log.create "BitonicSort.Tests"
1211

13-
let testContext =
14-
""
15-
|> avaliableContexts
16-
|> Seq.filter
17-
(fun context ->
18-
let mutable e = ErrorCode.Unknown
19-
let device = context.ClContext.ClDevice.Device
20-
21-
let deviceType =
22-
Cl
23-
.GetDeviceInfo(device, DeviceInfo.Type, &e)
24-
.CastTo<DeviceType>()
25-
26-
deviceType = DeviceType.Gpu)
27-
|> Seq.tryHead
28-
29-
let makeTest (context: ClContext) (q: MailboxProcessor<_>) sort (filter: 'a -> bool) (array: ('n * 'n * 'a) []) =
12+
let makeTest (context: ClContext) (q: MailboxProcessor<_>) sort (array: ('n * 'n * 'a) []) =
3013
if array.Length > 0 then
3114
let projection (row: 'n) (col: 'n) (v: 'a) = row, col
3215

33-
let rows, cols, vals =
34-
array
35-
|> Array.distinctBy ((<|||) projection)
36-
|> Array.filter (fun (_, _, v) -> filter v)
37-
|> Array.unzip3
16+
logger.debug (
17+
eventX "Initial size is {size}"
18+
>> setField "size" (sprintf "%A" array.Length)
19+
)
20+
21+
let rows, cols, vals = Array.unzip3 array
3822

3923
use clRows = context.CreateClArray rows
4024
use clCols = context.CreateClArray cols
@@ -55,56 +39,46 @@ let makeTest (context: ClContext) (q: MailboxProcessor<_>) sort (filter: 'a -> b
5539

5640
rows, cols, vals
5741

58-
logger.debug (
59-
eventX "Actual are {actualRows}, {actualCols}, {actualVals}"
60-
>> setField "actualRows" (sprintf "%A" actualRows)
61-
>> setField "actualCols" (sprintf "%A" actualCols)
62-
>> setField "actualVals" (sprintf "%A" actualVals)
63-
)
64-
6542
let expectedRows, expectedCols, expectedVals =
6643
(rows, cols, vals)
6744
|||> Array.zip3
6845
|> Array.sortBy ((<|||) projection)
6946
|> Array.unzip3
7047

7148
(sprintf "Row arrays should be equal. Actual is \n%A, expected \n%A, input is \n%A" actualRows expectedRows rows)
72-
|> Expect.sequenceEqual actualRows expectedRows
49+
|> compareArrays (=) actualRows expectedRows
7350

7451
(sprintf
7552
"Column arrays should be equal. Actual is \n%A, expected \n%A, input is \n%A"
7653
actualCols
7754
expectedCols
7855
cols)
79-
|> Expect.sequenceEqual actualCols expectedCols
56+
|> compareArrays (=) actualCols expectedCols
8057

8158
(sprintf
8259
"Value arrays should be equal. Actual is \n%A, expected \n%A, input is \n%A"
8360
actualVals
8461
expectedVals
8562
vals)
86-
|> Expect.sequenceEqual actualVals expectedVals
63+
|> compareArrays (=) actualVals expectedVals
8764

88-
let testFixtures<'a when 'a: equality> config wgSize context q filter =
65+
let testFixtures<'a when 'a: equality> config wgSize context q =
8966
let sort: MailboxProcessor<_> -> ClArray<int> -> ClArray<int> -> ClArray<'a> -> unit =
9067
BitonicSort.sortKeyValuesInplace context wgSize
9168

92-
makeTest context q sort filter
69+
makeTest context q sort
9370
|> testPropertyWithConfig config (sprintf "Correctness on %A" typeof<'a>)
9471

9572
let tests =
96-
match testContext with
97-
| Some c ->
98-
let context = c.ClContext
99-
let config = defaultConfig
100-
101-
let wgSize = 128
102-
let q = c.Queue
103-
q.Error.Add(fun e -> failwithf "%A" e)
104-
105-
[ testFixtures<int> config wgSize context q (fun _ -> true)
106-
testFixtures<float> config wgSize context q (System.Double.IsNaN >> not)
107-
testFixtures<byte> config wgSize context q (fun _ -> true)
108-
testFixtures<bool> config wgSize context q (fun _ -> true) ]
109-
| _ -> []
73+
let context = defaultContext.ClContext
74+
let config = { defaultConfig with endSize = 1000000 }
75+
76+
let wgSize = 32
77+
let q = defaultContext.Queue
78+
q.Error.Add(fun e -> failwithf "%A" e)
79+
80+
[ testFixtures<int> config wgSize context q
81+
testFixtures<float> config wgSize context q
82+
testFixtures<byte> config wgSize context q
83+
testFixtures<bool> config wgSize context q ]
11084
|> testList "Backend.Common.BitonicSort tests"

tests/GraphBLAS-sharp.Tests/BackendCommonTests/ConvertTests.fs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ open OpenCL.Net
1212

1313
let logger = Log.create "Convert.Tests"
1414

15-
let context = defaultContext.ClContext
1615
let config = defaultConfig
17-
let wgSize = 128
16+
let wgSize = 32
1817

19-
let makeTestCSR q toCOO isZero (array: 'a [,]) =
18+
let makeTestCSR context q toCOO isZero (array: 'a [,]) =
2019
let mtx = createMatrixFromArray2D CSR array isZero
2120

2221
if mtx.NNZCount > 0 then
@@ -38,7 +37,7 @@ let makeTestCSR q toCOO isZero (array: 'a [,]) =
3837
"Matrices should be equal"
3938
|> Expect.equal actual expected
4039

41-
let makeTestCOO q toCSR isZero (array: 'a [,]) =
40+
let makeTestCOO context q toCSR isZero (array: 'a [,]) =
4241
let mtx = createMatrixFromArray2D COO array isZero
4342

4443
if mtx.NNZCount > 0 then
@@ -68,49 +67,50 @@ let testFixtures case =
6867
System.Double.IsNaN x
6968
|| abs x < Accuracy.medium.absolute
7069

71-
let q = defaultContext.Queue
70+
let context = case.ClContext.ClContext
71+
let q = case.ClContext.Queue
7272
q.Error.Add(fun e -> failwithf "%A" e)
7373

7474
match case.MatrixCase with
7575
| COO ->
7676
[ let toCSR = Matrix.toCSR context wgSize
7777

78-
makeTestCOO q toCSR ((=) 0)
78+
makeTestCOO context q toCSR ((=) 0)
7979
|> testPropertyWithConfig config (getCorrectnessTestName "int")
8080

8181
let toCSR = Matrix.toCSR context wgSize
8282

83-
makeTestCOO q toCSR filterFloat
83+
makeTestCOO context q toCSR filterFloat
8484
|> testPropertyWithConfig config (getCorrectnessTestName "float")
8585

8686
let toCSR = Matrix.toCSR context wgSize
8787

88-
makeTestCOO q toCSR ((=) 0uy)
88+
makeTestCOO context q toCSR ((=) 0uy)
8989
|> testPropertyWithConfig config (getCorrectnessTestName "byte")
9090

9191
let toCSR = Matrix.toCSR context wgSize
9292

93-
makeTestCOO q toCSR ((=) false)
93+
makeTestCOO context q toCSR ((=) false)
9494
|> testPropertyWithConfig config (getCorrectnessTestName "bool") ]
9595
| CSR ->
9696
[ let toCOO = Matrix.toCOO context wgSize
9797

98-
makeTestCSR q toCOO ((=) 0)
98+
makeTestCSR context q toCOO ((=) 0)
9999
|> testPropertyWithConfig config (getCorrectnessTestName "int")
100100

101101
let toCOO = Matrix.toCOO context wgSize
102102

103-
makeTestCSR q toCOO filterFloat
103+
makeTestCSR context q toCOO filterFloat
104104
|> testPropertyWithConfig config (getCorrectnessTestName "float")
105105

106106
let toCOO = Matrix.toCOO context wgSize
107107

108-
makeTestCSR q toCOO ((=) 0uy)
108+
makeTestCSR context q toCOO ((=) 0uy)
109109
|> testPropertyWithConfig config (getCorrectnessTestName "byte")
110110

111111
let toCOO = Matrix.toCOO context wgSize
112112

113-
makeTestCSR q toCOO ((=) false)
113+
makeTestCSR context q toCOO ((=) false)
114114
|> testPropertyWithConfig config (getCorrectnessTestName "bool") ]
115115

116116
let tests =
@@ -126,6 +126,6 @@ let tests =
126126
.CastTo<DeviceType>()
127127

128128
deviceType = DeviceType.Gpu)
129-
|> List.distinctBy (fun case -> case.MatrixCase)
129+
|> List.distinctBy (fun case -> case.ClContext.ClContext.ClDevice.DeviceType, case.MatrixCase)
130130
|> List.collect testFixtures
131131
|> testList "Convert tests"

0 commit comments

Comments
 (0)