diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index e8d22a051cc4..f953dd3825ef 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -93,9 +93,12 @@ is_list_rprimitive, is_none_rprimitive, is_object_rprimitive, + is_str_rprimitive, + is_tagged, is_tuple_rprimitive, object_rprimitive, set_rprimitive, + short_int_rprimitive, vec_api_by_item_type, ) from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional @@ -119,6 +122,7 @@ apply_dunder_specialization, apply_function_specialization, apply_method_specialization, + translate_getitem_with_bounds_check, translate_object_new, translate_object_setattr, ) @@ -137,7 +141,12 @@ from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op -from mypyc.primitives.str_ops import str_slice_op +from mypyc.primitives.str_ops import ( + str_adjust_index_op, + str_get_item_unsafe_as_int_op, + str_range_check_op, + str_slice_op, +) from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op # Name and attribute references @@ -918,6 +927,16 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: return result if len(e.operators) == 1: + # s[i] == 'x' / s[i] != 'x' (and the symmetric RHS) -> int compare of + # codepoints. Skips the per-iteration 1-char str allocation/lookup and + # generic str equality call. 
def try_specialize_str_index_compare(
    builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
) -> Value | None:
    """Specialize `s[i] == 'x'` / `s[i] != 'x'` into an int compare of codepoints.

    The symmetric form with the literal on the left (`'x' == s[i]`) is
    accepted too; the operands are normalized so that the index expression
    is always treated as the left operand.

    `op` is the comparison operator and is forwarded unchanged to
    `builder.binary_op`; the call site in transform_comparison_expr only
    passes "==" or "!=". `line` is used for the ops emitted here.

    Returns the Value holding the comparison result, or None if the pattern
    doesn't match: the indexed base must be str, the index must be an
    integer (tagged or fixed-width), and the literal must be a 1-character
    str. Multi-character and empty literals are rejected here so that the
    caller falls through to the generic str comparison. None is also
    returned if the bounds-checked getitem helper declines to translate.
    """
    # Normalize so the IndexExpr is on the left. If both operands are
    # IndexExprs nothing is swapped, and the StrExpr check below rejects it.
    if isinstance(rhs, IndexExpr) and not isinstance(lhs, IndexExpr):
        lhs, rhs = rhs, lhs

    # Shape: s[i] {==, !=} "x" where "x" is exactly one codepoint.
    if not isinstance(lhs, IndexExpr):
        return None
    if not isinstance(rhs, StrExpr) or len(rhs.value) != 1:
        return None
    if not is_str_rprimitive(builder.node_type(lhs.base)):
        return None

    index_type = builder.node_type(lhs.index)
    if not (is_tagged(index_type) or is_fixed_width_rtype(index_type)):
        return None

    # ord(s[i]) with bounds check; raises IndexError for out-of-range indices,
    # matching the behavior of the generic s[i] path.
    codepoint = translate_getitem_with_bounds_check(
        builder,
        lhs.base,
        [lhs.index],
        lhs,
        str_adjust_index_op,
        str_range_check_op,
        str_get_item_unsafe_as_int_op,
    )
    if codepoint is None:
        return None

    # Compare against the literal's codepoint as an integer constant instead
    # of materializing a 1-character str for a str equality call.
    literal_cp = Integer(ord(rhs.value), short_int_rprimitive, line)
    return builder.binary_op(codepoint, literal_cp, op, line)
CPyStr_AdjustIndex(s, r3) + r8 = CPyStr_RangeCheck(s, r7) + if r8 goto L5 else goto L4 :: bool +L4: + r9 = raise IndexError('index out of range') + unreachable +L5: + r10 = CPyStr_GetItemUnsafeAsInt(s, r7) + r11 = int_eq r10, 88 + return r11 +def is_comma_ne(s, i): + s :: str + i :: int + r0 :: native_int + r1 :: bit + r2, r3 :: i64 + r4 :: ptr + r5 :: c_ptr + r6, r7 :: i64 + r8, r9 :: bool + r10 :: short_int + r11 :: bit +L0: + r0 = i & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = i >> 1 + r3 = r2 + goto L3 +L2: + r4 = i ^ 1 + r5 = r4 + r6 = CPyLong_AsInt64(r5) + r3 = r6 + keep_alive i +L3: + r7 = CPyStr_AdjustIndex(s, r3) + r8 = CPyStr_RangeCheck(s, r7) + if r8 goto L5 else goto L4 :: bool +L4: + r9 = raise IndexError('index out of range') + unreachable +L5: + r10 = CPyStr_GetItemUnsafeAsInt(s, r7) + r11 = int_ne r10, 88 + return r11 + +[case testStrIndexEqLiteralNoSpecialize] +def two_char_literal(s: str, i: int) -> bool: + # Multi-char literals don't match the specialization; falls through to + # the generic str equality path. + return s[i] == "ab" +def empty_literal(s: str, i: int) -> bool: + # Empty string literals also fall through; the generic path returns False. 
+ return s[i] == "" +[out] +def two_char_literal(s, i): + s :: str + i :: int + r0, r1 :: str + r2 :: bool +L0: + r0 = CPyStr_GetItem(s, i) + r1 = 'ab' + r2 = CPyStr_EqualLiteral(r0, r1, 2) + return r2 +def empty_literal(s, i): + s :: str + i :: int + r0, r1 :: str + r2 :: bool +L0: + r0 = CPyStr_GetItem(s, i) + r1 = '' + r2 = CPyStr_EqualLiteral(r0, r1, 0) + return r2 diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index ec662da96986..81b85580d7a7 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -1412,3 +1412,64 @@ def test_isdigit_strings() -> None: assert not "\u00e9\u00e8".isdigit() assert not "123\u00e9".isdigit() assert not "\U0001d7ce!".isdigit() + +[case testStrIndexEqLiteralSpecialize] +from testutil import assertRaises + +# The specializer fires on the AST shape `IndexExpr == StrLiteral` (or the +# symmetric swap, and `!=`). The literal has to be a real source-level +# string literal (can't be passed in as a parameter), so each test +# function pins one distinct shape. + +def eq_comma(s: str, i: int) -> bool: + # Specialized: s[i] == "x". + return s[i] == "," + +def ne_comma(s: str, i: int) -> bool: + # Specialized: s[i] != "x". + return s[i] != "," + +def comma_eq(s: str, i: int) -> bool: + # Specialized: "x" == s[i]. Operand-swap is normalized. + return "," == s[i] + +def eq_two_chars(s: str, i: int) -> bool: + # Not specialized: literal isn't 1 char. Falls through to the generic + # str compare, which returns False since s[i] is always 1 codepoint. + return s[i] == "ab" + +def eq_empty(s: str, i: int) -> bool: + # Not specialized: empty literal. Same fall-through. + return s[i] == "" + +def test_specialized_path() -> None: + s = "a,b" # comma at index 1 + assert eq_comma(s, 1) + assert not eq_comma(s, 0) + assert not eq_comma(s, 2) + # != inverts. + assert ne_comma(s, 0) + assert not ne_comma(s, 1) + # Literal on the LHS is normalized to the same shape. 
+ assert comma_eq(s, 1) + assert not comma_eq(s, 0) + +def test_negative_index_is_adjusted() -> None: + s = "a,b" + assert eq_comma(s, -2) # -2 -> 1 (',') + assert not eq_comma(s, -1) # -1 -> 2 ('b') + +def test_non_1char_literal_falls_through() -> None: + s = "a,b" + # Generic str compare answers False because s[i] has length 1. + assert not eq_two_chars(s, 0) + assert not eq_two_chars(s, 1) + assert not eq_empty(s, 0) + +def test_out_of_range_raises_indexerror() -> None: + # Bounds-check semantics match the unspecialized s[i] path. + s = "a,b" + with assertRaises(IndexError): + eq_comma(s, 3) + with assertRaises(IndexError): + eq_comma(s, -4)