From 60abe8210a485e68577666ba6ba8661bc4b10ded Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Tue, 2 Jun 2026 10:44:49 +0300 Subject: [PATCH 1/3] [mypyc] Specialize `s[i] == 'x'` to a codepoint int compare Recognizes the AST shape `IndexExpr(str) == StrLiteral` (and the symmetric `StrLiteral == IndexExpr(str)`, plus the `!=` variants) and lowers it to an int compare of codepoints reusing the existing CPyStr_GetItemUnsafeAsInt primitive. Today the pattern lowers to CPyStr_GetItem + CPyStr_EqualLiteral, which allocates or looks up a 1-character PyUnicode object per iteration and goes through a generic string-equality call. After specialization it becomes an inlined PyUnicode_READ plus an int compare -- about 4x faster on bench_str_compare with a 3-compares-per-iteration workload, and closer to ~9x with the more typical 1-compare-per-iteration shape. No annotations required; benefits any code that compares a string index against a 1-character literal. Multi-character / empty literals fall through to the generic path (which still correctly returns False). Bounds checking is preserved -- the helper raises IndexError for out-of-range indices, same as the unspecialized path. Stack: builds on the `ord(s[i])` primitive (#20578) and the librt.strings codepoint helpers (#21462, #21504, #21509, #21521, #21522, #21553). --- mypyc/irbuild/expression.py | 64 ++++++++++++- mypyc/test-data/irbuild-str.test | 150 +++++++++++++++++++++++++++++++ mypyc/test-data/run-strings.test | 71 +++++++++++++++ 3 files changed, 284 insertions(+), 1 deletion(-) diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index e8d22a051cc4d..a1a296220f53b 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -93,9 +93,12 @@ is_list_rprimitive, is_none_rprimitive, is_object_rprimitive, + is_str_rprimitive, + is_tagged, is_tuple_rprimitive, object_rprimitive, set_rprimitive, + short_int_rprimitive, vec_api_by_item_type, ) from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional @@ -119,6 +122,7 @@ apply_dunder_specialization, apply_function_specialization, apply_method_specialization, + translate_getitem_with_bounds_check, translate_object_new, translate_object_setattr, ) @@ -137,7 +141,12 @@ from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op -from mypyc.primitives.str_ops import str_slice_op +from mypyc.primitives.str_ops import ( + str_adjust_index_op, + str_get_item_unsafe_as_int_op, + str_range_check_op, + str_slice_op, +) from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op # Name and attribute references @@ -918,6 +927,16 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: return result if len(e.operators) == 1: + # s[i] == 'x' / s[i] != 'x' (and the symmetric RHS) -> int compare of + # codepoints. Skips the per-iteration 1-char str allocation/lookup and + # generic str equality call. + if first_op in ("==", "!="): + result = try_specialize_str_index_compare( + builder, first_op, e.operands[0], e.operands[1], e.line + ) + if result is not None: + return result + # Special some common simple cases if first_op in ("is", "is not"): right_expr = e.operands[1] @@ -960,6 +979,49 @@ def go(i: int, prev: Value) -> Value: return go(0, builder.accept(e.operands[0])) +def try_specialize_str_index_compare( + builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int +) -> Value | None: + """Specialize `s[i] == 'x'` / `s[i] != 'x'` (and the symmetric form with + operands swapped) into an int compare of codepoints. + + Returns None if the pattern doesn't match: the indexed base must be str, + the index must be an integer, and the literal must be a 1-character str. + Multi-character or empty literals fall through to the generic str compare + (which still returns False for them, matching today's behavior). + """ + # Normalize so the IndexExpr is on the left. + if isinstance(rhs, IndexExpr) and not isinstance(lhs, IndexExpr): + lhs, rhs = rhs, lhs + # Shape: s[i] {==, !=} "x" where "x" is exactly one codepoint. + if ( + not isinstance(lhs, IndexExpr) + or not isinstance(rhs, StrExpr) + or len(rhs.value) != 1 + or not is_str_rprimitive(builder.node_type(lhs.base)) + ): + return None + index_type = builder.node_type(lhs.index) + if not (is_tagged(index_type) or is_fixed_width_rtype(index_type)): + return None + + # ord(s[i]) with bounds check; raises IndexError for out-of-range indices, + # matching the behavior of the generic s[i] path. + codepoint = translate_getitem_with_bounds_check( + builder, + lhs.base, + [lhs.index], + lhs, + str_adjust_index_op, + str_range_check_op, + str_get_item_unsafe_as_int_op, + ) + if codepoint is None: + return None + literal_cp = Integer(ord(rhs.value), short_int_rprimitive, line) + return builder.binary_op(codepoint, literal_cp, op, line) + + def try_specialize_in_expr( builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int ) -> Value | None: diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 81cd5bd34c046..09b836deb2b84 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -1025,3 +1025,153 @@ def is_digit(x): L0: r0 = CPyStr_IsDigit(x) return r0 + +[case testStrIndexEqLiteral] +def is_comma(s: str, i: int) -> bool: + return s[i] == "," +def is_comma_swapped(s: str, i: int) -> bool: + return "," == s[i] +def is_comma_ne(s: str, i: int) -> bool: + return s[i] != "," +[out] +def is_comma(s, i): + s :: str + i :: int + r0 :: native_int + r1 :: bit + r2, r3 :: i64 + r4 :: ptr + r5 :: c_ptr + r6, r7 :: i64 + r8, r9 :: bool + r10 :: short_int + r11 :: bit +L0: + r0 = i & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = i >> 1 + r3 = r2 + goto L3 +L2: + r4 = i ^ 1 + r5 = r4 + r6 = CPyLong_AsInt64(r5) + r3 = r6 + keep_alive i +L3: + r7 = CPyStr_AdjustIndex(s, r3) + r8 = CPyStr_RangeCheck(s, r7) + if r8 goto L5 else goto L4 :: bool +L4: + r9 = raise IndexError('index out of range') + unreachable +L5: + r10 = CPyStr_GetItemUnsafeAsInt(s, r7) + r11 = int_eq r10, 88 + return r11 +def is_comma_swapped(s, i): + s :: str + i :: int + r0 :: native_int + r1 :: bit + r2, r3 :: i64 + r4 :: ptr + r5 :: c_ptr + r6, r7 :: i64 + r8, r9 :: bool + r10 :: short_int + r11 :: bit +L0: + r0 = i & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = i >> 1 + r3 = r2 + goto L3 +L2: + r4 = i ^ 1 + r5 = r4 + r6 = CPyLong_AsInt64(r5) + r3 = r6 + keep_alive i +L3: + r7 = CPyStr_AdjustIndex(s, r3) + r8 = CPyStr_RangeCheck(s, r7) + if r8 goto L5 else goto L4 :: bool +L4: + r9 = raise IndexError('index out of range') + unreachable +L5: + r10 = CPyStr_GetItemUnsafeAsInt(s, r7) + r11 = int_eq r10, 88 + return r11 +def is_comma_ne(s, i): + s :: str + i :: int + r0 :: native_int + r1 :: bit + r2, r3 :: i64 + r4 :: ptr + r5 :: c_ptr + r6, r7 :: i64 + r8, r9 :: bool + r10 :: short_int + r11 :: bit +L0: + r0 = i & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = i >> 1 + r3 = r2 + goto L3 +L2: + r4 = i ^ 1 + r5 = r4 + r6 = CPyLong_AsInt64(r5) + r3 = r6 + keep_alive i +L3: + r7 = CPyStr_AdjustIndex(s, r3) + r8 = CPyStr_RangeCheck(s, r7) + if r8 goto L5 else goto L4 :: bool +L4: + r9 = raise IndexError('index out of range') + unreachable +L5: + r10 = CPyStr_GetItemUnsafeAsInt(s, r7) + r11 = int_ne r10, 88 + return r11 + +[case testStrIndexEqLiteralNoSpecialize] +def two_char_literal(s: str, i: int) -> bool: + # Multi-char literals don't match the specialization; falls through to + # the generic str equality path. + return s[i] == "ab" +def empty_literal(s: str, i: int) -> bool: + # Empty string literals also fall through; the generic path returns False. + return s[i] == "" +[out] +def two_char_literal(s, i): + s :: str + i :: int + r0, r1 :: str + r2 :: bool +L0: + r0 = CPyStr_GetItem(s, i) + r1 = 'ab' + r2 = CPyStr_EqualLiteral(r0, r1, 2) + return r2 +def empty_literal(s, i): + s :: str + i :: int + r0, r1 :: str + r2 :: bool +L0: + r0 = CPyStr_GetItem(s, i) + r1 = '' + r2 = CPyStr_EqualLiteral(r0, r1, 0) + return r2 diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index ec662da969865..18d582849f443 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -1412,3 +1412,74 @@ def test_isdigit_strings() -> None: assert not "\u00e9\u00e8".isdigit() assert not "123\u00e9".isdigit() assert not "\U0001d7ce!".isdigit() + +[case testStrIndexEqLiteralSpecialize] +from typing import Any + +from testutil import assertRaises + +# The specializer fires on the AST shape `IndexExpr == StrLiteral` (or the +# symmetric swap, and `!=`). The literal has to be a real source-level +# string literal (can't be passed in as a parameter), so each test +# function pins one distinct shape. + +def eq_comma(s: str, i: int) -> bool: + # Specialized: s[i] == "x". + return s[i] == "," + +def ne_comma(s: str, i: int) -> bool: + # Specialized: s[i] != "x". + return s[i] != "," + +def comma_eq(s: str, i: int) -> bool: + # Specialized: "x" == s[i]. Operand-swap is normalized. + return "," == s[i] + +def eq_two_chars(s: str, i: int) -> bool: + # Not specialized: literal isn't 1 char. Falls through to the generic + # str compare, which returns False since s[i] is always 1 codepoint. + return s[i] == "ab" + +def eq_empty(s: str, i: int) -> bool: + # Not specialized: empty literal. Same fall-through. + return s[i] == "" + +def test_specialized_path() -> None: + s = "a,b" # comma at index 1 + assert eq_comma(s, 1) + assert not eq_comma(s, 0) + assert not eq_comma(s, 2) + # != inverts. + assert ne_comma(s, 0) + assert not ne_comma(s, 1) + # Literal on the LHS is normalized to the same shape. + assert comma_eq(s, 1) + assert not comma_eq(s, 0) + +def test_negative_index_is_adjusted() -> None: + s = "a,b" + assert eq_comma(s, -2) # -2 -> 1 (',') + assert not eq_comma(s, -1) # -1 -> 2 ('b') + +def test_non_1char_literal_falls_through() -> None: + s = "a,b" + # Generic str compare answers False because s[i] has length 1. + assert not eq_two_chars(s, 0) + assert not eq_two_chars(s, 1) + assert not eq_empty(s, 0) + +def test_out_of_range_raises_indexerror() -> None: + # Bounds-check semantics match the unspecialized s[i] path. + s = "a,b" + with assertRaises(IndexError): + eq_comma(s, 3) + with assertRaises(IndexError): + eq_comma(s, -4) + +def test_any_dispatch_uses_generic_path() -> None: + # Going through `Any` routes through the interpreted wrapper, which + # uses the unspecialized lowering. Confirms the str surface still + # works for callers that bypass the specializer. + f: Any = eq_comma + assert f("hello,world", 5) is True + assert f("hello", 0) is False From 1a2e2de649e99c7a88966bf393fd0e7cb9f7d31b Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 4 Jun 2026 10:02:55 +0300 Subject: [PATCH 2/3] Address review feedback - Use a temp variable for the swap normalization. Tuple-unpack form (lhs, rhs = rhs, lhs) interacted badly with mypy's narrowing in mypyc-compiled mypy, producing a runtime IndexExpr-vs-StrExpr cast failure (mypy#21586). Workaround per @p-sawicki on PR #21579. - Drop test_any_dispatch_uses_generic_path. The 'Any' dispatch still calls the mypyc-compiled eq_comma, which has the specialization, so this test was not exercising the unspecialized path as claimed. The IR golden pins the specialized lowering, and eq_two_chars / eq_empty cover the fall-through behavior. --- mypyc/irbuild/expression.py | 3 ++- mypyc/test-data/run-strings.test | 10 ---------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index a1a296220f53b..f953dd3825ef2 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -992,7 +992,8 @@ def try_specialize_str_index_compare( """ # Normalize so the IndexExpr is on the left. if isinstance(rhs, IndexExpr) and not isinstance(lhs, IndexExpr): - lhs, rhs = rhs, lhs + tmp = lhs + lhs, rhs = rhs, tmp # Shape: s[i] {==, !=} "x" where "x" is exactly one codepoint. if ( not isinstance(lhs, IndexExpr) diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 18d582849f443..81b85580d7a7b 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -1414,8 +1414,6 @@ def test_isdigit_strings() -> None: assert not "\U0001d7ce!".isdigit() [case testStrIndexEqLiteralSpecialize] -from typing import Any - from testutil import assertRaises # The specializer fires on the AST shape `IndexExpr == StrLiteral` (or the @@ -1475,11 +1473,3 @@ def test_out_of_range_raises_indexerror() -> None: eq_comma(s, 3) with assertRaises(IndexError): eq_comma(s, -4) - -def test_any_dispatch_uses_generic_path() -> None: - # Going through `Any` routes through the interpreted wrapper, which - # uses the unspecialized lowering. Confirms the str surface still - # works for callers that bypass the specializer. - f: Any = eq_comma - assert f("hello,world", 5) is True - assert f("hello", 0) is False From 72b6cc053526ec606302ca9291b483aaf810845f Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 4 Jun 2026 10:28:00 +0300 Subject: [PATCH 3/3] Skip the specialized IR golden on 32-bit Python Rename testStrIndexEqLiteral -> testStrIndexEqLiteral_64bit so it skips on 32-bit. The golden output captures the int-unbox-to-i64 path emitted by translate_getitem_with_bounds_check, which differs on 32-bit (extra 'extend signed i: builtins.int to i64' op shifts register numbering). testOrdOfStrIndex_64bit uses the same primitives and follows the same convention. The fall-through golden (testStrIndexEqLiteralNoSpecialize) keeps no suffix; its IR uses CPyStr_GetItem directly with no unboxing. --- mypyc/test-data/irbuild-str.test | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 09b836deb2b84..b16057f645ba7 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -1026,7 +1026,7 @@ L0: r0 = CPyStr_IsDigit(x) return r0 -[case testStrIndexEqLiteral] +[case testStrIndexEqLiteral_64bit] def is_comma(s: str, i: int) -> bool: return s[i] == "," def is_comma_swapped(s: str, i: int) -> bool: