@@ -82,6 +82,57 @@ Status BuildNodeMap(const Graph& graph,
8282 return Status::OK ();
8383}
8484
85+ Status FindExtraConcatInput (const Graph& graph,
86+ const std::vector<std::string>& input_output_names,
87+ std::vector<const Node*>* filter_concat_node) {
88+ std::unordered_set<const Node*> candidate_node;
89+ std::unordered_set<Node*> concat_nodes;
90+ for (auto * node : graph.nodes ()) {
91+ if (node->type_string () == " ConcatV2" ) {
92+ concat_nodes.insert (node);
93+ }
94+ }
95+ std::unordered_set<std::string> in_out_names;
96+ for (auto & name : input_output_names) {
97+ in_out_names.insert (name);
98+ }
99+ for (const Node* c_nodes : concat_nodes) {
100+ std::vector<const Node*> in_placeholder;
101+ ReverseDFSFrom (
102+ graph, {c_nodes},
103+ [&in_placeholder, in_out_names](const Node* node) {
104+ if (in_out_names.find (node->name ()) != in_out_names.end ()) {
105+ in_placeholder.emplace_back (node);
106+ }
107+ },
108+ /* end*/ nullptr );
109+ if (in_placeholder.size () > 1 ) { // verify node in common sub-graph
110+ DataType t_types;
111+ TF_RETURN_IF_ERROR (GetNodeAttr (c_nodes->attrs (), " T" , &t_types));
112+ if (t_types == DT_FLOAT) {
113+ candidate_node.insert (c_nodes);
114+ }
115+ }
116+ }
117+
118+ for (const Node* cnode : candidate_node) {
119+ bool is_admit = true ;
120+ ReverseDFSFrom (graph, {cnode},
121+ [&filter_concat_node, &is_admit, candidate_node,
122+ cnode](const Node* node) {
123+ if ((candidate_node.find (node) != candidate_node.end ()) &&
124+ (cnode->name () != node->name ())) {
125+ is_admit = false ;
126+ }
127+ },
128+ /* end*/ nullptr );
129+ if (is_admit) {
130+ filter_concat_node->emplace_back (cnode);
131+ }
132+ }
133+ return Status::OK ();
134+ }
135+
85136EngineInfo::EngineType GetEngineType (
86137 const TRTOptimizationPass::ConversionParams& params) {
87138 return (params.is_dynamic_op || params.use_calibration )
@@ -773,6 +824,14 @@ Status ConvertGraph(const TRTOptimizationPass::ConversionParams& params,
773824 for (const auto & node : input_output_names) {
774825 segment_options.exclude_node_list .insert (node);
775826 }
827+ std::vector<const Node*> filter_concat_node;
828+ TF_RETURN_IF_ERROR (
829+ FindExtraConcatInput (graph, input_output_names, &filter_concat_node));
830+ for (const auto * node : filter_concat_node) {
831+ for (auto * inode : node->in_nodes ()) {
832+ segment_options.exclude_node_list .insert (inode->name ());
833+ }
834+ }
776835 segment_options.minimum_segment_size = params.minimum_segment_size ;
777836 segment_options.use_implicit_batch = params.use_implicit_batch ;
778837 if (segment_options.use_implicit_batch )
0 commit comments