99from copy import deepcopy
1010
1111import numpy as np
12+ from ax .adapter .base import DataLoaderConfig
13+ from ax .adapter .data_utils import extract_experiment_data
1214from ax .adapter .transforms .unit_x import UnitX
1315from ax .core .observation import ObservationFeatures
1416from ax .core .parameter import ChoiceParameter , ParameterType , RangeParameter
1517from ax .core .parameter_constraint import ParameterConstraint
1618from ax .core .search_space import RobustSearchSpace , SearchSpace
1719from ax .exceptions .core import UnsupportedError , UserInputError
1820from ax .utils .common .testutils import TestCase
19- from ax .utils .testing .core_stubs import get_robust_search_space
21+ from ax .utils .testing .core_stubs import (
22+ get_experiment_with_observations ,
23+ get_robust_search_space ,
24+ )
25+ from pandas import DataFrame
26+ from pandas .testing import assert_frame_equal
2027from pyre_extensions import assert_is_instance
2128
2229
2330class UnitXTransformTest (TestCase ):
24- transform_class = UnitX
25- # pyre-fixme[4]: Attribute must be annotated.
26- expected_c_dicts = [{"x" : - 1.0 , "y" : 1.0 }, {"x" : - 1.0 , "a" : 1.0 }]
27- expected_c_bounds = [0.0 , 1.0 ]
28-
2931 def setUp (self ) -> None :
3032 super ().setUp ()
31- self .target_lb = self .transform_class .target_lb
32- self .target_range = self .transform_class .target_range
33- self .target_ub = self .target_lb + self .target_range
3433 self .search_space = SearchSpace (
3534 parameters = [
3635 RangeParameter (
@@ -56,10 +55,7 @@ def setUp(self) -> None:
5655 ParameterConstraint (constraint_dict = {"x" : - 0.5 , "a" : 1 }, bound = 0.5 ),
5756 ],
5857 )
59- self .t = self .transform_class (
60- search_space = self .search_space ,
61- observations = [],
62- )
58+ self .t = UnitX (search_space = self .search_space )
6359 self .search_space_with_target = SearchSpace (
6460 parameters = [
6561 RangeParameter (
@@ -86,13 +82,7 @@ def test_TransformObservationFeatures(self) -> None:
8682 obs_ft2 ,
8783 [
8884 ObservationFeatures (
89- parameters = {
90- "x" : self .target_lb + self .target_range / 2.0 ,
91- "y" : 1.0 ,
92- "z" : 2 ,
93- "a" : 2 ,
94- "b" : "b" ,
95- }
85+ parameters = {"x" : 0.5 , "y" : 1.0 , "z" : 2 , "a" : 2 , "b" : "b" }
9686 )
9787 ],
9888 )
@@ -103,7 +93,7 @@ def test_TransformObservationFeatures(self) -> None:
10393 obs_ft3 = self .t .transform_observation_features (obs_ft3 )
10494 self .assertEqual (
10595 obs_ft3 [0 ],
106- ObservationFeatures (parameters = {"x" : self . target_ub , "z" : 2 }),
96+ ObservationFeatures (parameters = {"x" : 1.0 , "z" : 2 }),
10797 )
10898 obs_ft5 = self .t .transform_observation_features ([ObservationFeatures ({})])
10999 self .assertEqual (obs_ft5 [0 ], ObservationFeatures ({}))
@@ -114,31 +104,35 @@ def test_TransformSearchSpace(self) -> None:
114104
115105 # Parameters transformed
116106 true_bounds = {
117- "x" : (self . target_lb , 1.0 ),
118- "y" : (self . target_lb , 1.0 ),
107+ "x" : (0.0 , 1.0 ),
108+ "y" : (0.0 , 1.0 ),
119109 "z" : (1.0 , 2.0 ),
120110 "a" : (1.0 , 2.0 ),
121111 }
122112 for p_name , (l , u ) in true_bounds .items ():
123- self .assertEqual (ss2 .parameters [p_name ].lower , l )
124- self .assertEqual (ss2 .parameters [p_name ].upper , u )
125- self .assertEqual (ss2 .parameters ["b" ].values , ["a" , "b" , "c" ])
113+ self .assertEqual (
114+ assert_is_instance (ss2 .parameters [p_name ], RangeParameter ).lower , l
115+ )
116+ self .assertEqual (
117+ assert_is_instance (ss2 .parameters [p_name ], RangeParameter ).upper , u
118+ )
119+ self .assertEqual (
120+ assert_is_instance (ss2 .parameters ["b" ], ChoiceParameter ).values ,
121+ ["a" , "b" , "c" ],
122+ )
126123 self .assertEqual (len (ss2 .parameters ), 5 )
127124 # Constraints transformed
128125 self .assertEqual (
129- ss2 .parameter_constraints [0 ].constraint_dict , self . expected_c_dicts [ 0 ]
126+ ss2 .parameter_constraints [0 ].constraint_dict , { "x" : - 1.0 , "y" : 1.0 }
130127 )
131- self .assertEqual (ss2 .parameter_constraints [0 ].bound , self . expected_c_bounds [ 0 ] )
128+ self .assertEqual (ss2 .parameter_constraints [0 ].bound , 0.0 )
132129 self .assertEqual (
133- ss2 .parameter_constraints [1 ].constraint_dict , self . expected_c_dicts [ 1 ]
130+ ss2 .parameter_constraints [1 ].constraint_dict , { "x" : - 1.0 , "a" : 1.0 }
134131 )
135- self .assertEqual (ss2 .parameter_constraints [1 ].bound , self . expected_c_bounds [ 1 ] )
132+ self .assertEqual (ss2 .parameter_constraints [1 ].bound , 1.0 )
136133
137134 # Test transform of target value
138- t = self .transform_class (
139- search_space = self .search_space_with_target ,
140- observations = [],
141- )
135+ t = UnitX (search_space = self .search_space_with_target )
142136 t .transform_search_space (self .search_space_with_target )
143137 self .assertEqual (
144138 self .search_space_with_target .parameters ["x" ].target_value , 1.0
@@ -175,14 +169,8 @@ def test_TransformNewSearchSpace(self) -> None:
175169 self .t .transform_search_space (new_ss )
176170 # Parameters transformed
177171 true_bounds = {
178- "x" : [
179- 0.25 * self .target_range + self .target_lb ,
180- 0.5 * self .target_range + self .target_lb ,
181- ],
182- "y" : [
183- 0.25 * self .target_range + self .target_lb ,
184- 1.0 * self .target_range + self .target_lb ,
185- ],
172+ "x" : [0.25 , 0.5 ],
173+ "y" : [0.25 , 1.0 ],
186174 "z" : [1.0 , 1.5 ],
187175 "a" : [0 , 2 ],
188176 }
@@ -197,23 +185,16 @@ def test_TransformNewSearchSpace(self) -> None:
197185 self .assertEqual (len (new_ss .parameters ), 5 )
198186 # # Constraints transformed
199187 self .assertEqual (
200- new_ss .parameter_constraints [0 ].constraint_dict , self .expected_c_dicts [0 ]
201- )
202- self .assertEqual (
203- new_ss .parameter_constraints [0 ].bound , self .expected_c_bounds [0 ]
204- )
205- self .assertEqual (
206- new_ss .parameter_constraints [1 ].constraint_dict , self .expected_c_dicts [1 ]
188+ new_ss .parameter_constraints [0 ].constraint_dict , {"x" : - 1.0 , "y" : 1.0 }
207189 )
190+ self .assertEqual (new_ss .parameter_constraints [0 ].bound , 0.0 )
208191 self .assertEqual (
209- new_ss .parameter_constraints [1 ].bound , self . expected_c_bounds [ 1 ]
192+ new_ss .parameter_constraints [1 ].constraint_dict , { "x" : - 1.0 , "a" : 1.0 }
210193 )
194+ self .assertEqual (new_ss .parameter_constraints [1 ].bound , 1.0 )
211195
212196 # Test transform of target value
213- t = self .transform_class (
214- search_space = self .search_space_with_target ,
215- observations = [],
216- )
197+ t = UnitX (search_space = self .search_space_with_target )
217198 new_search_space_with_target = SearchSpace (
218199 parameters = [
219200 RangeParameter (
@@ -227,50 +208,30 @@ def test_TransformNewSearchSpace(self) -> None:
227208 ]
228209 )
229210 t .transform_search_space (new_search_space_with_target )
230- self .assertEqual (
231- new_search_space_with_target .parameters ["x" ].target_value ,
232- 0.5 * self .target_range + self .target_lb ,
233- )
211+ self .assertEqual (new_search_space_with_target .parameters ["x" ].target_value , 0.5 )
234212
235213 def test_w_robust_search_space_univariate (self ) -> None :
236214 # Check that if no transforms are needed, it is untouched.
237215 for multivariate in (True , False ):
238- rss = get_robust_search_space (
239- multivariate = multivariate ,
240- lb = self .target_lb ,
241- ub = self .target_ub ,
242- )
216+ rss = get_robust_search_space (multivariate = multivariate , lb = 0.0 , ub = 1.0 )
243217 expected = str (rss )
244- t = self .transform_class (
245- search_space = rss ,
246- observations = [],
247- )
218+ t = UnitX (search_space = rss )
248219 self .assertEqual (expected , str (t .transform_search_space (rss )))
249220 # Error if distribution is multiplicative.
250221 rss = get_robust_search_space ()
251222 rss .parameter_distributions [0 ].multiplicative = True
252- t = self .transform_class (
253- search_space = rss ,
254- observations = [],
255- )
223+ t = UnitX (search_space = rss )
256224 with self .assertRaisesRegex (NotImplementedError , "multiplicative" ):
257225 t .transform_search_space (rss )
258226 # Correctly transform univariate additive distributions.
259227 rss = get_robust_search_space (lb = 5.0 , ub = 10.0 )
260- t = self .transform_class (
261- search_space = rss ,
262- observations = [],
263- )
228+ t = UnitX (search_space = rss )
264229 t .transform_search_space (rss )
265230 dists = rss .parameter_distributions
266- self .assertEqual (
267- dists [0 ].distribution_parameters ["loc" ], 0.2 * self .target_range
268- )
269- self .assertEqual (dists [0 ].distribution_parameters ["scale" ], self .target_range )
231+ self .assertEqual (dists [0 ].distribution_parameters ["loc" ], 0.2 )
232+ self .assertEqual (dists [0 ].distribution_parameters ["scale" ], 1.0 )
270233 self .assertEqual (dists [1 ].distribution_parameters ["loc" ], 0.0 )
271- self .assertEqual (
272- dists [1 ].distribution_parameters ["scale" ], 0.2 * self .target_range
273- )
234+ self .assertEqual (dists [1 ].distribution_parameters ["scale" ], 0.2 )
274235 # Correctly transform environmental distributions.
275236 rss = get_robust_search_space (lb = 5.0 , ub = 10.0 )
276237 all_parameters = list (rss .parameters .values ())
@@ -286,22 +247,19 @@ def test_w_robust_search_space_univariate(self) -> None:
286247 dist .distribution_parameters ["loc" ],
287248 t ._normalize_value (1.0 , (5.0 , 10.0 )),
288249 )
289- self .assertEqual (dist .distribution_parameters ["scale" ], self . target_range )
250+ self .assertEqual (dist .distribution_parameters ["scale" ], 1.0 )
290251 # Error if transform via loc / scale is not supported.
291252 rss = get_robust_search_space (use_discrete = True )
292253 rss .parameters ["z" ]._parameter_type = ParameterType .FLOAT
293- t = self .transform_class (
294- search_space = rss ,
295- observations = [],
296- )
254+ t = UnitX (search_space = rss )
297255 with self .assertRaisesRegex (UnsupportedError , "`loc` and `scale`" ):
298256 t .transform_search_space (rss )
299257
300258 def test_w_robust_search_space_multivariate (self ) -> None :
301259 # Error if trying to transform non-normal multivariate distributions.
302260 rss = get_robust_search_space (multivariate = True )
303261 rss .parameter_distributions [0 ].distribution_class = "multivariate_t"
304- t = self . transform_class (
262+ t = UnitX (
305263 search_space = rss ,
306264 observations = [],
307265 )
@@ -310,25 +268,16 @@ def test_w_robust_search_space_multivariate(self) -> None:
310268 # Transform multivariate normal.
311269 rss = get_robust_search_space (multivariate = True )
312270 old_params = deepcopy (rss .parameter_distributions [0 ].distribution_parameters )
313- t = self .transform_class (
314- search_space = rss ,
315- observations = [],
316- )
271+ t = UnitX (search_space = rss )
317272 t .transform_search_space (rss )
318273 new_params = rss .parameter_distributions [0 ].distribution_parameters
319274 self .assertIsInstance (new_params ["mean" ], np .ndarray )
320275 self .assertIsInstance (new_params ["cov" ], np .ndarray )
321276 self .assertTrue (
322- np .allclose (
323- new_params ["mean" ],
324- np .asarray (old_params ["mean" ]) / 5.0 * self .target_range ,
325- )
277+ np .allclose (new_params ["mean" ], np .asarray (old_params ["mean" ]) / 5.0 )
326278 )
327279 self .assertTrue (
328- np .allclose (
329- new_params ["cov" ],
330- np .asarray (old_params ["cov" ]) / ((5.0 / self .target_range ) ** 2 ),
331- )
280+ np .allclose (new_params ["cov" ], np .asarray (old_params ["cov" ]) / 25.0 )
332281 )
333282 # Transform multivariate normal environmental distribution.
334283 rss = get_robust_search_space (multivariate = True )
@@ -339,18 +288,11 @@ def test_w_robust_search_space_multivariate(self) -> None:
339288 num_samples = rss .num_samples ,
340289 environmental_variables = rss_params [:2 ],
341290 )
342- t = self .transform_class (
343- search_space = rss ,
344- observations = [],
345- )
291+ t = UnitX (search_space = rss )
346292 t .transform_search_space (rss )
347293 new_params = rss .parameter_distributions [0 ].distribution_parameters
348294 self .assertTrue (
349- np .allclose (
350- new_params ["mean" ],
351- np .asarray (old_params ["mean" ]) / 5.0 * self .target_range
352- + self .target_lb ,
353- )
295+ np .allclose (new_params ["mean" ], np .asarray (old_params ["mean" ]) / 5.0 )
354296 )
355297 # Errors if mean / cov are of wrong shape.
356298 rss .parameter_distributions [0 ].distribution_parameters ["mean" ] = [1.0 ]
@@ -360,3 +302,42 @@ def test_w_robust_search_space_multivariate(self) -> None:
360302 rss .parameter_distributions [0 ].distribution_parameters ["cov" ] = [1.0 ]
361303 with self .assertRaisesRegex (UserInputError , "cov" ):
362304 t .transform_search_space (rss )
305+
306+ def test_transform_experiment_data (self ) -> None :
307+ parameterizations = [
308+ {"x" : 1.0 , "y" : 1.5 , "z" : 1.0 , "a" : 1 , "b" : "b" },
309+ {"x" : 2.0 , "y" : 2.0 , "z" : 2.0 , "a" : 2 , "b" : "b" },
310+ ]
311+ experiment = get_experiment_with_observations (
312+ observations = [[1.0 ], [2.0 ]],
313+ search_space = self .search_space ,
314+ parameterizations = parameterizations ,
315+ )
316+ experiment_data = extract_experiment_data (
317+ experiment = experiment , data_loader_config = DataLoaderConfig ()
318+ )
319+ transformed_data = self .t .transform_experiment_data (
320+ experiment_data = deepcopy (experiment_data )
321+ )
322+
323+ # Check that `x` and `y` have been transformed.
324+ expected = DataFrame (
325+ index = transformed_data .arm_data .index ,
326+ data = {
327+ "x" : [0.0 , 0.5 ],
328+ "y" : [0.5 , 1.0 ],
329+ },
330+ columns = ["x" , "y" ],
331+ )
332+ assert_frame_equal (transformed_data .arm_data [["x" , "y" ]], expected )
333+
334+ # Remaining columns are unchanged.
335+ # "z" is log-scale and "a" is in, so they're not transformed.
336+ cols = ["z" , "a" , "b" , "metadata" ]
337+ assert_frame_equal (
338+ transformed_data .arm_data [cols ], experiment_data .arm_data [cols ]
339+ )
340+ # Observation data is unchanged.
341+ assert_frame_equal (
342+ transformed_data .observation_data , experiment_data .observation_data
343+ )
0 commit comments