Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion mypyc/irbuild/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,12 @@
is_list_rprimitive,
is_none_rprimitive,
is_object_rprimitive,
is_str_rprimitive,
is_tagged,
is_tuple_rprimitive,
object_rprimitive,
set_rprimitive,
short_int_rprimitive,
vec_api_by_item_type,
)
from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional
Expand All @@ -119,6 +122,7 @@
apply_dunder_specialization,
apply_function_specialization,
apply_method_specialization,
translate_getitem_with_bounds_check,
translate_object_new,
translate_object_setattr,
)
Expand All @@ -137,7 +141,12 @@
from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op
from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op
from mypyc.primitives.set_ops import set_add_op, set_in_op, set_update_op
from mypyc.primitives.str_ops import str_slice_op
from mypyc.primitives.str_ops import (
str_adjust_index_op,
str_get_item_unsafe_as_int_op,
str_range_check_op,
str_slice_op,
)
from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op

# Name and attribute references
Expand Down Expand Up @@ -918,6 +927,16 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value:
return result

if len(e.operators) == 1:
# s[i] == 'x' / s[i] != 'x' (and the symmetric RHS) -> int compare of
# codepoints. Skips the per-iteration 1-char str allocation/lookup and
# generic str equality call.
if first_op in ("==", "!="):
result = try_specialize_str_index_compare(
builder, first_op, e.operands[0], e.operands[1], e.line
)
if result is not None:
return result

# Special some common simple cases
if first_op in ("is", "is not"):
right_expr = e.operands[1]
Expand Down Expand Up @@ -960,6 +979,49 @@ def go(i: int, prev: Value) -> Value:
return go(0, builder.accept(e.operands[0]))


def try_specialize_str_index_compare(
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
) -> Value | None:
"""Specialize `s[i] == 'x'` / `s[i] != 'x'` (and the symmetric form with
operands swapped) into an int compare of codepoints.
Returns None if the pattern doesn't match: the indexed base must be str,
the index must be an integer, and the literal must be a 1-character str.
Multi-character or empty literals fall through to the generic str compare
(which still returns False for them, matching today's behavior).
"""
# Normalize so the IndexExpr is on the left.
if isinstance(rhs, IndexExpr) and not isinstance(lhs, IndexExpr):
lhs, rhs = rhs, lhs
Comment on lines +994 to +995
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the errors in the run tests are because of a mypy issue #21586 as it seems rhs is typed as IndexExpr after the swap and assigning lhs to it raises a type error.

you might need to use a temp variable as a work-around as this way it seems to work correctly.

tmp = lhs
lhs, rhs = rhs, tmp

# Shape: s[i] {==, !=} "x" where "x" is exactly one codepoint.
if (
not isinstance(lhs, IndexExpr)
or not isinstance(rhs, StrExpr)
or len(rhs.value) != 1
or not is_str_rprimitive(builder.node_type(lhs.base))
):
return None
index_type = builder.node_type(lhs.index)
if not (is_tagged(index_type) or is_fixed_width_rtype(index_type)):
return None

# ord(s[i]) with bounds check; raises IndexError for out-of-range indices,
# matching the behavior of the generic s[i] path.
codepoint = translate_getitem_with_bounds_check(
builder,
lhs.base,
[lhs.index],
lhs,
str_adjust_index_op,
str_range_check_op,
str_get_item_unsafe_as_int_op,
)
if codepoint is None:
return None
literal_cp = Integer(ord(rhs.value), short_int_rprimitive, line)
return builder.binary_op(codepoint, literal_cp, op, line)


def try_specialize_in_expr(
builder: IRBuilder, op: str, lhs: Expression, rhs: Expression, line: int
) -> Value | None:
Expand Down
150 changes: 150 additions & 0 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -1025,3 +1025,153 @@ def is_digit(x):
L0:
r0 = CPyStr_IsDigit(x)
return r0

[case testStrIndexEqLiteral]
def is_comma(s: str, i: int) -> bool:
return s[i] == ","
def is_comma_swapped(s: str, i: int) -> bool:
return "," == s[i]
def is_comma_ne(s: str, i: int) -> bool:
return s[i] != ","
[out]
def is_comma(s, i):
s :: str
i :: int
r0 :: native_int
r1 :: bit
r2, r3 :: i64
r4 :: ptr
r5 :: c_ptr
r6, r7 :: i64
r8, r9 :: bool
r10 :: short_int
r11 :: bit
L0:
r0 = i & 1
r1 = r0 == 0
if r1 goto L1 else goto L2 :: bool
L1:
r2 = i >> 1
r3 = r2
goto L3
L2:
r4 = i ^ 1
r5 = r4
r6 = CPyLong_AsInt64(r5)
r3 = r6
keep_alive i
L3:
r7 = CPyStr_AdjustIndex(s, r3)
r8 = CPyStr_RangeCheck(s, r7)
if r8 goto L5 else goto L4 :: bool
L4:
r9 = raise IndexError('index out of range')
unreachable
L5:
r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
r11 = int_eq r10, 88
return r11
def is_comma_swapped(s, i):
s :: str
i :: int
r0 :: native_int
r1 :: bit
r2, r3 :: i64
r4 :: ptr
r5 :: c_ptr
r6, r7 :: i64
r8, r9 :: bool
r10 :: short_int
r11 :: bit
L0:
r0 = i & 1
r1 = r0 == 0
if r1 goto L1 else goto L2 :: bool
L1:
r2 = i >> 1
r3 = r2
goto L3
L2:
r4 = i ^ 1
r5 = r4
r6 = CPyLong_AsInt64(r5)
r3 = r6
keep_alive i
L3:
r7 = CPyStr_AdjustIndex(s, r3)
r8 = CPyStr_RangeCheck(s, r7)
if r8 goto L5 else goto L4 :: bool
L4:
r9 = raise IndexError('index out of range')
unreachable
L5:
r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
r11 = int_eq r10, 88
return r11
def is_comma_ne(s, i):
s :: str
i :: int
r0 :: native_int
r1 :: bit
r2, r3 :: i64
r4 :: ptr
r5 :: c_ptr
r6, r7 :: i64
r8, r9 :: bool
r10 :: short_int
r11 :: bit
L0:
r0 = i & 1
r1 = r0 == 0
if r1 goto L1 else goto L2 :: bool
L1:
r2 = i >> 1
r3 = r2
goto L3
L2:
r4 = i ^ 1
r5 = r4
r6 = CPyLong_AsInt64(r5)
r3 = r6
keep_alive i
L3:
r7 = CPyStr_AdjustIndex(s, r3)
r8 = CPyStr_RangeCheck(s, r7)
if r8 goto L5 else goto L4 :: bool
L4:
r9 = raise IndexError('index out of range')
unreachable
L5:
r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
r11 = int_ne r10, 88
return r11

[case testStrIndexEqLiteralNoSpecialize]
def two_char_literal(s: str, i: int) -> bool:
# Multi-char literals don't match the specialization; falls through to
# the generic str equality path.
return s[i] == "ab"
def empty_literal(s: str, i: int) -> bool:
# Empty string literals also fall through; the generic path returns False.
return s[i] == ""
[out]
def two_char_literal(s, i):
s :: str
i :: int
r0, r1 :: str
r2 :: bool
L0:
r0 = CPyStr_GetItem(s, i)
r1 = 'ab'
r2 = CPyStr_EqualLiteral(r0, r1, 2)
return r2
def empty_literal(s, i):
s :: str
i :: int
r0, r1 :: str
r2 :: bool
L0:
r0 = CPyStr_GetItem(s, i)
r1 = ''
r2 = CPyStr_EqualLiteral(r0, r1, 0)
return r2
71 changes: 71 additions & 0 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -1412,3 +1412,74 @@ def test_isdigit_strings() -> None:
assert not "\u00e9\u00e8".isdigit()
assert not "123\u00e9".isdigit()
assert not "\U0001d7ce!".isdigit()

[case testStrIndexEqLiteralSpecialize]
from typing import Any

from testutil import assertRaises

# The specializer fires on the AST shape `IndexExpr == StrLiteral` (or the
# symmetric swap, and `!=`). The literal has to be a real source-level
# string literal (can't be passed in as a parameter), so each test
# function pins one distinct shape.

def eq_comma(s: str, i: int) -> bool:
# Specialized: s[i] == "x".
return s[i] == ","

def ne_comma(s: str, i: int) -> bool:
# Specialized: s[i] != "x".
return s[i] != ","

def comma_eq(s: str, i: int) -> bool:
# Specialized: "x" == s[i]. Operand-swap is normalized.
return "," == s[i]

def eq_two_chars(s: str, i: int) -> bool:
# Not specialized: literal isn't 1 char. Falls through to the generic
# str compare, which returns False since s[i] is always 1 codepoint.
return s[i] == "ab"

def eq_empty(s: str, i: int) -> bool:
# Not specialized: empty literal. Same fall-through.
return s[i] == ""

def test_specialized_path() -> None:
s = "a,b" # comma at index 1
assert eq_comma(s, 1)
assert not eq_comma(s, 0)
assert not eq_comma(s, 2)
# != inverts.
assert ne_comma(s, 0)
assert not ne_comma(s, 1)
# Literal on the LHS is normalized to the same shape.
assert comma_eq(s, 1)
assert not comma_eq(s, 0)

def test_negative_index_is_adjusted() -> None:
s = "a,b"
assert eq_comma(s, -2) # -2 -> 1 (',')
assert not eq_comma(s, -1) # -1 -> 2 ('b')

def test_non_1char_literal_falls_through() -> None:
s = "a,b"
# Generic str compare answers False because s[i] has length 1.
assert not eq_two_chars(s, 0)
assert not eq_two_chars(s, 1)
assert not eq_empty(s, 0)

def test_out_of_range_raises_indexerror() -> None:
# Bounds-check semantics match the unspecialized s[i] path.
s = "a,b"
with assertRaises(IndexError):
eq_comma(s, 3)
with assertRaises(IndexError):
eq_comma(s, -4)

def test_any_dispatch_uses_generic_path() -> None:
# Going through `Any` routes through the interpreted wrapper, which
# uses the unspecialized lowering. Confirms the str surface still
# works for callers that bypass the specializer.
f: Any = eq_comma
assert f("hello,world", 5) is True
assert f("hello", 0) is False
Comment on lines +1480 to +1485
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think the comment is true, the interpreted wrapper still calls the same generated C function for eq_comma that has the optimization.

to test unspecialized lowering you could add a test case that compares against a one-char str passed as a parameter instead of a literal. i'd imagine we have tests like that already though so i think you could just remove this test case.

Loading