@@ -8,37 +8,42 @@ namespace converters {
88namespace impl {
99namespace {
1010
11- bool relu (ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
12- auto in = args[0 ].ITensor ();
11+ #define convert (act, trt_type ) \
12+ bool act (ConversionCtx* ctx, const torch::jit::Node* n, args& args) { \
13+ auto in = args[0 ].ITensor (); \
14+ \
15+ auto new_layer = \
16+ ctx->net ->addActivation (*in, nvinfer1::ActivationType::trt_type); \
17+ TRTORCH_CHECK (new_layer, \
18+ " Unable to create " #act " layer from node: " << *n); \
19+ \
20+ new_layer->setName (util::node_info (n).c_str ()); \
21+ auto out_value = n->outputs ()[0 ]; \
22+ auto out_tensor = new_layer->getOutput (0 ); \
23+ out_tensor->setName (out_value->debugName ().c_str ()); \
24+ ctx->value_tensor_map [out_value] = out_tensor; \
25+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ()); \
26+ \
27+ return true ; \
28+ } \
29+ \
30+ auto act##_registrations TRTORCH_UNUSED = \
31+ RegisterNodeConversionPatterns () \
32+ .pattern({" aten::" #act " (Tensor input) -> (Tensor)" , \
33+ [](ConversionCtx *ctx, const torch::jit::Node *n, \
34+ args &args) -> bool { return act (ctx, n, args); }}) \
35+ .pattern({" aten::" #act " _(Tensor(a!) self) -> (Tensor(a!))" , \
36+ [](ConversionCtx *ctx, const torch::jit::Node *n, \
37+ args &args) -> bool { return act (ctx, n, args); }});
1338
14- auto new_layer = ctx->net ->addActivation (*in, nvinfer1::ActivationType::kRELU );
15- TRTORCH_CHECK (new_layer, " Unable to create ReLU layer from node: " << *n);
39+ convert (relu, kRELU );
40+ convert (sigmoid, kSIGMOID );
41+ convert (tanh, kTANH );
1642
17- new_layer->setName (util::node_info (n).c_str ());
18- auto out_value = n->outputs ()[0 ];
19- auto out_tensor = new_layer->getOutput (0 );
20- out_tensor->setName (out_value->debugName ().c_str ());
21- ctx->value_tensor_map [out_value] = out_tensor;
22- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
23-
24- return true ;
25- }
26-
27- auto relu_registrations = RegisterNodeConversionPatterns()
28- .pattern({
29- " aten::relu(Tensor input) -> (Tensor)" ,
30- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
31- return relu (ctx, n, args);
32- }
33- }).pattern({
34- " aten::relu_(Tensor(a!) self) -> (Tensor(a!))" ,
35- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
36- return relu (ctx, n, args);
37- }
38- });
43+ #undef convert
3944} // namespace
4045} // namespace impl
4146} // namespace converters
4247} // namespace conversion
4348} // namespace core
44- } // trtorch
49+ } // namespace trtorch
0 commit comments