
Commit f3257c9

Merge pull request #33 from pockerman/add_scala_utilities
Add scala utilities
2 parents 10cba59 + 019ee50 commit f3257c9

8 files changed: +175 additions, -28 deletions

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
/**
 * Investigate various output quality measures supplied by ARX
 *
 */
package examples.example_3


import base.DefaultConfiguration
import org.deidentifier.arx.Data
import postprocessor.ResultPrinter.printHandleTop

//import scala.collection.JavaConversions._
//import collection.convert.ImplicitConversionsToScala.map AsScala
import collection.JavaConverters.* // asScala
import collection.convert.ImplicitConversions.*
import java.io.File
import java.nio.charset.Charset

object MeasureDataQuality extends App{

  def loadData: Tuple2[Data, Data] = {

    val dataFileOrg: File = new File("/home/alex/qi3/drl_anonymity/src/examples/q_learn_distorted_sets/distorted_set_-1")
    val dataOrg: Data = Data.create(dataFileOrg, Charset.defaultCharset, ',')

    val dataFileDist: File = new File("/home/alex/qi3/drl_anonymity/src/examples/q_learn_distorted_sets/distorted_set_-2")
    val dataDist: Data = Data.create(dataFileDist, Charset.defaultCharset, ',')

    require(dataOrg.getHandle.getNumRows == dataDist.getHandle.getNumRows)
    require(dataOrg.getHandle.getNumColumns == dataDist.getHandle.getNumColumns)

    // define the attribute types
    System.out.println(s"Number of rows ${dataOrg.getHandle.getNumRows}")
    System.out.println(s"Number of cols ${dataOrg.getHandle.getNumColumns}")

    printHandleTop(handle = dataOrg.getHandle, n = 5)
    System.out.println("Done...")

    (dataOrg, dataDist)
  }

  def experiment1: Unit = {

    val data = loadData

    val dataHandleOrg = data._1.getHandle
    val dataHandleDist = data._2.getHandle

    val summaryStatsDist = dataHandleDist.getStatistics().getSummaryStatistics(true)
    val summaryStatsOrg = dataHandleOrg.getStatistics().getSummaryStatistics(true)
    // getEquivalenceClassStatistics(); //getEquivalenceClassStatistics();

    for((key, value) <- summaryStatsDist){
      println(s"Column: ${key}")
      println("-----------------------Distorted/Original")
      println(s"distinctNumberOfValues ${value.getNumberOfDistinctValuesAsString}/${summaryStatsOrg.get(key).getNumberOfDistinctValuesAsString}")
      println(s"Mode ${value.getModeAsString}/${summaryStatsOrg.get(key).getModeAsString}")
      if(value.isMaxAvailable) {
        println(s"Max ${value.getMaxAsString}/${summaryStatsOrg.get(key).getMaxAsString}")
        println(s"Min ${value.getMinAsString}/${summaryStatsOrg.get(key).getMinAsString}")
      }
    }
  }

  def runKAnonimity: Unit = {

    val data = loadData

    // create the hierarchies for the ethnicity and
    // salary

  }

  // execute Experiment 1
  experiment1

}
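The runKAnonimity stub above only reloads the data; the hierarchies and the privacy model are still missing. Below is a minimal sketch, not part of this commit, of how the body might continue with ARX. The hierarchy file path, the choice of k = 2 and the 2% suppression limit are illustrative assumptions (the salary column would need its own hierarchy as well), and on older ARX releases addPrivacyModel/setSuppressionLimit are named addCriterion/setMaxOutliers.

import org.deidentifier.arx.{ARXAnonymizer, ARXConfiguration, Data}
import org.deidentifier.arx.AttributeType.Hierarchy
import org.deidentifier.arx.criteria.KAnonymity
import java.io.File
import java.nio.charset.Charset

def runKAnonymitySketch(data: Data): Unit = {

  // attach a generalization hierarchy to the ethnicity quasi-identifier;
  // the CSV is produced by the hierarchy-generation example further down
  val ethnicityHierarchy = Hierarchy.create(
    new File("/home/alex/qi3/drl_anonymity/data/hierarchies/ethnicity_hierarchy.csv"),
    Charset.defaultCharset, ',')
  data.getDefinition.setAttributeType("ethnicity", ethnicityHierarchy)

  // require 2-anonymity and allow a small fraction of suppressed records
  val config = ARXConfiguration.create()
  config.addPrivacyModel(new KAnonymity(2))
  config.setSuppressionLimit(0.02d)

  // run the anonymizer and inspect the output handle
  val anonymizer = new ARXAnonymizer()
  val result = anonymizer.anonymize(data, config)
  val outHandle = result.getOutput
  println(s"Anonymized rows ${outHandle.getNumRows}")
}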

scala_helpers/build.sbt

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
name := "data_anonymizer_scala"

version := "0.1"

scalaVersion := "3.0.2"

libraryDependencies += "org.scalactic" %% "scalactic" % "3.2.10"
libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.10" % "test"
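Note that this build file only wires in ScalaTest and Scalactic, while the example above imports org.deidentifier.arx, so the ARX library has to be supplied separately. One possible way, assuming the ARX jar has been downloaded manually (the file name below is an assumption), is to rely on sbt's unmanaged dependencies:

// sbt treats every jar under the project's lib/ directory as an unmanaged dependency,
// so placing e.g. scala_helpers/lib/libarx-3.9.0.jar there is enough. The jar can also
// be referenced explicitly from build.sbt:
Compile / unmanagedJars += Attributed.blank(file("lib/libarx-3.9.0.jar"))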

src/algorithms/q_learning.py

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ def play(self, env: Env, stop_criterion: Criterion) -> None:
 
         # set the q_table for the policy
         self.config.policy.q_table = self.q_table
-        total_dist = env.total_average_current_distortion()
+        total_dist = env.total_current_distortion()
         while stop_criterion.continue_itr(total_dist):
 
             if stop_criterion.iteration_counter == 12:
@@ -87,7 +87,7 @@ def play(self, env: Env, stop_criterion: Criterion) -> None:
             print("{0} At state={1} with distortion={2} select action={3}".format("INFO: ", state_idx, total_dist,
                                                                                    action.column_name + "-" + action.action_type.name))
             env.step(action=action)
-            total_dist = env.total_average_current_distortion()
+            total_dist = env.total_current_distortion()
 
     def train(self, env: Env, **options) -> tuple:
 
src/datasets/dataset_wrapper.py

Lines changed: 7 additions & 13 deletions
@@ -29,7 +29,6 @@ def read(self, filename: Path, **options) -> None:
 
 
 class PandasDSWrapper(DSWrapper[pd.DataFrame]):
-
     """
     Simple wrapper to a pandas DataFrame object.
     Facilitates various actions on the original dataset
@@ -60,15 +59,15 @@ def n_columns(self) -> int:
     def schema(self) -> dict:
         return pd.io.json.build_table_schema(self.ds)
 
-    def save_to_csv(self, filename: Path) -> None:
+    def save_to_csv(self, filename: Path, save_index: bool) -> None:
         """
         Save the underlying dataset in a csv format
         :param filename:
         :return:
         """
-        self.ds.to_csv(filename)
+        self.ds.to_csv(filename, index=save_index)
 
-    def read(self, filename: Path, **options) -> None:
+    def read(self, filename: Path, **options) -> None:
         """
         Load a data set from a file
         :param filename:
@@ -145,14 +144,14 @@ def get_column(self, col_name: str):
         return self.ds.loc[:, col_name]
 
     def get_column_unique_values(self, col_name: str):
-        """
+        """
         Returns the unique values for the column
         :param col_name:
         :return:
         """
-        col = self.get_column(col_name=col_name)
-        vals = col.values.ravel()
-        return pd.unique(vals)
+        col = self.get_column(col_name=col_name)
+        vals = col.values.ravel()
+        return pd.unique(vals)
 
     def get_columns_types(self):
         return list(self.ds.dtypes)
@@ -181,8 +180,3 @@ def apply_column_transform(self, column_name: str, transform: Transform) -> None
         column = self.get_column(col_name=column_name)
         column = transform.act(**{"data": column.values})
         self.ds[transform.column_name] = column
-
-
-
-
-

src/examples/__init__.py

Whitespace-only changes.
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
"""
This example shows how to create hierarchies suitable to
be loaded into the ARX tool
"""
import csv
from src.datasets.datasets_loaders import MockSubjectsLoader


def get_ethnicity_hierarchy():

    ethnicity_hierarchy = {}

    ethnicity_hierarchy["Mixed White/Asian"] = ["White/Asian", "Mixed"]
    ethnicity_hierarchy["Chinese"] = ["Asian", "Asian"]
    ethnicity_hierarchy["Indian"] = ["Asian", "Asian"]
    ethnicity_hierarchy["Mixed White/Black African"] = ["White/Black", "Mixed"]
    ethnicity_hierarchy["Black African"] = ["Black", "African"]
    ethnicity_hierarchy["Asian other"] = ["Asian", "Other"]
    ethnicity_hierarchy["Black other"] = ["Black", "Other"]
    ethnicity_hierarchy["Mixed White/Black Caribbean"] = ["White/Black", "Mixed"]
    ethnicity_hierarchy["Mixed other"] = ["Mixed", "Mixe"]
    ethnicity_hierarchy["Arab"] = ["Asian", "Asian"]
    ethnicity_hierarchy["White Irish"] = ["Irish", "European"]
    ethnicity_hierarchy["Not stated"] = ["Not stated", "Not stated"]
    ethnicity_hierarchy["White Gypsy/Traveller"] = ["White", "White"]
    ethnicity_hierarchy["White British"] = ["British", "European"]
    ethnicity_hierarchy["Bangladeshi"] = ["Asian", "Asian"]
    ethnicity_hierarchy["White other"] = ["White", "White"]
    ethnicity_hierarchy["Black Caribbean"] = ["Black", "Caribbean"]
    ethnicity_hierarchy["Pakistani"] = ["Asian", "Asian"]

    return ethnicity_hierarchy


if __name__ == '__main__':

    # specify the columns to drop
    drop_columns = MockSubjectsLoader.FEATURES_DROP_NAMES + ["preventative_treatment", "gender",
                                                             "education", "mutation_status"]
    MockSubjectsLoader.FEATURES_DROP_NAMES = drop_columns

    # do a salary normalization
    MockSubjectsLoader.NORMALIZED_COLUMNS = ["salary"]

    # specify the columns to use
    MockSubjectsLoader.COLUMNS_TYPES = {"ethnicity": str, "salary": float, "diagnosis": int}
    ds = MockSubjectsLoader()

    ehnicity_map = get_ethnicity_hierarchy()
    # get the ethnicity column, loop over
    # the values and create the hierarchy file
    filename = "/home/alex/qi3/drl_anonymity/data/hierarchies/ethnicity_hierarchy.csv"
    with open(filename, 'w') as fh:
        writer = csv.writer(fh, delimiter=",")

        ethnicity_column = ds.get_column(col_name="ethnicity").values

        for val in ethnicity_column:

            if val not in ehnicity_map:
                raise ValueError("Value {0} not in ethnicity map".format(val))

            row = [val]
            row.extend(ehnicity_map[val])
            writer.writerow(row)
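Each row written by the loop above has the form value,level_1,level_2, i.e. the original value followed by its generalizations, which is the layout ARX expects for a materialized generalization hierarchy. For completeness, a small sketch (not part of this commit) of reading the generated file back on the Scala side; the path mirrors the one hard-coded above:

import org.deidentifier.arx.AttributeType.Hierarchy
import java.io.File
import java.nio.charset.Charset

// load the CSV produced by this script as an ARX generalization hierarchy
val ethnicityHierarchy: Hierarchy = Hierarchy.create(
  new File("/home/alex/qi3/drl_anonymity/data/hierarchies/ethnicity_hierarchy.csv"),
  Charset.defaultCharset, ',')

// every inner array holds a value followed by its generalization levels
println(s"Generalization levels: ${ethnicityHierarchy.getHierarchy.head.length - 1}")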

src/examples/qlearning_three_columns.py

Lines changed: 6 additions & 6 deletions
@@ -151,11 +151,10 @@ def get_ethinicity_hierarchy():
 # create the environment
 env = DiscreteStateEnvironment(env_config=env_config)
 env.reset()
-env.save_current_dataset(episode_index=-1)
 
-# save the original dataset for comparison
-env.save_current_dataset(episode_index=-1)
-env.reset()
+# save the data before distortion so that we can
+# later load it on ARX
+env.save_current_dataset(episode_index=-1, save_index=False)
 
 # configuration for the Q-learner
 algo_config = QLearnConfig()
@@ -195,7 +194,8 @@ def get_ethinicity_hierarchy():
 
 stop_criterion = IterationControl(n_itrs=10, min_dist=MIN_DISTORTION, max_dist=MAX_DISTORTION)
 agent.play(env=env, stop_criterion=stop_criterion)
-env.save_current_dataset(episode_index=-2)
-
+env.save_current_dataset(episode_index=-2, save_index=False)
+print("{0} Done....".format(INFO))
+print("=============================================")
 
src/spaces/discrete_state_environment.py

Lines changed: 9 additions & 7 deletions
@@ -128,14 +128,16 @@ def n_states(self) -> int:
     def get_action(self, aidx: int) -> ActionBase:
         return self.config.action_space[aidx]
 
-    def save_current_dataset(self, episode_index: int) -> None:
+    def save_current_dataset(self, episode_index: int, save_index: bool = False) -> None:
         """
         Save the current distorted dataset for the given episode index
         :param episode_index:
+        :param save_index:
         :return:
         """
         self.distorted_data_set.save_to_csv(
-            filename=Path(str(self.config.distorted_set_path) + "_" + str(episode_index)))
+            filename=Path(str(self.config.distorted_set_path) + "_" + str(episode_index)),
+            save_index=save_index)
 
     def create_bins(self) -> None:
         """
@@ -216,15 +218,14 @@ def apply_action(self, action: ActionBase):
 
         self.column_distances[action.column_name] = distance
 
-    def total_average_current_distortion(self) -> float:
+    def total_current_distortion(self) -> float:
         """
-        Calculates the average total distortion of the dataset
-        by summing over the current computed distances for each column
+        Calculates the current total distortion of the dataset.
         :return:
         """
 
         return self.config.distortion_calculator.total_distortion(
-            list(self.column_distances.values()))  # float(np.mean(list(self.column_distances.values())))
+            list(self.column_distances.values()))
 
     def reset(self, **options) -> TimeStep:
         """
@@ -270,7 +271,7 @@ def step(self, action: ActionBase) -> TimeStep:
         self.apply_action(action=action)
 
         # calculate the distortion of the dataset
-        current_distortion = self.total_average_current_distortion()
+        current_distortion = self.total_current_distortion()
 
         # get the reward for the current distortion
         reward = self.config.reward_manager.get_reward_for_state(state=current_distortion, **{"action": action})
@@ -312,6 +313,7 @@ def step(self, action: ActionBase) -> TimeStep:
 
         # TODO: these modifications will cause the agent to always
         # move close to transition points
+        # TODO: Remove the magic constants
         if next_state is not None and self.current_time_step.observation is not None:
             if next_state < min_dist_bin <= self.current_time_step.observation:
                 # the agent chose to step into the chaos again
