11#include < string>
22#include " gtest/gtest.h"
3- #include " torch/csrc/jit/irparser .h"
3+ #include " torch/script .h"
44#include " tests/util/util.h"
5- #include " cpp /trtorch.h"
5+ #include " trtorch /trtorch.h"
66
77TEST (ModuleTests, CanRunMultipleEngines) {
88 torch::jit::script::Module mod1;
@@ -16,7 +16,7 @@ TEST(ModuleTests, CanRunMultipleEngines) {
1616 return ;
1717 }
1818
19- const std::vector<int64_t > input_shape = {1 ,3 ,224 ,224 };
19+ const std::vector<std::vector< int64_t >> input_shapes = {{ 1 ,3 ,224 ,224 } };
2020
2121 std::vector<torch::jit::IValue> jit1_inputs_ivalues;
2222 std::vector<torch::jit::IValue> trt1_inputs_ivalues;
@@ -38,18 +38,18 @@ TEST(ModuleTests, CanRunMultipleEngines) {
3838 std::vector<at::Tensor> jit1_results;
3939 jit1_results.push_back (jit1_results_ivalues.toTensor ());
4040
41- torch::jit::IValue jit2_results_ivalues = trtorch::tests::util::RunModuleForward (mod2, jit2_inputs_ivalues);
41+ torch::jit::IValue jit2_results_ivalues = trtorch::tests::util::RunModuleForward (mod2, jit2_inputs_ivalues);
4242 std::vector<at::Tensor> jit2_results;
4343 jit2_results.push_back (jit2_results_ivalues.toTensor ());
4444
4545
4646 auto trt_mod1 = trtorch::CompileGraph (mod1, input_shapes);
47- torch::jit::IValue trt1_results_ivalues = trtorch::tests::util::RunModuleForward (trt1_mod , trt1_inputs_ivalues);
47+ torch::jit::IValue trt1_results_ivalues = trtorch::tests::util::RunModuleForward (trt_mod1 , trt1_inputs_ivalues);
4848 std::vector<at::Tensor> trt1_results;
4949 trt1_results.push_back (trt1_results_ivalues.toTensor ());
5050
5151 auto trt_mod2 = trtorch::CompileGraph (mod2, input_shapes);
52- torch::jit::IValue trt2_results_ivalues = trtorch::tests::util::RunModuleForward (trt2_mod , trt2_inputs_ivalues);
52+ torch::jit::IValue trt2_results_ivalues = trtorch::tests::util::RunModuleForward (trt_mod2 , trt2_inputs_ivalues);
5353 std::vector<at::Tensor> trt2_results;
5454 trt2_results.push_back (trt2_results_ivalues.toTensor ());
5555
0 commit comments