Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/14324.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix ``pytest.RaisesGroup`` incorrectly calling the ``check`` callback with contained exceptions instead of only the exception group.
76 changes: 40 additions & 36 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,42 +371,46 @@ def _read_pyc(
fp = open(pyc, "rb")
except OSError:
return None
with fp:
try:
stat_result = os.stat(source)
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(16)
except OSError as e:
trace(f"_read_pyc({source}): OSError {e}")
return None
# Check for invalid or out of date pyc file.
if len(data) != (16):
trace(f"_read_pyc({source}): invalid pyc (too short)")
return None
if data[:4] != importlib.util.MAGIC_NUMBER:
trace(f"_read_pyc({source}): invalid pyc (bad magic number)")
return None
if data[4:8] != b"\x00\x00\x00\x00":
trace(f"_read_pyc({source}): invalid pyc (unsupported flags)")
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace(f"_read_pyc({source}): out of date")
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace(f"_read_pyc({source}): invalid pyc (incorrect size)")
return None
try:
co = marshal.load(fp)
except Exception as e:
trace(f"_read_pyc({source}): marshal.load error {e}")
return None
if not isinstance(co, types.CodeType):
trace(f"_read_pyc({source}): not a code object")
return None
return co
try:
with fp:
try:
stat_result = os.stat(source)
mtime = int(stat_result.st_mtime)
size = stat_result.st_size
data = fp.read(16)
except OSError as e:
trace(f"_read_pyc({source}): OSError {e}")
return None
# Check for invalid or out of date pyc file.
if len(data) != (16):
trace(f"_read_pyc({source}): invalid pyc (too short)")
return None
if data[:4] != importlib.util.MAGIC_NUMBER:
trace(f"_read_pyc({source}): invalid pyc (bad magic number)")
return None
if data[4:8] != b"\x00\x00\x00\x00":
trace(f"_read_pyc({source}): invalid pyc (unsupported flags)")
return None
mtime_data = data[8:12]
if int.from_bytes(mtime_data, "little") != mtime & 0xFFFFFFFF:
trace(f"_read_pyc({source}): out of date")
return None
size_data = data[12:16]
if int.from_bytes(size_data, "little") != size & 0xFFFFFFFF:
trace(f"_read_pyc({source}): invalid pyc (incorrect size)")
return None
try:
co = marshal.load(fp)
except Exception as e:
trace(f"_read_pyc({source}): marshal.load error {e}")
return None
if not isinstance(co, types.CodeType):
trace(f"_read_pyc({source}): not a code object")
return None
return co
except OSError as e:
trace(f"_read_pyc({source}): OSError {e}")
return None


def rewrite_asserts(
Expand Down
29 changes: 25 additions & 4 deletions src/_pytest/raises.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,14 +1207,35 @@ def matches(
reason = (
cast(str, self._fail_reason) + f" on the {type(exception).__name__}"
)

suggest_subexception_check = False
if (
len(actual_exceptions) == len(self.expected_exceptions) == 1
self.check is not None
and len(actual_exceptions) == len(self.expected_exceptions) == 1
and isinstance(expected := self.expected_exceptions[0], type)
# we explicitly break typing here :)
and self._check_check(actual_exceptions[0]) # type: ignore[arg-type]
and isinstance(actual_exceptions[0], expected)
):
annotations = getattr(self.check, "__annotations__", {})
param_names = [name for name in annotations if name != "return"]
if param_names:
param_annotation = annotations[param_names[0]]

if isinstance(param_annotation, str):
suggest_subexception_check = (
"ExceptionGroup" not in param_annotation
and "BaseExceptionGroup" not in param_annotation
)
else:
origin = get_origin(param_annotation) or param_annotation
if isinstance(origin, type):
suggest_subexception_check = not issubclass(
origin, BaseExceptionGroup
)

if suggest_subexception_check:
self._fail_reason = reason + (
f", but did return True for the expected {self._repr_expected(expected)}."
f", but the single contained exception matches the expected "
f"{self._repr_expected(expected)}."
f" You might want RaisesGroup(RaisesExc({expected.__name__}, check=<...>))"
)
else:
Expand Down
20 changes: 18 additions & 2 deletions testing/python/raises_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,12 @@ def is_exc(e: ExceptionGroup[ValueError]) -> bool:
return e is exc

is_exc_repr = repr_callable(is_exc)

# this should pass (same object)
with RaisesGroup(ValueError, check=is_exc):
raise exc

# this should fail WITHOUT suggestion
with (
fails_raises_group(
f"check {is_exc_repr} did not return True on the ExceptionGroup"
Expand All @@ -426,16 +429,29 @@ def is_exc(e: ExceptionGroup[ValueError]) -> bool:
def is_value_error(e: BaseException) -> bool:
return isinstance(e, ValueError)

# helpful suggestion if the user thinks the check is for the sub-exception
# this should fail WITH suggestion (because check looks like it's for inner exception)
with (
fails_raises_group(
f"check {is_value_error} did not return True on the ExceptionGroup, but did return True for the expected ValueError. You might want RaisesGroup(RaisesExc(ValueError, check=<...>))"
f"check {is_value_error} did not return True on the ExceptionGroup, but the single contained exception matches the expected ValueError. You might want RaisesGroup(RaisesExc(ValueError, check=<...>))"
),
RaisesGroup(ValueError, check=is_value_error),
):
raise ExceptionGroup("", (ValueError(),))


def test_check_called_only_with_group() -> None:
seen = []

def check(exc_group: ExceptionGroup[ValueError]) -> bool:
seen.append(type(exc_group))
return len(exc_group.exceptions) == 1

with RaisesGroup(ValueError, match="Main message", check=check):
raise ExceptionGroup("Main message", [ValueError("foo")])

assert seen == [ExceptionGroup]


def test_unwrapped_match_check() -> None:
def my_check(e: object) -> bool: # pragma: no cover
return True
Expand Down
33 changes: 33 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,39 @@ def test_read_pyc(self, tmp_path: Path) -> None:

assert _read_pyc(source, pyc) is None # no error

def test_read_pyc_handles_context_manager_oserror(self, tmp_path: Path) -> None:
from _pytest.assertion.rewrite import _read_pyc

source = tmp_path / "source.py"
pyc = Path(str(source) + "c")
source.write_text("def test(): pass", encoding="utf-8")
py_compile.compile(str(source), str(pyc))

real_open = open

class FailingContextManager:
def __init__(self, fp) -> None:
self.fp = fp

def __enter__(self):
return self.fp

def __exit__(self, exc_type, exc, tb) -> None:
self.fp.close()
raise OSError(errno.EIO, "Input/output error")

def __getattr__(self, name):
return getattr(self.fp, name)

def mock_open(file, mode="r", *args, **kwargs):
fp = real_open(file, mode, *args, **kwargs)
if Path(file) == pyc and mode == "rb":
return FailingContextManager(fp)
return fp

with mock.patch("builtins.open", mock_open):
assert _read_pyc(source, pyc) is None

def test_read_pyc_success(self, tmp_path: Path, pytester: Pytester) -> None:
"""
Ensure that the _rewrite_test() -> _write_pyc() produces a pyc file
Expand Down
Loading