diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index 61b7dec9c..509eed3c8 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -1252,6 +1252,10 @@ class EmformerEncoder(nn.Module):
     ):
         super().__init__()
 
+        assert chunk_length > 0 and (chunk_length & (chunk_length - 1)) == 0, (
+            "chunk_length should be a power of 2."
+        )
+
         self.use_memory = memory_size > 0
         self.init_memory_op = nn.AvgPool1d(
             kernel_size=chunk_length,
@@ -1580,10 +1584,9 @@ class EmformerEncoder(nn.Module):
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
             (
-                torch.div(
-                    num_processed_frames,
-                    self.chunk_length,
-                    rounding_mode="floor",
+                (
+                    # Right shift by log2(chunk_length) == floor division.
+                    num_processed_frames >> int(math.log(self.chunk_length, 2))
                 ).view(x.size(1), 1)
                 <= torch.arange(self.memory_size, device=x.device).expand(
                     x.size(1), self.memory_size
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
index 2277710a4..1d16682c6 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
@@ -1188,6 +1188,10 @@ class EmformerEncoder(nn.Module):
     ):
         super().__init__()
 
+        assert chunk_length > 0 and (chunk_length & (chunk_length - 1)) == 0, (
+            "chunk_length should be a power of 2."
+        )
+
         self.use_memory = memory_size > 0
 
         self.emformer_layers = nn.ModuleList(
@@ -1488,10 +1492,9 @@ class EmformerEncoder(nn.Module):
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
             (
-                torch.div(
-                    num_processed_frames,
-                    self.chunk_length,
-                    rounding_mode="floor",
+                (
+                    # Right shift by log2(chunk_length) == floor division.
+                    num_processed_frames >> int(math.log(self.chunk_length, 2))
                 ).view(x.size(1), 1)
                 <= torch.arange(self.memory_size, device=x.device).expand(
                     x.size(1), self.memory_size