Skip to content

Commit 3ba2e3f

Browse files
committed
Add tests for solver and remove pdb in TR tests
1 parent 0990ef7 commit 3ba2e3f

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

dfols/tests/test_solver.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ def rosenbrock_jacobian(x):
4141
return np.array([[-20.0*x[0], 10.0], [-1.0, 0.0]])
4242

4343

44+
def p_box(x,l,u):
45+
return np.minimum(np.maximum(x,l), u)
46+
47+
def p_ball(x,c,r):
48+
return c + (r/np.max([np.linalg.norm(x-c),r]))*(x-c)
49+
50+
4451
class TestNans(unittest.TestCase):
4552
# Generic objective that only returns NaNs (like optclim code)
4653
# Verify get a sensible termination
@@ -185,3 +192,21 @@ def runTest(self):
185192
self.assertTrue(array_compare(soln.jacobian, jac(soln.x), thresh=1e-1), "Wrong Jacobian")
186193
self.assertTrue(abs(soln.f) < 1e-10, "Wrong fmin")
187194

195+
196+
class TestRosenbrockBoxBall(unittest.TestCase):
197+
# Minimise the (2d) Rosenbrock function, where x[1] hits the upper bound
198+
def runTest(self):
199+
# n, m = 2, 2
200+
x0 = np.array([-1.2, 0.7]) # standard start point does not satisfy bounds
201+
lower = np.array([0.7, -2.0])
202+
upper = np.array([1.0, 2])
203+
boxproj = lambda x: p_box(x,lower,upper)
204+
ballproj = lambda x: p_ball(x,np.array([0.5,1]),0.25)
205+
xmin = np.array([0.70424386, 0.85583188]) # approximate
206+
fmin = np.dot(rosenbrock(xmin), rosenbrock(xmin))
207+
soln = dfols.solve(rosenbrock, x0, projections=[boxproj,ballproj])
208+
print(soln.x)
209+
self.assertTrue(array_compare(soln.x, xmin, thresh=1e-2), "Wrong xmin")
210+
self.assertTrue(array_compare(soln.resid, rosenbrock(soln.x), thresh=1e-10), "Wrong resid")
211+
self.assertTrue(array_compare(soln.jacobian, rosenbrock_jacobian(soln.x), thresh=1e-2), "Wrong Jacobian")
212+
self.assertTrue(abs(soln.f - fmin) < 1e-4, "Wrong fmin")

dfols/tests/test_trust_region.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@
3030
from dfols.trust_region import ctrsbox, ctrsbox_geometry, trsbox, trsbox_geometry
3131
from dfols.util import model_value
3232

33-
import pdb
34-
3533

3634
def cauchy_pt(g, H, delta):
3735
# General expression for the Cauchy point

0 commit comments

Comments
 (0)