Skip to content

Commit f223b97

Browse files
committed
refactor: Map.tests, float32 generator shift
1 parent c89e1b9 commit f223b97

3 files changed

Lines changed: 27 additions & 28 deletions

File tree

src/GraphBLAS-sharp.Backend/Quotes/Arithmetic.fs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ module ArithmeticOperations =
6666

6767
if res then Some true else None @>
6868

69+
let inline addLeftConst zero constant =
70+
mkUnaryOp zero <@ fun x -> constant + x @>
71+
72+
let inline addRightConst zero constant =
73+
mkUnaryOp zero <@ fun x -> x + constant @>
74+
6975
let intSum = mkNumericSum 0
7076
let byteSum = mkNumericSum 0uy
7177
let floatSum = mkNumericSum 0.0
@@ -89,6 +95,12 @@ module ArithmeticOperations =
8995

9096
if res then Some true else None @>
9197

98+
let inline mulLeftConst zero constant =
99+
mkUnaryOp zero <@ fun x -> constant * x @>
100+
101+
let inline mulRightConst zero constant =
102+
mkUnaryOp zero <@ fun x -> x * constant @>
103+
92104
let intMul = mkNumericMul 0
93105
let byteMul = mkNumericMul 0uy
94106
let floatMul = mkNumericMul 0.0

tests/GraphBLAS-sharp.Tests/Generators.fs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,15 @@ module Generators =
4747

4848
let rec normalFloat32Generator (random: System.Random) =
4949
gen {
50-
let result = random.NextSingle()
50+
let rawValue = random.NextSingle()
5151

52-
if System.Single.IsNormal result then
53-
return result
52+
if System.Single.IsNormal rawValue then
53+
let sign = float32 <| sign rawValue
54+
let processedValue = ((+) 1.0f) <| (abs <| rawValue)
55+
56+
return processedValue * sign
5457
else
55-
return! normalFloat32Generator random
58+
return 0.0f
5659
}
5760

5861
let genericSparseGenerator zero valuesGen handler =

tests/GraphBLAS-sharp.Tests/Matrix/Map.fs

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,12 @@ let testFixturesMapAdd case =
115115
let q = case.TestContext.Queue
116116
q.Error.Add(fun e -> failwithf "%A" e)
117117

118-
let addFloat64Q =
119-
ArithmeticOperations.mkUnaryOp 0.0 <@ fun x -> x + 10.0 @>
120-
121-
let addFloat32Q =
122-
ArithmeticOperations.mkUnaryOp 0.0f <@ fun x -> x + 10.0f @>
123-
124-
let addByte =
125-
ArithmeticOperations.mkUnaryOp 0uy <@ fun x -> x + 10uy @>
126-
127118
if Utils.isFloat64Available context.ClDevice then
128-
createTestMap case 0.0 ((+) 10.0) Utils.floatIsEqual addFloat64Q Matrix.map
119+
createTestMap case 0.0 ((+) 10.0) Utils.floatIsEqual (ArithmeticOperations.addLeftConst 0.0 10.0) Matrix.map
120+
121+
createTestMap case 0.0f ((+) 10.0f) Utils.float32IsEqual (ArithmeticOperations.addLeftConst 0.0f 10.0f) Matrix.map
129122

130-
createTestMap case 0.0f ((+) 10.0f) Utils.float32IsEqual addFloat32Q Matrix.map
131-
createTestMap case 0uy ((+) 10uy) (=) addByte Matrix.map ]
123+
createTestMap case 0uy ((+) 10uy) (=) (ArithmeticOperations.addLeftConst 0uy 10uy) Matrix.map ]
132124

133125
let addTests =
134126
operationGPUTests "Backend.Matrix.map add tests" testFixturesMapAdd
@@ -138,20 +130,12 @@ let testFixturesMapMul case =
138130
let q = case.TestContext.Queue
139131
q.Error.Add(fun e -> failwithf "%A" e)
140132

141-
let mulFloat64Q =
142-
ArithmeticOperations.mkUnaryOp 0.0 <@ fun x -> x * 10.0 @>
143-
144-
let mulFloat32Q =
145-
ArithmeticOperations.mkUnaryOp 0.0f <@ fun x -> x * 10.0f @>
146-
147-
let mulByte =
148-
ArithmeticOperations.mkUnaryOp 0uy <@ fun x -> x * 10uy @>
149-
150133
if Utils.isFloat64Available context.ClDevice then
151-
createTestMap case 0.0 ((*) 10.0) Utils.floatIsEqual mulFloat64Q Matrix.map
134+
createTestMap case 0.0 ((*) 10.0) Utils.floatIsEqual (ArithmeticOperations.mulLeftConst 0.0 10.0) Matrix.map
135+
136+
createTestMap case 0.0f ((*) 10.0f) Utils.float32IsEqual (ArithmeticOperations.mulLeftConst 0.0f 10.0f) Matrix.map
152137

153-
createTestMap case 0.0f ((*) 10.0f) Utils.float32IsEqual mulFloat32Q Matrix.map
154-
createTestMap case 0uy ((*) 10uy) (=) mulByte Matrix.map ]
138+
createTestMap case 0uy ((*) 10uy) (=) (ArithmeticOperations.mulLeftConst 0uy 10uy) Matrix.map ]
155139

156140
let mulTests =
157141
operationGPUTests "Backend.Matrix.map mul tests" testFixturesMapMul

0 commit comments

Comments
 (0)