44"""
55
66import copy
7- import enum
87import numpy as np
98from pathlib import Path
10- import pandas as pd
11- import torch
12- from typing import NamedTuple , Generic , Optional , TypeVar , List
9+ from typing import TypeVar , List
1310import multiprocessing as mp
1411
1512from src .spaces .actions import ActionBase , ActionType
16- from src .utils .string_distance_calculator import StringDistanceType , TextDistanceCalculator
17- from src .utils .numeric_distance_type import NumericDistanceType
18- from src .utils .numeric_distance_calculator import NumericDistanceCalculator
13+ from src .spaces .time_step import TimeStep , StepType
1914
2015DataSet = TypeVar ("DataSet" )
2116RewardManager = TypeVar ("RewardManager" )
2217ActionSpace = TypeVar ("ActionSpace" )
2318DistortionCalculator = TypeVar ('DistortionCalculator' )
2419
25- _Reward = TypeVar ('_Reward' )
26- _Discount = TypeVar ('_Discount' )
27- _Observation = TypeVar ('_Observation' )
28-
29-
30- class StepType (enum .IntEnum ):
31- """
32- Defines the status of a `TimeStep` within a sequence.
33- """
34-
35- # Denotes the first `TimeStep` in a sequence.
36- FIRST = 0
37-
38- # Denotes any `TimeStep` in a sequence that is not FIRST or LAST.
39- MID = 1
40-
41- # Denotes the last `TimeStep` in a sequence.
42- LAST = 2
43-
44- def first (self ) -> bool :
45- return self is StepType .FIRST
46-
47- def mid (self ) -> bool :
48- return self is StepType .MID
49-
50- def last (self ) -> bool :
51- return self is StepType .LAST
52-
53-
54- class TimeStep (NamedTuple , Generic [_Reward , _Discount , _Observation ]):
55- step_type : StepType
56- info : dict
57- reward : Optional [_Reward ]
58- discount : Optional [_Discount ]
59- observation : _Observation
60-
61- def first (self ) -> bool :
62- return self .step_type == StepType .FIRST
63-
64- def mid (self ) -> bool :
65- return self .step_type == StepType .MID
66-
67- def last (self ) -> bool :
68- return self .step_type == StepType .LAST
69-
7020
7121class DiscreteEnvConfig (object ):
7222 """
@@ -79,8 +29,6 @@ def __init__(self) -> None:
7929 self .reward_manager : RewardManager = None
8030 self .average_distortion_constraint : float = 0.0
8131 self .gamma : float = 0.99
82- # self.string_column_distortion_type: StringDistanceType = StringDistanceType.INVALID
83- # self.numeric_column_distortion_metric_type: NumericDistanceType = NumericDistanceType.INVALID
8432 self .n_states : int = 10
8533 self .min_distortion : float = 0.4
8634 self .max_distortion : float = 0.7
@@ -115,6 +63,10 @@ def __init__(self, env_config: DiscreteEnvConfig) -> None:
11563 self .column_visits = {}
11664 self .create_bins ()
11765
66+ @property
67+ def columns_attribute_types (self ) -> dict :
68+ return self .config .data_set .columns_attribute_types
69+
11870 @property
11971 def action_space (self ):
12072 return self .config .action_space
0 commit comments