Skip to content

Commit cc6fd76

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
Completes OPEN-3490 Create a method inside the CommitBundleValidator that validates all resources individually
1 parent 995a306 commit cc6fd76

File tree

3 files changed

+138
-19
lines changed

3 files changed

+138
-19
lines changed

openlayer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,7 @@ def push(self, project_id: int):
974974

975975
# Validate bundle resources
976976
commit_bundle_validator = validators.CommitBundleValidator(
977-
commit_bundle_path=project_dir
977+
bundle_path=project_dir
978978
)
979979
failed_validations = commit_bundle_validator.validate()
980980

openlayer/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import sys
3+
import traceback
34
import warnings
45

56
import yaml
@@ -79,3 +80,12 @@ def write_yaml(dictionary: dict, filename: str):
7980
"""
8081
with open(filename, "w") as stream:
8182
yaml.dump(dictionary, stream)
83+
84+
85+
def get_exception_stacktrace(err: Exception):
86+
"""Returns the stacktrace of the most recent exception.
87+
88+
Returns:
89+
str: the stacktrace of the most recent exception.
90+
"""
91+
return "".join(traceback.format_exception(type(err), err, err.__traceback__))

openlayer/validators.py

Lines changed: 127 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@
1313
import ast
1414
import importlib
1515
import os
16-
import traceback
1716
import warnings
18-
from typing import Dict, List, Optional
17+
from typing import Any, Dict, List, Optional
1918

2019
import marshmallow as ma
2120
import pandas as pd
@@ -30,12 +29,13 @@ class CommitBundleValidator:
3029
3130
Parameters
3231
----------
33-
commit_bundle_path : str
32+
bundle_path : str
3433
The path to the commit bundle (staging area, if for the Python API).
3534
"""
3635

37-
def __init__(self, commit_bundle_path: str):
38-
self.commit_bundle_path = commit_bundle_path
36+
def __init__(self, bundle_path: str):
37+
self.bundle_path = bundle_path
38+
self._bundle_resources = self._list_resources_in_bundle()
3939
self.failed_validations = []
4040

4141
def _validate_bundle_state(self):
@@ -51,32 +51,30 @@ def _validate_bundle_state(self):
5151
"""
5252
bundle_state_failed_validations = []
5353

54-
bundle_resources = os.listdir(self.commit_bundle_path)
55-
5654
# Defining which datasets contain predictions
5755
training_predictions_column_name = None
5856
validation_predictions_column_name = None
59-
if "training" in bundle_resources:
57+
if "training" in self._bundle_resources:
6058
with open(
61-
f"{self.commit_bundle_path}/training/dataset_config.yaml", "r"
59+
f"{self.bundle_path}/training/dataset_config.yaml", "r"
6260
) as stream:
6361
training_dataset_config = yaml.safe_load(stream)
6462

6563
training_predictions_column_name = training_dataset_config.get(
6664
"predictionsColumnName"
6765
)
6866

69-
if "validation" in bundle_resources:
67+
if "validation" in self._bundle_resources:
7068
with open(
71-
f"{self.commit_bundle_path}/validation/dataset_config.yaml", "r"
69+
f"{self.bundle_path}/validation/dataset_config.yaml", "r"
7270
) as stream:
7371
validation_dataset_config = yaml.safe_load(stream)
7472

7573
validation_predictions_column_name = validation_dataset_config.get(
7674
"predictionsColumnName"
7775
)
7876

79-
if "model" in bundle_resources:
77+
if "model" in self._bundle_resources:
8078
if (
8179
training_predictions_column_name is None
8280
or validation_predictions_column_name is None
@@ -88,7 +86,7 @@ def _validate_bundle_state(self):
8886
)
8987
else:
9088
if (
91-
"training" in bundle_resources
89+
"training" in self._bundle_resources
9290
and validation_predictions_column_name is not None
9391
):
9492
bundle_state_failed_validations.append(
@@ -111,6 +109,120 @@ def _validate_bundle_state(self):
111109
# Add the bundle state failed validations to the list of all failed validations
112110
self.failed_validations.extend(bundle_state_failed_validations)
113111

112+
def _validate_bundle_resources(self):
113+
"""Runs the corresponding validations for each resource in the bundle."""
114+
bundle_resources_failed_validations = []
115+
116+
if "training" in self._bundle_resources:
117+
training_set_validator = DatasetValidator(
118+
dataset_config_file_path=f"{self.bundle_path}/training/dataset_config.yaml",
119+
dataset_file_path=f"{self.bundle_path}/training/dataset.csv",
120+
)
121+
bundle_resources_failed_validations.extend(
122+
training_set_validator.validate()
123+
)
124+
125+
if "validation" in self._bundle_resources:
126+
validation_set_validator = DatasetValidator(
127+
dataset_config_file_path=f"{self.bundle_path}/validation/dataset_config.yaml",
128+
dataset_file_path=f"{self.bundle_path}/training/dataset.csv",
129+
)
130+
bundle_resources_failed_validations.extend(
131+
validation_set_validator.validate()
132+
)
133+
134+
if "model" in self._bundle_resources:
135+
model_files = os.listdir(f"{self.bundle_path}/model")
136+
# Shell model
137+
if len(model_files) == 1:
138+
model_validator = ModelValidator(
139+
model_config_file_path=f"{self.bundle_path}/model/model_config.yaml"
140+
)
141+
# Model package
142+
else:
143+
# Use data from the validation as test data
144+
validation_dataset_df = self._load_dataset_from_bundle("validation")
145+
validation_dataset_config = self._load_dataset_config_from_bundle(
146+
"validation"
147+
)
148+
149+
sample_data = None
150+
if "textColumnName" in validation_dataset_config:
151+
sample_data = validation_dataset_df[
152+
validation_dataset_config["textColumnName"]
153+
].head()
154+
155+
else:
156+
sample_data = validation_dataset_df[
157+
validation_dataset_config["featureNames"]
158+
].head()
159+
160+
model_validator = ModelValidator(
161+
model_config_file_path=f"{self.bundle_path}/model/model_config.yaml",
162+
model_package_dir=f"{self.bundle_path}/model",
163+
sample_data=sample_data,
164+
)
165+
bundle_resources_failed_validations.extend(model_validator.validate())
166+
167+
# Print results of the validation
168+
if bundle_resources_failed_validations:
169+
print("Push failed validations: \n")
170+
_list_failed_validation_messages(bundle_resources_failed_validations)
171+
172+
# Add the bundle resources failed validations to the list of all failed validations
173+
self.failed_validations.extend(bundle_resources_failed_validations)
174+
175+
def _list_resources_in_bundle(self) -> List[str]:
176+
"""Lists the resources in a commit bundle."""
177+
# TODO: factor out list of valid resources
178+
VALID_RESOURCES = ["model", "training", "validation"]
179+
180+
resources = []
181+
182+
for resource in os.listdir(self.bundle_path):
183+
if resource in VALID_RESOURCES:
184+
resources.append(resource)
185+
return resources
186+
187+
def _load_dataset_from_bundle(self, label: str) -> pd.DataFrame:
188+
"""Loads a dataset from a commit bundle.
189+
190+
Parameters
191+
----------
192+
label : str
193+
The type of the dataset. Can be either "training" or "validation".
194+
195+
Returns
196+
-------
197+
pd.DataFrame
198+
The dataset.
199+
"""
200+
dataset_file_path = f"{self.bundle_path}/{label}/dataset.csv"
201+
202+
dataset_df = pd.read_csv(dataset_file_path)
203+
204+
return dataset_df
205+
206+
def _load_dataset_config_from_bundle(self, label: str) -> Dict[str, Any]:
207+
"""Loads a dataset config from a commit bundle.
208+
209+
Parameters
210+
----------
211+
label : str
212+
The type of the dataset. Can be either "training" or "validation".
213+
214+
Returns
215+
-------
216+
Dict[str, Any]
217+
The dataset config.
218+
"""
219+
dataset_config_file_path = f"{self.bundle_path}/{label}/dataset_config.yaml"
220+
221+
with open(dataset_config_file_path, "r") as stream:
222+
dataset_config = yaml.safe_load(stream)
223+
224+
return dataset_config
225+
114226
def validate(self) -> List[str]:
115227
"""Validates the commit bundle.
116228
@@ -120,6 +232,7 @@ def validate(self) -> List[str]:
120232
A list of failed validations.
121233
"""
122234
self._validate_bundle_state()
235+
self._validate_bundle_resources()
123236

124237
if not self.failed_validations:
125238
print("All validations passed!")
@@ -873,11 +986,7 @@ def _validate_prediction_interface(self):
873986
with utils.HidePrints():
874987
ml_model.predict_proba(self.sample_data)
875988
except Exception as err:
876-
exception_stack = "".join(
877-
traceback.format_exception(
878-
type(err), err, err.__traceback__
879-
)
880-
)
989+
exception_stack = utils.get_exception_stacktrace(err)
881990
prediction_interface_failed_validations.append(
882991
"The `predict_proba` function failed while running the test data. "
883992
"It is failing with the following error message: \n"

0 commit comments

Comments
 (0)