Minor fixes

pkufool 2022-06-10 19:06:06 +08:00
parent aebe9c22dd
commit 85ddfd958c
4 changed files with 25 additions and 22 deletions

View File

@@ -38,7 +38,8 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
         assert encoder_out.shape == decoder_out.shape
         logit = encoder_out + decoder_out

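The relaxed checks above let the joiner accept matching 2-D inputs as well as the original 4-D (N, T, s_range, C) inputs. A minimal sketch of what now passes, using illustrative shapes only (the 2-D case is presumably a flattened per-frame path used during decoding):

import torch

# Both a 4-D (N, T, s_range, C) pair and a 2-D pair satisfy the new asserts;
# all shapes below are illustrative.
for shape in [(8, 25, 5, 512), (200, 512)]:
    encoder_out = torch.randn(shape)
    decoder_out = torch.randn(shape)
    assert encoder_out.ndim == decoder_out.ndim
    assert encoder_out.ndim in (2, 4)
    assert encoder_out.shape == decoder_out.shape
    logit = encoder_out + decoder_out  # simple additive joiner, as in the hunk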
View File

@@ -250,12 +250,12 @@ class Conformer(EncoderInterface):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         states: List[Tensor],
-        warmup: float = 1.0,
-        chunk_size: int = 16,
-        processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 4,
+        chunk_size: int = 16,
         simulate_streaming: bool = False,
+        processed_lens: Optional[Tensor] = None,
+        warmup: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         """
         Args:
@@ -271,13 +271,8 @@ class Conformer(EncoderInterface):
             the second element is the conv_cache which has a shape of
             (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
             Note: states will be modified in this function.
-          warmup:
-            A floating point value that gradually increases from 0 throughout
-            training; when it is >= 1.0 we are "fully warmed up". It is used
-            to turn modules on sequentially.
-          chunk_size:
-            The chunk size for decoding, this will be used to simulate streaming
-            decoding using masking.
-          processed_lens:
-            How many frames (after subsampling) have been processed for each sequence.
           left_context:
             How many previous frames the attention can see in the current chunk.
             Note: It's not that each individual frame has `left_context` frames
@@ -286,13 +281,18 @@ class Conformer(EncoderInterface):
             How many future frames the attention can see in the current chunk.
             Note: It's not that each individual frame has `right_context` frames
             of right context; some have more.
+          chunk_size:
+            The chunk size for decoding; it will be used to simulate streaming
+            decoding using masking.
           simulate_streaming:
             If True, a masking strategy is used to simulate streaming (i.e. each
             chunk sees only a limited left and right context). The whole sequence
             is expected to be fed in at once when simulate_streaming is used.
+          processed_lens:
+            How many frames (after subsampling) have been processed for each sequence.
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up". It is used
+            to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
             - logits, its shape is (batch_size, output_seq_len, output_dim)

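The two hunks above only reorder keyword parameters and move their docstring entries to match; since every moved parameter has a default value, callers that pass them by keyword are unaffected. A hypothetical call in the new order (the enclosing method's name is not visible in these hunks, so streaming_forward and the model object are assumptions):

from typing import List, Optional, Tuple
from torch import Tensor

def run_streaming_chunk(
    model,  # assumed: a Conformer exposing the method shown above
    x: Tensor,
    x_lens: Tensor,
    states: List[Tensor],
    processed_lens: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
    # Keywords listed in the new parameter order; the method name
    # streaming_forward is an assumption, not shown in the diff.
    return model.streaming_forward(
        x,
        x_lens,
        states,
        left_context=64,
        right_context=4,
        chunk_size=16,
        simulate_streaming=False,
        processed_lens=processed_lens,
        warmup=1.0,
    )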
View File

@@ -52,8 +52,10 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
-        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
+        assert encoder_out.shape == decoder_out.shape
         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(

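Besides relaxing the ndim check, this hunk also tightens the shape check: the old assert compared only the leading dimensions (shape[:-1]), while the new one requires identical shapes, so encoder_out and decoder_out must already share the same last dimension even on the project_input path. A sketch of that path under this assumption (layer sizes are illustrative):

import torch
import torch.nn as nn

dim = 512  # illustrative; both inputs must now share this last dimension
encoder_proj = nn.Linear(dim, dim)
decoder_proj = nn.Linear(dim, dim)

encoder_out = torch.randn(8, 25, 5, dim)  # (N, T, s_range, C)
decoder_out = torch.randn(8, 25, 5, dim)
assert encoder_out.shape == decoder_out.shape  # full shapes must match

project_input = True
if project_input:
    # nn.Linear operates on the last dimension, so 4-D inputs are fine.
    logit = encoder_proj(encoder_out) + decoder_proj(decoder_out)
else:
    logit = encoder_out + decoder_out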
View File

@@ -246,11 +246,11 @@ class Conformer(Transformer):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         states: List[torch.Tensor],
-        chunk_size: int = 16,
-        processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 0,
+        chunk_size: int = 16,
         simulate_streaming: bool = False,
+        processed_lens: Optional[Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         """
         Args:
@@ -266,9 +266,8 @@ class Conformer(Transformer):
             the second element is the conv_cache which has a shape of
             (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
             Note: states will be modified in this function.
-          chunk_size:
-            The chunk size for decoding, this will be used to simulate streaming
-            decoding using masking.
-          processed_lens:
-            How many frames (after subsampling) have been processed for each sequence.
           left_context:
             How many previous frames the attention can see in the current chunk.
             Note: It's not that each individual frame has `left_context` frames
@@ -277,13 +276,14 @@ class Conformer(Transformer):
             How many future frames the attention can see in the current chunk.
             Note: It's not that each individual frame has `right_context` frames
             of right context; some have more.
+          chunk_size:
+            The chunk size for decoding; it will be used to simulate streaming
+            decoding using masking.
           simulate_streaming:
             If True, a masking strategy is used to simulate streaming (i.e. each
             chunk sees only a limited left and right context). The whole sequence
             is expected to be fed in at once when simulate_streaming is used.
+          processed_lens:
+            How many frames (after subsampling) have been processed for each sequence.
         Returns:
           Return a tuple containing 2 tensors:
             - logits, its shape is (batch_size, output_seq_len, output_dim)
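For reference, a sketch of the masking idea that chunk_size, left_context, and right_context describe; this is one reading of the docstring, not the recipe's actual implementation. Every frame attends within its own chunk, plus up to left_context frames before the chunk and right_context frames after it, which is why some frames end up with more than left_context frames of left context:

import torch

def chunk_attention_mask(
    seq_len: int, chunk_size: int, left_context: int, right_context: int
) -> torch.Tensor:
    # True marks positions attention must NOT see.
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
    for i in range(seq_len):
        chunk_start = (i // chunk_size) * chunk_size
        lo = max(0, chunk_start - left_context)
        hi = min(seq_len, chunk_start + chunk_size + right_context)
        mask[i, lo:hi] = False  # frame i may attend to [lo, hi)
    return mask

# Example: 8 frames, chunks of 4, 2 frames of left and 1 of right context.
print(chunk_attention_mask(8, chunk_size=4, left_context=2, right_context=1))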