Skip to content

Commit 654043f

Browse files
committed
Ewiseadd fix
1 parent 409cd9b commit 654043f

3 files changed

Lines changed: 43 additions & 22 deletions

File tree

src/GraphBLAS-sharp.Backend/Matrix/CSRMatrix/CSRMatrix.fs

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ module CSRMatrix =
8181

8282
let eWiseAdd (clContext: ClContext) (opAdd: Expr<'a option -> 'b option -> 'c option>) workGroupSize =
8383

84-
let toCOOInplaceLeft = toCOOInplace clContext workGroupSize
85-
let toCOOInplaceRight = toCOOInplace clContext workGroupSize
84+
let prepareRows = prepareRows clContext workGroupSize
8685

8786
let eWiseCOO =
8887
COOMatrix.eWiseAdd clContext opAdd workGroupSize
@@ -91,24 +90,32 @@ module CSRMatrix =
9190
COOMatrix.toCSRInplace clContext workGroupSize
9291

9392
fun (processor: MailboxProcessor<_>) (m1: CSRMatrix<'a>) (m2: CSRMatrix<'b>) ->
94-
95-
let m1COO = toCOOInplaceLeft processor m1
96-
let m2COO = toCOOInplaceRight processor m2
93+
let m1COO =
94+
{ Context = clContext
95+
RowCount = m1.RowCount
96+
ColumnCount = m1.ColumnCount
97+
Rows = prepareRows processor m1.RowPointers m1.Values.Length m1.RowCount
98+
Columns = m1.Columns
99+
Values = m1.Values }
100+
101+
let m2COO =
102+
{ Context = clContext
103+
RowCount = m2.RowCount
104+
ColumnCount = m2.ColumnCount
105+
Rows = prepareRows processor m2.RowPointers m2.Values.Length m2.RowCount
106+
Columns = m2.Columns
107+
Values = m2.Values }
97108

98109
let m3COO = eWiseCOO processor m1COO m2COO
99110

100111
processor.Post(Msg.CreateFreeMsg(m1COO.Rows))
101112
processor.Post(Msg.CreateFreeMsg(m2COO.Rows))
102113

103-
let m3 = toCSRInplace processor m3COO
104-
processor.Post(Msg.CreateFreeMsg(m3COO.Rows))
105-
106-
m3
114+
toCSRInplace processor m3COO
107115

108116
let eWiseAddAtLeastOne (clContext: ClContext) (opAdd: Expr<AtLeastOne<'a, 'b> -> 'c option>) workGroupSize =
109117

110-
let toCOOInplaceLeft = toCOOInplace clContext workGroupSize
111-
let toCOOInplaceRight = toCOOInplace clContext workGroupSize
118+
let prepareRows = prepareRows clContext workGroupSize
112119

113120
let eWiseCOO =
114121
COOMatrix.eWiseAddAtLeastOne clContext opAdd workGroupSize
@@ -117,19 +124,28 @@ module CSRMatrix =
117124
COOMatrix.toCSRInplace clContext workGroupSize
118125

119126
fun (processor: MailboxProcessor<_>) (m1: CSRMatrix<'a>) (m2: CSRMatrix<'b>) ->
120-
121-
let m1COO = toCOOInplaceLeft processor m1
122-
let m2COO = toCOOInplaceRight processor m2
127+
let m1COO =
128+
{ Context = clContext
129+
RowCount = m1.RowCount
130+
ColumnCount = m1.ColumnCount
131+
Rows = prepareRows processor m1.RowPointers m1.Values.Length m1.RowCount
132+
Columns = m1.Columns
133+
Values = m1.Values }
134+
135+
let m2COO =
136+
{ Context = clContext
137+
RowCount = m2.RowCount
138+
ColumnCount = m2.ColumnCount
139+
Rows = prepareRows processor m2.RowPointers m2.Values.Length m2.RowCount
140+
Columns = m2.Columns
141+
Values = m2.Values }
123142

124143
let m3COO = eWiseCOO processor m1COO m2COO
125144

126145
processor.Post(Msg.CreateFreeMsg(m1COO.Rows))
127146
processor.Post(Msg.CreateFreeMsg(m2COO.Rows))
128147

129-
let m3 = toCSRInplace processor m3COO
130-
processor.Post(Msg.CreateFreeMsg(m3COO.Rows))
131-
132-
m3
148+
toCSRInplace processor m3COO
133149

134150
let transposeInplace (clContext: ClContext) workGroupSize =
135151

tests/GraphBLAS-sharp.Tests/BackendCommonTests/MatrixEwiseAddTests.fs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ let correctnessGenericTest
8686

8787
let testFixturesEWiseAdd case =
8888
[ let config = defaultConfig
89-
let wgSize = 256
89+
let wgSize = 32
9090

9191
let getCorrectnessTestName datatype =
9292
sprintf "Correctness on %s, %A" datatype case
@@ -140,12 +140,13 @@ let tests =
140140
.CastTo<DeviceType>()
141141

142142
deviceType = DeviceType.Gpu)
143+
|> List.distinctBy (fun case -> case.ClContext.ClContext.ClDevice.DeviceType, case.MatrixCase)
143144
|> List.collect testFixturesEWiseAdd
144145
|> testList "Backend.Matrix.eWiseAdd tests"
145146

146147
let testFixturesEWiseAddAtLeastOne case =
147148
[ let config = defaultConfig
148-
let wgSize = 256
149+
let wgSize = 32
149150

150151
let getCorrectnessTestName datatype =
151152
sprintf "Correctness on %s, %A" datatype case
@@ -203,13 +204,14 @@ let tests2 =
203204
.CastTo<DeviceType>()
204205

205206
deviceType = DeviceType.Gpu)
207+
|> List.distinctBy (fun case -> case.ClContext.ClContext.ClDevice.DeviceType, case.MatrixCase)
206208
|> List.collect testFixturesEWiseAddAtLeastOne
207209
|> testList "Backend.Matrix.eWiseAddAtLeastOne tests"
208210

209211

210212
let testFixturesEWiseMulAtLeastOne case =
211213
[ let config = defaultConfig
212-
let wgSize = 256
214+
let wgSize = 32
213215

214216
let getCorrectnessTestName datatype =
215217
sprintf "Correctness on %s, %A" datatype case
@@ -267,5 +269,6 @@ let tests3 =
267269
.CastTo<DeviceType>()
268270

269271
deviceType = DeviceType.Gpu)
272+
|> List.distinctBy (fun case -> case.ClContext.ClContext.ClDevice.DeviceType, case.MatrixCase)
270273
|> List.collect testFixturesEWiseMulAtLeastOne
271274
|> testList "Backend.Matrix.eWiseMulAtLeastOne tests"

tests/GraphBLAS-sharp.Tests/Program.fs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ let allTests =
1717
Backend.RemoveDuplicates.tests
1818
Backend.Copy.tests
1919
Backend.Replicate.tests
20-
// Backend.EwiseAdd.tests
20+
Backend.EwiseAdd.tests
21+
Backend.EwiseAdd.tests2
22+
//Backend.EwiseAdd.tests3
2123
Backend.Transpose.tests
2224
//Matrix.GetTuples.tests
2325
//Matrix.Mxv.tests

0 commit comments

Comments
 (0)