Skip to content

Commit 855bade

Browse files
committed
Add the function of concatenating to crops after detection.
1 parent bfb030d commit 855bade

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

deploy/py_infer/src/infer_args.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ def get_args():
119119
"--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring."
120120
)
121121
parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.")
122+
parser.add_argument(
123+
"--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection."
124+
)
122125

123126
args = parser.parse_args()
124127
setup_logger(args)

deploy/py_infer/src/parallel/module/detection/det_post_node.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import cv2
12
import numpy as np
23

34
from ....data_process.utils import cv_utils
@@ -10,19 +11,44 @@ def __init__(self, args, msg_queue):
1011
super(DetPostNode, self).__init__(args, msg_queue)
1112
self.text_detector = None
1213
self.task_type = self.args.task_type
14+
self.is_concat = self.args.is_concat
1315

1416
def init_self_args(self):
1517
self.text_detector = TextDetector(self.args)
1618
self.text_detector.init(preprocess=False, model=False, postprocess=True)
1719
super().init_self_args()
1820

21+
def concat_crops(self, crops: list):
22+
"""
23+
Concatenates the list of cropped images horizontally after resizing them to have the same height.
24+
25+
Args:
26+
crops (list): A list of cropped images represented as numpy arrays.
27+
28+
Returns:
29+
numpy.ndarray: A horizontally concatenated image array.
30+
"""
31+
max_height = max(crop.shape[0] for crop in crops)
32+
resized_crops = []
33+
for crop in crops:
34+
h, w, c = crop.shape
35+
new_h = max_height
36+
new_w = int((w / h) * new_h)
37+
38+
resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
39+
resized_crops.append(resized_img)
40+
crops_concated = np.concatenate(resized_crops, axis=1)
41+
return crops_concated
42+
1943
def process(self, input_data):
2044
if input_data.skip:
2145
self.send_to_next_module(input_data)
2246
return
2347

2448
data = input_data.data
2549
boxes = self.text_detector.postprocess(data["pred"], data["shape_list"])
50+
if self.is_concat:
51+
boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))
2652

2753
infer_res_list = []
2854
for box in boxes:
@@ -39,6 +65,8 @@ def process(self, input_data):
3965
for box in infer_res_list:
4066
sub_image = cv_utils.crop_box_from_image(image, np.array(box))
4167
sub_image_list.append(sub_image)
68+
if self.is_concat:
69+
sub_image_list = [self.concat_crops(sub_image_list)]
4270
input_data.sub_image_list = sub_image_list
4371

4472
input_data.data = None

deploy/py_infer/src/parallel/module/recognition/rec_post_node.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ def __init__(self, args, msg_queue):
77
super(RecPostNode, self).__init__(args, msg_queue)
88
self.text_recognizer = None
99
self.task_type = self.args.task_type
10+
self.is_concat = self.args.is_concat
1011

1112
def init_self_args(self):
1213
self.text_recognizer = TextRecognizer(self.args)
@@ -28,9 +29,13 @@ def process(self, input_data):
2829
else:
2930
texts = output["texts"]
3031
confs = output["confs"]
31-
for result, text, conf in zip(input_data.infer_result, texts, confs):
32-
result.append(text)
33-
result.append(conf)
32+
for i, result in enumerate(input_data.infer_result):
33+
if self.is_concat:
34+
result.append(texts[0])
35+
result.append(confs[0])
36+
else:
37+
result.append(texts[i])
38+
result.append(confs[i])
3439

3540
input_data.data = None
3641

0 commit comments

Comments
 (0)