From 435a01dbdf3c0cefea3d56dbc65222f327151aab Mon Sep 17 00:00:00 2001
From: yaozengwei
Date: Wed, 25 May 2022 17:09:49 +0800
Subject: [PATCH] add doc about memory

---
 .../ASR/conv_emformer_transducer_stateless/emformer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
index bfb3875ca..57e2c6ee2 100644
--- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
+++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py
@@ -730,15 +730,14 @@ class EmformerAttention(nn.Module):
         matrix_bd_utterance = self._rel_shift(matrix_bd_utterance)
         # (B, nhead, U, U + right_context_length) for training and validation mode;  # noqa
         # (B, nhead, U, tot_left_length + U + right_context_length) for inference mode.  # noqa
-        matrix_bd_utterance = matrix_bd_utterance.contiguous().view(
-            B * self.nhead, U, -1
-        )
+        matrix_bd_utterance = matrix_bd_utterance.view(B * self.nhead, U, -1)
         matrix_bd = torch.zeros_like(matrix_ac)
         if left_context_key is not None and left_context_val is not None:
             # inference mode
             # key: [memory, right context, left context, utterance]
             # for memory
             if M > 0:
+                # take average over the chunk frames for the memory vector
                 matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
                     matrix_bd_utterance[:, :, :tot_left_length].unsqueeze(1),
                     kernel_size=(1, self.chunk_length),
@@ -758,6 +757,7 @@ class EmformerAttention(nn.Module):
         # key: [memory, right context, utterance]
         # for memory
         if M > 0:
+            # take average over the chunk frames for the memory vector
             matrix_bd[:, R : R + U, :M] = torch.nn.functional.avg_pool2d(
                 matrix_bd_utterance[:, :, :U].unsqueeze(1),
                 kernel_size=(1, self.chunk_length),