11import logging
22from smdebug .xgboost import Hook
3- import xgboost as xgb
43
54
65logger = logging .getLogger (__name__ )
76
87
9- class SmeDebugHook (xgb .callback .TrainingCallback , Hook ):
10- """Mix-in callback class. smedebug.xgboost.Hook uses legacy callback style
11- and since XGB-3.0.0 mixing legacy callback instances with new TrainingCallback
12- instances is not allowed.
13- See: https://github.com/dmlc/xgboost/blob/v1.3.0/python-package/xgboost/training.py#L92-L93
14- :param hyperparameters: Dict of hyperparamters.
15- Same as `params` in xgb.train(params, dtrain).
16- :param train_dmatrix: Training data set.
17- :param val_dmatrix: Validation data set.
18- """
19- def __init__ (self , json_config_path , hyperparameters ,
20- train_dmatrix , val_dmatrix ):
21- self = self .hook_from_config (json_config_path )
22- self .hyperparameters = hyperparameters
23- self .train_data = train_dmatrix
24- if val_dmatrix is not None :
25- self .validation_data = val_dmatrix
26-
27-
288def add_debugging (callbacks , hyperparameters , train_dmatrix ,
299 val_dmatrix = None , json_config_path = None ):
3010 """Add a sagemaker debug hook to a list of callbacks.
@@ -37,9 +17,13 @@ def add_debugging(callbacks, hyperparameters, train_dmatrix,
3717 instead of default config file.
3818 """
3919 try :
40- hook = SmeDebugHook (json_config_path , hyperparameters , train_dmatrix , val_dmatrix )
20+ hook = Hook .hook_from_config (json_config_path )
21+ hook .hyperparameters = hyperparameters
22+ hook .train_data = train_dmatrix
23+ if val_dmatrix is not None :
24+ hook .validation_data = val_dmatrix
25+ callbacks .append (hook )
4126 logging .info ("Debug hook created from config" )
4227 except Exception as e :
4328 logging .debug ("Failed to create debug hook" , e )
44- else :
45- callbacks .append (hook )
29+ return
0 commit comments