Skip to content

Commit 504b252

Browse files
authored
Fix integer mod/rem simplification (#115)
1 parent 025675a commit 504b252

3 files changed

Lines changed: 29 additions & 23 deletions

File tree

ksmt-core/src/main/kotlin/io/ksmt/expr/rewrite/simplify/ArithSimplification.kt

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ import io.ksmt.sort.KRealSort
1515
import io.ksmt.utils.ArithUtils.RealValue
1616
import io.ksmt.utils.ArithUtils.bigIntegerValue
1717
import io.ksmt.utils.ArithUtils.compareTo
18-
import io.ksmt.utils.ArithUtils.modWithNegativeNumbers
1918
import io.ksmt.utils.ArithUtils.numericValue
2019
import io.ksmt.utils.ArithUtils.toRealValue
2120
import io.ksmt.utils.uncheckedCast
@@ -231,12 +230,21 @@ fun KContext.simplifyIntMod(lhs: KExpr<KIntSort>, rhs: KExpr<KIntSort>): KExpr<K
231230
}
232231

233232
if (rValue != BigInteger.ZERO && lhs is KIntNumExpr) {
234-
return mkIntNum(modWithNegativeNumbers(lhs.bigIntegerValue, rValue))
233+
return mkIntNum(evalIntMod(lhs.bigIntegerValue, rValue))
235234
}
236235
}
237236
return mkIntModNoSimplify(lhs, rhs)
238237
}
239238

239+
/**
240+
* Eval integer mod wrt Int theory rules.
241+
* */
242+
private fun evalIntMod(a: BigInteger, b: BigInteger): BigInteger {
243+
val remainder = a.rem(b)
244+
if (remainder >= BigInteger.ZERO) return remainder
245+
return if (b >= BigInteger.ZERO) remainder + b else remainder - b
246+
}
247+
240248
fun KContext.simplifyIntRem(lhs: KExpr<KIntSort>, rhs: KExpr<KIntSort>): KExpr<KIntSort> {
241249
if (rhs is KIntNumExpr) {
242250
val rValue = rhs.bigIntegerValue
@@ -246,12 +254,20 @@ fun KContext.simplifyIntRem(lhs: KExpr<KIntSort>, rhs: KExpr<KIntSort>): KExpr<K
246254
}
247255

248256
if (rValue != BigInteger.ZERO && lhs is KIntNumExpr) {
249-
return mkIntNum(lhs.bigIntegerValue.rem(rValue))
257+
return mkIntNum(evalIntRem(lhs.bigIntegerValue, rValue))
250258
}
251259
}
252260
return mkIntRemNoSimplify(lhs, rhs)
253261
}
254262

263+
/**
264+
* Eval integer rem wrt Int theory rules.
265+
* */
266+
private fun evalIntRem(a: BigInteger, b: BigInteger): BigInteger {
267+
val mod = evalIntMod(a, b)
268+
return if (b >= BigInteger.ZERO) mod else -mod
269+
}
270+
255271
fun KContext.simplifyIntToReal(arg: KExpr<KIntSort>): KExpr<KRealSort> {
256272
if (arg is KIntNumExpr) {
257273
return mkRealNum(arg)

ksmt-core/src/main/kotlin/io/ksmt/utils/ArithUtils.kt

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,24 +44,6 @@ object ArithUtils {
4444
else -> decl.value.toBigInteger()
4545
}
4646

47-
/**
48-
* BigInteger doesn't support mod operation with negative modulus.
49-
* We use the mod operation with absolute values and then manually
50-
* recover the result depending on the arguments signs.
51-
* */
52-
fun modWithNegativeNumbers(a: BigInteger, b: BigInteger): BigInteger {
53-
val aAbs = a.abs()
54-
val bAbs = b.abs()
55-
val u = aAbs.mod(bAbs)
56-
return when {
57-
u == BigInteger.ZERO -> BigInteger.ZERO
58-
a >= BigInteger.ZERO && b >= BigInteger.ZERO -> u
59-
a < BigInteger.ZERO && b >= BigInteger.ZERO -> -u + b
60-
a >= BigInteger.ZERO && b < BigInteger.ZERO -> u + b
61-
else -> -u
62-
}
63-
}
64-
6547
class RealValue private constructor(
6648
numerator: BigInteger,
6749
denominator: BigInteger

ksmt-core/src/test/kotlin/io/ksmt/ArithSimplifyTest.kt

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,22 @@ class ArithSimplifyTest : ExpressionSimplifyTest() {
8080
@Test
8181
fun testIntMod() {
8282
testOperation(isInt = true, KContext::mkIntMod, KContext::mkIntModNoSimplify) {
83-
listOf(mkIntNum(0), mkIntNum(1), mkIntNum(-1)).uncheckedCast()
83+
listOf(
84+
mkIntNum(0), mkIntNum(1), mkIntNum(-1),
85+
// Values with non-trivial remainder
86+
mkIntNum(47), mkIntNum(-47), mkIntNum(13), mkIntNum(-13),
87+
).uncheckedCast()
8488
}
8589
}
8690

8791
@Test
8892
fun testIntRem() {
8993
testOperation(isInt = true, KContext::mkIntRem, KContext::mkIntRemNoSimplify) {
90-
listOf(mkIntNum(0), mkIntNum(1), mkIntNum(-1)).uncheckedCast()
94+
listOf(
95+
mkIntNum(0), mkIntNum(1), mkIntNum(-1),
96+
// Values with non-trivial remainder
97+
mkIntNum(47), mkIntNum(-47), mkIntNum(13), mkIntNum(-13),
98+
).uncheckedCast()
9199
}
92100
}
93101

0 commit comments

Comments
 (0)