diff --git a/.gitignore b/.gitignore index 8c79b505c..95d6ccf58 100644 --- a/.gitignore +++ b/.gitignore @@ -215,6 +215,9 @@ log/ .pyenv2/* .pyenv3/* +# Generated parameter library files +**/vision_parameters.py + ansible_robots/* doc_internal/* doku/* diff --git a/bitbots_vision/bitbots_vision/params.py b/bitbots_vision/bitbots_vision/params.py deleted file mode 100755 index 3fc386011..000000000 --- a/bitbots_vision/bitbots_vision/params.py +++ /dev/null @@ -1,111 +0,0 @@ -from rcl_interfaces.msg import FloatingPointRange, IntegerRange, ParameterDescriptor, ParameterType - - -class ParameterGenerator: # TODO own file - def __init__(self): - self.param_cache = [] - - def declare_params(self, node): - for param in self.param_cache: - node.declare_parameter(*param) - - def add(self, param_name, param_type=None, default=None, description=None, min=None, max=None, step=None): - describtor = ParameterDescriptor() - describtor.name = param_name - if description is None: - describtor.description = param_name - else: - describtor.description = description - - if param_type is None and default is not None: - param_type = type(default) - - py2ros_param_type = { - None: ParameterType.PARAMETER_NOT_SET, - bool: ParameterType.PARAMETER_BOOL, - int: ParameterType.PARAMETER_INTEGER, - float: ParameterType.PARAMETER_DOUBLE, - str: ParameterType.PARAMETER_STRING, - } - - param_type = py2ros_param_type.get(param_type, param_type) - - describtor.type = param_type - - if param_type == ParameterType.PARAMETER_INTEGER: - if step is None: - step = 1 - if all(x is not None or isinstance(x, int) for x in [min, max, step]): - param_range = IntegerRange() - param_range.from_value = min - param_range.to_value = max - param_range.step = step - describtor.integer_range = [param_range] - - if param_type == ParameterType.PARAMETER_DOUBLE: - if step is None: - step = 0.01 - if all(x is not None for x in [min, max]): - param_range = FloatingPointRange() - param_range.from_value = float(min) - param_range.to_value = float(max) - param_range.step = float(step) - describtor.floating_point_range = [param_range] - - type2default_default = { - ParameterType.PARAMETER_NOT_SET: 0, - ParameterType.PARAMETER_BOOL: False, - ParameterType.PARAMETER_INTEGER: 0, - ParameterType.PARAMETER_DOUBLE: 0.0, - ParameterType.PARAMETER_STRING: "", - } - - if default is None: - default = type2default_default[param_type] - - self.param_cache.append((param_name, default, describtor)) - - -gen = ParameterGenerator() - -########## -# Params # -########## - -gen.add("component_ball_detection_active", bool, description="Activate/Deactivate the ball detection component") -gen.add("component_debug_image_active", bool, description="Activate/Deactivate the debug image component") -gen.add("component_field_detection_active", bool, description="Activate/Deactivate the field detection component") -gen.add("component_goalpost_detection_active", bool, description="Activate/Deactivate the goalpost detection component") -gen.add("component_line_detection_active", bool, description="Activate/Deactivate the line detection component") -gen.add("component_robot_detection_active", bool, description="Activate/Deactivate the robot detection component") - -gen.add("ROS_img_msg_topic", str, description="ROS topic of the image message") -gen.add("ROS_ball_msg_topic", str, description="ROS topic of the ball message") -gen.add("ROS_goal_posts_msg_topic", str, description="ROS topic of the goal posts message") -gen.add("ROS_robot_msg_topic", str, description="ROS topic of the robots message") -gen.add("ROS_line_msg_topic", str, description="ROS topic of the line message") -gen.add("ROS_line_mask_msg_topic", str, description="ROS topic of the line mask message") -gen.add("ROS_debug_image_msg_topic", str, description="ROS topic of the debug image message") -gen.add("ROS_field_mask_image_msg_topic", str, description="ROS topic of the field mask debug image message") - -gen.add("yoeo_model_path", str, description="Name of YOEO model") -gen.add("yoeo_nms_threshold", float, description="YOEO Non-maximum suppression threshold", min=0.0, max=1.0) -gen.add("yoeo_conf_threshold", float, description="YOEO confidence threshold", min=0.0, max=1.0) -gen.add( - "yoeo_framework", - str, - description="The neural network framework that should be used ['pytorch', 'openvino', 'onnx', 'tvm']", -) - -gen.add( - "ball_candidate_rating_threshold", - float, - description="A threshold for the minimum candidate rating", - min=0.0, - max=1.0, -) -gen.add( - "ball_candidate_max_count", int, description="The maximum number of balls that should be published", min=0, max=50 -) - -gen.add("caching", bool, description="Used to deactivate caching for profiling reasons") diff --git a/bitbots_vision/bitbots_vision/vision.py b/bitbots_vision/bitbots_vision/vision.py index 9cc31131a..cb4a26b56 100755 --- a/bitbots_vision/bitbots_vision/vision.py +++ b/bitbots_vision/bitbots_vision/vision.py @@ -1,18 +1,15 @@ #! /usr/bin/env python3 -from copy import deepcopy from typing import Optional import rclpy from ament_index_python.packages import get_package_share_directory from cv_bridge import CvBridge from rcl_interfaces.msg import SetParametersResult -from rclpy.experimental.events_executor import EventsExecutor from rclpy.node import Node from sensor_msgs.msg import Image from bitbots_vision.vision_modules import debug, ros_utils, yoeo - -from .params import gen +from bitbots_vision.vision_parameters import bitbots_vision as parameters logger = rclpy.logging.get_logger("bitbots_vision") @@ -39,11 +36,14 @@ def __init__(self) -> None: logger.debug(f"Entering {self.__class__.__name__} constructor") + # Setup parameter listener directly + self.param_listener = parameters.ParamListener(self) + self.config = self.param_listener.get_params() + self._package_path = get_package_share_directory("bitbots_vision") yoeo.YOEOObjectManager.set_package_directory(self._package_path) - self._config: dict = {} self._cv_bridge = CvBridge() self._sub_image = None @@ -51,8 +51,7 @@ def __init__(self) -> None: self._vision_components: list[yoeo.AbstractVisionComponent] = [] self._debug_image: Optional[debug.DebugImage] = None - # Setup reconfiguration - gen.declare_params(self) + # Setup reconfiguration callback self.add_on_set_parameters_callback(self._dynamic_reconfigure_callback) # Add general params @@ -61,7 +60,8 @@ def __init__(self) -> None: # Update team color ros_utils.update_own_team_color(self) - self._dynamic_reconfigure_callback(self.get_parameters_by_prefix("").values()) + # Configure vision with initial parameters + self._configure_vision(self.config) logger.debug(f"Leaving {self.__class__.__name__} constructor") @@ -69,24 +69,20 @@ def _dynamic_reconfigure_callback(self, params) -> SetParametersResult: """ Callback for the dynamic reconfigure configuration. - :param dict params: new config + :param params: list of changed parameters """ - new_config = self._get_updated_config_with(params) - self._configure_vision(new_config) - self._config = new_config + # Update the config from the parameter listener + self.config = self.param_listener.get_params() + + # Configure vision with the updated config + self._configure_vision(self.config) return SetParametersResult(successful=True) - def _get_updated_config_with(self, params) -> dict: - new_config = deepcopy(self._config) - for param in params: - new_config[param.name] = param.value - return new_config - - def _configure_vision(self, new_config: dict) -> None: - yoeo.YOEOObjectManager.configure(new_config) + def _configure_vision(self, config) -> None: + yoeo.YOEOObjectManager.configure(config.yoeo) - debug_image = debug.DebugImage(new_config["component_debug_image_active"]) + debug_image = debug.DebugImage(config.component_debug_image_active) self._debug_image = debug_image def make_vision_component( @@ -96,39 +92,45 @@ def make_vision_component( node=self, yoeo_handler=yoeo.YOEOObjectManager.get(), debug_image=debug_image, - config=new_config, + config=config, # Now passing config object directly **kwargs, ) self._vision_components = [make_vision_component(yoeo.YOEOComponent)] - if new_config["component_ball_detection_active"]: + if config.component_ball_detection_active: self._vision_components.append(make_vision_component(yoeo.BallDetectionComponent)) - if new_config["component_robot_detection_active"]: + if config.component_robot_detection_active: self._vision_components.append( make_vision_component( yoeo.RobotDetectionComponent, team_color_detection_supported=yoeo.YOEOObjectManager.is_team_color_detection_supported(), ) ) - if new_config["component_goalpost_detection_active"]: + if config.component_goalpost_detection_active: self._vision_components.append(make_vision_component(yoeo.GoalpostDetectionComponent)) - if new_config["component_line_detection_active"]: + if config.component_line_detection_active: self._vision_components.append(make_vision_component(yoeo.LineDetectionComponent)) - if new_config["component_field_detection_active"]: + if config.component_field_detection_active: self._vision_components.append(make_vision_component(yoeo.FieldDetectionComponent)) - if new_config["component_debug_image_active"]: + if config.component_debug_image_active: self._vision_components.append(make_vision_component(yoeo.DebugImageComponent)) - self._sub_image = ros_utils.create_or_update_subscriber( + # For the subscriber update, use the improved ros_utils function + old_topic = getattr(self, '_last_img_topic', None) + current_topic = config.ROS_img_msg_topic + + self._sub_image = ros_utils.create_or_update_subscriber_with_config( self, - self._config, - new_config, + old_topic, + current_topic, self._sub_image, - "ROS_img_msg_topic", Image, - callback=self._run_vision_pipeline, + self._run_vision_pipeline, ) + + # Remember this topic for next time + self._last_img_topic = current_topic @profile def _run_vision_pipeline(self, image_msg: Image) -> None: diff --git a/bitbots_vision/bitbots_vision/vision_modules/ros_utils.py b/bitbots_vision/bitbots_vision/vision_modules/ros_utils.py index 53a41fcad..5ee1b3b3c 100644 --- a/bitbots_vision/bitbots_vision/vision_modules/ros_utils.py +++ b/bitbots_vision/bitbots_vision/vision_modules/ros_utils.py @@ -46,6 +46,32 @@ class RobotColor(Enum): own_team_color: RobotColor = RobotColor.UNKNOWN +def create_or_update_subscriber_with_config( + node, old_topic, new_topic, subscriber_object, data_class, callback, qos_profile=1, callback_group=None +): + """ + Creates or updates a subscriber using direct topic names instead of config dicts + + :param node: ROS node to which the publisher is bound + :param old_topic: Previous topic name + :param new_topic: New topic name + :param subscriber_object: The python object, that represents the subscriber + :param data_class: Data type class for ROS messages of the topic we want to subscribe + :param callback: The subscriber callback function + :param qos_profile: A QoSProfile or a history depth to apply to the subscription. + :param callback_group: The callback group for the subscription. + :return: adjusted subscriber object + """ + # Check if topic has changed + if old_topic != new_topic: + # Create the new subscriber + subscriber_object = node.create_subscription( + data_class, new_topic, callback, qos_profile, callback_group=callback_group + ) + logger.debug("Registered new subscriber at " + str(new_topic)) + return subscriber_object + + def create_or_update_subscriber( node, old_config, new_config, subscriber_object, topic_key, data_class, callback, qos_profile=1, callback_group=None ): diff --git a/bitbots_vision/bitbots_vision/vision_modules/yoeo/object_manager.py b/bitbots_vision/bitbots_vision/vision_modules/yoeo/object_manager.py index 3aadfb974..cca167f8f 100644 --- a/bitbots_vision/bitbots_vision/vision_modules/yoeo/object_manager.py +++ b/bitbots_vision/bitbots_vision/vision_modules/yoeo/object_manager.py @@ -4,6 +4,7 @@ import rclpy from bitbots_vision.vision_modules import ros_utils +from bitbots_vision.vision_parameters import bitbots_vision as parameters from . import yoeo_handlers from .model_config import ModelConfig, ModelConfigLoader @@ -72,19 +73,19 @@ def is_team_color_detection_supported(cls) -> bool: return cls._model_config.team_colors_are_provided() @classmethod - def configure(cls, config: dict) -> None: + def configure(cls, yoeo_config) -> None: if not cls._package_directory_set: logger.error("Package directory not set!") - framework = config["yoeo_framework"] + framework = yoeo_config.framework cls._verify_framework_parameter(framework) - model_path = cls._get_full_model_path(config["yoeo_model_path"]) + model_path = cls._get_full_model_path(yoeo_config.model_path) cls._verify_required_neural_network_files_exist(framework, model_path) - cls._configure_yoeo_instance(config, framework, model_path) + cls._configure_yoeo_instance(yoeo_config, framework, model_path) - cls._config = config + cls._config = yoeo_config cls._framework = framework cls._model_path = model_path @@ -107,13 +108,13 @@ def _model_files_exist(cls, framework: str, model_path: str) -> bool: return cls._HANDLERS_BY_NAME[framework].model_files_exist(model_path) @classmethod - def _configure_yoeo_instance(cls, config: dict, framework: str, model_path: str) -> None: + def _configure_yoeo_instance(cls, yoeo_config, framework: str, model_path: str) -> None: if cls._new_yoeo_handler_is_needed(framework, model_path): cls._load_model_config(model_path) cls._instantiate_new_yoeo_handler(config, framework, model_path) - elif cls._yoeo_parameters_have_changed(config): + elif cls._yoeo_parameters_have_changed(yoeo_config): assert cls._yoeo_instance is not None, "YOEO handler instance not set!" - cls._yoeo_instance.configure(config) + cls._yoeo_instance.configure(yoeo_config) @classmethod def _new_yoeo_handler_is_needed(cls, framework: str, model_path: str) -> bool: @@ -124,9 +125,9 @@ def _load_model_config(cls, model_path: str) -> None: cls._model_config = ModelConfigLoader.load_from(model_path) @classmethod - def _instantiate_new_yoeo_handler(cls, config: dict, framework: str, model_path: str) -> None: + def _instantiate_new_yoeo_handler(cls, yoeo_config, framework: str, model_path: str) -> None: cls._yoeo_instance = cls._HANDLERS_BY_NAME[framework]( - config, + yoeo_config, model_path, cls._model_config.get_detection_classes(), cls._model_config.get_robot_class_ids(), @@ -135,5 +136,9 @@ def _instantiate_new_yoeo_handler(cls, config: dict, framework: str, model_path: logger.info(f"Using {cls._yoeo_instance.__class__.__name__}") @classmethod - def _yoeo_parameters_have_changed(cls, new_config: dict) -> bool: - return ros_utils.config_param_change(cls._config, new_config, r"yoeo_") + def _yoeo_parameters_have_changed(cls, new_yoeo_config) -> bool: + if cls._config is None: + return True + + # Compare YOEO parameters using the hierarchical structure + return cls._config != new_yoeo_config diff --git a/bitbots_vision/bitbots_vision/vision_modules/yoeo/vision_components.py b/bitbots_vision/bitbots_vision/vision_modules/yoeo/vision_components.py index 2fd5d0a1f..74cc7663a 100644 --- a/bitbots_vision/bitbots_vision/vision_modules/yoeo/vision_components.py +++ b/bitbots_vision/bitbots_vision/vision_modules/yoeo/vision_components.py @@ -57,7 +57,7 @@ def __init__( ): super().__init__(node, yoeo_handler, debug_image, config) - self._publisher = self._node.create_publisher(BallArray, self._config["ROS_ball_msg_topic"], qos_profile=1) + self._publisher = self._node.create_publisher(BallArray, self._config.ROS_ball_msg_topic, qos_profile=1) def run(self, image: np.ndarray, header: Header) -> None: # Get all ball candidates from YOEO @@ -65,9 +65,9 @@ def run(self, image: np.ndarray, header: Header) -> None: # Filter candidates by rating and count candidates = candidate.Candidate.sort_candidates(candidates) - top_candidates = candidates[: self._config["ball_candidate_max_count"]] + top_candidates = candidates[: self._config.ball_candidate_max_count] final_candidates = candidate.Candidate.rating_threshold( - top_candidates, self._config["ball_candidate_rating_threshold"] + top_candidates, self._config.ball_candidate_rating_threshold ) # Publish ball candidates @@ -95,7 +95,7 @@ def __init__( super().__init__(node, yoeo_handler, debug_image, config) self._publisher = self._node.create_publisher( - GoalpostArray, self._config["ROS_goal_posts_msg_topic"], qos_profile=1 + GoalpostArray, self._config.ROS_goal_posts_msg_topic, qos_profile=1 ) def run(self, image: np.ndarray, header: Header) -> None: @@ -125,7 +125,7 @@ def __init__( self, node: Node, yoeo_handler: yoeo_handlers.IYOEOHandler, debug_image: debug.DebugImage, config: dict ): super().__init__(node, yoeo_handler, debug_image, config) - self._publisher = self._node.create_publisher(Image, self._config["ROS_line_mask_msg_topic"], qos_profile=1) + self._publisher = self._node.create_publisher(Image, self._config.ROS_line_mask_msg_topic, qos_profile=1) def run(self, image: np.ndarray, header: Header) -> None: # Get line mask from YOEO @@ -153,7 +153,7 @@ def __init__( ): super().__init__(node, yoeo_handler, debug_image, config) self._publisher = self._node.create_publisher( - Image, self._config["ROS_field_mask_image_msg_topic"], qos_profile=1 + Image, self._config.ROS_field_mask_image_msg_topic, qos_profile=1 ) def run(self, image: np.ndarray, header: Header) -> None: @@ -185,7 +185,7 @@ def __init__( super().__init__(node, yoeo_handler, debug_image, config) self._team_color_detection_supported = team_color_detection_supported - self._publisher = self._node.create_publisher(RobotArray, self._config["ROS_robot_msg_topic"], qos_profile=1) + self._publisher = self._node.create_publisher(RobotArray, self._config.ROS_robot_msg_topic, qos_profile=1) def run(self, image: np.ndarray, header: Header) -> None: robot_msgs: list[Robot] = [] @@ -282,7 +282,7 @@ def __init__( ): super().__init__(node, yoeo_handler, debug_image, config) - self._publisher = self._node.create_publisher(Image, self._config["ROS_debug_image_msg_topic"], qos_profile=1) + self._publisher = self._node.create_publisher(Image, self._config.ROS_debug_image_msg_topic, qos_profile=1) def run(self, image: np.ndarray, header: Header) -> None: debug_image_msg = ros_utils.build_image_msg(header, self._debug_image.get_image(), "bgr8") diff --git a/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py b/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py index 9409dc3ea..50061f285 100644 --- a/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py +++ b/bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py @@ -9,6 +9,8 @@ import numpy as np import rclpy +from bitbots_vision.vision_parameters import bitbots_vision as parameters + from bitbots_vision.vision_modules.candidate import Candidate from . import utils @@ -22,7 +24,7 @@ class IYOEOHandler(ABC): """ @abstractmethod - def configure(self, config: dict) -> None: + def configure(self, yoeo_config) -> None: """ Allows to (re-) configure the YOEO handler. """ @@ -98,7 +100,7 @@ class YOEOHandlerTemplate(IYOEOHandler): def __init__( self, - config: dict, + yoeo_config, model_directory: str, det_class_names: list[str], det_robot_class_ids: list[int], @@ -117,12 +119,12 @@ def __init__( self._seg_class_names: list[str] = seg_class_names self._seg_masks: dict = dict() - self._use_caching: bool = config["caching"] + self._use_caching: bool = yoeo_config.caching logger.debug("Leaving YOEOHandlerTemplate constructor") - def configure(self, config: dict) -> None: - self._use_caching = config["caching"] + def configure(self, yoeo_config) -> None: + self._use_caching = yoeo_config.caching def get_available_detection_class_names(self) -> list[str]: return self._det_class_names @@ -211,7 +213,7 @@ class YOEOHandlerONNX(YOEOHandlerTemplate): def __init__( self, - config: dict, + yoeo_config, model_directory: str, det_class_names: list[str], det_robot_class_ids: list[int], @@ -238,8 +240,8 @@ def __init__( self._det_postprocessor: utils.IDetectionPostProcessor = utils.DefaultDetectionPostProcessor( image_preprocessor=self._img_preprocessor, output_img_size=self._input_layer.shape[2], - conf_thresh=config["yoeo_conf_threshold"], - nms_thresh=config["yoeo_nms_threshold"], + conf_thresh=yoeo_config.conf_threshold, + nms_thresh=yoeo_config.nms_threshold, robot_class_ids=self.get_robot_class_ids(), ) self._seg_postprocessor: utils.ISegmentationPostProcessor = utils.DefaultSegmentationPostProcessor( @@ -248,13 +250,13 @@ def __init__( logger.debug(f"Leaving {self.__class__.__name__} constructor") - def configure(self, config: dict) -> None: + def configure(self, yoeo_config) -> None: super().configure(config) self._det_postprocessor.configure( image_preprocessor=self._img_preprocessor, output_img_size=self._input_layer.shape[2], - conf_thresh=config["yoeo_conf_threshold"], - nms_thresh=config["yoeo_nms_threshold"], + conf_thresh=yoeo_config.conf_threshold, + nms_thresh=yoeo_config.nms_threshold, robot_class_ids=self.get_robot_class_ids(), ) @@ -284,7 +286,7 @@ class YOEOHandlerOpenVino(YOEOHandlerTemplate): def __init__( self, - config: dict, + yoeo_config, model_directory: str, det_class_names: list[str], det_robot_class_ids: list[int], @@ -320,8 +322,8 @@ def __init__( self._det_postprocessor: utils.IDetectionPostProcessor = utils.DefaultDetectionPostProcessor( image_preprocessor=self._img_preprocessor, output_img_size=self._input_layer.shape[2], - conf_thresh=config["yoeo_conf_threshold"], - nms_thresh=config["yoeo_nms_threshold"], + conf_thresh=yoeo_config.conf_threshold, + nms_thresh=yoeo_config.nms_threshold, robot_class_ids=self.get_robot_class_ids(), ) self._seg_postprocessor: utils.ISegmentationPostProcessor = utils.DefaultSegmentationPostProcessor( @@ -337,13 +339,13 @@ def _select_device(self) -> str: device = "CPU" return device - def configure(self, config: dict) -> None: + def configure(self, yoeo_config) -> None: super().configure(config) self._det_postprocessor.configure( image_preprocessor=self._img_preprocessor, output_img_size=self._input_layer.shape[2], - conf_thresh=config["yoeo_conf_threshold"], - nms_thresh=config["yoeo_nms_threshold"], + conf_thresh=yoeo_config.conf_threshold, + nms_thresh=yoeo_config.nms_threshold, robot_class_ids=self.get_robot_class_ids(), ) @@ -372,7 +374,7 @@ class YOEOHandlerPytorch(YOEOHandlerTemplate): def __init__( self, - config: dict, + yoeo_config, model_directory: str, det_class_names: list[str], det_robot_class_ids: list[int], @@ -398,8 +400,8 @@ def __init__( logger.debug(f"Loading files...\n\t{config_path}\n\t{weights_path}") self._model = torch_models.load_model(config_path, weights_path) - self._conf_thresh: float = config["yoeo_conf_threshold"] - self._nms_thresh: float = config["yoeo_nms_threshold"] + self._conf_thresh: float = yoeo_config.conf_threshold + self._nms_thresh: float = yoeo_config.nms_threshold self._group_config: torch_GroupConfig = self._update_group_config() logger.debug(f"Leaving {self.__class__.__name__} constructor") @@ -409,10 +411,10 @@ def _update_group_config(self): return self.torch_group_config(group_ids=robot_class_ids, surrogate_id=robot_class_ids[0]) - def configure(self, config: dict) -> None: + def configure(self, yoeo_config) -> None: super().configure(config) - self._conf_thresh = config["yoeo_conf_threshold"] - self._nms_thresh = config["yoeo_nms_threshold"] + self._conf_thresh = yoeo_config.conf_threshold + self._nms_thresh = yoeo_config.nms_threshold self._group_config = self._update_group_config() @staticmethod @@ -446,7 +448,7 @@ class YOEOHandlerTVM(YOEOHandlerTemplate): def __init__( self, - config: dict, + yoeo_config, model_directory: str, det_class_names: list[str], det_robot_class_ids: list[int], @@ -485,8 +487,8 @@ def __init__( self._det_postprocessor: utils.IDetectionPostProcessor = utils.DefaultDetectionPostProcessor( image_preprocessor=self._img_preprocessor, output_img_size=self._input_layer_shape[2], - conf_thresh=config["yoeo_conf_threshold"], - nms_thresh=config["yoeo_nms_threshold"], + conf_thresh=yoeo_config.conf_threshold, + nms_thresh=yoeo_config.nms_threshold, robot_class_ids=self.get_robot_class_ids(), ) self._seg_postprocessor: utils.ISegmentationPostProcessor = utils.DefaultSegmentationPostProcessor( @@ -495,13 +497,13 @@ def __init__( logger.debug(f"Leaving {self.__class__.__name__} constructor") - def configure(self, config: dict) -> None: + def configure(self, yoeo_config) -> None: super().configure(config) self._det_postprocessor.configure( image_preprocessor=self._img_preprocessor, output_img_size=self._input_layer_shape[2], - conf_thresh=config["yoeo_conf_threshold"], - nms_thresh=config["yoeo_nms_threshold"], + conf_thresh=yoeo_config.conf_threshold, + nms_thresh=yoeo_config.nms_threshold, robot_class_ids=self.get_robot_class_ids(), ) diff --git a/bitbots_vision/config/vision_parameters.yaml b/bitbots_vision/config/vision_parameters.yaml new file mode 100644 index 000000000..e6106bc7d --- /dev/null +++ b/bitbots_vision/config/vision_parameters.yaml @@ -0,0 +1,128 @@ +bitbots_vision: + # Component activation parameters + component_ball_detection_active: + type: bool + default_value: true + description: "Activate/Deactivate the ball detection component" + + component_debug_image_active: + type: bool + default_value: false + description: "Activate/Deactivate the debug image component" + + component_field_detection_active: + type: bool + default_value: true + description: "Activate/Deactivate the field detection component" + + component_goalpost_detection_active: + type: bool + default_value: false + description: "Activate/Deactivate the goalpost detection component" + + component_line_detection_active: + type: bool + default_value: true + description: "Activate/Deactivate the line detection component" + + component_robot_detection_active: + type: bool + default_value: true + description: "Activate/Deactivate the robot detection component" + + # ROS topic parameters + ROS_img_msg_topic: + type: string + default_value: "camera/image_proc" + description: "ROS topic of the image message" + read_only: true + + ROS_ball_msg_topic: + type: string + default_value: "balls_in_image" + description: "ROS topic of the ball message" + read_only: true + + ROS_goal_posts_msg_topic: + type: string + default_value: "goal_posts_in_image" + description: "ROS topic of the goal posts message" + read_only: true + + ROS_robot_msg_topic: + type: string + default_value: "robots_in_image" + description: "ROS topic of the robots message" + read_only: true + + ROS_line_msg_topic: + type: string + default_value: "line_in_image" + description: "ROS topic of the line message" + read_only: true + + ROS_line_mask_msg_topic: + type: string + default_value: "line_mask_in_image" + description: "ROS topic of the line mask message" + read_only: true + + ROS_debug_image_msg_topic: + type: string + default_value: "debug_image" + description: "ROS topic of the debug image message" + read_only: true + + ROS_field_mask_image_msg_topic: + type: string + default_value: "field_mask" + description: "ROS topic of the field mask debug image message" + read_only: true + + # YOEO model parameters with hierarchy + yoeo: + model_path: + type: string + default_value: "2022_10_07_flo_torso21_yoeox" + description: "Name of YOEO model" + + nms_threshold: + type: double + default_value: 0.4 + description: "YOEO Non-maximum suppression threshold" + validation: + bounds<>: [0.0, 1.0] + + conf_threshold: + type: double + default_value: 0.5 + description: "YOEO confidence threshold" + validation: + bounds<>: [0.0, 1.0] + + framework: + type: string + default_value: "tvm" + description: "The neural network framework that should be used" + validation: + one_of<>: [["pytorch", "openvino", "onnx", "tvm"]] + + caching: + type: bool + default_value: true + description: "Used to deactivate caching for profiling reasons" + + # Ball detection parameters + ball_candidate_rating_threshold: + type: double + default_value: 0.5 + description: "A threshold for the minimum candidate rating" + validation: + bounds<>: [0.0, 1.0] + + ball_candidate_max_count: + type: int + default_value: 1 + description: "The maximum number of balls that should be published" + validation: + bounds<>: [0, 50] \ No newline at end of file diff --git a/bitbots_vision/package.xml b/bitbots_vision/package.xml index 6f5b75f5e..794cb8a77 100644 --- a/bitbots_vision/package.xml +++ b/bitbots_vision/package.xml @@ -29,6 +29,7 @@ rosidl_default_runtime bitbots_utils game_controller_hl_interfaces + generate_parameter_library geometry_msgs image_transport python3-numpy diff --git a/bitbots_vision/setup.py b/bitbots_vision/setup.py index 96b053d04..d7288708e 100755 --- a/bitbots_vision/setup.py +++ b/bitbots_vision/setup.py @@ -1,8 +1,14 @@ import glob import os +from generate_parameter_library_py.setup_helper import generate_parameter_module from setuptools import find_packages, setup +generate_parameter_module( + "vision_parameters", # python module name for parameter library + "config/vision_parameters.yaml", # path to input yaml file +) + package_name = "bitbots_vision"