From 85ddfd958c50f183c27c392453c712684216e7ad Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 10 Jun 2022 19:06:06 +0800 Subject: [PATCH] Minor fixes --- .../ASR/pruned_transducer_stateless/joiner.py | 3 ++- .../pruned_transducer_stateless2/conformer.py | 24 +++++++++---------- .../pruned_transducer_stateless2/joiner.py | 6 +++-- .../ASR/transducer_stateless/conformer.py | 14 +++++------ 4 files changed, 25 insertions(+), 22 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py index 7c5a93a86..d1af29e6e 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/joiner.py @@ -38,7 +38,8 @@ class Joiner(nn.Module): Returns: Return a tensor of shape (N, T, s_range, C). """ - assert encoder_out.ndim == decoder_out.ndim == 4 + assert encoder_out.ndim == decoder_out.ndim + assert encoder_out.ndim in (2, 4) assert encoder_out.shape == decoder_out.shape logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index e28b5034d..33ff03623 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -250,12 +250,12 @@ class Conformer(EncoderInterface): x: torch.Tensor, x_lens: torch.Tensor, states: List[Tensor], - warmup: float = 1.0, - chunk_size: int = 16, + processed_lens: Optional[Tensor] = None, left_context: int = 64, right_context: int = 4, + chunk_size: int = 16, simulate_streaming: bool = False, - processed_lens: Optional[Tensor] = None, + warmup: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: """ Args: @@ -271,13 +271,8 @@ class Conformer(EncoderInterface): the second element is the conv_cache which has a shape of (encoder_layers, cnn_module_kernel-1, batch, conv_dim). Note: states will be modified in this function. - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - chunk_size: - The chunk size for decoding, this will be used to simulate streaming - decoding using masking. + processed_lens: + How many frames (after subsampling) have been processed for each sequence. left_context: How many previous frames the attention can see in current chunk. Note: It's not that each individual frame has `left_context` frames @@ -286,13 +281,18 @@ class Conformer(EncoderInterface): How many future frames the attention can see in current chunk. Note: It's not that each individual frame has `right_context` frames of right context, some have more. + chunk_size: + The chunk size for decoding, this will be used to simulate streaming + decoding using masking. simulate_streaming: If setting True, it will use a masking strategy to simulate streaming fashion (i.e. every chunk data only see limited left context and right context). The whole sequence is supposed to be send at a time When using simulate_streaming. - processed_lens: - How many frames (after subsampling) have been processed for each sequence. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. Returns: Return a tuple containing 2 tensors: - logits, its shape is (batch_size, output_seq_len, output_dim) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 35f75ed2a..b916addf0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -52,8 +52,10 @@ class Joiner(nn.Module): Returns: Return a tensor of shape (N, T, s_range, C). """ - assert encoder_out.ndim == decoder_out.ndim == 4 - assert encoder_out.shape[:-1] == decoder_out.shape[:-1] + + assert encoder_out.ndim == decoder_out.ndim + assert encoder_out.ndim in (2, 4) + assert encoder_out.shape == decoder_out.shape if project_input: logit = self.encoder_proj(encoder_out) + self.decoder_proj( diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 61409a3a7..f460bffe5 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -246,11 +246,11 @@ class Conformer(Transformer): x: torch.Tensor, x_lens: torch.Tensor, states: List[torch.Tensor], - chunk_size: int = 16, + processed_lens: Optional[Tensor] = None, left_context: int = 64, right_context: int = 0, + chunk_size: int = 16, simulate_streaming: bool = False, - processed_lens: Optional[Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: """ Args: @@ -266,9 +266,8 @@ class Conformer(Transformer): the second element is the conv_cache which has a shape of (encoder_layers, cnn_module_kernel-1, batch, conv_dim). Note: states will be modified in this function. - chunk_size: - The chunk size for decoding, this will be used to simulate streaming - decoding using masking. + processed_lens: + How many frames (after subsampling) have been processed for each sequence. left_context: How many previous frames the attention can see in current chunk. Note: It's not that each individual frame has `left_context` frames @@ -277,13 +276,14 @@ class Conformer(Transformer): How many future frames the attention can see in current chunk. Note: It's not that each individual frame has `right_context` frames of right context, some have more. + chunk_size: + The chunk size for decoding, this will be used to simulate streaming + decoding using masking. simulate_streaming: If setting True, it will use a masking strategy to simulate streaming fashion (i.e. every chunk data only see limited left context and right context). The whole sequence is supposed to be send at a time When using simulate_streaming. - processed_lens: - How many frames (after subsampling) have been processed for each sequence. Returns: Return a tuple containing 2 tensors: - logits, its shape is (batch_size, output_seq_len, output_dim)