Skip to content

Commit 3e2e44e

Browse files
Alexandre Landeau
authored and committed
Requested changes fix
1 parent a217d59 commit 3e2e44e

File tree

4 files changed

+122
-83
lines changed

4 files changed

+122
-83
lines changed

.editorconfig

Whitespace-only changes.

custom-recipes/nlp-visualization-wordcloud/recipe.py

Lines changed: 16 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -4,36 +4,29 @@
44
import logging
55
from time import perf_counter
66

7-
from dataiku.customrecipe import get_recipe_resource
87
from spacy_tokenizer import MultilingualTokenizer
98
from wordcloud_visualizer import WordcloudVisualizer
10-
from plugin_config_loading import load_plugin_config_wordcloud
9+
from plugin_config_loading import load_config_and_data_wordcloud
1110

1211

1312
# Load config
14-
params = load_plugin_config_wordcloud()
15-
font_folder_path = os.path.join(get_recipe_resource(), "fonts")
16-
output_folder = params["output_folder"]
17-
output_partition_path = params["output_partition_path"]
18-
df = params["df"]
19-
20-
# Instanciate tokenizer
21-
tokenizer = MultilingualTokenizer(
22-
stopwords_folder_path=(params["stopwords_folder_path"] if params["remove_stopwords"] else None)
23-
)
13+
params, df = load_config_and_data_wordcloud()
14+
output_folder = params.output_folder
15+
output_partition_path = params.output_partition_path
16+
2417
# Load wordcloud visualizer
2518
worcloud_visualizer = WordcloudVisualizer(
26-
tokenizer=tokenizer,
27-
text_column=params["text_column"],
28-
font_folder_path=font_folder_path,
29-
language=params["language"],
30-
language_column=params["language_column"],
31-
subchart_column=params["subchart_column"],
32-
remove_stopwords=params["remove_stopwords"],
33-
remove_punctuation=params["remove_punctuation"],
34-
case_insensitive=params["case_insensitive"],
35-
max_words=params["max_words"],
36-
color_list=params["color_list"],
19+
tokenizer=MultilingualTokenizer(stopwords_folder_path=params.stopwords_folder_path),
20+
text_column=params.text_column,
21+
font_folder_path=params.font_folder_path,
22+
language=params.language,
23+
language_column=params.language_column,
24+
subchart_column=params.subchart_column,
25+
remove_stopwords=params.remove_stopwords,
26+
remove_punctuation=params.remove_punctuation,
27+
case_insensitive=params.case_insensitive,
28+
max_words=params.max_words,
29+
color_list=params.color_list,
3730
)
3831

3932
# Prepare data and count tokens for each subchart

python-lib/plugin_config_loading.py

Lines changed: 95 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,9 @@
33

44
import logging
55
import os
6-
from typing import Dict
6+
from typing import Tuple
77

8+
import pandas as pd
89
import matplotlib
910
import dataiku
1011
from dataiku.customrecipe import (
@@ -25,14 +26,40 @@ class PluginParamValidationError(ValueError):
2526
pass
2627

2728

28-
def load_plugin_config_wordcloud() -> Dict:
29-
"""Utility function to validate and load language detection parameters into a clean dictionary
29+
class PluginParams:
30+
"""Class to store recipe parameters"""
31+
32+
def __init__(self):
33+
pass
34+
35+
__slots__ = [
36+
"output_folder",
37+
"output_partition_path",
38+
"text_column",
39+
"language",
40+
"language_column",
41+
"subchart_column",
42+
"remove_stopwords",
43+
"stopwords_folder_path",
44+
"font_folder_path",
45+
"remove_punctuation",
46+
"case_insensitive",
47+
"max_words",
48+
"color_list",
49+
]
3050

31-
Returns:
32-
Dictionary of parameter names (key) and values
3351

52+
def load_config_and_data_wordcloud() -> Tuple[PluginParams, pd.DataFrame]:
53+
"""Utility function to:
54+
- Validate and load wordcloud parameters into a clean class
55+
- Validate input data, keep only necessary columns and drop invalid rows
56+
57+
Returns:
58+
- Class instance with parameter names as attributes and associated values
59+
- Pandas DataFrame with necessary input data
3460
"""
35-
params = {}
61+
62+
params = PluginParams()
3663
# Input dataset
3764
input_dataset_names = get_input_names_for_role("input_dataset")
3865
if len(input_dataset_names) != 1:
@@ -44,97 +71,110 @@ def load_plugin_config_wordcloud() -> Dict:
4471
output_folder_names = get_output_names_for_role("output_folder")
4572
if len(output_folder_names) != 1:
4673
raise PluginParamValidationError("Please specify one output folder")
47-
params["output_folder"] = dataiku.Folder(output_folder_names[0])
74+
params.output_folder = dataiku.Folder(output_folder_names[0])
4875

4976
# Partition handling
50-
params["output_partition_path"] = get_folder_partition_root(params["output_folder"])
77+
params.output_partition_path = get_folder_partition_root(params.output_folder)
5178

5279
# Recipe parameters
5380
recipe_config = get_recipe_config()
5481

5582
# Text column
56-
params["text_column"] = recipe_config.get("text_column")
57-
if params["text_column"] not in input_dataset_columns:
58-
raise PluginParamValidationError(f"Invalid text column selection: {params['text_column']}")
59-
logging.info(f"Text column: {params['text_column']}")
83+
if recipe_config.get("text_column") not in input_dataset_columns:
84+
raise PluginParamValidationError(f"Invalid text column selection: {recipe_config.get('text_column')}")
85+
params.text_column = recipe_config.get("text_column")
86+
logging.info(f"Text column: {params.text_column}")
6087
# Language selection
61-
params["language"] = recipe_config.get("language")
62-
if params["language"] == "language_column":
63-
params["language_column"] = recipe_config.get("language_column")
64-
if params["language_column"] not in input_dataset_columns:
65-
raise PluginParamValidationError(f"Invalid language column selection: {params['language_column']}")
66-
logging.info(f"Language column: {params['language_column']}")
88+
89+
if recipe_config.get("language") == "language_column":
90+
if recipe_config.get("language_column") not in input_dataset_columns:
91+
raise PluginParamValidationError(
92+
f"Invalid language column selection: {recipe_config.get('language_column')}"
93+
)
94+
params.language = recipe_config.get("language")
95+
params.language_column = recipe_config.get("language_column")
96+
logging.info(f"Language column: {params.language_column}")
6797
else:
68-
if not params["language"]:
98+
if not recipe_config.get("language"):
6999
raise PluginParamValidationError("Empty language selection")
70-
if params["language"] not in SUPPORTED_LANGUAGES_SPACY:
71-
raise PluginParamValidationError(f"Unsupported language code: {params['language']}")
72-
params["language_column"] = None
73-
logging.info(f"Language: {params['language']}")
100+
if recipe_config.get("language") not in SUPPORTED_LANGUAGES_SPACY:
101+
raise PluginParamValidationError(f"Unsupported language code: {recipe_config.get('language')}")
102+
params.language = recipe_config.get("language")
103+
params.language_column = None
104+
logging.info(f"Language: {params.language}")
74105

75106
# Subcharts
76-
params["subchart_column"] = recipe_config.get("subchart_column")
107+
subchart_column = recipe_config.get("subchart_column")
77108
# If parameter is saved then cleared, config retrieves ""
78-
params["subchart_column"] = None if not params["subchart_column"] else params["subchart_column"]
79-
if params["subchart_column"] and ((params["subchart_column"] not in input_dataset_columns + ["order66"])):
80-
raise PluginParamValidationError(f"Invalid categorical column selection: {params['subchart_column']}")
81-
logging.info(f"Subcharts column: {params['subchart_column']}")
109+
subchart_column = None if not subchart_column else subchart_column
110+
if subchart_column and ((subchart_column not in input_dataset_columns + ["order66"])):
111+
raise PluginParamValidationError(f"Invalid categorical column selection: {subchart_column}")
112+
params.subchart_column = subchart_column
113+
logging.info(f"Subcharts column: {params.subchart_column}")
82114

83115
# Input dataframe
84116
necessary_columns = [
85117
column
86-
for column in set([params["text_column"], params["language_column"], params["subchart_column"]])
118+
for column in set(
119+
[
120+
params.text_column,
121+
params.language_column,
122+
params.subchart_column,
123+
]
124+
)
87125
if (column not in [None, "order66"])
88126
]
89-
params["df"] = input_dataset.get_dataframe(columns=necessary_columns).dropna(subset=necessary_columns)
90-
if params["df"].empty:
127+
df = input_dataset.get_dataframe(columns=necessary_columns).dropna(subset=necessary_columns)
128+
if df.empty:
91129
raise PluginParamValidationError("Dataframe is empty")
92130
# Check if unsupported languages in multilingual case
93-
elif params["language_column"]:
94-
languages = set(params["df"][params["language_column"]].unique())
131+
elif params.language_column:
132+
languages = set(df[params.language_column].unique())
95133
unsupported_lang = languages - SUPPORTED_LANGUAGES_SPACY.keys()
96134
if unsupported_lang:
97135
raise PluginParamValidationError(
98136
f"Found {len(unsupported_lang)} unsupported languages: {', '.join(sorted(unsupported_lang))}"
99137
)
100138

101-
logging.info(f"Read dataset of shape: {params['df'].shape}")
139+
logging.info(f"Read dataset of shape: {df.shape}")
102140

103141
# Text simplification parameters
104-
params["remove_stopwords"] = recipe_config.get("remove_stopwords")
105-
params["stopwords_folder_path"] = os.path.join(get_recipe_resource(), "stopwords")
106-
params["remove_punctuation"] = recipe_config.get("remove_punctuation")
107-
params["case_insensitive"] = recipe_config.get("case_insensitive")
108-
logging.info(f"Remove stopwords: {params['remove_stopwords']}")
109-
logging.info(f"Remove punctuation: {params['remove_punctuation']}")
110-
logging.info(f"Case-insensitive: {params['case_insensitive']}")
142+
params.remove_stopwords = recipe_config.get("remove_stopwords")
143+
params.stopwords_folder_path = os.path.join(get_recipe_resource(), "stopwords") if params.remove_stopwords else None
144+
params.font_folder_path = os.path.join(get_recipe_resource(), "fonts")
145+
params.remove_punctuation = recipe_config.get("remove_punctuation")
146+
params.case_insensitive = recipe_config.get("case_insensitive")
147+
logging.info(f"Remove stopwords: {params.remove_stopwords}")
148+
logging.info(f"Stopwords folder path: {params.stopwords_folder_path}")
149+
logging.info(f"Fonts folder path: {params.font_folder_path}")
150+
logging.info(f"Remove punctuation: {params.remove_punctuation}")
151+
logging.info(f"Case-insensitive: {params.case_insensitive}")
111152

112153
# Display parameters:
113-
params["max_words"] = recipe_config.get("max_words")
114-
if (not params["max_words"]) or not ((isinstance(params["max_words"], int)) & (params["max_words"] >= 1)):
154+
max_words = recipe_config.get("max_words")
155+
if (not max_words) or not ((isinstance(max_words, int)) and (max_words >= 1)):
115156
raise PluginParamValidationError("Maximum number of words is not a positive integer")
116-
logging.info(f"Max number of words: {params['max_words']}")
157+
params.max_words = max_words
158+
logging.info(f"Max number of words: {params.max_words}")
117159

118160
color_palette = recipe_config.get("color_palette")
119161
if not color_palette:
120162
raise PluginParamValidationError("Empty color palette selection")
121163
if color_palette == "custom":
122-
params["color_list"] = recipe_config.get("color_list")
123-
if not (isinstance(params["color_list"], list) & (len(params["color_list"]) >= 1)):
164+
color_list = recipe_config.get("color_list")
165+
if not (isinstance(color_list, list) and (len(color_list) >= 1)):
124166
raise PluginParamValidationError("Empty custom palette")
125-
if not all([matplotlib.colors.is_color_like(color) for color in params["color_list"]]):
126-
raise PluginParamValidationError(f"Invalid custom palette: {params['color_list']}")
127-
params["color_list"] = [matplotlib.colors.to_hex(color) for color in params["color_list"]]
128-
logging.info(f"Custom palette: {params['color_list']}")
167+
if not all([matplotlib.colors.is_color_like(color) for color in color_list]):
168+
raise PluginParamValidationError(f"Invalid custom palette: {color_list}")
169+
params.color_list = [matplotlib.colors.to_hex(color) for color in color_list]
170+
logging.info(f"Custom palette: {params.color_list}")
129171
else:
130172
if color_palette not in {builtin_palette["id"] for builtin_palette in DSS_BUILTIN_COLOR_PALETTES}:
131173
raise PluginParamValidationError(f"Unsupported color palette: {color_palette}")
132174
selected_palette_dict = [
133175
builtin_palette for builtin_palette in DSS_BUILTIN_COLOR_PALETTES if builtin_palette["id"] == color_palette
134176
][0]
135-
params["color_list"] = selected_palette_dict["colors"]
136-
logging.info(
137-
f"Using built-in DSS palette: '{selected_palette_dict['name']}' with colors: {params['color_list']}"
138-
)
177+
params.color_list = selected_palette_dict["colors"]
178+
logging.info(f"Using built-in DSS palette: '{selected_palette_dict['name']}' with colors: {params.color_list}")
139179

140-
return params
180+
return params, df

python-lib/wordcloud_visualizer.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -247,10 +247,12 @@ def _count_tokens(self, docs: List[Doc]) -> List[Tuple[AnyStr, Dict]]:
247247
for doc in docs:
248248
counter = Counter()
249249
for token in doc:
250-
token_is_stopwords = token.is_stop if self.remove_stopwords else False
251-
token_is_punctuation = token.is_punct if self.remove_punctuation else False
252-
if not token_is_stopwords and not token_is_punctuation and not token.is_space:
253-
counter[token.text] += 1 # Equivalently, token.lemma_
250+
if not token.is_space:
251+
token_is_stopwords = token.is_stop if self.remove_stopwords else False
252+
if not token_is_stopwords:
253+
token_is_punctuation = token.is_punct if self.remove_punctuation else False
254+
if not token_is_punctuation:
255+
counter[token.text] += 1 # Equivalently, token.lemma_
254256
counters.append(counter)
255257

256258
if not self.subchart_column:
@@ -293,7 +295,11 @@ def generate_wordclouds(self, counts: List[Tuple[AnyStr, Dict]]) -> Generator[Tu
293295
wordcloud_title = f"{self.subchart_column}: {name}"
294296
# Generate chart
295297
if self.language_as_subchart:
296-
fig = self._generate_wordcloud(frequencies=count, language=name, title=wordcloud_title,)
298+
fig = self._generate_wordcloud(
299+
frequencies=count,
300+
language=name,
301+
title=wordcloud_title,
302+
)
297303
else:
298304
fig = self._generate_wordcloud(frequencies=count, language=self.language, title=wordcloud_title)
299305
# Return chart

0 commit comments

Comments
 (0)