diff --git a/tests/functional/config.py b/tests/functional/config.py index 9083668f7c..f2429d6f83 100644 --- a/tests/functional/config.py +++ b/tests/functional/config.py @@ -186,6 +186,9 @@ def get_uses_mapping(): infer_timeouts = { TargetDevice.CPU: default_infer_timeout, TargetDevice.GPU: default_gpu_infer_timeout, + TargetDevice.GPU_0: default_gpu_infer_timeout, + TargetDevice.GPU_1: default_gpu_infer_timeout, + TargetDevice.GPU_2: default_gpu_infer_timeout, TargetDevice.NPU: default_npu_infer_timeout, TargetDevice.AUTO: default_gpu_infer_timeout, TargetDevice.HETERO: default_gpu_infer_timeout, diff --git a/tests/functional/constants/target_device.py b/tests/functional/constants/target_device.py index 8d06e0f0df..55d83c11b7 100644 --- a/tests/functional/constants/target_device.py +++ b/tests/functional/constants/target_device.py @@ -20,6 +20,9 @@ class TargetDevice: CPU = "CPU" GPU = "GPU" + GPU_0 = "GPU:0" + GPU_1 = "GPU:1" + GPU_2 = "GPU:2" NPU = "NPU" AUTO = "AUTO:GPU,CPU" HETERO = "HETERO:GPU,CPU" diff --git a/tests/functional/utils/helpers.py b/tests/functional/utils/helpers.py index 5e9ffbae1d..81b1d93833 100644 --- a/tests/functional/utils/helpers.py +++ b/tests/functional/utils/helpers.py @@ -95,9 +95,27 @@ def get_multi_target_devices(target_devices_list, separator): return result +def _is_device_with_index(device_str): + """ Check if device string is a device with numeric index, e.g. GPU:0, GPU:1. """ + if ":" in device_str: + _, suffix = device_str.split(":", 1) + return suffix.isdigit() + return False + + def validate_supported_values(detected_list, supported_list): supported_list += ALL_AVAILABLE_OPTIONS # 'starred expression' will be evaluated during pytest_configure - check = all(_elem in supported_list for _elem in detected_list) + + def _is_supported(device): + if device in supported_list: + return True + # Accept indexed devices like GPU:0, GPU:1 if base device (GPU) is supported + if _is_device_with_index(device): + base_device = device.split(":", 1)[0] + return base_device in supported_list + return False + + check = all(_is_supported(_elem) for _elem in detected_list) assert check, f"Not supported target devices in {detected_list}" return detected_list @@ -106,7 +124,12 @@ def get_target_devices(): """ Convert comma separated string of devices into list """ target_devices_list = get_list("TT_TARGET_DEVICE", fallback=[TargetDevice.CPU]) separator_multi = ":" - if any(separator_multi in _target_device for _target_device in target_devices_list): + # Only treat as multi-target if ':' is followed by a device name, not a numeric index (GPU:0, GPU:1) + has_multi_target = any( + separator_multi in _td and not _is_device_with_index(_td) + for _td in target_devices_list + ) + if has_multi_target: target_devices_list = get_multi_target_devices(target_devices_list, separator_multi) ov_target_devices = [value for key, value in vars(TargetDevice).items() if not key.startswith("__")] target_devices_list = validate_supported_values(detected_list=target_devices_list, supported_list=ov_target_devices)