Skip to content

Commit fa090da

Browse files
committed
add string calculator
1 parent a0c610c commit fa090da

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import numpy as np
2+
import textdistance
3+
import enum
4+
from src.exceptions import Error
5+
6+
7+
class DistanceType(enum.IntEnum):
    """
    Enumerates the text distance algorithms that can be selected.
    """

    # Cosine measure; currently the only algorithm defined.
    COSINE = 0
14+
15+
16+
class TextDistanceCalculator:
    """
    Wrapper class for text distance calculation.

    An instance is bound to a single distance type. Each call to
    `calculate` builds the matching `textdistance` algorithm object and
    returns the value of its `similarity` method.
    """

    # Distance types accepted by this wrapper.
    DISTANCE_TYPES = [DistanceType.COSINE, ]

    @staticmethod
    def _check_dist_type(dist_type):
        """
        Raise `Error` unless `dist_type` is listed in `DISTANCE_TYPES`.
        """
        if dist_type not in TextDistanceCalculator.DISTANCE_TYPES:
            raise Error("Distance type '{0}' is invalid".format(str(dist_type)))

    @staticmethod
    def build_calculator(dist_type: DistanceType):
        """
        Create the algorithm object for the given distance type.

        :param dist_type: one of the values in `DISTANCE_TYPES`
        :return: a new `textdistance.Cosine` instance for `DistanceType.COSINE`
        :raises Error: if `dist_type` is not in `DISTANCE_TYPES`, or is listed
            there but has no construction branch in this method
        """
        TextDistanceCalculator._check_dist_type(dist_type)

        if dist_type == DistanceType.COSINE:
            return textdistance.Cosine()

        # A type was added to DISTANCE_TYPES without a branch above; fail
        # loudly here instead of handing `None` back to the caller.
        raise Error("Distance type '{0}' has no calculator".format(str(dist_type)))

    def __init__(self, dist_type):
        """
        Bind the calculator to a distance type.

        :param dist_type: one of the values in `DISTANCE_TYPES`
        :raises Error: if `dist_type` is not in `DISTANCE_TYPES`
        """
        TextDistanceCalculator._check_dist_type(dist_type)

        self._dist_type = dist_type

    def calculate(self, txt1, txt2, **options):
        """
        Compare two inputs using the configured distance type.

        Note that the returned value comes from the algorithm's `similarity`
        method, not from a `distance` method.

        :param txt1: first input, passed through to `similarity`
        :param txt2: second input, passed through to `similarity`
        :param options: forwarded to the calculator's `set_options` method
            when it has one; otherwise they are ignored
        :return: whatever `calculator.similarity(txt1, txt2)` returns
        :raises Error: if no calculator can be built for the configured type
        """
        # build a calculator
        calculator = TextDistanceCalculator.build_calculator(dist_type=self._dist_type)

        # Optional hook: only calculators exposing `set_options` receive the
        # keyword options.
        # NOTE(review): textdistance algorithms appear to take their options
        # in the constructor rather than via `set_options` -- confirm whether
        # `options` should be passed to `build_calculator` instead.
        set_options = getattr(calculator, "set_options", None)
        if set_options is not None:
            set_options(**options)

        return calculator.similarity(txt1, txt2)
50+
51+
52+
53+
54+
55+
56+
57+
58+
59+

0 commit comments

Comments
 (0)