Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/quality.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
with:
python-version: "3.10"
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
- name: Install pre-commit
run: pip install pre-commit
- name: Install PyVMCON with dev dependencies (includes pre-commit)
run: pip install '.[dev]'
- name: Run pre-commit
run: pre-commit run --all-files
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ repos:
- id: ruff-check
args: [--fix]
- id: ruff-format
- repo: local
hooks:
- id: ty
name: ty check
entry: ty check .
pass_filenames: false
language: python
4 changes: 2 additions & 2 deletions examples/converges.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
"from pyvmcon import Problem\n",
"\n",
"\n",
"def f(x: list):\n",
" return (x[0] - 2) ** 2 + (x[1] - 1) ** 2\n",
"def f(x):\n",
" return np.array((x[0] - 2) ** 2 + (x[1] - 1) ** 2)\n",
"\n",
"\n",
"problem = Problem(\n",
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Source = "https://github.com/ukaea/PyVMCON"

[project.optional-dependencies]
test = ["pytest"]
dev = ["pre-commit", "ty==0.0.38"]
docs = [
"Sphinx>=6.1",
"sphinxcontrib-mermaid>=0.9",
Expand Down Expand Up @@ -60,3 +61,6 @@ ignore = [

[tool.ruff.format]
preview = true

[tool.ty.rules]
unresolved-import = "ignore"
16 changes: 8 additions & 8 deletions src/pyvmcon/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import TypeVar, cast
from typing import TypeAlias

import numpy as np
from numpy.typing import NDArray

ScalarType = TypeVar("ScalarType", NDArray, np.number, float)
ScalarType: TypeAlias = NDArray | np.number | float
"""A scalar variable e.g. a single number (which could be a 0D numpy array)"""
VectorType = NDArray
VectorType: TypeAlias = NDArray
"""A numpy array with only 1 dimension"""
MatrixType = NDArray
MatrixType: TypeAlias = NDArray
"""A numpy array with 2 dimensions"""


Expand Down Expand Up @@ -111,10 +111,10 @@ def __call__(self, x: VectorType) -> Result:
return Result(
self.f(x),
self.df(x),
cast("VectorType", np.array([c(x) for c in self.equality_constraints])),
cast("MatrixType", np.array([c(x) for c in self.dequality_constraints])),
cast("VectorType", np.array([c(x) for c in self.inequality_constraints])),
cast("MatrixType", np.array([c(x) for c in self.dinequality_constraints])),
np.array([c(x) for c in self.equality_constraints]),
np.array([c(x) for c in self.dequality_constraints]),
np.array([c(x) for c in self.inequality_constraints]),
np.array([c(x) for c in self.dinequality_constraints]),
)

@property
Expand Down
14 changes: 8 additions & 6 deletions src/pyvmcon/vmcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from collections.abc import Callable
from typing import Any, cast
from typing import Any

import cvxpy as cp
import numpy as np
Expand Down Expand Up @@ -31,7 +31,7 @@ def solve(
initial_B: np.ndarray | None = None,
callback: Callable[[int, Result, VectorType, float], None] | None = None,
additional_convergence: (
Callable[[Result, VectorType, VectorType, VectorType, VectorType], None] | None
Callable[[Result, VectorType, VectorType, VectorType, VectorType], bool] | None
) = None,
overwrite_convergence_criteria: bool = False,
) -> tuple[VectorType, VectorType, VectorType, Result]:
Expand Down Expand Up @@ -149,8 +149,10 @@ def solve(
B = np.identity(n) if initial_B is None else initial_B

callback = callback or (lambda _i, _result, _x, _con: None)
additional_convergence = additional_convergence or (
lambda _result, _x, _delta, _lambda_eq, _lambda_in: True
additional_convergence_function = (
(lambda _result, _x, _delta, _lambda_eq, _lambda_in: True)
if additional_convergence is None
else additional_convergence
)

# These two values being None allows the line
Expand Down Expand Up @@ -201,7 +203,7 @@ def solve(

callback(j, result, x, convergence_info)

if additional_convergence(
if additional_convergence_function(
result,
x,
delta,
Expand Down Expand Up @@ -479,7 +481,7 @@ def phi(result: Result) -> ScalarType:
lamda_inequality=lamda_inequality,
)

return cast("ScalarType", alpha), mu_equality, mu_inequality, new_result
return alpha, mu_equality, mu_inequality, new_result


def _derivative_lagrangian(
Expand Down
40 changes: 20 additions & 20 deletions tests/test_vmcon_paper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class VMCONTestAsset:
expected_x: np.ndarray
expected_lamda_equality: np.ndarray
expected_lamda_inequality: np.ndarray
lbs: np.ndarray = None
ubs: np.ndarray = None
lbs: np.ndarray | None = None
ubs: np.ndarray | None = None
max_iter: int = 10
epsilon: float = 1e-8

Expand All @@ -39,9 +39,9 @@ class VMCONTestAsset:
[lambda x: np.array([-0.5 * x[0], -2 * x[1]])],
),
initial_x=np.array([2.0, 2.0]),
expected_x=[8.228756e-1, 9.114378e-1],
expected_lamda_equality=[-1.594491],
expected_lamda_inequality=[1.846591],
expected_x=np.array([8.228756e-1, 9.114378e-1]),
expected_lamda_equality=np.array([-1.594491]),
expected_lamda_inequality=np.array([1.846591]),
),
# Test 1 detailed in ANL-80-64 page 25
# with one of the constraints duplicated
Expand All @@ -57,11 +57,11 @@ class VMCONTestAsset:
[lambda x: np.array([-0.5 * x[0], -2 * x[1]])],
),
initial_x=np.array([2.0, 2.0]),
expected_x=[8.228756e-1, 9.114378e-1],
expected_x=np.array([8.228756e-1, 9.114378e-1]),
# duplicating the constraint is probably expected
# to change the Lagrange multipliers
expected_lamda_equality=[-0.7972455591261, -0.7972455591261],
expected_lamda_inequality=[1.846591],
expected_lamda_equality=np.array([-0.7972455591261, -0.7972455591261]),
expected_lamda_inequality=np.array([1.846591]),
),
# Test 1 detailed in ANL-80-64 page 25
# with added, unintrusive, bounds
Expand All @@ -75,9 +75,9 @@ class VMCONTestAsset:
[lambda x: np.array([-0.5 * x[0], -2 * x[1]])],
),
initial_x=np.array([2.0, 2.0]),
expected_x=[8.228756e-1, 9.114378e-1],
expected_lamda_equality=[-1.594491],
expected_lamda_inequality=[1.846591],
expected_x=np.array([8.228756e-1, 9.114378e-1]),
expected_lamda_equality=np.array([-1.594491]),
expected_lamda_inequality=np.array([1.846591]),
lbs=np.array([-10, -10]),
ubs=np.array([10, 10]),
),
Expand All @@ -98,9 +98,9 @@ class VMCONTestAsset:
],
),
initial_x=np.array([2.0, 2.0]),
expected_x=[1.6649685472365443, 0.55404867491788852],
expected_lamda_equality=[],
expected_lamda_inequality=[0, 0.80489557193146243],
expected_x=np.array([1.6649685472365443, 0.55404867491788852]),
expected_lamda_equality=np.array([]),
expected_lamda_inequality=np.array([0, 0.80489557193146243]),
),
# Example 1a of https://en.wikipedia.org/wiki/Lagrange_multiplier
VMCONTestAsset(
Expand All @@ -114,9 +114,9 @@ class VMCONTestAsset:
),
initial_x=np.array([1.0, 1.05]),
epsilon=1e-8,
expected_x=[-0.5 * 2**0.5, -0.5 * 2**0.5],
expected_lamda_equality=[-(2 ** (-0.5))],
expected_lamda_inequality=[],
expected_x=np.array([-0.5 * 2**0.5, -0.5 * 2**0.5]),
expected_lamda_equality=np.array([-(2 ** (-0.5))]),
expected_lamda_inequality=np.array([]),
max_iter=30,
),
],
Expand Down Expand Up @@ -161,9 +161,9 @@ def test_vmcon_paper_feasible_examples(vmcon_example: VMCONTestAsset):
),
initial_x=np.array([2.0, 2.0]),
max_iter=5,
expected_x=[2.3999994310874733, 0.6],
expected_lamda_equality=[0.0],
expected_lamda_inequality=[0.0],
expected_x=np.array([2.3999994310874733, 0.6]),
expected_lamda_equality=np.array([0.0]),
expected_lamda_inequality=np.array([0.0]),
),
],
)
Expand Down
Loading