diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index fa908ffcd..8ca7d5568 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -1253,8 +1253,9 @@ class EmformerEncoder(nn.Module): super().__init__() assert ( - chunk_length - 1 & chunk_length == 0 - ), "chunk_length should be a power of 2." + chunk_length - 1 + ) & chunk_length == 0, "chunk_length should be a power of 2." + self.shift = int(math.log(chunk_length, 2)) self.use_memory = memory_size > 0 self.init_memory_op = nn.AvgPool1d( @@ -1584,9 +1585,7 @@ class EmformerEncoder(nn.Module): chunk_mask = make_pad_mask(output_lengths).to(x.device) memory_mask = ( ( - ( - num_processed_frames >> int(math.log(self.chunk_length, 2)) - ).view(x.size(1), 1) + (num_processed_frames >> self.shift).view(x.size(1), 1) <= torch.arange(self.memory_size, device=x.device).expand( x.size(1), self.memory_size ) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index 18d2047bd..f16f5acc7 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -1189,8 +1189,9 @@ class EmformerEncoder(nn.Module): super().__init__() assert ( - chunk_length - 1 & chunk_length == 0 - ), "chunk_length should be a power of 2." + chunk_length - 1 + ) & chunk_length == 0, "chunk_length should be a power of 2." + self.shift = int(math.log(chunk_length, 2)) self.use_memory = memory_size > 0 @@ -1492,9 +1493,7 @@ class EmformerEncoder(nn.Module): chunk_mask = make_pad_mask(output_lengths).to(x.device) memory_mask = ( ( - ( - num_processed_frames >> int(math.log(self.chunk_length, 2)) - ).view(x.size(1), 1) + (num_processed_frames >> self.shift).view(x.size(1), 1) <= torch.arange(self.memory_size, device=x.device).expand( x.size(1), self.memory_size )