diff --git a/c/src/ml-api-inference-single.c b/c/src/ml-api-inference-single.c index 03df6326..5c52350d 100644 --- a/c/src/ml-api-inference-single.c +++ b/c/src/ml-api-inference-single.c @@ -1015,6 +1015,33 @@ _ml_nnfw_to_str_prop (const ml_nnfw_hw_e hw) return str_prop; } +/** + * @brief Specifies whether input/output tensor information is required for each neural network framework + * @param nnfw_type The neural network framework type (e.g., TensorFlow, PyTorch) + * @param requires_in_info TRUE if the framework requires input tensors information + * @param requires_out_info TRUE if the framework requires output tensors information + */ +typedef struct +{ + ml_nnfw_type_e nnfw_type; + gboolean requires_in_info; + gboolean requires_out_info; +} nnfw_tensors_info; + +/** + * @brief A table defining frameworks that require tensor info setting. + * @details This table-driven approach simplifies adding or modifying framework-specific logic. + */ +static const nnfw_tensors_info nnfw_tensors_info_table[] = { + {ML_NNFW_TYPE_TENSORFLOW, TRUE, TRUE}, + {ML_NNFW_TYPE_SNAP, TRUE, TRUE}, + {ML_NNFW_TYPE_PYTORCH, TRUE, TRUE}, + {ML_NNFW_TYPE_TRIX_ENGINE, TRUE, TRUE}, + {ML_NNFW_TYPE_NCNN, TRUE, TRUE}, + {ML_NNFW_TYPE_ARMNN, FALSE, FALSE} +}; + + /** * @brief Opens an ML model with the custom options and returns the instance as a handle. */ @@ -1031,6 +1058,9 @@ ml_single_open_custom (ml_single_h * single, ml_single_preset * info) g_autofree gchar *converted_models = NULL; gchar **list_models; guint i, num_models; + gboolean requires_in_info = FALSE; + gboolean requires_out_info = FALSE; + gboolean found_in_table = FALSE; char *hw_name; check_feature_state (ML_FEATURE_INFERENCE); @@ -1097,59 +1127,50 @@ ml_single_open_custom (ml_single_h * single, ml_single_preset * info) * 3. Construct a direct connection with the nnfw. * Note that we do not construct a pipeline since 2019.12. */ - if (nnfw == ML_NNFW_TYPE_TENSORFLOW || nnfw == ML_NNFW_TYPE_SNAP || - nnfw == ML_NNFW_TYPE_PYTORCH || nnfw == ML_NNFW_TYPE_TRIX_ENGINE || - nnfw == ML_NNFW_TYPE_NCNN) { - /* set input and output tensors information */ - if (in_tensors_info && out_tensors_info) { - status = - ml_single_set_inout_tensors_info (filter_obj, TRUE, in_tensors_info); - if (status != ML_ERROR_NONE) { - _ml_error_report_continue - ("Input tensors info is given; however, failed to set input tensors info. Error code: %d", - status); - goto error; - } + /* Look up framework properties from the table */ + for (i = 0; i < G_N_ELEMENTS (nnfw_tensors_info_table); i++) { + if (nnfw_tensors_info_table[i].nnfw_type == nnfw) { + requires_in_info = nnfw_tensors_info_table[i].requires_in_info; + requires_out_info = nnfw_tensors_info_table[i].requires_out_info; + found_in_table = TRUE; + break; + } + } - status = - ml_single_set_inout_tensors_info (filter_obj, FALSE, - out_tensors_info); - if (status != ML_ERROR_NONE) { - _ml_error_report_continue - ("Output tensors info is given; however, failed to set output tensors info. Error code: %d", - status); - goto error; - } - } else { - _ml_error_report - ("To run the given nnfw, '%s', with a neural network model, both input and output information should be provided.", - fw_name); + if (found_in_table) { + if (requires_in_info && !in_tensors_info) { + _ml_error_report ("Framework '%s' requires input information.", fw_name); + status = ML_ERROR_INVALID_PARAMETER; + goto error; + } + + if (requires_out_info && !out_tensors_info) { + _ml_error_report ("Framework '%s' requires output information.", fw_name); status = ML_ERROR_INVALID_PARAMETER; goto error; } - } else if (nnfw == ML_NNFW_TYPE_ARMNN) { - /* set input and output tensors information, if available */ + if (in_tensors_info) { - status = - ml_single_set_inout_tensors_info (filter_obj, TRUE, in_tensors_info); + status = ml_single_set_inout_tensors_info (filter_obj, TRUE, in_tensors_info); if (status != ML_ERROR_NONE) { _ml_error_report_continue - ("With nnfw '%s', input tensors info is optional. However, the user has provided an invalid input tensors info. Error code: %d", - fw_name, status); + ("Input tensors info is given; however, failed to set input tensors info. Error code: %d", + status); goto error; } } + if (out_tensors_info) { - status = - ml_single_set_inout_tensors_info (filter_obj, FALSE, - out_tensors_info); + status = ml_single_set_inout_tensors_info (filter_obj, FALSE, out_tensors_info); if (status != ML_ERROR_NONE) { - _ml_error_report_continue - ("With nnfw '%s', output tensors info is optional. However, the user has provided an invalid output tensors info. Error code: %d", - fw_name, status); + _ml_error_report_continue + ("Output tensors info is given; however, failed to set output tensors info. Error code: %d", + status); goto error; } } + } else { + _ml_logi("Framework '%s' not found in tensors info table, using default behavior", fw_name); } /* set accelerator, framework, model files and custom option */