From 7de27c9b7bd6da4e480d7108a6d0a0005c720c12 Mon Sep 17 00:00:00 2001 From: Timothy Nunn Date: Thu, 21 May 2026 09:51:47 +0100 Subject: [PATCH 1/2] Correct typing errors and add ty to pre-commit --- .github/workflows/quality.yml | 4 ++-- .pre-commit-config.yaml | 7 ++++++ examples/converges.ipynb | 4 ++-- pyproject.toml | 4 ++++ src/pyvmcon/problem.py | 16 +++++++------- src/pyvmcon/vmcon.py | 10 +++++---- tests/test_vmcon_paper.py | 40 +++++++++++++++++------------------ 7 files changed, 49 insertions(+), 36 deletions(-) diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 9725acd..0bddfa4 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -16,7 +16,7 @@ jobs: with: python-version: "3.10" - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd - - name: Install pre-commit - run: pip install pre-commit + - name: Install PyVMCON with dev dependencies (includes pre-commit) + run: pip install '.[dev]' - name: Run pre-commit run: pre-commit run --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e8c029..193c81f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,3 +11,10 @@ repos: - id: ruff-check args: [--fix] - id: ruff-format + - repo: local + hooks: + - id: ty + name: ty check + entry: ty check . + pass_filenames: false + language: python diff --git a/examples/converges.ipynb b/examples/converges.ipynb index 95c20f8..57332af 100644 --- a/examples/converges.ipynb +++ b/examples/converges.ipynb @@ -35,8 +35,8 @@ "from pyvmcon import Problem\n", "\n", "\n", - "def f(x: list):\n", - " return (x[0] - 2) ** 2 + (x[1] - 1) ** 2\n", + "def f(x):\n", + " return np.array((x[0] - 2) ** 2 + (x[1] - 1) ** 2)\n", "\n", "\n", "problem = Problem(\n", diff --git a/pyproject.toml b/pyproject.toml index c4e512d..200c647 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ Source = "https://github.com/ukaea/PyVMCON" [project.optional-dependencies] test = ["pytest"] +dev = ["pre-commit", "ty==0.0.38"] docs = [ "Sphinx>=6.1", "sphinxcontrib-mermaid>=0.9", @@ -60,3 +61,6 @@ ignore = [ [tool.ruff.format] preview = true + +[tool.ty.rules] +unresolved-import = "ignore" diff --git a/src/pyvmcon/problem.py b/src/pyvmcon/problem.py index e42c771..38cc8d5 100644 --- a/src/pyvmcon/problem.py +++ b/src/pyvmcon/problem.py @@ -3,16 +3,16 @@ from abc import ABC, abstractmethod from collections.abc import Callable from dataclasses import dataclass, field -from typing import TypeVar, cast +from typing import TypeAlias import numpy as np from numpy.typing import NDArray -ScalarType = TypeVar("ScalarType", NDArray, np.number, float) +ScalarType: TypeAlias = NDArray | np.number | float """A scalar variable e.g. a single number (which could be a 0D numpy array)""" -VectorType = NDArray +VectorType: TypeAlias = NDArray """A numpy array with only 1 dimension""" -MatrixType = NDArray +MatrixType: TypeAlias = NDArray """A numpy array with 2 dimensions""" @@ -111,10 +111,10 @@ def __call__(self, x: VectorType) -> Result: return Result( self.f(x), self.df(x), - cast("VectorType", np.array([c(x) for c in self.equality_constraints])), - cast("MatrixType", np.array([c(x) for c in self.dequality_constraints])), - cast("VectorType", np.array([c(x) for c in self.inequality_constraints])), - cast("MatrixType", np.array([c(x) for c in self.dinequality_constraints])), + np.array([c(x) for c in self.equality_constraints]), + np.array([c(x) for c in self.dequality_constraints]), + np.array([c(x) for c in self.inequality_constraints]), + np.array([c(x) for c in self.dinequality_constraints]), ) @property diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index 3dce4dd..155432f 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -31,7 +31,7 @@ def solve( initial_B: np.ndarray | None = None, callback: Callable[[int, Result, VectorType, float], None] | None = None, additional_convergence: ( - Callable[[Result, VectorType, VectorType, VectorType, VectorType], None] | None + Callable[[Result, VectorType, VectorType, VectorType, VectorType], bool] | None ) = None, overwrite_convergence_criteria: bool = False, ) -> tuple[VectorType, VectorType, VectorType, Result]: @@ -149,8 +149,10 @@ def solve( B = np.identity(n) if initial_B is None else initial_B callback = callback or (lambda _i, _result, _x, _con: None) - additional_convergence = additional_convergence or ( - lambda _result, _x, _delta, _lambda_eq, _lambda_in: True + additional_convergence_function = ( + (lambda _result, _x, _delta, _lambda_eq, _lambda_in: True) + if additional_convergence is None + else additional_convergence ) # These two values being None allows the line @@ -201,7 +203,7 @@ def solve( callback(j, result, x, convergence_info) - if additional_convergence( + if additional_convergence_function( result, x, delta, diff --git a/tests/test_vmcon_paper.py b/tests/test_vmcon_paper.py index 1a8b907..b1e9923 100644 --- a/tests/test_vmcon_paper.py +++ b/tests/test_vmcon_paper.py @@ -19,8 +19,8 @@ class VMCONTestAsset: expected_x: np.ndarray expected_lamda_equality: np.ndarray expected_lamda_inequality: np.ndarray - lbs: np.ndarray = None - ubs: np.ndarray = None + lbs: np.ndarray | None = None + ubs: np.ndarray | None = None max_iter: int = 10 epsilon: float = 1e-8 @@ -39,9 +39,9 @@ class VMCONTestAsset: [lambda x: np.array([-0.5 * x[0], -2 * x[1]])], ), initial_x=np.array([2.0, 2.0]), - expected_x=[8.228756e-1, 9.114378e-1], - expected_lamda_equality=[-1.594491], - expected_lamda_inequality=[1.846591], + expected_x=np.array([8.228756e-1, 9.114378e-1]), + expected_lamda_equality=np.array([-1.594491]), + expected_lamda_inequality=np.array([1.846591]), ), # Test 1 detailed in ANL-80-64 page 25 # with one of the constraints duplicated @@ -57,11 +57,11 @@ class VMCONTestAsset: [lambda x: np.array([-0.5 * x[0], -2 * x[1]])], ), initial_x=np.array([2.0, 2.0]), - expected_x=[8.228756e-1, 9.114378e-1], + expected_x=np.array([8.228756e-1, 9.114378e-1]), # duplicating the constraint is probably expected # to change the Lagrange multipliers - expected_lamda_equality=[-0.7972455591261, -0.7972455591261], - expected_lamda_inequality=[1.846591], + expected_lamda_equality=np.array([-0.7972455591261, -0.7972455591261]), + expected_lamda_inequality=np.array([1.846591]), ), # Test 1 detailed in ANL-80-64 page 25 # with added, unintrusive, bounds @@ -75,9 +75,9 @@ class VMCONTestAsset: [lambda x: np.array([-0.5 * x[0], -2 * x[1]])], ), initial_x=np.array([2.0, 2.0]), - expected_x=[8.228756e-1, 9.114378e-1], - expected_lamda_equality=[-1.594491], - expected_lamda_inequality=[1.846591], + expected_x=np.array([8.228756e-1, 9.114378e-1]), + expected_lamda_equality=np.array([-1.594491]), + expected_lamda_inequality=np.array([1.846591]), lbs=np.array([-10, -10]), ubs=np.array([10, 10]), ), @@ -98,9 +98,9 @@ class VMCONTestAsset: ], ), initial_x=np.array([2.0, 2.0]), - expected_x=[1.6649685472365443, 0.55404867491788852], - expected_lamda_equality=[], - expected_lamda_inequality=[0, 0.80489557193146243], + expected_x=np.array([1.6649685472365443, 0.55404867491788852]), + expected_lamda_equality=np.array([]), + expected_lamda_inequality=np.array([0, 0.80489557193146243]), ), # Example 1a of https://en.wikipedia.org/wiki/Lagrange_multiplier VMCONTestAsset( @@ -114,9 +114,9 @@ class VMCONTestAsset: ), initial_x=np.array([1.0, 1.05]), epsilon=1e-8, - expected_x=[-0.5 * 2**0.5, -0.5 * 2**0.5], - expected_lamda_equality=[-(2 ** (-0.5))], - expected_lamda_inequality=[], + expected_x=np.array([-0.5 * 2**0.5, -0.5 * 2**0.5]), + expected_lamda_equality=np.array([-(2 ** (-0.5))]), + expected_lamda_inequality=np.array([]), max_iter=30, ), ], @@ -161,9 +161,9 @@ def test_vmcon_paper_feasible_examples(vmcon_example: VMCONTestAsset): ), initial_x=np.array([2.0, 2.0]), max_iter=5, - expected_x=[2.3999994310874733, 0.6], - expected_lamda_equality=[0.0], - expected_lamda_inequality=[0.0], + expected_x=np.array([2.3999994310874733, 0.6]), + expected_lamda_equality=np.array([0.0]), + expected_lamda_inequality=np.array([0.0]), ), ], ) From e4b2de07710c9fa15f9d01e2297d54c593ebee94 Mon Sep 17 00:00:00 2001 From: Timothy Nunn Date: Fri, 22 May 2026 09:45:23 +0100 Subject: [PATCH 2/2] Remove casting in vmcon --- src/pyvmcon/vmcon.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index 155432f..427b6bf 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -2,7 +2,7 @@ import logging from collections.abc import Callable -from typing import Any, cast +from typing import Any import cvxpy as cp import numpy as np @@ -481,7 +481,7 @@ def phi(result: Result) -> ScalarType: lamda_inequality=lamda_inequality, ) - return cast("ScalarType", alpha), mu_equality, mu_inequality, new_result + return alpha, mu_equality, mu_inequality, new_result def _derivative_lagrangian(