Skip to content

Commit 995a306

Browse files
committed
Update churn notebook
1 parent bae2ab0 commit 995a306

File tree

2 files changed

+87
-17
lines changed

2 files changed

+87
-17
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ server-tests.ipynb
1212
dependencies/
1313
*.bin
1414
*.csv
15+
*.yaml
1516

1617
# Ignore everything in examples/ except the task dirs
1718
!examples

examples/tabular-classification/sklearn/churn-classifier/churn-classifier-sklearn.ipynb

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@
276276
"openlayer.api.STORAGE = openlayer.api.StorageType.ONPREM\n",
277277
"openlayer.api.OPENLAYER_ENDPOINT = \"http://localhost:8080/v1\"\n",
278278
"\n",
279-
"client = openlayer.OpenlayerClient(\"YOUR_API_KEY_HERE\")"
279+
"client = openlayer.OpenlayerClient(\"YOUR_API_KEY\")"
280280
]
281281
},
282282
{
@@ -329,6 +329,78 @@
329329
"training_set['churn'] = y_train.values"
330330
]
331331
},
332+
{
333+
"cell_type": "code",
334+
"execution_count": null,
335+
"id": "c2d842da",
336+
"metadata": {},
337+
"outputs": [],
338+
"source": [
339+
"val_preds_df = pd.DataFrame({\"predictions\": sklearn_model.predict_proba(x_val_one_hot).tolist()})\n",
340+
"validation_set = validation_set.copy().reset_index(drop=True)\n",
341+
"validation_set[\"preds\"] = val_preds_df[\"predictions\"]\n",
342+
"validation_set"
343+
]
344+
},
345+
{
346+
"cell_type": "code",
347+
"execution_count": null,
348+
"id": "62969755",
349+
"metadata": {},
350+
"outputs": [],
351+
"source": [
352+
"train_preds_df = pd.DataFrame({\"predictions\": sklearn_model.predict_proba(x_train_one_hot).tolist()})\n",
353+
"training_set = training_set.copy().reset_index(drop=True)\n",
354+
"training_set[\"preds\"] = train_preds_df\n",
355+
"training_set"
356+
]
357+
},
358+
{
359+
"cell_type": "code",
360+
"execution_count": null,
361+
"id": "03688a2f",
362+
"metadata": {},
363+
"outputs": [],
364+
"source": [
365+
"import yaml \n",
366+
"\n",
367+
"validation_dataset_config = {\n",
368+
" \"label\": \"validation\",\n",
369+
" \"classNames\": class_names,\n",
370+
" \"categoricalFeatureNames\": [\"Gender\", \"Geography\"],\n",
371+
" \"featureNames\":feature_names,\n",
372+
" \"columnNames\":list(validation_set.columns),\n",
373+
" \"labelColumnName\": \"churn\",\n",
374+
" \"predictionsColumnName\": \"preds\",\n",
375+
"}\n",
376+
"\n",
377+
"with open('validation_dataset_config.yaml', 'w') as dataset_config_file:\n",
378+
" yaml.dump(validation_dataset_config, dataset_config_file, default_flow_style=False)"
379+
]
380+
},
381+
{
382+
"cell_type": "code",
383+
"execution_count": null,
384+
"id": "0e7257a3",
385+
"metadata": {},
386+
"outputs": [],
387+
"source": [
388+
"import yaml \n",
389+
"\n",
390+
"training_dataset_config = {\n",
391+
" \"label\": \"training\",\n",
392+
" \"classNames\": class_names,\n",
393+
" \"categoricalFeatureNames\": [\"Gender\", \"Geography\"],\n",
394+
" \"featureNames\":feature_names,\n",
395+
" \"columnNames\":list(training_set.columns),\n",
396+
" \"labelColumnName\": \"churn\",\n",
397+
" \"predictionsColumnName\": \"preds\",\n",
398+
"}\n",
399+
"\n",
400+
"with open('training_dataset_config.yaml', 'w') as dataset_config_file:\n",
401+
" yaml.dump(training_dataset_config, dataset_config_file, default_flow_style=False)"
402+
]
403+
},
332404
{
333405
"cell_type": "code",
334406
"execution_count": null,
@@ -341,11 +413,7 @@
341413
"# Validation set\n",
342414
"project.add_dataframe(\n",
343415
" df=validation_set,\n",
344-
" dataset_type=DatasetType.Validation,\n",
345-
" class_names=class_names,\n",
346-
" label_column_name='churn',\n",
347-
" feature_names=feature_names,\n",
348-
" categorical_feature_names=[\"Gender\", \"Geography\"],\n",
416+
" dataset_config_file_path='validation_dataset_config.yaml',\n",
349417
")"
350418
]
351419
},
@@ -359,11 +427,7 @@
359427
"# Training set\n",
360428
"project.add_dataframe(\n",
361429
" df=training_set,\n",
362-
" dataset_type=DatasetType.Training,\n",
363-
" class_names=class_names,\n",
364-
" label_column_name='churn',\n",
365-
" feature_names=feature_names,\n",
366-
" categorical_feature_names=[\"Gender\", \"Geography\"],\n",
430+
" dataset_config_file_path='training_dataset_config.yaml',\n",
367431
")"
368432
]
369433
},
@@ -538,10 +602,13 @@
538602
"\n",
539603
"model_config = {\n",
540604
" \"name\": \"Churn prediction model\",\n",
541-
" \"model_type\": \"sklearn\",\n",
542-
" \"class_names\": class_names,\n",
543-
" \"categorical_feature_names\": [\"Gender\", \"Geography\"],\n",
544-
" \"feature_names\":feature_names\n",
605+
" \"architectureType\": \"sklearn\",\n",
606+
" \"classNames\": class_names,\n",
607+
" \"categoricalFeatureNames\": [\"Gender\", \"Geography\"],\n",
608+
" \"featureNames\":feature_names,\n",
609+
" \"metadata\": {\n",
610+
" \"test\": \"name\"\n",
611+
" }\n",
545612
"}\n",
546613
"\n",
547614
"with open('model_package/model_config.yaml', 'w') as model_config_file:\n",
@@ -567,7 +634,8 @@
567634
"\n",
568635
"model_validator = ModelValidator(\n",
569636
" model_package_dir=\"model_package\", \n",
570-
" sample_data = x_val.iloc[:10, :]\n",
637+
" model_config_file_path='model_package/model_config.yaml',\n",
638+
" sample_data = x_val.iloc[:10, :],\n",
571639
")\n",
572640
"model_validator.validate()"
573641
]
@@ -589,7 +657,8 @@
589657
"source": [
590658
"project.add_model(\n",
591659
" model_package_dir=\"model_package\",\n",
592-
" sample_data=x_val.iloc[:10, :]\n",
660+
" model_config_file_path='model_package/model_config.yaml',\n",
661+
" sample_data=x_val.iloc[:10, :],\n",
593662
")"
594663
]
595664
},

0 commit comments

Comments
 (0)