diff --git a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
index e5ceb3683..88c58f581 100755
--- a/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
+++ b/egs/librispeech/ASR/zipformer/export-onnx-streaming.py
@@ -74,7 +74,6 @@ import onnx
 import torch
 import torch.nn as nn
 from decoder import Decoder
-from onnxconverter_common import float16
 from onnxruntime.quantization import QuantType, quantize_dynamic
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_model, get_params
@@ -756,6 +755,7 @@ def main():
     logging.info(f"Exported joiner to {joiner_filename}")
 
     if(params.fp16) :
+        from onnxconverter_common import float16
         logging.info("Generate fp16 models")
 
         encoder = onnx.load(encoder_filename)
diff --git a/egs/librispeech/ASR/zipformer/zipformer.py b/egs/librispeech/ASR/zipformer/zipformer.py
index 69059287b..2a0ae0129 100644
--- a/egs/librispeech/ASR/zipformer/zipformer.py
+++ b/egs/librispeech/ASR/zipformer/zipformer.py
@@ -191,6 +191,7 @@
                     dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
                     dropout=dropout,
+                    causal=causal,
                 )
 
             encoders.append(encoder)
@@ -198,7 +199,10 @@
         self.encoders = nn.ModuleList(encoders)
 
         self.downsample_output = SimpleDownsample(
-            max(encoder_dim), downsample=output_downsampling_factor, dropout=dropout
+            max(encoder_dim),
+            downsample=output_downsampling_factor,
+            dropout=dropout,
+            causal=causal,
         )
 
     def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]:
@@ -1217,11 +1221,16 @@ class DownsampledZipformer2Encoder(nn.Module):
     """
 
     def __init__(
-        self, encoder: nn.Module, dim: int, downsample: int, dropout: FloatLike
+        self,
+        encoder: nn.Module,
+        dim: int,
+        downsample: int,
+        dropout: FloatLike,
+        causal: bool,
     ):
         super(DownsampledZipformer2Encoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = SimpleDownsample(dim, downsample, dropout)
+        self.downsample = SimpleDownsample(dim, downsample, dropout, causal)
         self.num_layers = encoder.num_layers
         self.encoder = encoder
         self.upsample = SimpleUpsample(dim, downsample)
@@ -1310,9 +1319,12 @@ class SimpleDownsample(torch.nn.Module):
     Does downsampling with attention, by weighted sum, and a projection..
     """
 
-    def __init__(self, channels: int, downsample: int, dropout: FloatLike):
+    def __init__(
+        self, channels: int, downsample: int, dropout: FloatLike, causal: bool
+    ):
         super(SimpleDownsample, self).__init__()
 
+        self.causal = causal
         self.bias = nn.Parameter(torch.zeros(downsample))
 
         self.name = None  # will be set from training code
@@ -1333,9 +1345,18 @@
         # Pad to an exact multiple of self.downsample
         # right-pad src, repeating the last element.
         pad = d_seq_len * ds - seq_len
-        src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
-        src = torch.cat((src, src_extra), dim=0)
-        assert src.shape[0] == d_seq_len * ds
+
+        if self.causal and torch.jit.is_tracing():
+            assert (
+                pad == 0
+            ), f"pad should be zero for exporting streaming models. Given {pad}"
+
+        # If we are exporting a streaming model, then we skip the if statement
+        if not self.causal or not torch.jit.is_tracing():
+            src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2])
+            src = torch.cat((src, src_extra), dim=0)
+
+        assert src.shape[0] == d_seq_len * ds, (src.shape, d_seq_len, ds)
 
         src = src.reshape(d_seq_len, ds, batch_size, in_channels)
 
@@ -1609,7 +1630,11 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         k = x[..., query_dim : 2 * query_dim]
         # p is the position-encoding query
         p = x[..., 2 * query_dim :]
-        assert p.shape[-1] == num_heads * pos_head_dim, (p.shape[-1], num_heads, pos_head_dim)
+        assert p.shape[-1] == num_heads * pos_head_dim, (
+            p.shape[-1],
+            num_heads,
+            pos_head_dim,
+        )
 
         q = self.copy_query(q)  # for diagnostics only, does nothing.
         k = self.whiten_keys(self.balance_keys(k))  # does nothing in the forward pass.