Minor fixes

pkufool 2022-06-10 19:06:06 +08:00
parent aebe9c22dd
commit 85ddfd958c
4 changed files with 25 additions and 22 deletions

View File

@@ -38,7 +38,8 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
         assert encoder_out.shape == decoder_out.shape
         logit = encoder_out + decoder_out
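Note on the hunk above: the joiner previously insisted on 4-D inputs; the split assertion now admits both the 4-D pruned-transducer training layout (N, T, s_range, C) and a 2-D per-frame layout (N, C) used during streaming decoding. A minimal sketch of what the relaxed checks accept; every concrete size here is an assumption for illustration:

import torch

# Two layouts that both satisfy the relaxed assertions. The sizes
# (8, 100, 5, 512) are made up for this sketch.
pairs = [
    (torch.randn(8, 100, 5, 512), torch.randn(8, 100, 5, 512)),  # training layout
    (torch.randn(8, 512), torch.randn(8, 512)),                  # per-frame decoding
]
for encoder_out, decoder_out in pairs:
    assert encoder_out.ndim == decoder_out.ndim
    assert encoder_out.ndim in (2, 4)
    assert encoder_out.shape == decoder_out.shape
    logit = encoder_out + decoder_out  # elementwise sum, as in forward()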

View File

@@ -250,12 +250,12 @@ class Conformer(EncoderInterface):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         states: List[Tensor],
-        warmup: float = 1.0,
-        chunk_size: int = 16,
+        processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 4,
+        chunk_size: int = 16,
         simulate_streaming: bool = False,
-        processed_lens: Optional[Tensor] = None,
+        warmup: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         """
         Args:
@@ -271,13 +271,8 @@ class Conformer(EncoderInterface):
           the second element is the conv_cache which has a shape of
           (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
           Note: states will be modified in this function.
-        warmup:
-          A floating point value that gradually increases from 0 throughout
-          training; when it is >= 1.0 we are "fully warmed up". It is used
-          to turn modules on sequentially.
-        chunk_size:
-          The chunk size for decoding, this will be used to simulate streaming
-          decoding using masking.
+        processed_lens:
+          How many frames (after subsampling) have been processed for each sequence.
         left_context:
           How many previous frames the attention can see in current chunk.
           Note: It's not that each individual frame has `left_context` frames
@@ -286,13 +281,18 @@ class Conformer(EncoderInterface):
           How many future frames the attention can see in current chunk.
           Note: It's not that each individual frame has `right_context` frames
           of right context, some have more.
+        chunk_size:
+          The chunk size for decoding, this will be used to simulate streaming
+          decoding using masking.
         simulate_streaming:
           If setting True, it will use a masking strategy to simulate streaming
           fashion (i.e. every chunk data only see limited left context and
           right context). The whole sequence is supposed to be send at a time
           When using simulate_streaming.
-        processed_lens:
-          How many frames (after subsampling) have been processed for each sequence.
+        warmup:
+          A floating point value that gradually increases from 0 throughout
+          training; when it is >= 1.0 we are "fully warmed up". It is used
+          to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
             - logits, its shape is (batch_size, output_seq_len, output_dim)
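Because this hunk reorders the keyword parameters, positional call sites would silently misbind arguments; calling with keywords stays correct across the change. A hedged sketch of a call against the new signature, where model, features, feature_lens, states, and processed_lens are placeholder names for illustration, not names from this commit:

encoder_out, encoder_out_lens, states = model.streaming_forward(
    x=features,                     # (batch, time, feature_dim)
    x_lens=feature_lens,            # (batch,)
    states=states,                  # [attn_cache, conv_cache]
    processed_lens=processed_lens,  # frames already seen, after subsampling
    left_context=64,
    right_context=4,
    chunk_size=16,
)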

View File

@@ -52,8 +52,10 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
-        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
+        assert encoder_out.shape == decoder_out.shape
         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(
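For context on what these assertions guard, here is a minimal stand-in for the projection path visible at the end of the hunk. The class and all dimensions are assumptions made for this sketch, not code from the commit:

import torch
import torch.nn as nn

class TinyJoiner(nn.Module):
    """Illustrative stand-in for the Joiner above; all dims are assumed."""

    def __init__(self, encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out, decoder_out, project_input=True):
        # The assertions from the diff: matching rank (2-D or 4-D) and shape.
        assert encoder_out.ndim == decoder_out.ndim
        assert encoder_out.ndim in (2, 4)
        assert encoder_out.shape == decoder_out.shape
        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))

joiner = TinyJoiner()
logits = joiner(torch.randn(8, 100, 5, 512), torch.randn(8, 100, 5, 512))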

View File

@@ -246,11 +246,11 @@ class Conformer(Transformer):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         states: List[torch.Tensor],
-        chunk_size: int = 16,
+        processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 0,
+        chunk_size: int = 16,
         simulate_streaming: bool = False,
-        processed_lens: Optional[Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         """
         Args:
@@ -266,9 +266,8 @@ class Conformer(Transformer):
           the second element is the conv_cache which has a shape of
           (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
           Note: states will be modified in this function.
-        chunk_size:
-          The chunk size for decoding, this will be used to simulate streaming
-          decoding using masking.
+        processed_lens:
+          How many frames (after subsampling) have been processed for each sequence.
         left_context:
           How many previous frames the attention can see in current chunk.
           Note: It's not that each individual frame has `left_context` frames
@@ -277,13 +276,14 @@ class Conformer(Transformer):
           How many future frames the attention can see in current chunk.
           Note: It's not that each individual frame has `right_context` frames
           of right context, some have more.
+        chunk_size:
+          The chunk size for decoding, this will be used to simulate streaming
+          decoding using masking.
         simulate_streaming:
           If setting True, it will use a masking strategy to simulate streaming
           fashion (i.e. every chunk data only see limited left context and
           right context). The whole sequence is supposed to be send at a time
           When using simulate_streaming.
-        processed_lens:
-          How many frames (after subsampling) have been processed for each sequence.
         Returns:
           Return a tuple containing 2 tensors:
             - logits, its shape is (batch_size, output_seq_len, output_dim)
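With processed_lens now the first optional parameter, a chunk-by-chunk decoding loop that maintains it per the docstring ("frames after subsampling, per sequence") might look like the sketch below. Everything here (model, chunks, states, the bookkeeping) is assumed for illustration and is not part of this commit:

import torch

batch_size = 8
processed_lens = torch.zeros(batch_size, dtype=torch.int64)
for chunk in chunks:  # each chunk: (batch, chunk_frames, feature_dim)
    chunk_lens = torch.full((batch_size,), chunk.size(1), dtype=torch.int64)
    out, out_lens, states = model.streaming_forward(
        x=chunk,
        x_lens=chunk_lens,
        states=states,
        processed_lens=processed_lens,
        left_context=64,
        right_context=0,
        chunk_size=16,
    )
    processed_lens = processed_lens + out_lens  # accumulate subsampled frames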