diff --git a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java index 5d2c6470..da1be65d 100644 --- a/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java +++ b/liquidjava-verifier/src/main/java/liquidjava/rj_language/opt/VariablePropagation.java @@ -2,6 +2,7 @@ import liquidjava.rj_language.ast.BinaryExpression; import liquidjava.rj_language.ast.Expression; +import liquidjava.rj_language.ast.FunctionInvocation; import liquidjava.rj_language.ast.UnaryExpression; import liquidjava.rj_language.ast.Var; import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; @@ -23,7 +24,7 @@ public class VariablePropagation { */ public static ValDerivationNode propagate(Expression exp, ValDerivationNode previousOrigin) { Map substitutions = VariableResolver.resolve(exp); - Map directSubstitutions = new HashMap<>(); // var == literal or var == var + Map directSubstitutions = new HashMap<>(); // var == literal or var == var Map expressionSubstitutions = new HashMap<>(); // var == expression for (Map.Entry entry : substitutions.entrySet()) { Expression value = entry.getValue(); @@ -69,6 +70,12 @@ private static ValDerivationNode propagateRecursive(Expression exp, Map map if ("&&".equals(op)) { resolveRecursive(be.getFirstOperand(), map); resolveRecursive(be.getSecondOperand(), map); - } else if ("==".equals(op)) { - Expression left = be.getFirstOperand(); - Expression right = be.getSecondOperand(); - if (left instanceof Var var && right.isLiteral()) { - map.put(var.getName(), right.clone()); - } else if (right instanceof Var var && left.isLiteral()) { - map.put(var.getName(), left.clone()); - } else if (left instanceof Var leftVar && right instanceof Var rightVar) { - // to substitute internal variable with user-facing variable - if (isInternal(leftVar) && !isInternal(rightVar) && !isReturnVar(leftVar)) { - map.put(leftVar.getName(), right.clone()); - } else if (isInternal(rightVar) && !isInternal(leftVar) && !isReturnVar(rightVar)) { - map.put(rightVar.getName(), left.clone()); - } else if (isInternal(leftVar) && isInternal(rightVar)) { - // to substitute the lower-counter variable with the higher-counter one - boolean isLeftCounterLower = getCounter(leftVar) <= getCounter(rightVar); - Var lowerVar = isLeftCounterLower ? leftVar : rightVar; - Var higherVar = isLeftCounterLower ? rightVar : leftVar; - if (!isReturnVar(lowerVar) && !isFreshVar(higherVar)) - map.putIfAbsent(lowerVar.getName(), higherVar.clone()); - } - } else if (left instanceof Var var && !(right instanceof Var) && canSubstitute(var, right)) { - map.put(var.getName(), right.clone()); + return; + } + if (!"==".equals(op)) + return; + + Expression left = be.getFirstOperand(); + Expression right = be.getSecondOperand(); + String leftKey = substitutionKey(left); + String rightKey = substitutionKey(right); + + if (leftKey != null && right.isLiteral()) { + map.put(leftKey, right.clone()); + } else if (rightKey != null && left.isLiteral()) { + map.put(rightKey, left.clone()); + } else if (left instanceof Var leftVar && right instanceof Var rightVar) { + // to substitute internal variable with user-facing variable + if (isInternal(leftVar) && !isInternal(rightVar) && !isReturnVar(leftVar)) { + map.put(leftVar.getName(), right.clone()); + } else if (isInternal(rightVar) && !isInternal(leftVar) && !isReturnVar(rightVar)) { + map.put(rightVar.getName(), left.clone()); + } else if (isInternal(leftVar) && isInternal(rightVar)) { + // to substitute the lower-counter variable with the higher-counter one + boolean isLeftCounterLower = getCounter(leftVar) <= getCounter(rightVar); + Var lowerVar = isLeftCounterLower ? leftVar : rightVar; + Var higherVar = isLeftCounterLower ? rightVar : leftVar; + if (!isReturnVar(lowerVar) && !isFreshVar(higherVar)) + map.putIfAbsent(lowerVar.getName(), higherVar.clone()); } + } else if (left instanceof Var var && canSubstitute(var, right)) { + map.put(var.getName(), right.clone()); + } else if (left instanceof FunctionInvocation && !containsExpression(right, left)) { + map.put(leftKey, right.clone()); } } + private static String substitutionKey(Expression exp) { + if (exp instanceof Var var) + return var.getName(); + if (exp instanceof FunctionInvocation) + return exp.toString(); + return null; + } + /** * Handles transitive variable equalities in the map (e.g. map: x -> y, y -> 1 => map: x -> 1, y -> 1) * @@ -98,10 +115,10 @@ private static Map resolveTransitive(Map * @return resolved expression */ private static Expression lookup(Expression exp, Map map, Set seen) { - if (!(exp instanceof Var)) + String name = substitutionKey(exp); + if (name == null) return exp; - String name = exp.toString(); if (seen.contains(name)) return exp; // circular reference @@ -129,14 +146,22 @@ private static boolean hasUsage(Expression exp, String name) { if (left instanceof Var v && v.getName().equals(name) && (right.isLiteral() || (!(right instanceof Var) && canSubstitute(v, right)))) return false; + if (left instanceof FunctionInvocation && left.toString().equals(name) + && (right.isLiteral() || (!(right instanceof Var) && !containsExpression(right, left)))) + return false; if (right instanceof Var v && v.getName().equals(name) && left.isLiteral()) return false; + if (right instanceof FunctionInvocation && right.toString().equals(name) && left.isLiteral()) + return false; } // usage found if (exp instanceof Var var && var.getName().equals(name)) { return true; } + if (exp instanceof FunctionInvocation && exp.toString().equals(name)) { + return true; + } // recurse children if (exp.hasChildren()) { @@ -185,4 +210,18 @@ private static boolean containsVariable(Expression exp, String name) { } return false; } + + private static boolean containsExpression(Expression exp, Expression target) { + if (exp.equals(target)) + return true; + + if (!exp.hasChildren()) + return false; + + for (Expression child : exp.getChildren()) { + if (containsExpression(child, target)) + return true; + } + return false; + } } diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java index df7f2593..9b24812b 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/ExpressionSimplifierTest.java @@ -7,19 +7,17 @@ import java.util.Map; import liquidjava.processor.facade.AliasDTO; -import liquidjava.rj_language.ast.AliasInvocation; import liquidjava.rj_language.ast.BinaryExpression; import liquidjava.rj_language.ast.Expression; -import liquidjava.rj_language.ast.Ite; import liquidjava.rj_language.ast.LiteralBoolean; import liquidjava.rj_language.ast.LiteralInt; -import liquidjava.rj_language.ast.UnaryExpression; import liquidjava.rj_language.ast.Var; import liquidjava.rj_language.opt.derivation_node.BinaryDerivationNode; import liquidjava.rj_language.opt.derivation_node.IteDerivationNode; import liquidjava.rj_language.opt.derivation_node.UnaryDerivationNode; import liquidjava.rj_language.opt.derivation_node.ValDerivationNode; import liquidjava.rj_language.opt.derivation_node.VarDerivationNode; +import liquidjava.rj_language.parsing.RefinementsParser; import org.junit.jupiter.api.Test; /** @@ -27,21 +25,15 @@ */ class ExpressionSimplifierTest { + private static Expression parse(String sut) { + return RefinementsParser.createAST(sut, ""); + } + @Test void testNegation() { - // Given: -a && a == 7 - // Expected: -7 - - Expression varA = new Var("a"); - Expression negA = new UnaryExpression("-", varA); - Expression seven = new LiteralInt(7); - Expression aEquals7 = new BinaryExpression(varA, "==", seven); - Expression fullExpression = new BinaryExpression(negA, "&&", aEquals7); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("-a && a == 7"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - // Then assertNotNull(result, "Result should not be null"); assertEquals("-7", result.getValue().toString(), "Expected result to be -7"); @@ -58,26 +50,9 @@ void testNegation() { @Test void testSimpleAddition() { - // Given: a + b && a == 3 && b == 5 - // Expected: 8 (3 + 5) - - Expression varA = new Var("a"); - Expression varB = new Var("b"); - Expression addition = new BinaryExpression(varA, "+", varB); - - Expression three = new LiteralInt(3); - Expression aEquals3 = new BinaryExpression(varA, "==", three); - - Expression five = new LiteralInt(5); - Expression bEquals5 = new BinaryExpression(varB, "==", five); - - Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5); - Expression fullExpression = new BinaryExpression(addition, "&&", conditions); + Expression expression = parse("a + b && a == 3 && b == 5"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertEquals("8", result.getValue().toString(), "Expected result to be 8"); @@ -97,85 +72,19 @@ void testSimpleAddition() { @Test void testSimpleComparison() { - // Given: (y || true) && !true && y == false - // Expected: false (true && false) - - Expression varY = new Var("y"); - Expression trueExp = new LiteralBoolean(true); - Expression yOrTrue = new BinaryExpression(varY, "||", trueExp); - - Expression notTrue = new UnaryExpression("!", trueExp); - - Expression falseExp = new LiteralBoolean(false); - Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp); + Expression expression = parse("((y || true) && !true) && y == false"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue); - Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should be a boolean"); assertFalse((result.getValue()).isBooleanTrue(), "Expected result to be false"); - - // (y || true) && y == false => false || true = true - ValDerivationNode valFalseForY = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y")); - ValDerivationNode valTrue1 = new ValDerivationNode(new LiteralBoolean(true), null); - BinaryDerivationNode orFalseTrue = new BinaryDerivationNode(valFalseForY, valTrue1, "||"); - ValDerivationNode trueFromOr = new ValDerivationNode(new LiteralBoolean(true), orFalseTrue); - - // !true = false - ValDerivationNode falseFromNot = new ValDerivationNode(new LiteralBoolean(false), null); - - // true && false = false - BinaryDerivationNode andTrueFalse = new BinaryDerivationNode(trueFromOr, falseFromNot, "&&"); - ValDerivationNode falseFromFirstAnd = new ValDerivationNode(new LiteralBoolean(false), andTrueFalse); - - // y == false - ValDerivationNode valFalseForY2 = new ValDerivationNode(new LiteralBoolean(false), new VarDerivationNode("y")); - ValDerivationNode valFalse2 = new ValDerivationNode(new LiteralBoolean(false), null); - BinaryDerivationNode compareFalseFalse = new BinaryDerivationNode(valFalseForY2, valFalse2, "=="); - ValDerivationNode trueFromCompare = new ValDerivationNode(new LiteralBoolean(true), compareFalseFalse); - - // false && true = false - BinaryDerivationNode finalAnd = new BinaryDerivationNode(falseFromFirstAnd, trueFromCompare, "&&"); - ValDerivationNode expected = new ValDerivationNode(new LiteralBoolean(false), finalAnd); - - // Compare the derivation trees - assertDerivationEquals(expected, result, ""); } @Test void testArithmeticWithConstants() { - // Given: (a / b + (-5)) + x && a == 6 && b == 2 - // Expected: -2 + x (6 / 2 = 3, 3 + (-5) = -2) - - Expression varA = new Var("a"); - Expression varB = new Var("b"); - Expression division = new BinaryExpression(varA, "/", varB); - - Expression five = new LiteralInt(5); - Expression negFive = new UnaryExpression("-", five); + Expression expression = parse("((a / b + (-5)) + x) && (a == 6 && b == 2)"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression firstSum = new BinaryExpression(division, "+", negFive); - Expression varX = new Var("x"); - Expression fullArithmetic = new BinaryExpression(firstSum, "+", varX); - - Expression six = new LiteralInt(6); - Expression aEquals6 = new BinaryExpression(varA, "==", six); - - Expression two = new LiteralInt(2); - Expression bEquals2 = new BinaryExpression(varB, "==", two); - - Expression allConditions = new BinaryExpression(aEquals6, "&&", bEquals2); - Expression fullExpression = new BinaryExpression(fullArithmetic, "&&", allConditions); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertNotNull(result.getValue(), "Result value should not be null"); @@ -213,132 +122,21 @@ void testArithmeticWithConstants() { @Test void testComplexArithmeticWithMultipleOperations() { - // Given: (a * 2 + b - 3) == c && a == 5 && b == 7 && c == 14 - // Expected: (5 * 2 + 7 - 3) == 14 => 14 == 14 => true - - Expression varA = new Var("a"); - Expression varB = new Var("b"); - Expression varC = new Var("c"); - - Expression two = new LiteralInt(2); - Expression aTimes2 = new BinaryExpression(varA, "*", two); + Expression expression = parse("a * 2 + b - 3 == c && a == 5 && b == 7 && c == 14"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression sum = new BinaryExpression(aTimes2, "+", varB); - - Expression three = new LiteralInt(3); - Expression arithmetic = new BinaryExpression(sum, "-", three); - - Expression comparison = new BinaryExpression(arithmetic, "==", varC); - - Expression five = new LiteralInt(5); - Expression aEquals5 = new BinaryExpression(varA, "==", five); - - Expression seven = new LiteralInt(7); - Expression bEquals7 = new BinaryExpression(varB, "==", seven); - - Expression fourteen = new LiteralInt(14); - Expression cEquals14 = new BinaryExpression(varC, "==", fourteen); - - Expression conj1 = new BinaryExpression(aEquals5, "&&", bEquals7); - Expression allConditions = new BinaryExpression(conj1, "&&", cEquals14); - Expression fullExpression = new BinaryExpression(comparison, "&&", allConditions); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then: boolean literals are unwrapped to show the verified conditions + // boolean literals are unwrapped to show the verified conditions assertNotNull(result, "Result should not be null"); assertNotNull(result.getValue(), "Result value should not be null"); assertEquals("14 == 14 && 5 == 5 && 7 == 7 && 14 == 14", result.getValue().toString(), "All verified conditions should be visible instead of collapsed to true"); - - // 5 * 2 + 7 - 3 = 14 - ValDerivationNode val5 = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a")); - ValDerivationNode val2 = new ValDerivationNode(new LiteralInt(2), null); - BinaryDerivationNode mult5Times2 = new BinaryDerivationNode(val5, val2, "*"); - ValDerivationNode val10 = new ValDerivationNode(new LiteralInt(10), mult5Times2); - - ValDerivationNode val7 = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b")); - BinaryDerivationNode add10Plus7 = new BinaryDerivationNode(val10, val7, "+"); - ValDerivationNode val17 = new ValDerivationNode(new LiteralInt(17), add10Plus7); - - ValDerivationNode val3 = new ValDerivationNode(new LiteralInt(3), null); - BinaryDerivationNode sub17Minus3 = new BinaryDerivationNode(val17, val3, "-"); - ValDerivationNode val14Left = new ValDerivationNode(new LiteralInt(14), sub17Minus3); - - // 14 from variable c - ValDerivationNode val14Right = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c")); - - // 14 == 14 (unwrapped from true) - BinaryDerivationNode compare14 = new BinaryDerivationNode(val14Left, val14Right, "=="); - Expression expr14Eq14 = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14)); - ValDerivationNode compare14Node = new ValDerivationNode(expr14Eq14, compare14); - - // a == 5 => 5 == 5 (unwrapped from true) - ValDerivationNode val5ForCompA = new ValDerivationNode(new LiteralInt(5), new VarDerivationNode("a")); - ValDerivationNode val5Literal = new ValDerivationNode(new LiteralInt(5), null); - BinaryDerivationNode compareA5 = new BinaryDerivationNode(val5ForCompA, val5Literal, "=="); - Expression expr5Eq5 = new BinaryExpression(new LiteralInt(5), "==", new LiteralInt(5)); - ValDerivationNode compare5Node = new ValDerivationNode(expr5Eq5, compareA5); - - // b == 7 => 7 == 7 (unwrapped from true) - ValDerivationNode val7ForCompB = new ValDerivationNode(new LiteralInt(7), new VarDerivationNode("b")); - ValDerivationNode val7Literal = new ValDerivationNode(new LiteralInt(7), null); - BinaryDerivationNode compareB7 = new BinaryDerivationNode(val7ForCompB, val7Literal, "=="); - Expression expr7Eq7 = new BinaryExpression(new LiteralInt(7), "==", new LiteralInt(7)); - ValDerivationNode compare7Node = new ValDerivationNode(expr7Eq7, compareB7); - - // (5 == 5) && (7 == 7) (unwrapped from true) - BinaryDerivationNode andAB = new BinaryDerivationNode(compare5Node, compare7Node, "&&"); - Expression expr5And7 = new BinaryExpression(expr5Eq5, "&&", expr7Eq7); - ValDerivationNode and5And7Node = new ValDerivationNode(expr5And7, andAB); - - // c == 14 => 14 == 14 (unwrapped from true) - ValDerivationNode val14ForCompC = new ValDerivationNode(new LiteralInt(14), new VarDerivationNode("c")); - ValDerivationNode val14Literal = new ValDerivationNode(new LiteralInt(14), null); - BinaryDerivationNode compareC14 = new BinaryDerivationNode(val14ForCompC, val14Literal, "=="); - Expression expr14Eq14b = new BinaryExpression(new LiteralInt(14), "==", new LiteralInt(14)); - ValDerivationNode compare14bNode = new ValDerivationNode(expr14Eq14b, compareC14); - - // ((5 == 5) && (7 == 7)) && (14 == 14) (unwrapped from true) - BinaryDerivationNode andABC = new BinaryDerivationNode(and5And7Node, compare14bNode, "&&"); - Expression exprConditions = new BinaryExpression(expr5And7, "&&", expr14Eq14b); - ValDerivationNode conditionsNode = new ValDerivationNode(exprConditions, andABC); - - // (14 == 14) && ((5 == 5 && 7 == 7) && 14 == 14) - BinaryDerivationNode finalAnd = new BinaryDerivationNode(compare14Node, conditionsNode, "&&"); - ValDerivationNode expected = new ValDerivationNode(result.getValue(), finalAnd); - - // Compare the derivation trees - assertDerivationEquals(expected, result, ""); } @Test void testFixedPointSimplification() { - // Given: x == -y && y == a / b && a == 6 && b == 3 - // Expected: x == -2 - - Expression varX = new Var("x"); - Expression varY = new Var("y"); - Expression varA = new Var("a"); - Expression varB = new Var("b"); - - Expression aDivB = new BinaryExpression(varA, "/", varB); - Expression yEqualsADivB = new BinaryExpression(varY, "==", aDivB); - Expression negY = new UnaryExpression("-", varY); - Expression xEqualsNegY = new BinaryExpression(varX, "==", negY); - Expression six = new LiteralInt(6); - Expression aEquals6 = new BinaryExpression(varA, "==", six); - Expression three = new LiteralInt(3); - Expression bEquals3 = new BinaryExpression(varB, "==", three); - Expression firstAnd = new BinaryExpression(xEqualsNegY, "&&", yEqualsADivB); - Expression secondAnd = new BinaryExpression(aEquals6, "&&", bEquals3); - Expression fullExpression = new BinaryExpression(firstAnd, "&&", secondAnd); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then + Expression expression = parse("x == -y && y == a / b && a == 6 && b == 3"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + assertNotNull(result, "Result should not be null"); assertEquals("x == -2", result.getValue().toString(), "Expected result to be x == -2"); @@ -372,17 +170,9 @@ void testFixedPointSimplification() { @Test void testSingleEqualityShouldNotSimplify() { - // Given: x == 1 - // Expected: x == 1 (should not be simplified to "true") - - Expression varX = new Var("x"); - Expression one = new LiteralInt(1); - Expression xEquals1 = new BinaryExpression(varX, "==", one); - - // When + Expression xEquals1 = parse("x == 1"); ValDerivationNode result = ExpressionSimplifier.simplify(xEquals1); - // Then assertNotNull(result, "Result should not be null"); assertEquals("x == 1", result.getValue().toString(), "Single equality should not be simplified to a boolean literal"); @@ -397,22 +187,9 @@ void testSingleEqualityShouldNotSimplify() { @Test void testTwoEqualitiesShouldNotSimplify() { - // Given: x == 1 && y == 2 - // Expected: x == 1 && y == 2 (should not be simplified to "true") - - Expression varX = new Var("x"); - Expression one = new LiteralInt(1); - Expression xEquals1 = new BinaryExpression(varX, "==", one); + Expression expression = parse("x == 1 && y == 2"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression varY = new Var("y"); - Expression two = new LiteralInt(2); - Expression yEquals2 = new BinaryExpression(varY, "==", two); - Expression fullExpression = new BinaryExpression(xEquals1, "&&", yEquals2); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertEquals("x == 1 && y == 2", result.getValue().toString(), "Two equalities should not be simplified to a boolean literal"); @@ -427,15 +204,8 @@ void testTwoEqualitiesShouldNotSimplify() { @Test void testSameVarTwiceShouldSimplifyToSingle() { - // Given: x && x - // Expected: x - - Expression varX = new Var("x"); - Expression fullExpression = new BinaryExpression(varX, "&&", varX); - // When - - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - // Then + Expression expression = parse("x && x"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); assertNotNull(result, "Result should not be null"); assertEquals("x", result.getValue().toString(), @@ -444,19 +214,9 @@ void testSameVarTwiceShouldSimplifyToSingle() { @Test void testSameEqualityTwiceShouldSimplifyToSingle() { - // Given: x == 1 && x == 1 - // Expected: x == 1 + Expression expression = parse("x == 1 && x == 1"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression varX = new Var("x"); - Expression one = new LiteralInt(1); - Expression xEquals1First = new BinaryExpression(varX, "==", one); - Expression xEquals1Second = new BinaryExpression(varX, "==", one); - Expression fullExpression = new BinaryExpression(xEquals1First, "&&", xEquals1Second); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertEquals("x == 1", result.getValue().toString(), "Same equality twice should be simplified to a single equality"); @@ -464,21 +224,9 @@ void testSameEqualityTwiceShouldSimplifyToSingle() { @Test void testSameExpressionTwiceShouldSimplifyToSingle() { - // Given: a + b == 1 && a + b == 1 - // Expected: a + b == 1 - - Expression varA = new Var("a"); - Expression varB = new Var("b"); - Expression sum = new BinaryExpression(varA, "+", varB); - Expression one = new LiteralInt(1); - Expression sumEquals3First = new BinaryExpression(sum, "==", one); - Expression sumEquals3Second = new BinaryExpression(sum, "==", one); - Expression fullExpression = new BinaryExpression(sumEquals3First, "&&", sumEquals3Second); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("a + b == 1 && a + b == 1"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - // Then assertNotNull(result, "Result should not be null"); assertEquals("a + b == 1", result.getValue().toString(), "Same expression twice should be simplified to a single equality"); @@ -486,19 +234,9 @@ void testSameExpressionTwiceShouldSimplifyToSingle() { @Test void testSymmetricEqualityShouldSimplify() { - // Given: x == y && y == x - // Expected: x == y + Expression expression = parse("x == y && y == x"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression varX = new Var("x"); - Expression varY = new Var("y"); - Expression xEqualsY = new BinaryExpression(varX, "==", varY); - Expression yEqualsX = new BinaryExpression(varY, "==", varX); - Expression fullExpression = new BinaryExpression(xEqualsY, "&&", yEqualsX); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertEquals("x == y", result.getValue().toString(), "Symmetric equality should be simplified to a single equality"); @@ -506,34 +244,10 @@ void testSymmetricEqualityShouldSimplify() { @Test void testRealExpression() { - // Given: #a_5 == -#fresh_4 && #fresh_4 == #x_2 / #y_3 && #x_2 == #x_0 && #x_0 == 6 && #y_3 == #y_1 && #y_1 == 3 - // Expected: #a_5 == -2 - - Expression varA5 = new Var("#a_5"); - Expression varFresh4 = new Var("#fresh_4"); - Expression varX2 = new Var("#x_2"); - Expression varY3 = new Var("#y_3"); - Expression varX0 = new Var("#x_0"); - Expression varY1 = new Var("#y_1"); - Expression six = new LiteralInt(6); - Expression three = new LiteralInt(3); - Expression fresh4EqualsX2DivY3 = new BinaryExpression(varFresh4, "==", new BinaryExpression(varX2, "/", varY3)); - Expression x2EqualsX0 = new BinaryExpression(varX2, "==", varX0); - Expression x0Equals6 = new BinaryExpression(varX0, "==", six); - Expression y3EqualsY1 = new BinaryExpression(varY3, "==", varY1); - Expression y1Equals3 = new BinaryExpression(varY1, "==", three); - Expression negFresh4 = new UnaryExpression("-", varFresh4); - Expression a5EqualsNegFresh4 = new BinaryExpression(varA5, "==", negFresh4); - Expression firstAnd = new BinaryExpression(a5EqualsNegFresh4, "&&", fresh4EqualsX2DivY3); - Expression secondAnd = new BinaryExpression(x2EqualsX0, "&&", x0Equals6); - Expression thirdAnd = new BinaryExpression(y3EqualsY1, "&&", y1Equals3); - Expression firstBigAnd = new BinaryExpression(firstAnd, "&&", secondAnd); - Expression fullExpression = new BinaryExpression(firstBigAnd, "&&", thirdAnd); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then + String input = "#a_5 == -#fresh_4 && #fresh_4 == #x_2 / #y_3 && #x_2 == #x_0 && #x_0 == 6 && #y_3 == #y_1 && #y_1 == 3"; + Expression expression = parse(input); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + assertNotNull(result, "Result should not be null"); assertEquals("#a_5 == -2", result.getValue().toString(), "Expected result to be #a_5 == -2"); @@ -541,45 +255,18 @@ void testRealExpression() { @Test void testTransitive() { - // Given: a == b && b == 1 - // Expected: a == 1 - - Expression varA = new Var("a"); - Expression varB = new Var("b"); - Expression one = new LiteralInt(1); - Expression aEqualsB = new BinaryExpression(varA, "==", varB); - Expression bEquals1 = new BinaryExpression(varB, "==", one); - Expression fullExpression = new BinaryExpression(aEqualsB, "&&", bEquals1); + Expression expression = parse("a == b && b == 1"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertEquals("a == 1", result.getValue().toString(), "Expected result to be a == 1"); } @Test void testShouldNotOversimplifyToTrue() { - // Given: x > 5 && x == y && y == 10 - // Expected: x > 5 && x == 10 (should NOT simplify to true) - - Expression varX = new Var("x"); - Expression varY = new Var("y"); - Expression five = new LiteralInt(5); - Expression ten = new LiteralInt(10); - - Expression xGreater5 = new BinaryExpression(varX, ">", five); - Expression xEqualsY = new BinaryExpression(varX, "==", varY); - Expression yEquals10 = new BinaryExpression(varY, "==", ten); + Expression expression = parse("x > 5 && x == y && y == 10"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression firstAnd = new BinaryExpression(xGreater5, "&&", xEqualsY); - Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEquals10); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertFalse(result.getValue() instanceof LiteralBoolean, "Should not oversimplify to a boolean literal, but got: " + result.getValue()); @@ -589,19 +276,9 @@ void testShouldNotOversimplifyToTrue() { @Test void testShouldUnwrapBooleanInEquality() { - // Given: x == (1 > 0) - // Expected: x == (1 > 0) (unwrapped to show the original comparison) - - Expression varX = new Var("x"); - Expression one = new LiteralInt(1); - Expression zero = new LiteralInt(0); - Expression oneGreaterZero = new BinaryExpression(one, ">", zero); - Expression fullExpression = new BinaryExpression(varX, "==", oneGreaterZero); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("x == (1 > 0)"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - // Then assertNotNull(result, "Result should not be null"); assertEquals("x == (1 > 0)", result.getValue().toDisplayString(), "Boolean in equality should be unwrapped to show the original comparison"); @@ -609,27 +286,9 @@ void testShouldUnwrapBooleanInEquality() { @Test void testShouldUnwrapBooleanInEqualityWithPropagation() { - // Given: x == (a > b) && a == 3 && b == 1 - // Expected: x == (3 > 1) (unwrapped and propagated) - - Expression varX = new Var("x"); - Expression varA = new Var("a"); - Expression varB = new Var("b"); - Expression aGreaterB = new BinaryExpression(varA, ">", varB); - Expression xEqualsComp = new BinaryExpression(varX, "==", aGreaterB); - - Expression three = new LiteralInt(3); - Expression aEquals3 = new BinaryExpression(varA, "==", three); - Expression one = new LiteralInt(1); - Expression bEquals1 = new BinaryExpression(varB, "==", one); - - Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals1); - Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("x == (a > b) && a == 3 && b == 1"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - // Then assertNotNull(result, "Result should not be null"); assertEquals("x == (3 > 1)", result.getValue().toDisplayString(), "Boolean in equality should be unwrapped after propagation"); @@ -637,23 +296,10 @@ void testShouldUnwrapBooleanInEqualityWithPropagation() { @Test void testShouldNotUnwrapBooleanWithBooleanChildren() { - // Given: (y || true) && !true && y == false - // Expected: false (both children of the fold are boolean, so no unwrapping needed) + Expression expression = parse("(y || true) && !true && y == false"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression varY = new Var("y"); - Expression trueExp = new LiteralBoolean(true); - Expression yOrTrue = new BinaryExpression(varY, "||", trueExp); - Expression notTrue = new UnaryExpression("!", trueExp); - Expression falseExp = new LiteralBoolean(false); - Expression yEqualsFalse = new BinaryExpression(varY, "==", falseExp); - - Expression firstAnd = new BinaryExpression(yOrTrue, "&&", notTrue); - Expression fullExpression = new BinaryExpression(firstAnd, "&&", yEqualsFalse); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then: false stays as false since both sides in the derivation are booleans + // false stays as false since both sides in the derivation are booleans assertNotNull(result, "Result should not be null"); assertInstanceOf(LiteralBoolean.class, result.getValue(), "Result should remain a boolean"); assertFalse(result.getValue().isBooleanTrue(), "Expected result to be false"); @@ -661,29 +307,9 @@ void testShouldNotUnwrapBooleanWithBooleanChildren() { @Test void testShouldUnwrapNestedBooleanInEquality() { - // Given: x == (a + b > 10) && a == 3 && b == 5 - // Expected: x == (8 > 10) (shows the actual comparison that produced the boolean) - - Expression varX = new Var("x"); - Expression varA = new Var("a"); - Expression varB = new Var("b"); - Expression aPlusB = new BinaryExpression(varA, "+", varB); - Expression ten = new LiteralInt(10); - Expression comparison = new BinaryExpression(aPlusB, ">", ten); - Expression xEqualsComp = new BinaryExpression(varX, "==", comparison); + Expression expression = parse("x == (a + b > 10) && a == 3 && b == 5"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression three = new LiteralInt(3); - Expression aEquals3 = new BinaryExpression(varA, "==", three); - Expression five = new LiteralInt(5); - Expression bEquals5 = new BinaryExpression(varB, "==", five); - - Expression conditions = new BinaryExpression(aEquals3, "&&", bEquals5); - Expression fullExpression = new BinaryExpression(xEqualsComp, "&&", conditions); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertEquals("x == (8 > 10)", result.getValue().toDisplayString(), "Boolean in equality should be unwrapped to show the computed comparison"); @@ -691,19 +317,9 @@ void testShouldUnwrapNestedBooleanInEquality() { @Test void testVarToVarPropagationWithInternalVariable() { - // Given: #x_0 == a && #x_0 > 5 - // Expected: a > 5 (internal #x_0 substituted with user-facing a) - - Expression varX0 = new Var("#x_0"); - Expression varA = new Var("a"); - Expression x0EqualsA = new BinaryExpression(varX0, "==", varA); - Expression x0Greater5 = new BinaryExpression(varX0, ">", new LiteralInt(5)); - Expression fullExpression = new BinaryExpression(x0EqualsA, "&&", x0Greater5); + Expression expression = parse("#x_0 == a && #x_0 > 5"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertEquals("a > 5", result.getValue().toString(), "Internal variable #x_0 should be substituted with user-facing variable a"); @@ -711,24 +327,9 @@ void testVarToVarPropagationWithInternalVariable() { @Test void testVarToVarInternalToInternal() { - // Given: #a_1 == #b_2 && #b_2 == 5 && x == #a_1 + 1 - // Expected: x == 5 + 1 = x == 6 - - Expression varA = new Var("#a_1"); - Expression varB = new Var("#b_2"); - Expression varX = new Var("x"); - Expression five = new LiteralInt(5); - Expression aEqualsB = new BinaryExpression(varA, "==", varB); - Expression bEquals5 = new BinaryExpression(varB, "==", five); - Expression aPlus1 = new BinaryExpression(varA, "+", new LiteralInt(1)); - Expression xEqualsAPlus1 = new BinaryExpression(varX, "==", aPlus1); - Expression firstAnd = new BinaryExpression(aEqualsB, "&&", bEquals5); - Expression fullExpression = new BinaryExpression(firstAnd, "&&", xEqualsAPlus1); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then + Expression expression = parse("#a_1 == #b_2 && #b_2 == 5 && x == #a_1 + 1"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + assertNotNull(result, "Result should not be null"); assertEquals("x == 6", result.getValue().toString(), "#a should resolve through #b to 5 across passes, then x == 5 + 1 = x == 6"); @@ -736,19 +337,9 @@ void testVarToVarInternalToInternal() { @Test void testVarToVarDoesNotAffectUserFacingVariables() { - // Given: x == y && x > 5 - // Expected: x == y && x > 5 (user-facing var-to-var should not be propagated) - - Expression varX = new Var("x"); - Expression varY = new Var("y"); - Expression xEqualsY = new BinaryExpression(varX, "==", varY); - Expression xGreater5 = new BinaryExpression(varX, ">", new LiteralInt(5)); - Expression fullExpression = new BinaryExpression(xEqualsY, "&&", xGreater5); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("x == y && x > 5"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - // Then assertNotNull(result, "Result should not be null"); assertEquals("x == y && x > 5", result.getValue().toString(), "User-facing variable equalities should not trigger var-to-var propagation"); @@ -756,25 +347,9 @@ void testVarToVarDoesNotAffectUserFacingVariables() { @Test void testVarToVarRemovesRedundantEquality() { - // Given: #ret_1 == #b_0 - 100 && #b_0 == b && b >= -128 && b <= 127 - // Expected: #ret_1 == b - 100 && b >= -128 && b <= 127 (#b_0 replaced with b, #b_0 == b removed) - - Expression ret1 = new Var("#ret_1"); - Expression b0 = new Var("#b_0"); - Expression b = new Var("b"); - Expression ret1EqB0Minus100 = new BinaryExpression(ret1, "==", - new BinaryExpression(b0, "-", new LiteralInt(100))); - Expression b0EqB = new BinaryExpression(b0, "==", b); - Expression bGeMinus128 = new BinaryExpression(b, ">=", new UnaryExpression("-", new LiteralInt(128))); - Expression bLe127 = new BinaryExpression(b, "<=", new LiteralInt(127)); - Expression and1 = new BinaryExpression(ret1EqB0Minus100, "&&", b0EqB); - Expression and2 = new BinaryExpression(bGeMinus128, "&&", bLe127); - Expression fullExpression = new BinaryExpression(and1, "&&", and2); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then + Expression expression = parse("#ret_1 == #b_0 - 100 && #b_0 == b && b >= -128 && b <= 127"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + assertNotNull(result, "Result should not be null"); assertEquals("#ret_1 == b - 100 && b >= -128 && b <= 127", result.getValue().toString(), "Internal variable #b_0 should be replaced with b and redundant equality removed"); @@ -783,16 +358,8 @@ void testVarToVarRemovesRedundantEquality() { @Test void testInternalToInternalReducesRedundantVariable() { - // Given: #a_3 == #b_7 && #a_3 > 5 - // Expected: #b_7 > 5 (#a_3 has lower counter, so #a_3 -> #b_7) - - Expression a3 = new Var("#a_3"); - Expression b7 = new Var("#b_7"); - Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7); - Expression a3Greater5 = new BinaryExpression(a3, ">", new LiteralInt(5)); - Expression fullExpression = new BinaryExpression(a3EqualsB7, "&&", a3Greater5); - - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("#a_3 == #b_7 && #a_3 > 5"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); assertNotNull(result); assertEquals("#b_7 > 5", result.getValue().toString(), @@ -801,19 +368,8 @@ void testInternalToInternalReducesRedundantVariable() { @Test void testInternalToInternalChainWithUserFacingVariableUserFacingFirst() { - // Given: #b_7 == x && #a_3 == #b_7 && x > 0 - // Expected: x > 0 (#b_7 -> x (user-facing); #a_3 has lower counter so #a_3 -> #b_7) - - Expression a3 = new Var("#a_3"); - Expression b7 = new Var("#b_7"); - Expression x = new Var("x"); - Expression b7EqualsX = new BinaryExpression(b7, "==", x); - Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7); - Expression xGreater0 = new BinaryExpression(x, ">", new LiteralInt(0)); - Expression and1 = new BinaryExpression(b7EqualsX, "&&", a3EqualsB7); - Expression fullExpression = new BinaryExpression(and1, "&&", xGreater0); - - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("#b_7 == x && #a_3 == #b_7 && x > 0"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); assertNotNull(result); assertEquals("x > 0", result.getValue().toString(), @@ -822,19 +378,8 @@ void testInternalToInternalChainWithUserFacingVariableUserFacingFirst() { @Test void testInternalToInternalChainWithUserFacingVariableInternalFirst() { - // Given: #a_3 == #b_7 && #b_7 == x && x > 0 - // Expected: x > 0 (#a_3 has lower counter so #a_3 -> #b_7; #b_7 -> x (user-facing) overwrites) - - Expression a3 = new Var("#a_3"); - Expression b7 = new Var("#b_7"); - Expression x = new Var("x"); - Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7); - Expression b7EqualsX = new BinaryExpression(b7, "==", x); - Expression xGreater0 = new BinaryExpression(x, ">", new LiteralInt(0)); - Expression and1 = new BinaryExpression(a3EqualsB7, "&&", b7EqualsX); - Expression fullExpression = new BinaryExpression(and1, "&&", xGreater0); - - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("#a_3 == #b_7 && #b_7 == x && x > 0"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); assertNotNull(result); assertEquals("x > 0", result.getValue().toString(), @@ -843,17 +388,8 @@ void testInternalToInternalChainWithUserFacingVariableInternalFirst() { @Test void testInternalToInternalBothResolvingToLiteral() { - // Given: #a_3 == #b_7 && #b_7 == 5 - // Expected: 5 == 5 && 5 == 5 (#a_3 has lower counter so #a_3 -> #b_7; #b_7 -> 5) - - Expression a3 = new Var("#a_3"); - Expression b7 = new Var("#b_7"); - Expression five = new LiteralInt(5); - Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7); - Expression b7Equals5 = new BinaryExpression(b7, "==", five); - Expression fullExpression = new BinaryExpression(a3EqualsB7, "&&", b7Equals5); - - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("#a_3 == #b_7 && #b_7 == 5"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); assertNotNull(result); assertEquals("5 == 5 && 5 == 5", result.getValue().toString(), @@ -862,17 +398,8 @@ void testInternalToInternalBothResolvingToLiteral() { @Test void testInternalToInternalNoFurtherResolution() { - // Given: #a_3 == #b_7 && #b_7 + 1 > 0 - // Expected: #b_7 + 1 > 0 (#a_3 has lower counter, so #a_3 -> #b_7) - - Expression a3 = new Var("#a_3"); - Expression b7 = new Var("#b_7"); - Expression a3EqualsB7 = new BinaryExpression(a3, "==", b7); - Expression b7Plus1 = new BinaryExpression(b7, "+", new LiteralInt(1)); - Expression b7Plus1Greater0 = new BinaryExpression(b7Plus1, ">", new LiteralInt(0)); - Expression fullExpression = new BinaryExpression(a3EqualsB7, "&&", b7Plus1Greater0); - - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse("#a_3 == #b_7 && #b_7 + 1 > 0"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); assertNotNull(result); assertEquals("#b_7 + 1 > 0", result.getValue().toString(), @@ -881,15 +408,9 @@ void testInternalToInternalNoFurtherResolution() { @Test void testIteTrueConditionSimplifiesToThenBranch() { - // Given: true ? a : b - // Expected: a - - Expression expr = new Ite(new LiteralBoolean(true), new Var("a"), new Var("b")); - - // When + Expression expr = parse("true ? a : b"); ValDerivationNode result = ExpressionSimplifier.simplify(expr); - // Then assertNotNull(result, "Result should not be null"); assertEquals("a", result.getValue().toString(), "Expected result to be a"); @@ -904,15 +425,9 @@ void testIteTrueConditionSimplifiesToThenBranch() { @Test void testIteFalseConditionSimplifiesToElseBranch() { - // Given: false ? a : b - // Expected: b - - Expression expr = new Ite(new LiteralBoolean(false), new Var("a"), new Var("b")); - - // When + Expression expr = parse("false ? a : b"); ValDerivationNode result = ExpressionSimplifier.simplify(expr); - // Then assertNotNull(result, "Result should not be null"); assertEquals("b", result.getValue().toString(), "Expected result to be b"); @@ -927,16 +442,9 @@ void testIteFalseConditionSimplifiesToElseBranch() { @Test void testIteEqualBranchesSimplifiesToBranch() { - // Given: cond ? b : b - // Expected: b - - Expression branch = new Var("b"); - Expression expr = new Ite(new Var("cond"), branch, branch.clone()); - - // When + Expression expr = parse("cond ? b : b"); ValDerivationNode result = ExpressionSimplifier.simplify(expr); - // Then assertNotNull(result, "Result should not be null"); assertEquals("b", result.getValue().toString(), "Expected result to be b"); @@ -951,16 +459,13 @@ void testIteEqualBranchesSimplifiesToBranch() { @Test void testByteAliasExpansion() { - // Given: Byte(b) with alias Byte(int b) { b >= -128 && b <= 127 } + String sut = "Byte(b)"; AliasDTO byteAlias = new AliasDTO("Byte", List.of("int"), List.of("b"), "b >= -128 && b <= 127"); byteAlias.parse(""); Map aliases = Map.of("Byte", byteAlias); - Expression exp = new AliasInvocation("Byte", List.of(new Var("b"))); - - // When + Expression exp = parse(sut); ValDerivationNode result = ExpressionSimplifier.simplify(exp, aliases); - // Then assertEquals("Byte(b)", result.getValue().toString()); assertNotNull(result.getOrigin(), "Origin should contain the expanded body"); ValDerivationNode origin = (ValDerivationNode) result.getOrigin(); @@ -969,16 +474,13 @@ void testByteAliasExpansion() { @Test void testPositiveAliasExpansion() { - // Given: Positive(x) with alias Positive(int v) { v > 0 } + String sut = "Positive(x)"; AliasDTO positiveAlias = new AliasDTO("Positive", List.of("int"), List.of("v"), "v > 0"); positiveAlias.parse(""); Map aliases = Map.of("Positive", positiveAlias); - Expression exp = new AliasInvocation("Positive", List.of(new Var("x"))); - - // When + Expression exp = parse(sut); ValDerivationNode result = ExpressionSimplifier.simplify(exp, aliases); - // Then assertEquals("Positive(x)", result.getValue().toString()); assertNotNull(result.getOrigin(), "Origin should contain the expanded body"); ValDerivationNode origin = (ValDerivationNode) result.getOrigin(); @@ -987,20 +489,13 @@ void testPositiveAliasExpansion() { @Test void testTwoArgAliasWithNormalExpression() { - // Given: Bounded(v, 100) && v > 50 with alias Bounded(int x, int n) { x > 0 && x < n } + String sut = "Bounded(v, 100) && v > 50"; AliasDTO boundedAlias = new AliasDTO("Bounded", List.of("int", "int"), List.of("x", "n"), "x > 0 && x < n"); boundedAlias.parse(""); Map aliases = Map.of("Bounded", boundedAlias); + Expression expression = parse(sut); + ValDerivationNode result = ExpressionSimplifier.simplify(expression, aliases); - Expression varV = new Var("v"); - Expression bounded = new AliasInvocation("Bounded", List.of(varV, new LiteralInt(100))); - Expression vGt50 = new BinaryExpression(varV, ">", new LiteralInt(50)); - Expression fullExpression = new BinaryExpression(bounded, "&&", vGt50); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression, aliases); - - // Then assertEquals("Bounded(v, 100) && v > 50", result.getValue().toString()); assertInstanceOf(BinaryDerivationNode.class, result.getOrigin()); BinaryDerivationNode binOrigin = (BinaryDerivationNode) result.getOrigin(); @@ -1017,21 +512,18 @@ void testTwoArgAliasWithNormalExpression() { @Test void testEntailedConjunctIsRemovedButOriginIsPreserved() { - // Given: b >= 100 && b > 0 - // Expected: b >= 100 (b >= 100 implies b > 0) - + String sut = "b >= 100 && b > 0"; addIntVariableToContext("b"); - Expression b = new Var("b"); - Expression bGe100 = new BinaryExpression(b, ">=", new LiteralInt(100)); - Expression bGt0 = new BinaryExpression(b, ">", new LiteralInt(0)); - Expression fullExpression = new BinaryExpression(bGe100, "&&", bGt0); - - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse(sut); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); assertNotNull(result); assertEquals("b >= 100", result.getValue().toString(), "The weaker conjunct should be removed when implied by the stronger one"); + BinaryExpression parsed = (BinaryExpression) expression; + Expression bGe100 = parsed.getFirstOperand(); + Expression bGt0 = parsed.getSecondOperand(); ValDerivationNode expectedLeft = new ValDerivationNode(bGe100, null); ValDerivationNode expectedRight = new ValDerivationNode(bGt0, null); ValDerivationNode expected = new ValDerivationNode(bGe100, @@ -1042,23 +534,19 @@ void testEntailedConjunctIsRemovedButOriginIsPreserved() { @Test void testStrictComparisonImpliesNonStrictComparison() { - // Given: x > y && x >= y - // Expected: x > y (x > y implies x >= y) - + String sut = "x > y && x >= y"; addIntVariableToContext("x"); addIntVariableToContext("y"); - Expression x = new Var("x"); - Expression y = new Var("y"); - Expression xGtY = new BinaryExpression(x, ">", y); - Expression xGeY = new BinaryExpression(x, ">=", y); - Expression fullExpression = new BinaryExpression(xGtY, "&&", xGeY); - - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse(sut); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); assertNotNull(result); assertEquals("x > y", result.getValue().toString(), "The stricter comparison should be kept when it implies the weaker one"); + BinaryExpression parsed = (BinaryExpression) expression; + Expression xGtY = parsed.getFirstOperand(); + Expression xGeY = parsed.getSecondOperand(); ValDerivationNode expectedLeft = new ValDerivationNode(xGtY, null); ValDerivationNode expectedRight = new ValDerivationNode(xGeY, null); ValDerivationNode expected = new ValDerivationNode(xGtY, @@ -1069,19 +557,17 @@ void testStrictComparisonImpliesNonStrictComparison() { @Test void testEquivalentBoundsKeepOneSide() { - // Given: i >= 0 && 0 <= i - // Expected: 0 <= i (both conjuncts express the same condition) + String sut = "0 <= i && i >= 0"; addIntVariableToContext("i"); - Expression i = new Var("i"); - Expression zeroLeI = new BinaryExpression(new LiteralInt(0), "<=", i); - Expression iGeZero = new BinaryExpression(i, ">=", new LiteralInt(0)); - Expression fullExpression = new BinaryExpression(zeroLeI, "&&", iGeZero); - - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); + Expression expression = parse(sut); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); assertNotNull(result); assertEquals("0 <= i", result.getValue().toString(), "Equivalent bounds should collapse to a single conjunct"); + BinaryExpression parsed = (BinaryExpression) expression; + Expression zeroLeI = parsed.getFirstOperand(); + Expression iGeZero = parsed.getSecondOperand(); ValDerivationNode expectedLeft = new ValDerivationNode(zeroLeI, null); ValDerivationNode expectedRight = new ValDerivationNode(iGeZero, null); ValDerivationNode expected = new ValDerivationNode(zeroLeI, @@ -1092,45 +578,42 @@ void testEquivalentBoundsKeepOneSide() { @Test void testSubstitutesVariableDefinedByArithmeticExpression() { - // Given: z == y - 2 && y == x + 1 - // Expected: z == x - 1 + Expression expression = parse("z == y - 2 && y == x + 1"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); - Expression z = new Var("z"); - Expression y = new Var("y"); - Expression x = new Var("x"); - - Expression returnExpression = new BinaryExpression(z, "==", new BinaryExpression(y, "-", new LiteralInt(2))); - Expression yDefinition = new BinaryExpression(y, "==", new BinaryExpression(x, "+", new LiteralInt(1))); - Expression fullExpression = new BinaryExpression(returnExpression, "&&", yDefinition); - - // When - ValDerivationNode result = ExpressionSimplifier.simplify(fullExpression); - - // Then assertNotNull(result, "Result should not be null"); assertEquals("z == x - 1", result.getValue().toString(), "Expected variable definition to be substituted"); } @Test void testFoldsAdjacentIntegerConstantsInLeftAssociatedArithmetic() { - // Given: x + 1 - 2, x - 1 + 2, x + 1 + 2, and x + 1 - 1 - // Expected: x - 1, x + 1, x + 3, and x - - Expression x = new Var("x"); - - Expression xPlus1Minus2 = new BinaryExpression(new BinaryExpression(x, "+", new LiteralInt(1)), "-", - new LiteralInt(2)); - Expression xMinus1Plus2 = new BinaryExpression(new BinaryExpression(x, "-", new LiteralInt(1)), "+", - new LiteralInt(2)); - Expression xPlus1Plus2 = new BinaryExpression(new BinaryExpression(x, "+", new LiteralInt(1)), "+", - new LiteralInt(2)); - Expression xPlus1Minus1 = new BinaryExpression(new BinaryExpression(x, "+", new LiteralInt(1)), "-", - new LiteralInt(1)); - - // When / Then - assertEquals("x - 1", ExpressionSimplifier.simplify(xPlus1Minus2).getValue().toString()); - assertEquals("x + 1", ExpressionSimplifier.simplify(xMinus1Plus2).getValue().toString()); - assertEquals("x + 3", ExpressionSimplifier.simplify(xPlus1Plus2).getValue().toString()); - assertEquals("x", ExpressionSimplifier.simplify(xPlus1Minus1).getValue().toString()); + assertEquals("x - 1", ExpressionSimplifier.simplify(parse("x + 1 - 2")).getValue().toString()); + assertEquals("x + 1", ExpressionSimplifier.simplify(parse("x - 1 + 2")).getValue().toString()); + assertEquals("x + 3", ExpressionSimplifier.simplify(parse("x + 1 + 2")).getValue().toString()); + assertEquals("x", ExpressionSimplifier.simplify(parse("x + 1 - 1")).getValue().toString()); + } + + @Test + void testFunctionInvocationEqualitiesPropagateTransitively() { + Expression expression = parse("size(x3) == size(x2) - 1 && size(x2) == size(x1) + 1 && size(x1) == 0"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + + assertEquals("size(x3) == 0", result.getValue().toString()); + } + + @Test + void testFunctionInvocationOnLeftBehavesLikeVariable() { + Expression expression = parse("func(a) == func(b) && func(b) == 1"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + + assertEquals("func(a) == 1", result.getValue().toString()); + } + + @Test + void testFunctionInvocationEqualitiesMixWithVariables() { + Expression expression = parse("func(a) + x && func(a) == y && y == 1 && x == 2"); + ValDerivationNode result = ExpressionSimplifier.simplify(expression); + + assertEquals("3", result.getValue().toString()); } } diff --git a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VariableResolverTest.java b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VariableResolverTest.java index 561d6a3e..34e66626 100644 --- a/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VariableResolverTest.java +++ b/liquidjava-verifier/src/test/java/liquidjava/rj_language/opt/VariableResolverTest.java @@ -6,41 +6,28 @@ import org.junit.jupiter.api.Test; -import liquidjava.rj_language.ast.BinaryExpression; import liquidjava.rj_language.ast.Expression; -import liquidjava.rj_language.ast.GroupExpression; -import liquidjava.rj_language.ast.LiteralInt; -import liquidjava.rj_language.ast.UnaryExpression; -import liquidjava.rj_language.ast.Var; +import liquidjava.rj_language.parsing.RefinementsParser; class VariableResolverTest { + private static Expression parse(String refinement) { + return RefinementsParser.createAST(refinement, ""); + } + @Test void testSingleEqualityNotExtracted() { - // x == 1 should not extract because it's a single equality - Expression varX = new Var("x"); - Expression one = new LiteralInt(1); - Expression xEquals1 = new BinaryExpression(varX, "==", one); - Map result = VariableResolver.resolve(xEquals1); + Expression expression = parse("x == 1"); + Map result = VariableResolver.resolve(expression); + assertTrue(result.isEmpty(), "Single equality should not extract variable mapping"); } @Test void testConjunctionExtractsVariables() { - // x + y && x == 1 && y == 2 should extract x -> 1, y -> 2 - Expression varX = new Var("x"); - Expression varY = new Var("y"); - Expression one = new LiteralInt(1); - Expression two = new LiteralInt(2); - - Expression xPlusY = new BinaryExpression(varX, "+", varY); - Expression xEquals1 = new BinaryExpression(varX, "==", one); - Expression yEquals2 = new BinaryExpression(varY, "==", two); + Expression expression = parse("x + y && x == 1 && y == 2"); + Map result = VariableResolver.resolve(expression); - Expression conditions = new BinaryExpression(xEquals1, "&&", yEquals2); - Expression fullExpr = new BinaryExpression(xPlusY, "&&", conditions); - - Map result = VariableResolver.resolve(fullExpr); assertEquals(2, result.size(), "Should extract both variables"); assertEquals("1", result.get("x").toString()); assertEquals("2", result.get("y").toString()); @@ -48,124 +35,119 @@ void testConjunctionExtractsVariables() { @Test void testSingleComparisonNotExtracted() { - // x > 0 should not extract anything - Expression varX = new Var("x"); - Expression zero = new LiteralInt(0); - Expression xGreaterZero = new BinaryExpression(varX, ">", zero); + Expression expression = parse("x > 0"); + Map result = VariableResolver.resolve(expression); - Map result = VariableResolver.resolve(xGreaterZero); assertTrue(result.isEmpty(), "Single comparison should not extract variable mapping"); } @Test void testSingleArithmeticExpression() { - // x + 1 should not extract anything - Expression varX = new Var("x"); - Expression one = new LiteralInt(1); - Expression xPlusOne = new BinaryExpression(varX, "+", one); + Expression expression = parse("x + 1"); + Map result = VariableResolver.resolve(expression); - Map result = VariableResolver.resolve(xPlusOne); assertTrue(result.isEmpty(), "Single arithmetic expression should not extract variable mapping"); } @Test void testDisjunctionWithEqualities() { - // x == 1 || y == 2 should not extract anything - Expression varX = new Var("x"); - Expression varY = new Var("y"); - Expression one = new LiteralInt(1); - Expression two = new LiteralInt(2); - - Expression xEquals1 = new BinaryExpression(varX, "==", one); - Expression yEquals2 = new BinaryExpression(varY, "==", two); - Expression disjunction = new BinaryExpression(xEquals1, "||", yEquals2); + Expression expression = parse("x == 1 || y == 2"); + Map result = VariableResolver.resolve(expression); - Map result = VariableResolver.resolve(disjunction); assertTrue(result.isEmpty(), "Disjunction should not extract variable mappings"); } @Test void testNegatedEquality() { - // !(x == 1) should not extract because it's a single equality - Expression varX = new Var("x"); - Expression one = new LiteralInt(1); - Expression xEquals1 = new BinaryExpression(varX, "==", one); - Expression notXEquals1 = new UnaryExpression("!", xEquals1); + Expression expression = parse("!(x == 1)"); + Map result = VariableResolver.resolve(expression); - Map result = VariableResolver.resolve(notXEquals1); assertTrue(result.isEmpty(), "Negated equality should not extract variable mapping"); } @Test void testGroupedEquality() { - // (x == 1) should not extract because it's a single equality - Expression varX = new Var("x"); - Expression one = new LiteralInt(1); - Expression xEquals1 = new BinaryExpression(varX, "==", one); - Expression grouped = new GroupExpression(xEquals1); + Expression expression = parse("(x == 1)"); + Map result = VariableResolver.resolve(expression); - Map result = VariableResolver.resolve(grouped); assertTrue(result.isEmpty(), "Grouped single equality should not extract variable mapping"); } @Test void testCircularDependency() { - // x == y && y == x should not extract anything due to circular dependency - Expression varX = new Var("x"); - Expression varY = new Var("y"); + Expression expression = parse("x == y && y == x"); + Map result = VariableResolver.resolve(expression); - Expression xEqualsY = new BinaryExpression(varX, "==", varY); - Expression yEqualsX = new BinaryExpression(varY, "==", varX); - Expression conjunction = new BinaryExpression(xEqualsY, "&&", yEqualsX); - - Map result = VariableResolver.resolve(conjunction); assertTrue(result.isEmpty(), "Circular dependency should not extract variable mappings"); } @Test void testUnusedEqualitiesShouldBeIgnored() { - // z > 0 && x == 1 && y == 2 && z == 3 - Expression varX = new Var("x"); - Expression varY = new Var("y"); - Expression varZ = new Var("z"); - Expression one = new LiteralInt(1); - Expression two = new LiteralInt(2); - Expression three = new LiteralInt(3); - Expression zero = new LiteralInt(0); - Expression zGreaterZero = new BinaryExpression(varZ, ">", zero); - Expression xEquals1 = new BinaryExpression(varX, "==", one); - Expression yEquals2 = new BinaryExpression(varY, "==", two); - Expression zEquals3 = new BinaryExpression(varZ, "==", three); - Expression conditions = new BinaryExpression(xEquals1, "&&", new BinaryExpression(yEquals2, "&&", zEquals3)); - Expression fullExpr = new BinaryExpression(zGreaterZero, "&&", conditions); - Map result = VariableResolver.resolve(fullExpr); + Expression expression = parse("z > 0 && x == 1 && y == 2 && z == 3"); + Map result = VariableResolver.resolve(expression); + assertEquals(1, result.size(), "Should only extract used variable z"); assertEquals("3", result.get("z").toString()); } @Test void testReturnVariableIsNotSubstituted() { - // #ret_1 == x && x > 0 should not substitute #ret_1 with x - Expression ret = new Var("#ret_1"); - Expression x = new Var("x"); - Expression xGreaterZero = new BinaryExpression(x, ">", new LiteralInt(0)); - Expression retEqualsX = new BinaryExpression(ret, "==", x); - Expression fullExpr = new BinaryExpression(xGreaterZero, "&&", retEqualsX); - - Map result = VariableResolver.resolve(fullExpr); + Expression expression = parse("x > 0 && #ret_1 == x"); + Map result = VariableResolver.resolve(expression); + assertTrue(result.isEmpty(), "Return variables should not be substituted with another variable"); } @Test void testFreshVariableIsNotUsedAsSubstitutionTarget() { - // #tmp_1 > 0 && #tmp_1 == #fresh_2 should not substitute #tmp_1 with #fresh_2 - Expression internal = new Var("#tmp_1"); - Expression fresh = new Var("#fresh_2"); - Expression internalGreaterZero = new BinaryExpression(internal, ">", new LiteralInt(0)); - Expression internalEqualsFresh = new BinaryExpression(internal, "==", fresh); - Expression fullExpr = new BinaryExpression(internalGreaterZero, "&&", internalEqualsFresh); - - Map result = VariableResolver.resolve(fullExpr); + Expression expression = parse("#tmp_1 > 0 && #tmp_1 == #fresh_2"); + Map result = VariableResolver.resolve(expression); + assertTrue(result.isEmpty(), "Fresh variables should not replace another variable"); } + + @Test + void testFunctionInvocationEqualityExtractsFunctionKey() { + Expression expression = parse("size(stack) > 0 && size(stack) == 1"); + Map result = VariableResolver.resolve(expression); + + assertEquals(1, result.size(), "Should extract the function invocation as a substitution key"); + assertEquals("1", result.get("size(stack)").toString()); + } + + @Test + void testLiteralOnLeftExtractsFunctionInvocationKey() { + Expression expression = parse("size(stack) > 0 && 1 == size(stack)"); + Map result = VariableResolver.resolve(expression); + + assertEquals(1, result.size(), "Should extract function invocation equalities from either side"); + assertEquals("1", result.get("size(stack)").toString()); + } + + @Test + void testFunctionInvocationEqualitiesResolveTransitively() { + Expression expression = parse("func(a) > 0 && func(a) == func(b) && func(b) == 1"); + Map result = VariableResolver.resolve(expression); + + assertEquals(2, result.size(), "Should keep the function invocation chain that was used"); + assertEquals("1", result.get("func(a)").toString()); + assertEquals("1", result.get("func(b)").toString()); + } + + @Test + void testFunctionInvocationNamesAreMatchedStructurally() { + Expression expression = parse("f(a) > 0 && f(a) == ff(a) + b"); + Map result = VariableResolver.resolve(expression); + + assertEquals(1, result.size(), "Should not treat ff(a) as a use of f(a)"); + assertEquals("ff(a) + b", result.get("f(a)").toString()); + } + + @Test + void testUnusedFunctionInvocationEqualityIsIgnored() { + Expression expression = parse("x > 0 && size(stack) == 1"); + Map result = VariableResolver.resolve(expression); + + assertTrue(result.isEmpty(), "Function invocation definitions with no usage should be ignored"); + } }