diff --git a/replicate/exceptions.py b/replicate/exceptions.py index f52f9fb..dbf563f 100644 --- a/replicate/exceptions.py +++ b/replicate/exceptions.py @@ -17,7 +17,9 @@ class ModelError(ReplicateException): def __init__(self, prediction: "Prediction") -> None: self.prediction = prediction - super().__init__(prediction.error) + super().__init__( + f"Prediction {prediction.id} {prediction.status}: {prediction.error}" + ) class ReplicateError(ReplicateException): diff --git a/tests/test_run.py b/tests/test_run.py index 93f7248..fd66645 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -581,7 +581,7 @@ async def test_run_with_model_error(mock_replicate_api_token): }, ) - assert str(excinfo.value) == "OOM" + assert str(excinfo.value) == "Prediction p1 failed: OOM" assert excinfo.value.prediction.error == "OOM" assert excinfo.value.prediction.status == "failed"