@@ -9,37 +9,42 @@ namespace converters {
99namespace impl {
1010namespace {
1111
12- bool relu (ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
13- auto in = args[0 ].ITensor ();
12+ #define convert (act, trt_type ) \
13+ bool act (ConversionCtx* ctx, const torch::jit::Node* n, args& args) { \
14+ auto in = args[0 ].ITensor (); \
15+ \
16+ auto new_layer = \
17+ ctx->net ->addActivation (*in, nvinfer1::ActivationType::trt_type); \
18+ TRTORCH_CHECK (new_layer, \
19+ " Unable to create " #act " layer from node: " << *n); \
20+ \
21+ new_layer->setName (util::node_info (n).c_str ()); \
22+ auto out_value = n->outputs ()[0 ]; \
23+ auto out_tensor = new_layer->getOutput (0 ); \
24+ out_tensor->setName (out_value->debugName ().c_str ()); \
25+ ctx->value_tensor_map [out_value] = out_tensor; \
26+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ()); \
27+ \
28+ return true ; \
29+ } \
30+ \
31+ auto act##_registrations TRTORCH_UNUSED = \
32+ RegisterNodeConversionPatterns () \
33+ .pattern({" aten::" #act " (Tensor input) -> (Tensor)" , \
34+ [](ConversionCtx *ctx, const torch::jit::Node *n, \
35+ args &args) -> bool { return act (ctx, n, args); }}) \
36+ .pattern({" aten::" #act " _(Tensor(a!) self) -> (Tensor(a!))" , \
37+ [](ConversionCtx *ctx, const torch::jit::Node *n, \
38+ args &args) -> bool { return act (ctx, n, args); }});
1439
15- auto new_layer = ctx->net ->addActivation (*in, nvinfer1::ActivationType::kRELU );
16- TRTORCH_CHECK (new_layer, " Unable to create ReLU layer from node: " << *n);
40+ convert (relu, kRELU );
41+ convert (sigmoid, kSIGMOID );
42+ convert (tanh, kTANH );
1743
18- new_layer->setName (util::node_info (n).c_str ());
19- auto out_value = n->outputs ()[0 ];
20- auto out_tensor = new_layer->getOutput (0 );
21- out_tensor->setName (out_value->debugName ().c_str ());
22- ctx->value_tensor_map [out_value] = out_tensor;
23- LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
24-
25- return true ;
26- }
27-
28- auto relu_registrations = RegisterNodeConversionPatterns()
29- .pattern({
30- " aten::relu(Tensor input) -> (Tensor)" ,
31- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
32- return relu (ctx, n, args);
33- }
34- }).pattern({
35- " aten::relu_(Tensor(a!) self) -> (Tensor(a!))" ,
36- [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
37- return relu (ctx, n, args);
38- }
39- });
44+ #undef convert
4045} // namespace
4146} // namespace impl
4247} // namespace converters
4348} // namespace conversion
4449} // namespace core
45- } // trtorch
50+ } // namespace trtorch
0 commit comments