diff --git a/src/Bounds.cpp b/src/Bounds.cpp index f4493474c49f..9087516b54b8 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1827,6 +1827,18 @@ Interval bounds_of_expr_in_scope(const Expr &expr, const Scope &scope, return bounds_of_expr_in_scope_with_indent(expr, scope, fb, const_bound, 0); } +Expr and_condition_over_domain(const Expr &e, const Scope &varying) { + internal_assert(e.type().is_bool()) << "Expr provided to and_condition_over_domain is not boolean: " << e << "\n"; + Interval bounds = bounds_of_expr_in_scope(e, varying); + internal_assert(bounds.has_lower_bound()) << "Failed to produce bound on boolean value in and_condition_over_domain" << e << "\n"; + // Minimum of a boolean value is sufficient condition, implies expression. + return simplify(bounds.min); +} + +Expr or_condition_over_domain(const Expr &c, const Scope &varying) { + return simplify(!and_condition_over_domain(simplify(!c), varying)); +} + void merge_boxes(Box &a, const Box &b) { if (b.empty()) { return; diff --git a/src/Bounds.h b/src/Bounds.h index bafa42ecda1a..a06980492d60 100644 --- a/src/Bounds.h +++ b/src/Bounds.h @@ -48,6 +48,22 @@ Expr find_constant_bound(const Expr &e, Direction d, * +/-inf. */ Interval find_constant_bounds(const Expr &e, const Scope &scope); +/** Take a conditional that includes variables that vary over some + * domain, and convert it to a more conservative (less frequently + * true) condition that doesn't depend on those variables. Formally, + * the output expr implies the input expr. + * + * The condition may be a vector condition, in which case we also + * 'and' over the vector lanes, and return a scalar result. */ +Expr and_condition_over_domain(const Expr &c, const Scope &varying); + +/** Take a conditional that includes variables that vary over some + * domain, and convert it to a weaker (less frequently false) condition + * that doesn't depend on those variables. Formally, the input expr + * implies the output expr. Note that this function might be unable to + * provide a better response than simply const_true(). */ +Expr or_condition_over_domain(const Expr &c, const Scope &varying); + /** Represents the bounds of a region of arbitrary dimension. Zero * dimensions corresponds to a scalar region. */ struct Box { diff --git a/src/Simplify_Cast.cpp b/src/Simplify_Cast.cpp index 631686ee0bfc..0d1b278a1e85 100644 --- a/src/Simplify_Cast.cpp +++ b/src/Simplify_Cast.cpp @@ -110,13 +110,20 @@ Expr Simplify::visit(const Cast *op, ExprInfo *info) { } else if (cast && op->type.is_int_or_uint() && cast->type.is_int_or_uint() && + cast->value.type().is_int_or_uint() && op->type.bits() <= cast->type.bits() && op->type.bits() <= op->value.type().bits()) { // If this is a cast between integer types, where the // outer cast is narrower than the inner cast and the // inner cast's argument, the inner cast can be // eliminated. The inner cast is either a sign extend - // or a zero extend, and the outer cast truncates the extended bits + // or a zero extend, and the outer cast truncates the extended bits. + // The requirement that cast->value is itself int-or-uint is crucial: + // a float source makes `cast` an fp-to-int conversion, whose low + // bits are not the same as an fp-to-int conversion of a narrower + // type. For example, int32(uint64(float64(-21))) evaluates to 0 + // (float-to-uint of a negative value saturates to 0 in Halide), + // while the stripped form int32(float64(-21)) evaluates to -21. if (op->type == cast->value.type()) { return mutate(cast->value, info); } else { diff --git a/src/Solve.cpp b/src/Solve.cpp index a0b91be7a287..5bb742917809 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -156,7 +156,12 @@ class SolveExpression : public IRMutator { } } else if (a_uses_var && b_uses_var) { if (equal(a, b)) { - expr = mutate(a * 2); + // Use Mul::make + make_const rather than operator*(Expr, int) + // because the latter rejects constants that don't fit in a's + // type (e.g. `2` in UInt(1)). make_const truncates modulo the + // width, so `UInt(1) * 2` becomes `UInt(1) * 0`, which is + // the correct modular result of `a + a` for UInt(1). + expr = mutate(Mul::make(a, make_const(a.type(), 2))); } else if (add_a && !a_failed) { // (f(x) + a) + g(x) -> (f(x) + g(x)) + a expr = mutate((add_a->a + b) + add_a->b); @@ -181,11 +186,15 @@ class SolveExpression : public IRMutator { } else if (mul_b && equal(mul_b->a, a)) { // f(x) + f(x)*a -> f(x) * (a + 1) expr = mutate(a * (mul_b->b + 1)); - } else if (div_a && !a_failed) { - // f(x)/a + g(x) -> (f(x) + g(x) * a) / b + } else if (div_a && !a_failed && no_overflow_int(op->type)) { + // f(x)/a + g(x) -> (f(x) + g(x) * a) / a + // Only valid when multiplication and division don't wrap: + // under modular arithmetic g(x)*a can overflow and the + // rewrite changes the value. Gated to Int(32)+. expr = mutate((div_a->a + b * div_a->b) / div_a->b); - } else if (div_b && !b_failed) { + } else if (div_b && !b_failed && no_overflow_int(op->type)) { // f(x) + g(x)/b -> (f(x) * b + g(x)) / b + // Same overflow concern as above. expr = mutate((a * div_b->b + div_b->a) / div_b->b); } else { expr = fail(a + b); @@ -269,8 +278,9 @@ class SolveExpression : public IRMutator { } else if (mul_a && mul_b && equal(mul_a->b, mul_b->b)) { // f(x)*a - g(x)*a -> (f(x) - g(x))*a; expr = mutate((mul_a->a - mul_b->a) * mul_a->b); - } else if (div_a && !a_failed) { - // f(x)/a - g(x) -> (f(x) - g(x) * a) / b + } else if (div_a && !a_failed && no_overflow_int(op->type)) { + // f(x)/a - g(x) -> (f(x) - g(x) * a) / a + // Same overflow concern as the analogous Add rewrite. expr = mutate((div_a->a - b * div_a->b) / div_a->b); } else { expr = fail(a - b); @@ -358,7 +368,10 @@ class SolveExpression : public IRMutator { const Sub *sub_a = a.as(); const Mul *mul_a = a.as(); Expr expr; - if (a_uses_var && !b_uses_var) { + if (a_uses_var && !b_uses_var && no_overflow_int(op->type)) { + // Distributing division across +/-/* is only sound under + // non-wrapping integer arithmetic. Floats lose precision per + // operation, and narrower / unsigned ints can wrap. auto ib = as_const_int(b); auto is_multiple_of_b = [&](const Expr &e) { if (ib && op->type.is_scalar()) { @@ -376,7 +389,7 @@ class SolveExpression : public IRMutator { is_multiple_of_b(sub_a->a)) { // (f(x) - a) / b -> f(x) / b - a / b expr = mutate(simplify(sub_a->a / b) - sub_a->b / b); - } else if (mul_a && !a_failed && no_overflow_int(op->type) && + } else if (mul_a && !a_failed && is_multiple_of_b(mul_a->b)) { // (f(x) * a) / b -> f(x) * (a / b) expr = mutate(mul_a->a * (mul_a->b / b)); @@ -609,11 +622,20 @@ class SolveExpression : public IRMutator { Expr expr; if (a_uses_var && !b_uses_var) { - // We have f(x) < y. Try to unwrap f(x) - if (add_a && !a_failed) { + // We have f(x) < y. Try to unwrap f(x). + // + // Several of these rewrites rearrange the comparison by adding + // or subtracting on both sides. That's only sound under an + // assumption of no integer overflow -- for types that wrap + // (unsigned and narrow signed), ordering comparisons flip + // under wrap even though equality is preserved. So gate the + // rewrite on no_overflow_int for LT/LE/GT/GE but allow EQ/NE + // for all types (modular arithmetic preserves equality). + const bool safe_to_rearrange = no_overflow_int(a.type()) || is_eq || is_ne; + if (add_a && !a_failed && safe_to_rearrange) { // f(x) + b < c -> f(x) < c - b expr = mutate(Cmp::make(add_a->a, (b - add_a->b))); - } else if (sub_a && !a_failed) { + } else if (sub_a && !a_failed && safe_to_rearrange) { // f(x) - b < c -> f(x) < c + b expr = mutate(Cmp::make(sub_a->a, (b + sub_a->b))); } else if (mul_a) { @@ -631,11 +653,14 @@ class SolveExpression : public IRMutator { // check is true, but put an assertion anyway. internal_assert(!b.type().is_uint()) << "Negating unsigned is not legal\n"; expr = mutate(Opp::make(mul_a->a * negate(mul_a->b), negate(b))); - } else { - // Don't use operator/ and operator % to sneak - // past the division-by-zero check. We'll only - // actually use these when mul_a->b is a positive - // or negative constant. + } else if (is_positive_const(mul_a->b) && no_overflow_int(a.type())) { + // The rewrites below divide by mul_a->b, so require + // it to be a nonzero constant of known sign. + // no_overflow_int also rules out unsigned and narrow + // signed types, for which `a*c == b <=> a == b/c && + // b%c == 0` fails under modular arithmetic (consider + // uint8 with c = 3, b = 7: the rewrite misses the + // solutions that arise from wrap). Expr div = Div::make(b, mul_a->b); Expr rem = Mod::make(b, mul_a->b); if (is_eq) { @@ -644,16 +669,14 @@ class SolveExpression : public IRMutator { } else if (is_ne) { // f(x) * c != b -> f(x) != b/c || b%c != 0 expr = mutate((mul_a->a != div) || (rem != 0)); - } else if (is_positive_const(mul_a->b)) { - if (is_le) { - expr = mutate(mul_a->a <= div); - } else if (is_lt) { - expr = mutate(mul_a->a <= (b - 1) / mul_a->b); - } else if (is_gt) { - expr = mutate(mul_a->a > div); - } else if (is_ge) { - expr = mutate(mul_a->a > (b - 1) / mul_a->b); - } + } else if (is_le) { + expr = mutate(mul_a->a <= div); + } else if (is_lt) { + expr = mutate(mul_a->a <= (b - 1) / mul_a->b); + } else if (is_gt) { + expr = mutate(mul_a->a > div); + } else if (is_ge) { + expr = mutate(mul_a->a > (b - 1) / mul_a->b); } } } else if (div_a) { @@ -663,7 +686,7 @@ class SolveExpression : public IRMutator { } else if (is_negative_const(div_a->b)) { expr = mutate(Opp::make(div_a->a, b * div_a->b)); } - } else if (a.type().is_int() && a.type().bits() >= 32) { + } else if (no_overflow_int(a.type())) { if (is_eq || is_ne) { // Can't do anything with this } else if (is_negative_const(div_a->b)) { @@ -689,7 +712,7 @@ class SolveExpression : public IRMutator { } } } - } else if (a_uses_var && b_uses_var && a.type().is_int() && a.type().bits() >= 32) { + } else if (a_uses_var && b_uses_var && no_overflow_int(a.type())) { // Convert to f(x) - g(x) == 0 and let the subtract mutator clean up. // Only safe if the type is not subject to overflow. expr = mutate(Cmp::make(a - b, make_zero(a.type()))); @@ -1173,312 +1196,5 @@ Interval solve_for_outer_interval(const Expr &c, const std::string &var) { return s.result; } -Expr and_condition_over_domain(const Expr &e, const Scope &varying) { - internal_assert(e.type().is_bool()) << "Expr provided to and_condition_over_domain is not boolean: " << e << "\n"; - Interval bounds = bounds_of_expr_in_scope(e, varying); - internal_assert(bounds.has_lower_bound()) << "Failed to produce bound on boolean value in and_condition_over_domain" << e << "\n"; - // Minimum of a boolean value is sufficient condition, implies expression. - return simplify(bounds.min); -} - -Expr or_condition_over_domain(const Expr &c, const Scope &varying) { - return simplify(!and_condition_over_domain(simplify(!c), varying)); -} - -// Testing code - -namespace { - -void check_solve(const Expr &a, const Expr &b) { - SolverResult solved = solve_expression(a, "x"); - internal_assert(equal(solved.result, b)) - << "Expression: " << a << "\n" - << " solved to " << solved.result << "\n" - << " instead of " << b << "\n"; -} - -void check_interval(const Expr &a, const Interval &i, bool outer) { - Interval result = - outer ? solve_for_outer_interval(a, "x") : solve_for_inner_interval(a, "x"); - result.min = simplify(result.min); - result.max = simplify(result.max); - internal_assert(equal(result.min, i.min) && equal(result.max, i.max)) - << "Expression " << a << " solved to the interval:\n" - << " min: " << result.min << "\n" - << " max: " << result.max << "\n" - << " instead of:\n" - << " min: " << i.min << "\n" - << " max: " << i.max << "\n"; -} - -void check_outer_interval(const Expr &a, const Expr &min, const Expr &max) { - check_interval(a, Interval(min, max), true); -} - -void check_inner_interval(const Expr &a, const Expr &min, const Expr &max) { - check_interval(a, Interval(min, max), false); -} - -void check_and_condition(const Expr &orig, const Expr &result, const Interval &i) { - Scope s; - s.push("x", i); - Expr cond = and_condition_over_domain(orig, s); - internal_assert(equal(cond, result)) - << "Expression " << orig - << " reduced to " << cond - << " instead of " << result << "\n"; -} -} // namespace - -void solve_test() { - using ConciseCasts::i16; - - Expr x = Variable::make(Int(32), "x"); - Expr y = Variable::make(Int(32), "y"); - Expr z = Variable::make(Int(32), "z"); - - // Check some simple cases - check_solve(3 - 4 * x, x * (-4) + 3); - check_solve(min(5, x), min(x, 5)); - check_solve(max(5, (5 + x) * y), max(x * y + 5 * y, 5)); - check_solve(5 * y + 3 * x == 2, ((x == ((2 - (5 * y)) / 3)) && (((2 - (5 * y)) % 3) == 0))); - check_solve(min(min(z, x), min(x, y)), min(x, min(y, z))); - check_solve(min(x + y, x + 5), x + min(y, 5)); - - // Check solver with expressions containing division - check_solve(x + (x * 2) / 2, x * 2); - check_solve(x + (x * 2 + y) / 2, x * 2 + (y / 2)); - check_solve(x + (x * 2 - y) / 2, x * 2 - (y / 2)); - check_solve(x + (-(x * 2) / 2), x * 0 + 0); - check_solve(x + (-(x * 2 + -3)) / 2, x * 0 + 1); - check_solve(x + (z - (x * 2 + -3)) / 2, x * 0 + (z - (-3)) / 2); - check_solve(x + (y * 16 + (z - (x * 2 + -1))) / 2, - (x * 0) + (((z - -1) + (y * 16)) / 2)); - - check_solve((x * 9 + 3) / 4 - x * 2, (x * 1 + 3) / 4); - check_solve((x * 9 + 3) / 4 + x * 2, (x * 17 + 3) / 4); - check_solve(x * 2 + (x * 9 + 3) / 4, (x * 17 + 3) / 4); - - // Check the solver doesn't perform transformations that change integer overflow behavior. - check_solve(i16(x + y) * i16(2) / i16(2), i16(x + y) * i16(2) / i16(2)); - - // A let statement - check_solve(Let::make("z", 3 + 5 * x, y + z < 8), - x <= (((8 - (3 + y)) - 1) / 5)); - - // A let statement where the variable gets used twice. - check_solve(Let::make("z", 3 + 5 * x, y + (z + z) < 8), - x <= (((8 - (6 + y)) - 1) / 10)); - - // Something where we expect a let in the output. - { - Expr e = y + 1; - for (int i = 0; i < 10; i++) { - e *= (e + 1); - } - SolverResult solved = solve_expression(x + e < e * e, "x"); - internal_assert(solved.fully_solved && solved.result.as()); - } - - // Solving inequalities for integers is a pain to get right with - // all the rounding rules. Check we didn't make a mistake with - // brute force. - for (int den = -3; den <= 3; den++) { - if (den == 0) { - continue; - } - for (int num = 5; num <= 10; num++) { - Expr in[] = { - {x * den < num}, - {x * den <= num}, - {x * den == num}, - {x * den != num}, - {x * den >= num}, - {x * den > num}, - {x / den < num}, - {x / den <= num}, - {x / den == num}, - {x / den != num}, - {x / den >= num}, - {x / den > num}, - }; - for (const auto &e : in) { - SolverResult solved = solve_expression(e, "x"); - internal_assert(solved.fully_solved) << "Error: failed to solve for x in " << e << "\n"; - Expr out = simplify(solved.result); - for (int i = -10; i < 10; i++) { - Expr in_val = substitute("x", i, e); - Expr out_val = substitute("x", i, out); - in_val = simplify(in_val); - out_val = simplify(out_val); - internal_assert(equal(in_val, out_val)) - << "Error: " - << e << " is not equivalent to " - << out << " when x == " << i << "\n"; - } - } - } - } - - // Check for combinatorial explosion - Expr e = x + y; - for (int i = 0; i < 20; i++) { - e += (e + 1) * y; - } - SolverResult solved = solve_expression(e, "x"); - internal_assert(solved.fully_solved && solved.result.defined()); - - // Check some things that we don't expect to work. - - // Quadratics: - internal_assert(!solve_expression(x * x < 4, "x").fully_solved); - - // Function calls, cast nodes, or multiplications by unknown sign - // don't get inverted, but the bit containing x still gets moved - // leftwards. - check_solve(4.0f > sqrt(x), sqrt(x) < 4.0f); - - check_solve(4 > y * x, x * y < 4); - - // Now test solving for an interval - check_inner_interval(x > 0, 1, Interval::pos_inf()); - check_inner_interval(x < 100, Interval::neg_inf(), 99); - check_outer_interval(x > 0 && x < 100, 1, 99); - check_inner_interval(x > 0 && x < 100, 1, 99); - - Expr c = Variable::make(Bool(), "c"); - check_outer_interval(Let::make("y", 0, x > y && x < 100), 1, 99); - check_outer_interval(Let::make("c", x > 0, c && x < 100), 1, 99); - - check_outer_interval((x >= 10 && x <= 90) && sin(x) > 0.5f, 10, 90); - check_inner_interval((x >= 10 && x <= 90) && sin(x) > 0.6f, Interval::pos_inf(), Interval::neg_inf()); - - check_inner_interval(x == 10, 10, 10); - check_outer_interval(x == 10, 10, 10); - - check_inner_interval(!(x != 10), 10, 10); - check_outer_interval(!(x != 10), 10, 10); - - check_inner_interval(3 * x + 4 < 27, Interval::neg_inf(), 7); - check_outer_interval(3 * x + 4 < 27, Interval::neg_inf(), 7); - - check_inner_interval(min(x, y) > 17, 18, y); - check_outer_interval(min(x, y) > 17, 18, Interval::pos_inf()); - - check_inner_interval(x / 5 < 17, Interval::neg_inf(), 84); - check_outer_interval(x / 5 < 17, Interval::neg_inf(), 84); - - // Test anding a condition over a domain - check_and_condition(x > 0, const_true(), Interval(1, y)); - check_and_condition(x > 0, const_true(), Interval(5, y)); - check_and_condition(x > 0, const_false(), Interval(-5, y)); - check_and_condition(x > 0 && x < 10, const_true(), Interval(1, 9)); - check_and_condition(x > 0 || sin(x) == 0.5f, const_true(), Interval(100, 200)); - - check_and_condition(x <= 0, const_true(), Interval(-100, 0)); - check_and_condition(x <= 0, const_false(), Interval(-100, 1)); - - check_and_condition(x <= 0 || y > 2, const_true(), Interval(-100, 0)); - check_and_condition(x > 0 || y > 2, 2 < y, Interval(-100, 0)); - - check_and_condition(x == 0, const_true(), Interval(0, 0)); - check_and_condition(x == 0, const_false(), Interval(-10, 10)); - check_and_condition(x != 0, const_false(), Interval(-10, 10)); - check_and_condition(x != 0, const_true(), Interval(-20, -10)); - - check_and_condition(y == 0, y == 0, Interval(-10, 10)); - check_and_condition(y != 0, y != 0, Interval(-10, 10)); - check_and_condition((x == 5) && (y != 0), const_false(), Interval(-10, 10)); - check_and_condition((x == 5) && (y != 3), y != 3, Interval(5, 5)); - check_and_condition((x != 0) && (y != 0), const_false(), Interval(-10, 10)); - check_and_condition((x != 0) && (y != 0), y != 0, Interval(-20, -10)); - - { - // This case used to break due to signed integer overflow in - // the simplifier. - Expr a16 = Load::make(Int(16), "a", {x}, Buffer<>(), Parameter(), const_true(), ModulusRemainder()); - Expr b16 = Load::make(Int(16), "b", {x}, Buffer<>(), Parameter(), const_true(), ModulusRemainder()); - Expr lhs = pow(cast(a16), 2) + pow(cast(b16), 2); - - Scope s; - s.push("x", Interval(-10, 10)); - Expr cond = and_condition_over_domain(lhs < 0, s); - internal_assert(!is_const_one(simplify(cond))); - } - - { - // This cause use to cause infinite recursion: - Expr t = Variable::make(Int(32), "t"); - Expr test = (x <= min(max((y - min(((z * x) + t), t)), 1), 0)); - Interval result = solve_for_outer_interval(test, "z"); - } - - { - // This case caused exponential behavior - Expr t = Variable::make(Int(32), "t"); - for (int i = 0; i < 50; i++) { - t = min(t, Variable::make(Int(32), unique_name('v'))); - t = max(t, Variable::make(Int(32), unique_name('v'))); - } - solve_for_outer_interval(t <= 5, "t"); - solve_for_inner_interval(t <= 5, "t"); - } - - // Check for partial results - check_solve(max(min(y, x), x), max(min(x, y), x)); - check_solve(min(y, x) + max(y, 2 * x), min(x, y) + max(x * 2, y)); - check_solve((min(x, y) + min(y, x)) * max(y, x), (min(x, y) * 2) * max(x, y)); - check_solve(max((min((y * x), x) + min((1 + y), x)), (y + 2 * x)), - max((min((x * y), x) + min(x, (1 + y))), (x * 2 + y))); - - { - Expr x = Variable::make(UInt(32), "x"); - Expr y = Variable::make(UInt(32), "y"); - Expr z = Variable::make(UInt(32), "z"); - check_solve(5 - (4 - 4 * x), x * (4) + 1); - check_solve(z - (y - x), x + (z - y)); - check_solve(z - (y - x) == 2, x == 2 - (z - y)); - - check_solve(x - (x - y), (x - x) + y); - - // This is used to cause infinite recursion - Expr expr = Add::make(z, Sub::make(x, y)); - SolverResult solved = solve_expression(expr, "y"); - } - - // This case was incorrect due to canonicalization of the multiply - // occurring after unpacking the LHS. - check_solve((y - z) * x, x * (y - z)); - - // These cases were incorrectly not flipping min/max when moving - // it out of the RHS of a subtract. - check_solve(min(x - y, x - z), x - max(y, z)); - check_solve(min(x - y, x), x - max(y, 0)); - check_solve(min(x, x - y), x - max(y, 0)); - check_solve(max(x - y, x - z), x - min(y, z)); - check_solve(max(x - y, x), x - min(y, 0)); - check_solve(max(x, x - y), x - min(y, 0)); - - // Check mixed add/sub - check_solve(min(x - y, x + z), x + min(0 - y, z)); - check_solve(max(x - y, x + z), x + max(0 - y, z)); - check_solve(min(x + y, x - z), x + min(y, 0 - z)); - check_solve(max(x + y, x - z), x + max(y, 0 - z)); - - check_solve((5 * Broadcast::make(x, 4) + y) / 5, - Broadcast::make(x, 4) + (Broadcast::make(y, 4) / 5)); - - // Select negates the condition to move x leftward - check_solve(select(y < z, z, x), - select(z <= y, x, z)); - - // Select negates the condition and then mutates it, moving x - // leftward (despite the simplifier preferring < to >). - check_solve(select(x < 10, 10, x), - select(x >= 10, x, 10)); - - std::cout << "Solve test passed\n"; -} - } // namespace Internal } // namespace Halide diff --git a/src/Solve.h b/src/Solve.h index 4d06fda47d6b..c47462329adc 100644 --- a/src/Solve.h +++ b/src/Solve.h @@ -38,24 +38,6 @@ Interval solve_for_outer_interval(const Expr &c, const std::string &variable); * true inside of it, and might be true or false outside of it. */ Interval solve_for_inner_interval(const Expr &c, const std::string &variable); -/** Take a conditional that includes variables that vary over some - * domain, and convert it to a more conservative (less frequently - * true) condition that doesn't depend on those variables. Formally, - * the output expr implies the input expr. - * - * The condition may be a vector condition, in which case we also - * 'and' over the vector lanes, and return a scalar result. */ -Expr and_condition_over_domain(const Expr &c, const Scope &varying); - -/** Take a conditional that includes variables that vary over some - * domain, and convert it to a weaker (less frequently false) condition - * that doesn't depend on those variables. Formally, the input expr - * implies the output expr. Note that this function might be unable to - * provide a better response than simply const_true(). */ -Expr or_condition_over_domain(const Expr &c, const Scope &varying); - -void solve_test(); - } // namespace Internal } // namespace Halide diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index c6c90f833db0..913cfceb4281 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -294,6 +294,7 @@ tests(GROUPS correctness sliding_over_guard_with_if.cpp sliding_reduction.cpp sliding_window.cpp + solve.cpp sort_exprs.cpp specialize.cpp specialize_to_gpu.cpp diff --git a/test/correctness/solve.cpp b/test/correctness/solve.cpp new file mode 100644 index 000000000000..f52c49f9edb1 --- /dev/null +++ b/test/correctness/solve.cpp @@ -0,0 +1,526 @@ +#include "Halide.h" + +#include +#include +#include +#include + +using namespace Halide; +using namespace Halide::Internal; + +namespace { + +// Assert that solve_expression produces exactly the given expected expression. +void check_solve(const Expr &in, const Expr &expected) { + SolverResult solved = solve_expression(in, "x"); + if (!equal(solved.result, expected)) { + std::cerr << "solve_expression produced unexpected result:\n" + << " input: " << in << "\n" + << " expected: " << expected << "\n" + << " actual: " << solved.result << "\n"; + std::abort(); + } +} + +void check_interval(const Expr &a, const Interval &i, bool outer) { + Interval result = + outer ? solve_for_outer_interval(a, "x") : solve_for_inner_interval(a, "x"); + result.min = simplify(result.min); + result.max = simplify(result.max); + if (!equal(result.min, i.min) || !equal(result.max, i.max)) { + std::cerr << "Expression " << a << " solved to the interval:\n" + << " min: " << result.min << "\n" + << " max: " << result.max << "\n" + << " instead of:\n" + << " min: " << i.min << "\n" + << " max: " << i.max << "\n"; + std::abort(); + } +} + +void check_outer_interval(const Expr &a, const Expr &min, const Expr &max) { + check_interval(a, Interval(min, max), true); +} + +void check_inner_interval(const Expr &a, const Expr &min, const Expr &max) { + check_interval(a, Interval(min, max), false); +} + +void check_and_condition(const Expr &orig, const Expr &result, const Interval &i) { + Scope s; + s.push("x", i); + Expr cond = and_condition_over_domain(orig, s); + if (!equal(cond, result)) { + std::cerr << "Expression " << orig + << " reduced to " << cond + << " instead of " << result << "\n"; + std::abort(); + } +} + +// Assert that solve_expression produces a result that is semantically +// equivalent to the input under the given substitution. This is used for +// cases where we care about preserved meaning, not exact syntactic form. +void check_solve_equivalent(const Expr &in, const std::map &vars) { + SolverResult solved = solve_expression(in, "x"); + Expr in_v = simplify(substitute(vars, in)); + Expr out_v = simplify(substitute(vars, solved.result)); + if (!equal(in_v, out_v)) { + std::cerr << "solve_expression changed value under substitution:\n" + << " input: " << in << "\n" + << " solved: " << solved.result << "\n"; + for (const auto &[name, val] : vars) { + std::cerr << " " << name << " = " << val << "\n"; + } + std::cerr << " input evaluated: " << in_v << "\n" + << " solved evaluated: " << out_v << "\n"; + std::abort(); + } +} + +// Bug #1: the solver was rewriting `f(x) + b @ c` to `f(x) @ c - b` for +// every comparison @, but for unsigned types the subtraction wraps, which +// does not preserve the *ordering* comparisons LT/LE/GT/GE (the EQ/NE +// rewrite is still valid under modular arithmetic, so those stay). +void test_unsigned_ordering_not_rearranged() { + Expr x = Variable::make(UInt(32), "x"); + Expr y = Variable::make(UInt(32), "y"); + + // A concrete substitution that demonstrates the wrap: with x = 4 and + // y = (uint32_t)-14 = 4294967282, x + y = 4294967286, so + // `x + y < 1641646169` is false. The buggy rewrite `x < 1641646169 - y` + // underflows 1641646169 - 4294967282 to 1641646183, making it true. + std::map vars{ + {"x", UIntImm::make(UInt(32), 4)}, + {"y", UIntImm::make(UInt(32), 4294967282u)}, + }; + + Expr c = UIntImm::make(UInt(32), 1641646169u); + check_solve_equivalent(x + y < c, vars); + check_solve_equivalent(x + y <= c, vars); + check_solve_equivalent(x + y > c, vars); + check_solve_equivalent(x + y >= c, vars); + + // The symmetric subtraction form must be preserved too. + check_solve_equivalent(x - y < c, vars); + check_solve_equivalent(x - y <= c, vars); + check_solve_equivalent(x - y > c, vars); + check_solve_equivalent(x - y >= c, vars); +} + +// Bug #1 corollary: EQ/NE rewrites are still safe under modular arithmetic +// (modular equivalence preserves equality), so these should continue to be +// rewritten to isolate x on the left. +void test_unsigned_equality_still_rearranged() { + Expr x = Variable::make(UInt(32), "x"); + Expr y = Variable::make(UInt(32), "y"); + Expr c = UIntImm::make(UInt(32), 2u); + + // `x + y == c` should solve to `x == c - y`, matching existing tests + // in src/Solve.cpp's solve_test() for unsigned rewrites. + check_solve(x + y == c, x == (c - y)); + check_solve(x + y != c, x != (c - y)); +} + +// Bug #2: the solver was rewriting `f(x) * y @ b` to forms involving `b / y` +// and `b % y` even when `y` was a non-constant expression. When `y` evaluates +// to zero the rewrite changes the expression's value even though Halide +// defines div/mod-by-zero to return zero -- `a * 0 == b` becomes `a == b/0 && +// b%0 == 0` which collapses to `a == 0 && b == 0`, losing the original +// "always false when b != 0" semantics. +void test_nonconstant_multiplier_not_rewritten() { + Expr x = Variable::make(Int(32), "x"); + Expr y = Variable::make(Int(32), "y"); + + // At y = 0, `x * y == 1` is the well-defined `0 == 1 == false`. + // The buggy rewrite `x == 1/y && 1%y == 0` evaluates to + // `x == 0 && true == true`, which is true at x = 0 -- changing the + // value of the expression. + std::map vars_zero{ + {"x", Expr(7)}, + {"y", Expr(0)}, + }; + check_solve_equivalent(x * y == 1, vars_zero); + check_solve_equivalent(x * y != 1, vars_zero); + + // Non-zero y: must still be semantically preserved. + std::map vars_nonzero{ + {"x", Expr(7)}, + {"y", Expr(3)}, + }; + check_solve_equivalent(x * y == 1, vars_nonzero); + check_solve_equivalent(x * y != 1, vars_nonzero); +} + +// The guarded form of the Mul rewrite -- a positive constant multiplier -- +// must continue to work after the fix. visit(Div) constant-folds when both +// operands are const, so `Div::make(7, 3)` reduces to 2 during mutation; +// there is no analogous fold for Mod so the Mod node stays. +void test_positive_const_multiplier_still_rewritten() { + Expr x = Variable::make(Int(32), "x"); + Expr seven = Expr(7); + Expr three = Expr(3); + check_solve(3 * x == 7, + (x == 2) && (Mod::make(seven, three) == 0)); + check_solve(3 * x != 7, + (x != 2) || (Mod::make(seven, three) != 0)); +} + +// Solver used to rewrite `f(x) + f(x) -> f(x) * 2` via `operator*(Expr, int)`, +// which rejects constants that don't fit in the expression type. For UInt(1), +// the literal 2 isn't representable, aborting the whole solve. Use Mul::make +// directly with make_const (which truncates modulo width) so the rewrite +// applies soundly for every integer type -- for UInt(1), `a * 2` correctly +// becomes `a * 0`, matching the modular value of `a + a`. +void test_solve_does_not_abort_on_narrow_self_add() { + Expr x = Variable::make(UInt(1), "x"); + // This used to abort with + // "Integer constant 2 will be implicitly coerced to type uint1..." + SolverResult s = solve_expression(x + x, "x"); + // The actual rewritten form is unimportant here -- the test just locks + // in that solve_expression doesn't abort on this shape. + if (!s.result.defined()) { + std::cerr << "solve_expression returned undefined on `x + x` (UInt(1))\n"; + std::abort(); + } +} + +// Solver's `f(x)/a + g(x) -> (f(x) + g(x) * a) / a` rewrite is only valid +// under non-wrapping arithmetic: modularly, g(x)*a can overflow the width +// and the rewrite changes the computed value. Guard it on no_overflow_int. +void test_narrow_div_add_equivalence() { + // Reproduced from the fuzzer (seed 9414558261169807111, minimized): + // `(uint8(a4)/137) + uint8(a4)` at a4=-13 (uint8 243) is + // 243/137 + 243 = 1 + 243 = 244 (uint8, no wrap). + // The previous rewrite would convert this to + // (uint8(a4) * 138) / 137 + // which at uint8 243 gives (243*138 mod 256)/137 = 254/137 = 1. + Expr a = Variable::make(Int(32), "a"); + Expr u = Cast::make(UInt(8), a); + Expr input = u / UIntImm::make(UInt(8), 137) + u; + SolverResult s = solve_expression(input, "a"); + // Verify by concrete substitution: the solved expression must evaluate + // to the same value as the input at a = -13. + std::map subst{{"a", Expr(-13)}}; + Expr in_v = simplify(substitute(subst, input)); + Expr out_v = simplify(substitute(subst, s.result)); + if (!equal(in_v, out_v)) { + std::cerr << "solve_expression changed value on narrow div+add:\n" + << " input: " << input << " -> " << in_v << "\n" + << " solved: " << s.result << " -> " << out_v << "\n"; + std::abort(); + } +} + +// Simplify_Cast was applying a cast-chain simplification +// int32(uint64(X)) -> int32(X) +// whenever widths and the two outer types all lined up for the +// "sign-extend then truncate" shape. The rule's correctness depends on +// the *inner* cast actually being a sign/zero extend, which only holds +// when its source is an integer. For `int32(uint64(float64(a)))` the +// inner cast is an fp-to-uint conversion, which has entirely different +// semantics -- so the stripped form `int32(float64(a))` evaluates to a +// different value (fp-to-int vs fp-to-uint-then-truncate). +void test_simplify_preserves_float_to_uint_cast_chain() { + Expr a = Variable::make(Int(32), "a"); + Expr chained = Cast::make(Int(32), + Cast::make(UInt(64), + Cast::make(Float(64), a))); + Expr simplified = simplify(chained); + + // At a = -21, the two forms must agree. + std::map subst{{"a", Expr(-21)}}; + Expr v1 = simplify(substitute(subst, chained)); + Expr v2 = simplify(substitute(subst, simplified)); + if (!equal(v1, v2)) { + std::cerr << "simplify changed the value of a cast chain:\n" + << " original: " << chained << " -> " << v1 << "\n" + << " simplified: " << simplified << " -> " << v2 << "\n"; + std::abort(); + } +} + +// Previously lived as `solve_test()` at the bottom of src/Solve.cpp and +// was invoked from test/internal.cpp. Moved here so all solver tests are +// in one place. +void test_original_solve_test_cases() { + using ConciseCasts::i16; + + Expr x = Variable::make(Int(32), "x"); + Expr y = Variable::make(Int(32), "y"); + Expr z = Variable::make(Int(32), "z"); + + // Check some simple cases + check_solve(3 - 4 * x, x * (-4) + 3); + check_solve(min(5, x), min(x, 5)); + check_solve(max(5, (5 + x) * y), max(x * y + 5 * y, 5)); + check_solve(5 * y + 3 * x == 2, ((x == ((2 - (5 * y)) / 3)) && (((2 - (5 * y)) % 3) == 0))); + check_solve(min(min(z, x), min(x, y)), min(x, min(y, z))); + check_solve(min(x + y, x + 5), x + min(y, 5)); + + // Check solver with expressions containing division + check_solve(x + (x * 2) / 2, x * 2); + check_solve(x + (x * 2 + y) / 2, x * 2 + (y / 2)); + check_solve(x + (x * 2 - y) / 2, x * 2 - (y / 2)); + check_solve(x + (-(x * 2) / 2), x * 0 + 0); + check_solve(x + (-(x * 2 + -3)) / 2, x * 0 + 1); + check_solve(x + (z - (x * 2 + -3)) / 2, x * 0 + (z - (-3)) / 2); + check_solve(x + (y * 16 + (z - (x * 2 + -1))) / 2, + (x * 0) + (((z - -1) + (y * 16)) / 2)); + + check_solve((x * 9 + 3) / 4 - x * 2, (x * 1 + 3) / 4); + check_solve((x * 9 + 3) / 4 + x * 2, (x * 17 + 3) / 4); + check_solve(x * 2 + (x * 9 + 3) / 4, (x * 17 + 3) / 4); + + // Check the solver doesn't perform transformations that change integer overflow behavior. + check_solve(i16(x + y) * i16(2) / i16(2), i16(x + y) * i16(2) / i16(2)); + + // A let statement + check_solve(Let::make("z", 3 + 5 * x, y + z < 8), + x <= (((8 - (3 + y)) - 1) / 5)); + + // A let statement where the variable gets used twice. + check_solve(Let::make("z", 3 + 5 * x, y + (z + z) < 8), + x <= (((8 - (6 + y)) - 1) / 10)); + + // Something where we expect a let in the output. + { + Expr e = y + 1; + for (int i = 0; i < 10; i++) { + e *= (e + 1); + } + SolverResult solved = solve_expression(x + e < e * e, "x"); + if (!(solved.fully_solved && solved.result.as())) { + std::cerr << "Expected fully-solved Let-bearing result\n"; + std::abort(); + } + } + + // Solving inequalities for integers is a pain to get right with + // all the rounding rules. Check we didn't make a mistake with + // brute force. + for (int den = -3; den <= 3; den++) { + if (den == 0) { + continue; + } + for (int num = 5; num <= 10; num++) { + Expr in[] = { + {x * den < num}, + {x * den <= num}, + {x * den == num}, + {x * den != num}, + {x * den >= num}, + {x * den > num}, + {x / den < num}, + {x / den <= num}, + {x / den == num}, + {x / den != num}, + {x / den >= num}, + {x / den > num}, + }; + for (const auto &e : in) { + SolverResult solved = solve_expression(e, "x"); + if (!solved.fully_solved) { + std::cerr << "Error: failed to solve for x in " << e << "\n"; + std::abort(); + } + Expr out = simplify(solved.result); + for (int i = -10; i < 10; i++) { + Expr in_val = substitute("x", i, e); + Expr out_val = substitute("x", i, out); + in_val = simplify(in_val); + out_val = simplify(out_val); + if (!equal(in_val, out_val)) { + std::cerr << "Error: " + << e << " is not equivalent to " + << out << " when x == " << i << "\n"; + std::abort(); + } + } + } + } + } + + // Check for combinatorial explosion + { + Expr e = x + y; + for (int i = 0; i < 20; i++) { + e += (e + 1) * y; + } + SolverResult solved = solve_expression(e, "x"); + if (!(solved.fully_solved && solved.result.defined())) { + std::cerr << "Expected fully-solved defined result for combinatorial case\n"; + std::abort(); + } + } + + // Check some things that we don't expect to work. + + // Quadratics: + if (solve_expression(x * x < 4, "x").fully_solved) { + std::cerr << "Expected quadratic to not be fully solved\n"; + std::abort(); + } + + // Function calls, cast nodes, or multiplications by unknown sign + // don't get inverted, but the bit containing x still gets moved + // leftwards. + check_solve(4.0f > sqrt(x), sqrt(x) < 4.0f); + + check_solve(4 > y * x, x * y < 4); + + // Now test solving for an interval + check_inner_interval(x > 0, 1, Interval::pos_inf()); + check_inner_interval(x < 100, Interval::neg_inf(), 99); + check_outer_interval(x > 0 && x < 100, 1, 99); + check_inner_interval(x > 0 && x < 100, 1, 99); + + Expr c = Variable::make(Bool(), "c"); + check_outer_interval(Let::make("y", 0, x > y && x < 100), 1, 99); + check_outer_interval(Let::make("c", x > 0, c && x < 100), 1, 99); + + check_outer_interval((x >= 10 && x <= 90) && sin(x) > 0.5f, 10, 90); + check_inner_interval((x >= 10 && x <= 90) && sin(x) > 0.6f, Interval::pos_inf(), Interval::neg_inf()); + + check_inner_interval(x == 10, 10, 10); + check_outer_interval(x == 10, 10, 10); + + check_inner_interval(!(x != 10), 10, 10); + check_outer_interval(!(x != 10), 10, 10); + + check_inner_interval(3 * x + 4 < 27, Interval::neg_inf(), 7); + check_outer_interval(3 * x + 4 < 27, Interval::neg_inf(), 7); + + check_inner_interval(min(x, y) > 17, 18, y); + check_outer_interval(min(x, y) > 17, 18, Interval::pos_inf()); + + check_inner_interval(x / 5 < 17, Interval::neg_inf(), 84); + check_outer_interval(x / 5 < 17, Interval::neg_inf(), 84); + + // Test anding a condition over a domain + check_and_condition(x > 0, const_true(), Interval(1, y)); + check_and_condition(x > 0, const_true(), Interval(5, y)); + check_and_condition(x > 0, const_false(), Interval(-5, y)); + check_and_condition(x > 0 && x < 10, const_true(), Interval(1, 9)); + check_and_condition(x > 0 || sin(x) == 0.5f, const_true(), Interval(100, 200)); + + check_and_condition(x <= 0, const_true(), Interval(-100, 0)); + check_and_condition(x <= 0, const_false(), Interval(-100, 1)); + + check_and_condition(x <= 0 || y > 2, const_true(), Interval(-100, 0)); + check_and_condition(x > 0 || y > 2, 2 < y, Interval(-100, 0)); + + check_and_condition(x == 0, const_true(), Interval(0, 0)); + check_and_condition(x == 0, const_false(), Interval(-10, 10)); + check_and_condition(x != 0, const_false(), Interval(-10, 10)); + check_and_condition(x != 0, const_true(), Interval(-20, -10)); + + check_and_condition(y == 0, y == 0, Interval(-10, 10)); + check_and_condition(y != 0, y != 0, Interval(-10, 10)); + check_and_condition((x == 5) && (y != 0), const_false(), Interval(-10, 10)); + check_and_condition((x == 5) && (y != 3), y != 3, Interval(5, 5)); + check_and_condition((x != 0) && (y != 0), const_false(), Interval(-10, 10)); + check_and_condition((x != 0) && (y != 0), y != 0, Interval(-20, -10)); + + { + // This case used to break due to signed integer overflow in + // the simplifier. + Expr a16 = Load::make(Int(16), "a", {x}, Buffer<>(), Parameter(), const_true(), ModulusRemainder()); + Expr b16 = Load::make(Int(16), "b", {x}, Buffer<>(), Parameter(), const_true(), ModulusRemainder()); + Expr lhs = pow(cast(a16), 2) + pow(cast(b16), 2); + + Scope s; + s.push("x", Interval(-10, 10)); + Expr cond = and_condition_over_domain(lhs < 0, s); + if (is_const_one(simplify(cond))) { + std::cerr << "Expected cond to not simplify to const_one\n"; + std::abort(); + } + } + + { + // This cause use to cause infinite recursion: + Expr t = Variable::make(Int(32), "t"); + Expr test = (x <= min(max((y - min(((z * x) + t), t)), 1), 0)); + Interval result = solve_for_outer_interval(test, "z"); + } + + { + // This case caused exponential behavior + Expr t = Variable::make(Int(32), "t"); + for (int i = 0; i < 50; i++) { + t = min(t, Variable::make(Int(32), unique_name('v'))); + t = max(t, Variable::make(Int(32), unique_name('v'))); + } + solve_for_outer_interval(t <= 5, "t"); + solve_for_inner_interval(t <= 5, "t"); + } + + // Check for partial results + check_solve(max(min(y, x), x), max(min(x, y), x)); + check_solve(min(y, x) + max(y, 2 * x), min(x, y) + max(x * 2, y)); + check_solve((min(x, y) + min(y, x)) * max(y, x), (min(x, y) * 2) * max(x, y)); + check_solve(max((min((y * x), x) + min((1 + y), x)), (y + 2 * x)), + max((min((x * y), x) + min(x, (1 + y))), (x * 2 + y))); + + { + Expr x = Variable::make(UInt(32), "x"); + Expr y = Variable::make(UInt(32), "y"); + Expr z = Variable::make(UInt(32), "z"); + check_solve(5 - (4 - 4 * x), x * (4) + 1); + check_solve(z - (y - x), x + (z - y)); + check_solve(z - (y - x) == 2, x == 2 - (z - y)); + + check_solve(x - (x - y), (x - x) + y); + + // This is used to cause infinite recursion + Expr expr = Add::make(z, Sub::make(x, y)); + SolverResult solved = solve_expression(expr, "y"); + } + + // This case was incorrect due to canonicalization of the multiply + // occurring after unpacking the LHS. + check_solve((y - z) * x, x * (y - z)); + + // These cases were incorrectly not flipping min/max when moving + // it out of the RHS of a subtract. + check_solve(min(x - y, x - z), x - max(y, z)); + check_solve(min(x - y, x), x - max(y, 0)); + check_solve(min(x, x - y), x - max(y, 0)); + check_solve(max(x - y, x - z), x - min(y, z)); + check_solve(max(x - y, x), x - min(y, 0)); + check_solve(max(x, x - y), x - min(y, 0)); + + // Check mixed add/sub + check_solve(min(x - y, x + z), x + min(0 - y, z)); + check_solve(max(x - y, x + z), x + max(0 - y, z)); + check_solve(min(x + y, x - z), x + min(y, 0 - z)); + check_solve(max(x + y, x - z), x + max(y, 0 - z)); + + check_solve((5 * Broadcast::make(x, 4) + y) / 5, + Broadcast::make(x, 4) + (Broadcast::make(y, 4) / 5)); + + // Select negates the condition to move x leftward + check_solve(select(y < z, z, x), + select(z <= y, x, z)); + + // Select negates the condition and then mutates it, moving x + // leftward (despite the simplifier preferring < to >). + check_solve(select(x < 10, 10, x), + select(x >= 10, x, 10)); +} + +} // namespace + +int main(int argc, char **argv) { + test_original_solve_test_cases(); + test_unsigned_ordering_not_rearranged(); + test_unsigned_equality_still_rearranged(); + test_nonconstant_multiplier_not_rewritten(); + test_positive_const_multiplier_still_rewritten(); + test_solve_does_not_abort_on_narrow_self_add(); + test_narrow_div_add_equivalence(); + test_simplify_preserves_float_to_uint_cast_chain(); + std::printf("Success!\n"); + return 0; +} diff --git a/test/fuzz/CMakeLists.txt b/test/fuzz/CMakeLists.txt index af4ba4fad84d..95716c2c134c 100644 --- a/test/fuzz/CMakeLists.txt +++ b/test/fuzz/CMakeLists.txt @@ -11,6 +11,7 @@ tests(GROUPS fuzz cse.cpp lossless_cast.cpp simplify.cpp + solve.cpp widening_lerp.cpp # By default, the libfuzzer harness runs with a timeout of 1200 seconds. # Let's dial that back: diff --git a/test/fuzz/random_expr_generator.h b/test/fuzz/random_expr_generator.h index cc196ba76f5a..a87bf7c7205f 100644 --- a/test/fuzz/random_expr_generator.h +++ b/test/fuzz/random_expr_generator.h @@ -200,8 +200,8 @@ class RandomExpressionGenerator { return fuzz.PickValueInArray(make_bin_op)(a, b); }); } - if (gen_bitwise) { - // Bitwise + if (gen_bitwise && !t.is_float()) { + // Bitwise -- not valid on float types, so skip when t is float. ops.push_back([&] { static make_bin_op_fn make_bin_op[] = { make_bitwise_or, diff --git a/test/fuzz/solve.cpp b/test/fuzz/solve.cpp new file mode 100644 index 000000000000..20197c5d7a54 --- /dev/null +++ b/test/fuzz/solve.cpp @@ -0,0 +1,450 @@ +#include "Halide.h" +#include + +#include "IRGraphCXXPrinter.h" +#include "fuzz_helpers.h" +#include "random_expr_generator.h" + +// Test the solver in Halide by generating random expressions and verifying that +// solve_expression, solve_for_inner_interval, and solve_for_outer_interval +// satisfy their respective contracts under random concrete substitutions. +namespace { + +using std::map; +using std::string; +using namespace Halide; +using namespace Halide::Internal; + +// Wrap a call that may throw InternalError in an std::variant so callers can +// report the failure with context rather than aborting the whole fuzzer. +template +struct SafeResult : std::variant { + using std::variant::variant; + bool ok() const { + return this->index() == 0; + } + bool failed() const { + return this->index() == 1; + } + const T &value() const { + return std::get(*this); + } +}; + +SafeResult safe_simplify(const Expr &e) { + try { + return simplify(e); + } catch (InternalError &err) { + std::cerr << "simplify threw on:\n" + << e << "\n" + << err.what() << "\n"; + return err; + } +} + +SafeResult safe_solve_expression(const Expr &e, const string &var) { + try { + return solve_expression(e, var); + } catch (InternalError &err) { + std::cerr << "solve_expression threw on:\n" + << e << "\n solving for \"" << var << "\"\n" + << err.what() << "\n"; + return err; + } +} + +SafeResult safe_solve_for_inner_interval(const Expr &c, const string &var) { + try { + return solve_for_inner_interval(c, var); + } catch (InternalError &err) { + std::cerr << "solve_for_inner_interval threw on:\n" + << c << "\n solving for \"" << var << "\"\n" + << err.what() << "\n"; + return err; + } +} + +SafeResult safe_solve_for_outer_interval(const Expr &c, const string &var) { + try { + return solve_for_outer_interval(c, var); + } catch (InternalError &err) { + std::cerr << "solve_for_outer_interval threw on:\n" + << c << "\n solving for \"" << var << "\"\n" + << err.what() << "\n"; + return err; + } +} + +Expr random_int_val(FuzzingContext &fuzz, int lo, int hi) { + return cast(Int(32), fuzz.ConsumeIntegralInRange(lo, hi)); +} + +// Returns true if the expression, under the given substitution, contains a +// division or modulo whose divisor simplifies to zero. Halide defines +// div/mod-by-zero to return zero, but the simplifier doesn't always fold +// that consistently across syntactically-different forms -- so solve can +// rearrange an expression into an equivalent shape whose simplified value +// at a concrete substitution differs only because one side gets the +// "returns zero" fold applied while the other doesn't. Skip those samples +// when checking equivalence. Solve often emits Let bindings, so inline +// them first (otherwise Div::b is a variable reference and we can't see +// whether it's zero). +bool has_div_or_mod_by_zero(const Expr &e, const map &vars) { + Expr inlined = substitute_in_all_lets(e); + bool found = false; + auto check_denom = [&](const Expr &denom) { + if (found) return; + if (SafeResult r = safe_simplify(substitute(vars, denom)); r.ok()) { + if (Internal::is_const_zero(r.value())) { + found = true; + } + } + }; + visit_with( + inlined, + [&](auto *self, const Div *op) { + check_denom(op->b); + self->visit_base(op); + }, + [&](auto *self, const Mod *op) { + check_denom(op->b); + self->visit_base(op); + }); + return found; +} + +// Returns true if the expression, under the given substitution, contains a +// narrowing cast whose source value doesn't fit in the destination type. +// Halide's bounds analysis assumes such casts don't overflow (see PR #7814 +// discussion) -- that's a programmer-level contract that the fuzzer's +// random value substitutions can easily violate, and the resulting +// runtime wrap then disagrees with bounds_of's "assumed-fits" prediction. +// Skip those samples when checking contracts that rely on bounds_of. +bool has_overflowing_cast(const Expr &e, const map &vars) { + Expr inlined = substitute_in_all_lets(e); + bool found = false; + auto check_cast = [&](const Cast *op) { + if (found) return; + Type to = op->type; + Type from = op->value.type(); + // Only care about casts between integer/unsigned types that could + // overflow the destination. + if (!(to.is_int_or_uint() && from.is_int_or_uint())) return; + if (to.can_represent(from)) return; + SafeResult r = safe_simplify(substitute(vars, op->value)); + if (!r.ok()) return; + if (auto iv = as_const_int(r.value())) { + if (!to.can_represent(*iv)) found = true; + } else if (auto uv = as_const_uint(r.value())) { + if (!to.can_represent(*uv)) found = true; + } + }; + visit_with( + inlined, + [&](auto *self, const Cast *op) { + check_cast(op); + self->visit_base(op); + }); + return found; +} + +// Test that solve_expression(test, var) produces an expression equivalent to +// `test` under random concrete substitutions of all variables. Modeled after +// the brute-force check at the bottom of Solve.cpp's solve_test(). +bool test_solve_expression_equivalence(RandomExpressionGenerator ®, + const Expr &test, + const string &var, + int samples) { + SafeResult res = safe_solve_expression(test, var); + if (res.failed()) { + return false; + } + Expr solved = res.value().result; + if (!solved.defined()) { + std::cerr << "solve_expression returned an undefined Expr for:\n" + << test << "\n"; + return false; + } + + // Solving again should not throw. + if (safe_solve_expression(solved, var).failed()) { + return false; + } + + map vars; + for (const auto &v : reg.fuzz_vars) { + vars[v.name()] = Expr(); + } + + for (int i = 0; i < samples; i++) { + for (auto &[name, val] : vars) { + val = random_int_val(reg.fuzz, -32, 32); + } + + // Skip samples that invoke div/mod-by-zero in the input: Halide + // defines the result as zero, but the simplifier may apply the + // fold asymmetrically across two syntactically-distinct forms + // that are otherwise semantically equivalent. We don't skip + // based on the *solved* form -- solve must never introduce new + // div/mod-by-zero that wasn't already in the input. + if (has_div_or_mod_by_zero(test, vars) || + has_overflowing_cast(test, vars)) { + continue; + } + + SafeResult test_v = safe_simplify(substitute(vars, test)); + SafeResult solved_v = safe_simplify(substitute(vars, solved)); + if (test_v.failed() || solved_v.failed()) { + return false; + } + + // If either side didn't simplify to a constant, there's likely UB + // (e.g. signed integer overflow) somewhere -- skip this sample. + if (!Internal::is_const(test_v.value()) || !Internal::is_const(solved_v.value())) { + continue; + } + + if (!equal(test_v.value(), solved_v.value())) { + std::cerr << "solve_expression produced a non-equivalent result:\n"; + for (const auto &[name, val] : vars) { + std::cerr << " " << name << " = " << val << "\n"; + } + std::cerr << " variable being solved: " << var << "\n"; + std::cerr << " original: " << test << " -> " << test_v.value() << "\n"; + std::cerr << " solved: " << solved << " -> " << solved_v.value() << "\n"; + return false; + } + } + return true; +} + +// Substitute the given variables and simplify. +Expr subst_and_simplify(const map &vars, const Expr &e) { + return simplify(substitute(vars, e)); +} + +// Returns 1 if `c` simplifies to a true constant, 0 if a false constant, -1 +// otherwise. Used to handle partial results from the simplifier safely. +int try_resolve_bool(const Expr &c) { + Expr s; + if (SafeResult r = safe_simplify(c); r.ok()) { + s = r.value(); + } else { + return -1; + } + if (is_const_one(s)) { + return 1; + } + if (is_const_zero(s)) { + return 0; + } + return -1; +} + +// Test the contracts of solve_for_inner_interval and solve_for_outer_interval +// by sampling values of `var` and checking: +// - if sample is inside the inner interval, the condition must be true +// - if sample is outside the outer interval, the condition must be false +// Non-solving variables are given concrete random values before sampling. +bool test_solve_intervals(RandomExpressionGenerator ®, + const Expr &cond, + const string &var, + int samples) { + internal_assert(cond.type().is_bool()); + + SafeResult inner_res = safe_solve_for_inner_interval(cond, var); + SafeResult outer_res = safe_solve_for_outer_interval(cond, var); + if (inner_res.failed() || outer_res.failed()) { + return false; + } + Interval inner = inner_res.value(); + Interval outer = outer_res.value(); + + map other_vars; + for (const auto &v : reg.fuzz_vars) { + if (v.name() != var) { + other_vars[v.name()] = Expr(); + } + } + + for (int i = 0; i < samples; i++) { + for (auto &[name, val] : other_vars) { + val = random_int_val(reg.fuzz, -16, 16); + } + // Skip substitutions that violate the "assumed not to overflow" + // contract for narrowing int casts. + if (has_overflowing_cast(cond, other_vars)) { + continue; + } + + Expr inner_min_v, inner_max_v, outer_min_v, outer_max_v; + if (inner.has_lower_bound()) inner_min_v = subst_and_simplify(other_vars, inner.min); + if (inner.has_upper_bound()) inner_max_v = subst_and_simplify(other_vars, inner.max); + if (outer.has_lower_bound()) outer_min_v = subst_and_simplify(other_vars, outer.min); + if (outer.has_upper_bound()) outer_max_v = subst_and_simplify(other_vars, outer.max); + Expr cond_sub = substitute(other_vars, cond); + + int val = reg.fuzz.ConsumeIntegralInRange(-64, 64); + Expr var_val = cast(Int(32), val); + int cond_truth = try_resolve_bool(substitute(var, var_val, cond_sub)); + if (cond_truth < 0) { + // Can't resolve (symbolic leftover or UB) -- skip. + continue; + } + + // Inner interval: var_val in [inner.min, inner.max] => cond is true. + // An empty inner interval is a trivial (vacuously true) claim. + int in_inner = inner.is_empty() ? 0 : 1; + if (in_inner == 1 && inner.has_lower_bound()) { + int r = try_resolve_bool(var_val >= inner_min_v); + if (r < 0) { + in_inner = -1; + } else if (r == 0) { + in_inner = 0; + } + } + if (in_inner == 1 && inner.has_upper_bound()) { + int r = try_resolve_bool(var_val <= inner_max_v); + if (r < 0) { + in_inner = -1; + } else if (r == 0) { + in_inner = 0; + } + } + if (in_inner == 1 && cond_truth == 0) { + std::cerr << "solve_for_inner_interval violation\n" + << " cond: " << cond << "\n" + << " var: " << var << " = " << val << "\n" + << " inner interval: [" << inner.min << ", " << inner.max << "]\n"; + for (const auto &[name, v] : other_vars) { + std::cerr << " " << name << " = " << v << "\n"; + } + return false; + } + + // Outer interval: var_val NOT in [outer.min, outer.max] => cond is false. + // An empty outer interval means cond is unsatisfiable, so any sample + // that evaluates to true is a violation. + int out_lb = 0, out_ub = 0; + if (outer.is_empty()) { + out_lb = 1; + } + if (outer.has_lower_bound()) { + int r = try_resolve_bool(var_val < outer_min_v); + if (r < 0) { + out_lb = -1; + } else { + out_lb = r; + } + } + if (outer.has_upper_bound()) { + int r = try_resolve_bool(var_val > outer_max_v); + if (r < 0) { + out_ub = -1; + } else { + out_ub = r; + } + } + if ((out_lb == 1 || out_ub == 1) && cond_truth == 1) { + std::cerr << "solve_for_outer_interval violation\n" + << " cond: " << cond << "\n" + << " var: " << var << " = " << val << "\n" + << " outer interval: [" << outer.min << ", " << outer.max << "]\n"; + for (const auto &[name, v] : other_vars) { + std::cerr << " " << name << " = " << v << "\n"; + } + return false; + } + } + return true; +} + +Expr random_comparison(RandomExpressionGenerator ®, int depth) { + using make_bin_op_fn = Expr (*)(Expr, Expr); + static make_bin_op_fn ops[] = { + EQ::make, + NE::make, + LT::make, + LE::make, + GT::make, + GE::make, + }; + Expr a = reg.random_expr(Int(32), depth); + Expr b = reg.random_expr(Int(32), depth); + return reg.fuzz.PickValueInArray(ops)(a, b); +} + +} // namespace + +FUZZ_TEST(solve, FuzzingContext &fuzz) { + // Depth of the randomly generated expression trees. + constexpr int depth = 6; + // Number of samples to test each invariant at. + constexpr int samples = 20; + + RandomExpressionGenerator reg{fuzz}; + reg.fuzz_types = {Int(8), Int(16), Int(32), Int(64), + UInt(1), UInt(8), UInt(16), UInt(32), UInt(64), + Float(32), Float(64)}; + // Leave gen_shuffles / gen_vector_reduce / gen_reinterpret off for now + // -- those exercise Deinterleaver / shuffle lowering more than solve + // proper. gen_broadcast_of_vector and gen_ramp_of_vector are on so the + // solver sees vector-typed expressions. + reg.gen_shuffles = false; + reg.gen_vector_reduce = false; + reg.gen_reinterpret = false; + + // Pick one of the generator's variables to solve for. + const string var = reg.fuzz_vars[fuzz.ConsumeIntegralInRange(0, reg.fuzz_vars.size() - 1)].name(); + + // solve_expression: arithmetic equivalence. Pick a random width so the + // generator's Broadcast/Ramp lambdas actually fire (they're no-ops on + // scalar types). Vector subtrees containing scalar variables exercise + // the solver's vector-aware handling (see e.g. the Broadcast case in + // src/Solve.cpp's solve_test). + int width = fuzz.PickValueInArray({1, 2, 3, 4, 6, 8}); + Expr test_expr = reg.random_expr(Int(32).with_lanes(width), depth); + if (!test_solve_expression_equivalence(reg, test_expr, var, samples)) { + std::cerr << "Failing expression (C++):\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(test_expr); + std::cerr << "Expr final_expr = " << printer.node_names[test_expr.get()] << ";\n"; + std::cerr << " solving for \"" << var << "\"\n"; + return 1; + } + + // solve_expression: also handle comparisons (the solver inverts these). + Expr cmp = random_comparison(reg, depth); + if (!test_solve_expression_equivalence(reg, cmp, var, samples)) { + std::cerr << "Failing comparison (C++):\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(cmp); + std::cerr << "Expr final_expr = " << printer.node_names[cmp.get()] << ";\n"; + std::cerr << " solving for \"" << var << "\"\n"; + return 1; + } + + // solve_for_inner_interval / solve_for_outer_interval. + if (!test_solve_intervals(reg, cmp, var, samples)) { + std::cerr << "Failing condition (C++):\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(cmp); + std::cerr << "Expr final_expr = " << printer.node_names[cmp.get()] << ";\n"; + std::cerr << " solving for \"" << var << "\"\n"; + return 1; + } + + // Also exercise solve_for_*_interval with compound boolean conditions. + Expr cmp2 = random_comparison(reg, depth); + Expr compound = fuzz.ConsumeBool() ? (cmp && cmp2) : (cmp || cmp2); + if (!test_solve_intervals(reg, compound, var, samples)) { + std::cerr << "Failing compound condition (C++):\n"; + IRGraphCXXPrinter printer(std::cerr); + printer.print(compound); + std::cerr << "Expr final_expr = " << printer.node_names[compound.get()] << ";\n"; + std::cerr << " solving for \"" << var << "\"\n"; + return 1; + } + + return 0; +} diff --git a/test/internal.cpp b/test/internal.cpp index 08283fa9cf54..f64bdfbca1a8 100644 --- a/test/internal.cpp +++ b/test/internal.cpp @@ -31,7 +31,6 @@ int main(int argc, const char **argv) { deinterleave_vector_test(); modulus_remainder_test(); cse_test(); - solve_test(); target_test(); cplusplus_mangle_test(); is_monotonic_test();