Skip to content

Commit ded1788

Browse files
committed
#23 Add distortion calculator
1 parent 1819412 commit ded1788

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

src/utils/distortion_calculator.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
Utilities for dataset distortion calculation
3+
"""
4+
import enum
5+
from typing import TypeVar
6+
from src.utils.numeric_distance_type import NumericDistanceType
7+
from src.utils.numeric_distance_calculator import NumericDistanceCalculator
8+
from src.utils.string_distance_calculator import StringDistanceType, TextDistanceCalculator
9+
from src.exceptions.exceptions import InvalidParamValue
10+
11+
Vector = TypeVar('Vector')
12+
13+
14+
class DistortionCalculationType(enum.IntEnum):
15+
"""
16+
17+
"""
18+
19+
INVALID = -1
20+
SUM = 0
21+
AVG = 1
22+
23+
24+
class DistortionCalculator(object):
25+
26+
def __init__(self, numeric_column_distortion_metric_type: NumericDistanceType,
27+
string_column_distortion_metric_type: StringDistanceType,
28+
dataset_distortion_type: DistortionCalculationType):
29+
self.numeric_column_distortion_metric_type = numeric_column_distortion_metric_type
30+
self.string_column_distortion_metric_type = string_column_distortion_metric_type
31+
self.dataset_distortion_type = dataset_distortion_type
32+
33+
def calculate(self, vec1: Vector, vec2: Vector, datatype: str) -> float:
34+
35+
if datatype == 'str':
36+
return TextDistanceCalculator(dist_type=self.string_column_distortion_metric_type).calculate(txt1=vec1,
37+
txt2=vec2)
38+
elif datatype == 'float' or datatype == 'int':
39+
return NumericDistanceCalculator(dist_type=self.numeric_column_distortion_metric_type).calculate(state1=vec1,
40+
state2=vec2)
41+
raise InvalidParamValue(param_name='datatype', param_value=datatype)
42+
43+
def total_distortion(self, distortions: Vector) -> float:
44+
45+
if self.dataset_distortion_type == DistortionCalculationType.SUM:
46+
return float(sum(distortions))
47+
elif self.dataset_distortion_type == DistortionCalculationType.AVG:
48+
return float(sum(distortions) / len(distortions))
49+
50+
raise InvalidParamValue(param_name='dataset_distortion_type', param_value=self.dataset_distortion_type.name)
51+

0 commit comments

Comments
 (0)