Skip to content

Commit 960fdda

Browse files
authored
Merge pull request #2530 from advikkabra/new-str-attributes
Improved str.count() algorithm using the Knuth-Morris-Pratt algorithm
2 parents 496f29e + 12c88e7 commit 960fdda

File tree

4 files changed

+138
-11
lines changed

4 files changed

+138
-11
lines changed

integration_tests/test_str_attributes.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,26 @@ def find():
7171
assert s2.find("we") == -1
7272
assert "".find("") == 0
7373

74+
def count():
75+
s: str
76+
sub: str
77+
s = "ABC ABCDAB ABCDABCDABDE"
78+
sub = "ABC"
79+
assert s.count(sub) == 4
80+
assert s.count("ABC") == 4
81+
82+
sub = "AB"
83+
assert s.count(sub) == 6
84+
assert s.count("AB") == 6
85+
86+
sub = "ABC"
87+
assert "ABC ABCDAB ABCDABCDABDE".count(sub) == 4
88+
assert "ABC ABCDAB ABCDABCDABDE".count("ABC") == 4
89+
90+
sub = "AB"
91+
assert "ABC ABCDAB ABCDABCDABDE".count(sub) == 6
92+
assert "ABC ABCDAB ABCDABCDABDE".count("AB") == 6
93+
7494

7595
def startswith():
7696
s: str
@@ -307,6 +327,7 @@ def check():
307327
strip()
308328
swapcase()
309329
find()
330+
count()
310331
startswith()
311332
endswith()
312333
partition()

src/libasr/asr_utils.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4242,6 +4242,44 @@ static inline int KMP_string_match(std::string &s_var, std::string &sub) {
42424242
return res;
42434243
}
42444244

4245+
static inline int KMP_string_match_count(std::string &s_var, std::string &sub) {
4246+
int str_len = s_var.size();
4247+
int sub_len = sub.size();
4248+
int count = 0;
4249+
std::vector<int> lps(sub_len, 0);
4250+
if (sub_len == 0) {
4251+
count = str_len + 1;
4252+
} else {
4253+
for(int i = 1, len = 0; i < sub_len;) {
4254+
if (sub[i] == sub[len]) {
4255+
lps[i++] = ++len;
4256+
} else {
4257+
if (len != 0) {
4258+
len = lps[len - 1];
4259+
} else {
4260+
lps[i++] = 0;
4261+
}
4262+
}
4263+
}
4264+
for (int i = 0, j = 0; (str_len - i) >= (sub_len - j);) {
4265+
if (sub[j] == s_var[i]) {
4266+
j++, i++;
4267+
}
4268+
if (j == sub_len) {
4269+
count++;
4270+
j = lps[j - 1];
4271+
} else if (i < str_len && sub[j] != s_var[i]) {
4272+
if (j != 0) {
4273+
j = lps[j - 1];
4274+
} else {
4275+
i = i + 1;
4276+
}
4277+
}
4278+
}
4279+
}
4280+
return count;
4281+
}
4282+
42454283
static inline void visit_expr_list(Allocator &al, Vec<ASR::call_arg_t>& exprs,
42464284
Vec<ASR::expr_t*>& exprs_vec) {
42474285
LCOMPILERS_ASSERT(exprs_vec.reserve_called);

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6872,13 +6872,13 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
68726872
}
68736873
} else if (attr_name == "find") {
68746874
if (args.size() != 1) {
6875-
throw SemanticError("str.find() takes one arguments",
6875+
throw SemanticError("str.find() takes one argument",
68766876
loc);
68776877
}
68786878
ASR::expr_t *arg = args[0].m_value;
68796879
ASR::ttype_t *type = ASRUtils::expr_type(arg);
68806880
if (!ASRUtils::is_character(*type)) {
6881-
throw SemanticError("str.find() takes one arguments of type: str",
6881+
throw SemanticError("str.find() takes one argument of type: str",
68826882
arg->base.loc);
68836883
}
68846884
if (ASRUtils::expr_value(arg) != nullptr) {
@@ -6905,6 +6905,41 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
69056905
tmp = make_call_helper(al, fn_div, current_scope, args, "_lpython_str_find", loc);
69066906
}
69076907
return;
6908+
} else if (attr_name == "count") {
6909+
if (args.size() != 1) {
6910+
throw SemanticError("str.count() takes one argument",
6911+
loc);
6912+
}
6913+
ASR::expr_t *arg = args[0].m_value;
6914+
ASR::ttype_t *type = ASRUtils::expr_type(arg);
6915+
if (!ASRUtils::is_character(*type)) {
6916+
throw SemanticError("str.count() takes one argument of type: str",
6917+
arg->base.loc);
6918+
}
6919+
if (ASRUtils::expr_value(arg) != nullptr) {
6920+
ASR::StringConstant_t* sub_str_con = ASR::down_cast<ASR::StringConstant_t>(arg);
6921+
std::string sub = sub_str_con->m_s;
6922+
int res = ASRUtils::KMP_string_match_count(s_var, sub);
6923+
tmp = ASR::make_IntegerConstant_t(al, loc, res,
6924+
ASRUtils::TYPE(ASR::make_Integer_t(al,loc, 4)));
6925+
} else {
6926+
ASR::symbol_t *fn_div = resolve_intrinsic_function(loc, "_lpython_str_count");
6927+
Vec<ASR::call_arg_t> args;
6928+
args.reserve(al, 1);
6929+
ASR::call_arg_t str_arg;
6930+
str_arg.loc = loc;
6931+
ASR::ttype_t *str_type = ASRUtils::TYPE(ASR::make_Character_t(al, loc,
6932+
1, s_var.size(), nullptr));
6933+
str_arg.m_value = ASRUtils::EXPR(
6934+
ASR::make_StringConstant_t(al, loc, s2c(al, s_var), str_type));
6935+
ASR::call_arg_t sub_arg;
6936+
sub_arg.loc = loc;
6937+
sub_arg.m_value = arg;
6938+
args.push_back(al, str_arg);
6939+
args.push_back(al, sub_arg);
6940+
tmp = make_call_helper(al, fn_div, current_scope, args, "_lpython_str_count", loc);
6941+
}
6942+
return;
69086943
} else if (attr_name == "rstrip") {
69096944
if (args.size() != 0) {
69106945
throw SemanticError("str.rstrip() takes no arguments",

src/runtime/lpython_builtin.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -637,17 +637,50 @@ def _lpython_str_capitalize(x: str) -> str:
637637

638638

639639
@overload
640-
def _lpython_str_count(x: str, y: str) -> i32:
641-
if(len(y) == 0): return len(x) + 1
640+
def _lpython_str_count(s: str, sub: str) -> i32:
641+
s_len :i32; sub_len :i32; flag: bool; _len: i32;
642+
count: i32; i: i32;
643+
lps: list[i32] = []
644+
s_len = len(s)
645+
sub_len = len(sub)
642646

643-
count: i32 = 0
644-
curr_char: str
645-
i: i32
647+
if sub_len == 0:
648+
return s_len + 1
649+
650+
count = 0
651+
652+
for i in range(sub_len):
653+
lps.append(0)
654+
655+
i = 1
656+
_len = 0
657+
while i < sub_len:
658+
if sub[i] == sub[_len]:
659+
_len += 1
660+
lps[i] = _len
661+
i += 1
662+
else:
663+
if _len != 0:
664+
_len = lps[_len - 1]
665+
else:
666+
lps[i] = 0
667+
i += 1
646668

647-
for i in range(len(x)):
648-
curr_char = x[i]
649-
if curr_char == y[0]:
650-
count += i32(x[i:i+len(y)] == y)
669+
j: i32
670+
j = 0
671+
i = 0
672+
while (s_len - i) >= (sub_len - j):
673+
if sub[j] == s[i]:
674+
i += 1
675+
j += 1
676+
if j == sub_len:
677+
count += 1
678+
j = lps[j - 1]
679+
elif i < s_len and sub[j] != s[i]:
680+
if j != 0:
681+
j = lps[j - 1]
682+
else:
683+
i = i + 1
651684

652685
return count
653686

0 commit comments

Comments
 (0)