mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-18 13:34:20 +00:00
Minor fixes
parent aebe9c22dd · commit 85ddfd958c
@@ -38,7 +38,8 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
         assert encoder_out.shape == decoder_out.shape
 
         logit = encoder_out + decoder_out
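The relaxed checks let the joiner accept matching 2-D (N, C) inputs (e.g. one frame per stream during decoding) as well as the 4-D (N, T, s_range, C) inputs used by pruned RNN-T training. A minimal runnable sketch of a joiner carrying these assertions; the `output_linear`/`tanh` tail is an assumption for illustration, not part of this diff:

    import torch
    import torch.nn as nn


    class Joiner(nn.Module):
        """Minimal joiner sketch; `output_linear` is an assumed layer name."""

        def __init__(self, input_dim: int, output_dim: int) -> None:
            super().__init__()
            self.output_linear = nn.Linear(input_dim, output_dim)

        def forward(
            self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
        ) -> torch.Tensor:
            # Accept matching 2-D (N, C) or 4-D (N, T, s_range, C) inputs.
            assert encoder_out.ndim == decoder_out.ndim
            assert encoder_out.ndim in (2, 4)
            assert encoder_out.shape == decoder_out.shape

            logit = encoder_out + decoder_out
            return self.output_linear(torch.tanh(logit))


    joiner = Joiner(input_dim=512, output_dim=500)
    out = joiner(torch.randn(8, 25, 5, 512), torch.randn(8, 25, 5, 512))
    # out has shape (8, 25, 5, 500)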
@@ -250,12 +250,12 @@ class Conformer(EncoderInterface):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         states: List[Tensor],
-        warmup: float = 1.0,
-        chunk_size: int = 16,
-        processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 4,
+        chunk_size: int = 16,
         simulate_streaming: bool = False,
+        processed_lens: Optional[Tensor] = None,
+        warmup: float = 1.0,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         """
         Args:
@@ -271,13 +271,8 @@ class Conformer(EncoderInterface):
           the second element is the conv_cache which has a shape of
           (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
           Note: states will be modified in this function.
-        warmup:
-          A floating point value that gradually increases from 0 throughout
-          training; when it is >= 1.0 we are "fully warmed up". It is used
-          to turn modules on sequentially.
-        chunk_size:
-          The chunk size for decoding, this will be used to simulate streaming
-          decoding using masking.
-        processed_lens:
-          How many frames (after subsampling) have been processed for each sequence.
         left_context:
           How many previous frames the attention can see in current chunk.
           Note: It's not that each individual frame has `left_context` frames
@@ -286,13 +281,18 @@ class Conformer(EncoderInterface):
           How many future frames the attention can see in current chunk.
           Note: It's not that each individual frame has `right_context` frames
           of right context, some have more.
+        chunk_size:
+          The chunk size for decoding, this will be used to simulate streaming
+          decoding using masking.
         simulate_streaming:
           If setting True, it will use a masking strategy to simulate streaming
           fashion (i.e. every chunk data only see limited left context and
           right context). The whole sequence is supposed to be send at a time
           When using simulate_streaming.
+        processed_lens:
+          How many frames (after subsampling) have been processed for each sequence.
+        warmup:
+          A floating point value that gradually increases from 0 throughout
+          training; when it is >= 1.0 we are "fully warmed up". It is used
+          to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
           - logits, its shape is (batch_size, output_seq_len, output_dim)
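Only the order of these keyword parameters changed, so a call site passing them positionally would now silently mis-bind (a `warmup` float would land in `left_context`). A runnable stub with the new ordering, showing the keyword-argument calling style that survives such reorders; the free-standing function and placeholder body are illustrative only:

    from typing import List, Optional, Tuple

    import torch
    from torch import Tensor


    def streaming_forward(  # stub with the reordered signature from above
        x: Tensor,
        x_lens: Tensor,
        states: List[Tensor],
        left_context: int = 64,
        right_context: int = 4,
        chunk_size: int = 16,
        simulate_streaming: bool = False,
        processed_lens: Optional[Tensor] = None,
        warmup: float = 1.0,
    ) -> Tuple[Tensor, Tensor, List[Tensor]]:
        return x, x_lens, states  # placeholder body


    # Keyword arguments keep this call correct regardless of parameter order:
    y, y_lens, states = streaming_forward(
        torch.randn(1, 160, 80),   # (N, T, C) features, shapes illustrative
        torch.tensor([160]),
        states=[],
        chunk_size=16,
        simulate_streaming=True,
        warmup=1.0,
    )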
@@ -52,8 +52,10 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim == 4
-        assert encoder_out.shape[:-1] == decoder_out.shape[:-1]
+        assert encoder_out.ndim == decoder_out.ndim
+        assert encoder_out.ndim in (2, 4)
+        assert encoder_out.shape == decoder_out.shape
 
         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(
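This second joiner can receive unprojected inputs; `project_input` selects whether `encoder_proj`/`decoder_proj` are applied inside the joiner. A sketch of that branch, assuming (as the full-shape assert requires) that both inputs share one `input_dim`; only `encoder_proj`, `decoder_proj`, and `project_input` come from the diff, the rest is illustrative:

    import torch
    import torch.nn as nn


    class ProjectingJoiner(nn.Module):
        def __init__(self, input_dim: int, joiner_dim: int, vocab_size: int) -> None:
            super().__init__()
            self.encoder_proj = nn.Linear(input_dim, joiner_dim)
            self.decoder_proj = nn.Linear(input_dim, joiner_dim)
            self.output_linear = nn.Linear(joiner_dim, vocab_size)

        def forward(
            self,
            encoder_out: torch.Tensor,
            decoder_out: torch.Tensor,
            project_input: bool = True,
        ) -> torch.Tensor:
            assert encoder_out.ndim == decoder_out.ndim
            assert encoder_out.ndim in (2, 4)
            assert encoder_out.shape == decoder_out.shape

            if project_input:
                # Raw encoder/decoder outputs: project each to joiner_dim first.
                logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
            else:
                # Caller already projected both inputs to joiner_dim.
                logit = encoder_out + decoder_out
            return self.output_linear(torch.tanh(logit))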
@@ -246,11 +246,11 @@ class Conformer(Transformer):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         states: List[torch.Tensor],
-        chunk_size: int = 16,
-        processed_lens: Optional[Tensor] = None,
         left_context: int = 64,
         right_context: int = 0,
+        chunk_size: int = 16,
         simulate_streaming: bool = False,
+        processed_lens: Optional[Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         """
         Args:
@@ -266,9 +266,8 @@ class Conformer(Transformer):
           the second element is the conv_cache which has a shape of
           (encoder_layers, cnn_module_kernel-1, batch, conv_dim).
           Note: states will be modified in this function.
-        chunk_size:
-          The chunk size for decoding, this will be used to simulate streaming
-          decoding using masking.
-        processed_lens:
-          How many frames (after subsampling) have been processed for each sequence.
         left_context:
           How many previous frames the attention can see in current chunk.
           Note: It's not that each individual frame has `left_context` frames
@@ -277,13 +276,14 @@ class Conformer(Transformer):
           How many future frames the attention can see in current chunk.
           Note: It's not that each individual frame has `right_context` frames
           of right context, some have more.
+        chunk_size:
+          The chunk size for decoding, this will be used to simulate streaming
+          decoding using masking.
         simulate_streaming:
           If setting True, it will use a masking strategy to simulate streaming
           fashion (i.e. every chunk data only see limited left context and
           right context). The whole sequence is supposed to be send at a time
           When using simulate_streaming.
+        processed_lens:
+          How many frames (after subsampling) have been processed for each sequence.
         Returns:
           Return a tuple containing 2 tensors:
           - logits, its shape is (batch_size, output_seq_len, output_dim)
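The `simulate_streaming` path described in both docstrings masks self-attention so that each chunk sees only `left_context` frames back and `right_context` frames ahead, while the whole utterance is fed at once. A self-contained sketch of such a chunk mask (an illustration of the masking idea, not icefall's exact implementation):

    import torch


    def chunk_attention_mask(
        seq_len: int, chunk_size: int, left_context: int, right_context: int
    ) -> torch.Tensor:
        """Boolean mask of shape (seq_len, seq_len); True marks blocked pairs.

        Frame i may attend within [chunk_start - left_context,
        chunk_end + right_context), where chunk_start/chunk_end bound i's
        chunk. Frames early in a chunk therefore see more than
        `left_context` frames of history, matching the docstring's note.
        """
        idx = torch.arange(seq_len)
        chunk_start = (idx // chunk_size) * chunk_size
        lo = (chunk_start - left_context).clamp(min=0)
        hi = chunk_start + chunk_size + right_context
        j = idx.unsqueeze(0)  # key positions, shape (1, seq_len)
        allowed = (j >= lo.unsqueeze(1)) & (j < hi.unsqueeze(1))
        return ~allowed  # True = masked out


    mask = chunk_attention_mask(seq_len=32, chunk_size=16, left_context=8, right_context=4)
    # A boolean mask like this can be passed as attn_mask to
    # torch.nn.MultiheadAttention, where True means "do not attend".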