diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bab640c5..12c42cd8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,16 +9,16 @@ repos: - id: end-of-file-fixer exclude_types: [json, binary] - repo: https://github.com/psf/black-pre-commit-mirror - rev: "25.11.0" + rev: "26.1.0" hooks: - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.14.7" + rev: "v0.15.4" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - repo: https://github.com/kynan/nbstripout - rev: "0.8.2" + rev: "0.9.1" hooks: - id: nbstripout exclude: docs/benchmarks.ipynb diff --git a/pyproject.toml b/pyproject.toml index 276767e9..d55864ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ ignore = [ "PLR0915", # Allow many statements "PLR2004", # Allow magic numbers in comparisons "B905", # Allow zip() without explicit `strict=` parameter + "UP047", # PEP 695 type params require Python 3.12+; project supports 3.11+ ] exclude = [] diff --git a/tests/test_gp.py b/tests/test_gp.py index add8ad91..9268b6ad 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -25,9 +25,7 @@ def test_sample(data): X, _ = data with jax.enable_x64(True): - gp = GaussianProcess( - kernels.Matern32(1.5), X, diag=0.01, mean=lambda x: jnp.sum(x) - ) + gp = GaussianProcess(kernels.Matern32(1.5), X, diag=0.01, mean=jnp.sum) y = gp.sample(jax.random.PRNGKey(543)) assert y.shape == (len(X),)