From 20a23d13ce9354fd4c7bfc9b9a9a1fb038ca93fc Mon Sep 17 00:00:00 2001 From: yaozengwei Date: Sun, 15 May 2022 12:11:05 +0800 Subject: [PATCH] minor fix --- .../conv_emformer_transducer_stateless/emformer.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index f95072970..6f986f163 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -1637,7 +1637,7 @@ class EmformerEncoder(nn.Module): return attention_mask def forward( - self, x: torch.Tensor, lengths: torch.Tensor + self, x: torch.Tensor, lengths: torch.Tensor, warmup: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for training and validation mode. @@ -1684,6 +1684,7 @@ class EmformerEncoder(nn.Module): memory, attention_mask, pos_emb, + warmup=warmup, ) return output, output_lengths @@ -1831,7 +1832,7 @@ class Emformer(EncoderInterface): ) def forward( - self, x: torch.Tensor, x_lens: torch.Tensor + self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 ) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for training and non-streaming inference. @@ -1847,6 +1848,10 @@ class Emformer(EncoderInterface): With shape (B,) and i-th element representing number of valid utterance frames for i-th batch element in x, containing the right_context at the end. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. Returns: (Tensor, Tensor): @@ -1864,7 +1869,9 @@ class Emformer(EncoderInterface): x_lens = ((x_lens - 1) // 2 - 1) // 2 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens) # (T, N, C) + output, output_lengths = self.encoder( + x, x_lens, warmup=warmup + ) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C)