1313import ast
1414import importlib
1515import os
16- import traceback
1716import warnings
18- from typing import Dict , List , Optional
17+ from typing import Any , Dict , List , Optional
1918
2019import marshmallow as ma
2120import pandas as pd
@@ -30,12 +29,13 @@ class CommitBundleValidator:
3029
3130 Parameters
3231 ----------
33- commit_bundle_path : str
32+ bundle_path : str
3433 The path to the commit bundle (staging area, if for the Python API).
3534 """
3635
37- def __init__ (self , commit_bundle_path : str ):
38- self .commit_bundle_path = commit_bundle_path
36+ def __init__ (self , bundle_path : str ):
37+ self .bundle_path = bundle_path
38+ self ._bundle_resources = self ._list_resources_in_bundle ()
3939 self .failed_validations = []
4040
4141 def _validate_bundle_state (self ):
@@ -51,32 +51,30 @@ def _validate_bundle_state(self):
5151 """
5252 bundle_state_failed_validations = []
5353
54- bundle_resources = os .listdir (self .commit_bundle_path )
55-
5654 # Defining which datasets contain predictions
5755 training_predictions_column_name = None
5856 validation_predictions_column_name = None
59- if "training" in bundle_resources :
57+ if "training" in self . _bundle_resources :
6058 with open (
61- f"{ self .commit_bundle_path } /training/dataset_config.yaml" , "r"
59+ f"{ self .bundle_path } /training/dataset_config.yaml" , "r"
6260 ) as stream :
6361 training_dataset_config = yaml .safe_load (stream )
6462
6563 training_predictions_column_name = training_dataset_config .get (
6664 "predictionsColumnName"
6765 )
6866
69- if "validation" in bundle_resources :
67+ if "validation" in self . _bundle_resources :
7068 with open (
71- f"{ self .commit_bundle_path } /validation/dataset_config.yaml" , "r"
69+ f"{ self .bundle_path } /validation/dataset_config.yaml" , "r"
7270 ) as stream :
7371 validation_dataset_config = yaml .safe_load (stream )
7472
7573 validation_predictions_column_name = validation_dataset_config .get (
7674 "predictionsColumnName"
7775 )
7876
79- if "model" in bundle_resources :
77+ if "model" in self . _bundle_resources :
8078 if (
8179 training_predictions_column_name is None
8280 or validation_predictions_column_name is None
@@ -88,7 +86,7 @@ def _validate_bundle_state(self):
8886 )
8987 else :
9088 if (
91- "training" in bundle_resources
89+ "training" in self . _bundle_resources
9290 and validation_predictions_column_name is not None
9391 ):
9492 bundle_state_failed_validations .append (
@@ -111,6 +109,120 @@ def _validate_bundle_state(self):
111109 # Add the bundle state failed validations to the list of all failed validations
112110 self .failed_validations .extend (bundle_state_failed_validations )
113111
def _validate_bundle_resources(self):
    """Runs the corresponding validations for each resource in the bundle.

    A validator is instantiated for every resource present in
    ``self._bundle_resources`` ("training", "validation" and "model") and
    the failed validations each one returns are collected. The collected
    messages are printed, if there are any, and appended to
    ``self.failed_validations``.
    """
    bundle_resources_failed_validations = []

    # Datasets: each one is validated against its *own* config and csv file.
    # Building both paths from the label keeps the validation set from being
    # checked against the training csv by mistake.
    for label in ("training", "validation"):
        if label not in self._bundle_resources:
            continue
        dataset_validator = DatasetValidator(
            dataset_config_file_path=(
                f"{self.bundle_path}/{label}/dataset_config.yaml"
            ),
            dataset_file_path=f"{self.bundle_path}/{label}/dataset.csv",
        )
        bundle_resources_failed_validations.extend(dataset_validator.validate())

    if "model" in self._bundle_resources:
        model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"
        model_files = os.listdir(f"{self.bundle_path}/model")

        if len(model_files) == 1:
            # Shell model: the model directory holds a single file, so only
            # the model config is validated.
            model_validator = ModelValidator(
                model_config_file_path=model_config_file_path
            )
        else:
            # Model package: use data from the validation set as test data.
            # NOTE(review): this reads the validation set without checking
            # that "validation" is in the bundle; a model package pushed
            # without one raises here. Confirm the intended behavior.
            validation_dataset_df = self._load_dataset_from_bundle("validation")
            validation_dataset_config = self._load_dataset_config_from_bundle(
                "validation"
            )

            # The sample is the first rows of the text column when the config
            # declares one, otherwise of the feature columns.
            if "textColumnName" in validation_dataset_config:
                sample_columns = validation_dataset_config["textColumnName"]
            else:
                sample_columns = validation_dataset_config["featureNames"]
            sample_data = validation_dataset_df[sample_columns].head()

            model_validator = ModelValidator(
                model_config_file_path=model_config_file_path,
                model_package_dir=f"{self.bundle_path}/model",
                sample_data=sample_data,
            )
        bundle_resources_failed_validations.extend(model_validator.validate())

    # Print results of the validation
    if bundle_resources_failed_validations:
        print("Push failed validations: \n")
        _list_failed_validation_messages(bundle_resources_failed_validations)

    # Add the bundle resources failed validations to the list of all failed
    # validations
    self.failed_validations.extend(bundle_resources_failed_validations)
174+
175+ def _list_resources_in_bundle (self ) -> List [str ]:
176+ """Lists the resources in a commit bundle."""
177+ # TODO: factor out list of valid resources
178+ VALID_RESOURCES = ["model" , "training" , "validation" ]
179+
180+ resources = []
181+
182+ for resource in os .listdir (self .bundle_path ):
183+ if resource in VALID_RESOURCES :
184+ resources .append (resource )
185+ return resources
186+
187+ def _load_dataset_from_bundle (self , label : str ) -> pd .DataFrame :
188+ """Loads a dataset from a commit bundle.
189+
190+ Parameters
191+ ----------
192+ label : str
193+ The type of the dataset. Can be either "training" or "validation".
194+
195+ Returns
196+ -------
197+ pd.DataFrame
198+ The dataset.
199+ """
200+ dataset_file_path = f"{ self .bundle_path } /{ label } /dataset.csv"
201+
202+ dataset_df = pd .read_csv (dataset_file_path )
203+
204+ return dataset_df
205+
def _load_dataset_config_from_bundle(self, label: str) -> Dict[str, Any]:
    """Reads the YAML config of one of the bundle's datasets.

    Parameters
    ----------
    label : str
        The type of the dataset. Can be either "training" or "validation".

    Returns
    -------
    Dict[str, Any]
        The parsed contents of ``<bundle_path>/<label>/dataset_config.yaml``.
    """
    config_path = f"{self.bundle_path}/{label}/dataset_config.yaml"

    with open(config_path, "r") as config_file:
        return yaml.safe_load(config_file)
225+
114226 def validate (self ) -> List [str ]:
115227 """Validates the commit bundle.
116228
@@ -120,6 +232,7 @@ def validate(self) -> List[str]:
120232 A list of failed validations.
121233 """
122234 self ._validate_bundle_state ()
235+ self ._validate_bundle_resources ()
123236
124237 if not self .failed_validations :
125238 print ("All validations passed!" )
@@ -873,11 +986,7 @@ def _validate_prediction_interface(self):
873986 with utils .HidePrints ():
874987 ml_model .predict_proba (self .sample_data )
875988 except Exception as err :
876- exception_stack = "" .join (
877- traceback .format_exception (
878- type (err ), err , err .__traceback__
879- )
880- )
989+ exception_stack = utils .get_exception_stacktrace (err )
881990 prediction_interface_failed_validations .append (
882991 "The `predict_proba` function failed while running the test data. "
883992 "It is failing with the following error message: \n "
0 commit comments