From 10662c5c3897800c9de89545a6521bbc7b61f8fa Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Mon, 4 Jul 2022 17:05:47 +0800
Subject: [PATCH] fix bug about memory mask when memory_size==0

---
 .flake8         |  3 +--
 .../emformer.py | 20 +++++++++-----
 .../emformer.py | 26 +++++++++++++------
 3 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/.flake8 b/.flake8
index 9dd8d6207..b2eb2e943 100644
--- a/.flake8
+++ b/.flake8
@@ -9,8 +9,7 @@ per-file-ignores =
     egs/*/ASR/pruned_transducer_stateless*/*.py: E501,
     egs/*/ASR/*/optim.py: E501,
     egs/*/ASR/*/scaling.py: E501,
-    egs/librispeech/ASR/conv_emformer_transducer_stateless/*.py: E501, E203
-    egs/librispeech/ASR/conv_emformer_transducer_stateless2/*.py: E501, E203
+    egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203

     # invalid escape sequence (cause by tex formular), W605
     icefall/utils.py: E501, W605
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index 753e5c473..8d1a56736 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -1579,13 +1579,19 @@ class EmformerEncoder(nn.Module):
         # calcualte padding mask to mask out initial zero caches
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
-            torch.div(
-                num_processed_frames, self.chunk_length, rounding_mode="floor"
-            ).view(x.size(1), 1)
-            <= torch.arange(self.memory_size, device=x.device).expand(
-                x.size(1), self.memory_size
-            )
-        ).flip(1)
+            (
+                torch.div(
+                    num_processed_frames,
+                    self.chunk_length,
+                    rounding_mode="floor",
+                ).view(x.size(1), 1)
+                <= torch.arange(self.memory_size, device=x.device).expand(
+                    x.size(1), self.memory_size
+                )
+            ).flip(1)
+            if self.use_memory
+            else torch.empty(0).to(dtype=torch.bool, device=x.device)
+        )
         left_context_mask = (
             num_processed_frames.view(x.size(1), 1)
             <= torch.arange(self.left_context_length, device=x.device).expand(
diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
index e3a598b0e..015ce9b9e 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py
@@ -1388,7 +1388,11 @@ class EmformerEncoder(nn.Module):
         output_lengths = torch.clamp(lengths - self.right_context_length, min=0)
         attention_mask = self._gen_attention_mask(utterance)

-        M = right_context.size(0) // self.right_context_length - 1
+        M = (
+            right_context.size(0) // self.right_context_length - 1
+            if self.use_memory
+            else 0
+        )
         padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths)

         output = utterance
@@ -1480,13 +1484,19 @@ class EmformerEncoder(nn.Module):
         # calcualte padding mask to mask out initial zero caches
         chunk_mask = make_pad_mask(output_lengths).to(x.device)
         memory_mask = (
-            torch.div(
-                num_processed_frames, self.chunk_length, rounding_mode="floor"
-            ).view(x.size(1), 1)
-            <= torch.arange(self.memory_size, device=x.device).expand(
-                x.size(1), self.memory_size
-            )
-        ).flip(1)
+            (
+                torch.div(
+                    num_processed_frames,
+                    self.chunk_length,
+                    rounding_mode="floor",
+                ).view(x.size(1), 1)
+                <= torch.arange(self.memory_size, device=x.device).expand(
+                    x.size(1), self.memory_size
+                )
+            ).flip(1)
+            if self.use_memory
+            else torch.empty(0).to(dtype=torch.bool, device=x.device)
+        )
         left_context_mask = (
             num_processed_frames.view(x.size(1), 1)
             <= torch.arange(self.left_context_length, device=x.device).expand(
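
A minimal standalone sketch of the guarded memory-mask construction this patch introduces (not the emformer.py code itself; batch size, hyper-parameters, and frame counts below are made-up values): when use_memory is false, i.e. memory_size == 0, the mask becomes an explicitly empty boolean tensor instead of being derived from num_processed_frames, while the left-context mask is built as before.

# Sketch only: simplified, assumed values, not the real EmformerEncoder state.
import torch

batch_size = 2
memory_size = 0           # the case the patch fixes: memory banks disabled
chunk_length = 32
left_context_length = 64
# Number of frames each stream has already processed (made-up values).
num_processed_frames = torch.tensor([0, 96])

use_memory = memory_size > 0

if use_memory:
    # Mask memory slots that still hold the initial zero caches.
    memory_mask = (
        torch.div(
            num_processed_frames, chunk_length, rounding_mode="floor"
        ).view(batch_size, 1)
        <= torch.arange(memory_size).expand(batch_size, memory_size)
    ).flip(1)
else:
    # No memory slots exist, so the mask is an explicitly empty bool tensor.
    memory_mask = torch.empty(0, dtype=torch.bool)

# The left-context mask is built the same way in both cases.
left_context_mask = (
    num_processed_frames.view(batch_size, 1)
    <= torch.arange(left_context_length).expand(batch_size, left_context_length)
).flip(1)

print(memory_mask.shape)        # torch.Size([0])
print(left_context_mask.shape)  # torch.Size([2, 64])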