
Converting HIFI-GAN decoder to TensorRT #4682

@Amarnath1906


I am trying to convert the HiFi-GAN decoder to TensorRT. Here is the script I am using:
```python

import tensorrt as trt
import os
  
def build_engine(onnx_file_path):

  logger = trt.Logger(trt.Logger.VERBOSE)
  builder = trt.Builder(logger)

  network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
  network = builder.create_network(network_flags)
  parser = trt.OnnxParser(network, logger)

  if not os.path.exists(onnx_file_path):
      print(f"ONNX file {onnx_file_path} not found.")
      return None

  print(f"Loading ONNX file: {onnx_file_path}")
  with open(onnx_file_path, "rb") as model:
      if not parser.parse(model.read()):
          for i in range(parser.num_errors):
              print(parser.get_error(i))
          return None

  config = builder.create_builder_config()

  # L40S has large memory → allow bigger workspace
  config.set_memory_pool_limit(
      trt.MemoryPoolType.WORKSPACE,
      4 * 1024 * 1024 * 1024   # 4GB workspace (safe on 48GB L40S)
  )

  # FP16 (always good)
  if builder.platform_has_fast_fp16:
      config.set_flag(trt.BuilderFlag.FP16)
      print("FP16 mode enabled (L40S optimized).")

  
  # ───────────────────────────────────────────────
  # Dynamic Shapes (same as your model)
  # ───────────────────────────────────────────────
  profile = builder.create_optimization_profile()

  profile.set_shape("ASR",
      (1, 512, 28),
      (1, 512, 100),
      (1, 512, 1106)
  )

  profile.set_shape("F0_PRED",
      (1, 56),
      (1, 200),
      (1, 2212)
  )

  profile.set_shape("N_PRED",
      (1, 56),
      (1, 200),
      (1, 2212)
  )

  profile.set_shape("REF",
      (1, 128),
      (1, 128),
      (1, 128)
  )

  config.add_optimization_profile(profile)

  print("Building TensorRT engine on L40S...")

  serialized_engine = builder.build_serialized_network(network, config)

  if serialized_engine:
      with open("sample.engine", "wb") as f:
          f.write(serialized_engine)
      print("Success: Engine saved as sample.engine")
      return serialized_engine
  else:
      print("Error: Build failed.")
      return None

if __name__ == "__main__":
  ONNX_PATH = "decoder_v2.onnx"
  build_engine(ONNX_PATH)

```
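As a sanity check before setting the optimization profile, it can help to print the input names and symbolic shapes that TensorRT actually parsed from the ONNX graph, so the names and ranks in `profile.set_shape(...)` line up with what the parser sees. A minimal sketch, reusing the same builder/network/parser setup as the script above (dynamic dimensions show up as -1):

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)

with open("decoder_v2.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise SystemExit("ONNX parse failed")

# These names must match the profile.set_shape(...) calls
# (ASR, F0_PRED, N_PRED, REF) and have the same rank.
for i in range(network.num_inputs):
    inp = network.get_input(i)
    print(inp.name, tuple(inp.shape))
```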

And I am getting the following error:
[01/16/2026-10:05:01] [TRT] [E] IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: /Concat: axis 2 dimensions must be equal for concatenation on axis 1. Dimensions are seq_len and (+ (CEIL_DIV (+ seq_len -2) 2) 1). Condition '==' violated: 100 != 50.) Error: Build failed.
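To unpack the arithmetic in that message: `(+ (CEIL_DIV (+ seq_len -2) 2) 1)` is the output length of a kernel-3, stride-2, padding-1 convolution (the `F0_conv`/`N_conv` defined in the decoder below), written in terms of the same symbolic dimension `seq_len` that the ASR input carries, and the concat requires the two lengths to be equal. Evaluating that formula at the profile's seq_len values reproduces the reported 100 != 50 (the helper name below is just for illustration):

```python
def f0_branch_len(seq_len, kernel=3, stride=2, padding=1):
    # Standard Conv1d output-length formula; for kernel=3, stride=2, padding=1
    # TensorRT prints the equivalent expression CEIL_DIV(seq_len - 2, 2) + 1.
    return (seq_len + 2 * padding - kernel) // stride + 1

for L in (28, 100, 1106):              # the min/opt/max seq_len values in the profile
    print(L, "->", f0_branch_len(L))   # 28 -> 14, 100 -> 50, 1106 -> 553
```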

This is the Decoder module I have:

```python

class Decoder(nn.Module):
    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80,
                 resblock_kernel_sizes=[3, 7, 11],
                 upsample_rates=[10, 5, 3, 2],
                 upsample_initial_channel=512,
                 resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                 upsample_kernel_sizes=[20, 10, 6, 4]):
        super().__init__()

        self.decode = nn.ModuleList()

        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)

        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))

        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))

        self.asr_res = nn.Sequential(
            weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
        )

        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
                                   upsample_initial_channel, resblock_dilation_sizes,
                                   upsample_kernel_sizes)

    def match(self, t, ref):
        if t.shape[-1] != ref.shape[-1]:
            t = F.interpolate(t, size=ref.shape[-1], mode="nearest")
        return t

    def forward(self, asr, F0_curve, N, s):
        if self.training:
            downlist = [0, 3, 7]
            F0_down = downlist[random.randint(0, 2)]
            downlist = [0, 3, 7, 15]
            N_down = downlist[random.randint(0, 3)]
            if F0_down:
                F0_curve = nn.functional.conv1d(F0_curve.unsqueeze(1),
                                                torch.ones(1, 1, F0_down).to('cuda'),
                                                padding=F0_down // 2).squeeze(1) / F0_down
            if N_down:
                N = nn.functional.conv1d(N.unsqueeze(1),
                                         torch.ones(1, 1, N_down).to('cuda'),
                                         padding=N_down // 2).squeeze(1) / N_down

        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))

        x = torch.cat([asr, F0, N], axis=1)
        x = self.encode(x, s)

        asr_res = self.asr_res(asr)

        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False

        x = self.generator(x, s, F0_curve)
        return x

```
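Because `F0_conv` and `N_conv` use stride 2, the F0/N inputs need to be roughly twice as long as the ASR feature sequence for the three branches to line up in the concat. A small standalone check of that relationship (a sketch using a plain `nn.Conv1d` with the same kernel_size=3, stride=2, padding=1; `weight_norm` is omitted since it does not change output shapes):

```python
import torch
import torch.nn as nn

f0_conv = nn.Conv1d(1, 1, kernel_size=3, stride=2, padding=1)

asr_len = 100
f0 = torch.randn(1, asr_len * 2)    # F0_PRED / N_PRED shaped [B, 2*T]
out = f0_conv(f0.unsqueeze(1))      # -> [B, 1, T], matching asr's time axis
print(out.shape[-1])                # 100, equal to asr_len
```

This matches the export script below, which builds `F0_pred` and `N_pred` with `seq_len * 2` samples.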

This is how I am converting the decoder to ONNX:

```python

def export_decoder_dynamic(model, repo_path):
    bmodel = model
    decoder = bmodel.decoder.eval().cuda()

    model_path = "decoder_v2.onnx"

    # os.makedirs(os.path.dirname(model_path), exist_ok=True)

    batch = 2
    seq_len = 40

    asr = torch.randn(batch, 512, seq_len, dtype=torch.float32).cuda()
    F0_pred = torch.randn(batch, seq_len * 2, dtype=torch.float32).cuda()
    # F0 = torch.randn(batch, 1, seq_len, dtype=torch.float32)
    N_pred = torch.randn(batch, seq_len * 2, dtype=torch.float32).cuda()
    # N_pred = torch.randn(batch, 1, seq_len, dtype=torch.float32)
    ref = torch.randn(batch, 128, dtype=torch.float32).cuda()               # [B, S]

    print(asr.shape, F0_pred.shape, N_pred.shape, ref.shape)

    try:
        torch.onnx.export(
            decoder,
            (asr, F0_pred, N_pred, ref),
            model_path,
            input_names=["ASR", "F0_PRED", "N_PRED", "REF"],
            output_names=["AUDIO_OUT"],
            dynamic_axes={
                "ASR": {0: "batch", 2: "seq_len"},        # input [B, C, T]
                "F0_PRED": {0: "batch", 1: "seq_len"},    # input [B, T]
                "N_PRED": {0: "batch", 1: "seq_len"},     # input [B, T]
                "REF": {0: "batch"},                      # input [B, S]
                "AUDIO_OUT": {0: "batch", 2: "seq_len"},  # output [B, C, T]
            },
            opset_version=17,
            do_constant_folding=True,
            verbose=False
        )

        print("decoder exported:", model_path)
    except Exception as e:
        print(traceback.print_exc())
        print(f"Failed to export decoder: {e}")

    try:
        model = onnx.load(model_path)
        model_with_shapes = onnx.shape_inference.infer_shapes(model)
        onnx.save(model_with_shapes, model_path)

        # Create ONNX Runtime session
        session = ort.InferenceSession(model_path)

        # Check the input names
        input_names = [inp.name for inp in session.get_inputs()]
        print("ONNX input names:", input_names)

        # Check output names
        output_names = [out.name for out in session.get_outputs()]
        print("ONNX output names:", output_names)

        onnx_out = session.run(None,
            {"ASR": asr.cpu().numpy(),
             "F0_PRED": F0_pred.cpu().numpy(),
             "N_PRED": N_pred.cpu().numpy(),
             "REF": ref.cpu().numpy()
             }
        )[0]

        print("ONNX output shape:", onnx_out.shape)

    except Exception as e:
        print(traceback.print_exc())
        print(f"Failed to verify decoder export: {e}")
        return

```
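To confirm the dynamic axes behave as intended away from the export shape, the same onnxruntime session can be fed a different sequence length. A sketch, assuming decoder_v2.onnx has already been exported as above:

```python
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("decoder_v2.onnx")

seq_len = 100
feeds = {
    "ASR": np.random.randn(1, 512, seq_len).astype(np.float32),
    "F0_PRED": np.random.randn(1, seq_len * 2).astype(np.float32),
    "N_PRED": np.random.randn(1, seq_len * 2).astype(np.float32),
    "REF": np.random.randn(1, 128).astype(np.float32),
}
audio = session.run(None, feeds)[0]
print("AUDIO_OUT shape:", audio.shape)
```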

I am using the nvcr.io/nvidia/tensorrt:25.10-py3 container for the conversion, and here are a few more specs about the GPU:
```

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    Off |   00000000:01:01.0 Off |                    0 |
| N/A   39C    P0             81W /  350W |   23923MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

```

CUDA Version: 12.6.

Let me know where I am going wrong.

Thanks in advance.
