diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index a30391d11..7d7def879 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -19,7 +19,6 @@
 # 2) https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py  # noqa
 
 import math
-import warnings
 from typing import List, Optional, Tuple
 
 import torch
@@ -1558,12 +1557,6 @@ class EmformerEncoder(nn.Module):
                 self.cnn_module_kernel - 1,
             ), conv_caches[i].shape
 
-        # assert x.size(0) == self.chunk_length + self.right_context_length, (
-        #     "Per configured chunk_length and right_context_length, "
-        #     f"expected size of {self.chunk_length + self.right_context_length} "
-        #     f"for dimension 1 of x, but got {x.size(0)}."
-        # )
-
         right_context = x[-self.right_context_length :]
         utterance = x[: -self.right_context_length]
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
@@ -1576,7 +1569,9 @@
         # calcualte padding mask to mask out initial zero caches
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
-            (num_processed_frames // self.chunk_length).view(x.size(1), 1)
+            torch.div(
+                num_processed_frames, self.chunk_length, rounding_mode="floor"
+            ).view(x.size(1), 1)
             <= torch.arange(self.memory_size, device=x.device).expand(
                 x.size(1), self.memory_size
            )
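
Note on the last hunk: for non-negative integer tensors such as frame counts, torch.div(..., rounding_mode="floor") computes the same values as the // operator, but states the intended rounding explicitly; some PyTorch releases also emit a deprecation warning for tensor floor division written with //, which the explicit form avoids. A minimal sketch of the equivalence, with illustrative values that are not taken from the patch:

    import torch

    # Frames already processed by three parallel streams (illustrative values).
    num_processed_frames = torch.tensor([0, 32, 64])
    chunk_length = 32

    # Operator form: floor division on a tensor.
    a = num_processed_frames // chunk_length

    # Explicit form used by the patch, with the rounding mode spelled out.
    b = torch.div(num_processed_frames, chunk_length, rounding_mode="floor")

    assert torch.equal(a, b)  # both are tensor([0, 1, 2])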