mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
Add more documents
This commit is contained in:
parent
5bd2490b44
commit
118b09463d
@ -66,6 +66,18 @@ Usage:
|
||||
--beam 8 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(6) decode in streaming mode (take greedy search as an example)
|
||||
./pruned_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--simulate-streaming 1 \
|
||||
--causal-convolution 1 \
|
||||
--right-chunk-size 16 \
|
||||
--left-context 64 \
|
||||
--exp-dir ./pruned_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
"""
|
||||
|
||||
|
||||
@ -249,24 +261,66 @@ def get_parser():
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--streaming-mode",
|
||||
"--dynamic-chunk-training",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""
|
||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||
model, this requires to be True.
|
||||
Note: not needed for decoding, adding it here to construct transducer model,
|
||||
as we reuse the code in train.py.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--short-chunk-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="""Chunk length of dynamic training, the chunk size would be either
|
||||
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||
Note: not needed for decoding, adding it here to construct transducer model,
|
||||
as we reuse the code in train.py.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-left-chunks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""How many left context can be seen in chunks when calculating attention.
|
||||
Note: not needed for decoding, adding it here to construct transducer model,
|
||||
as we reuse the code in train.py.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||
test a streaming model.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--causal-convolution",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use causal convolution, this requires to be True when
|
||||
using dynamic_chunk_training.
|
||||
""",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--right-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="right context to attend during decoding",
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context to attend during decoding",
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
return parser
|
||||
@ -320,13 +374,13 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
if params.streaming_mode:
|
||||
if params.simulate_streaming:
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.right_chunk_size,
|
||||
left_context=params.left_context,
|
||||
streaming_data=False
|
||||
simulate_streaming=True
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
@ -554,7 +608,7 @@ def main():
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.streaming_mode:
|
||||
if params.simulate_streaming:
|
||||
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
|
||||
@ -590,13 +644,14 @@ def main():
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "Decoding in streaming requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
# TODO(wei kang): make following config more elegant
|
||||
params.dynamic_chunk_training=params.streaming_mode
|
||||
params.short_chunk_size=25
|
||||
params.num_left_chunks=params.left_context // params.right_chunk_size
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if params.iter > 0:
|
||||
|
@ -28,6 +28,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
|
||||
# train a streaming model
|
||||
./pruned_transducer_stateless/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--dynamic-chunk-training 1 \
|
||||
--causal-convolution 1 \
|
||||
--short-chunk-size 25 \
|
||||
--num-left-chunks 4 \
|
||||
--max-duration 300
|
||||
"""
|
||||
|
||||
|
||||
@ -222,28 +235,40 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dynamic-chunk-training",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||
model, this requires to be True.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--causal-convolution",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use causal convolution, this requires to be True when
|
||||
using dynamic_chunk_training.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--short-chunk-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="chunk length of dynamic training",
|
||||
help="""Chunk length of dynamic training, the chunk size would be either
|
||||
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-left-chunks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="chunk length of dynamic training",
|
||||
help="How many left context can be seen in chunks when calculating attention.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dynamic-chunk-training",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||
model, this requires to be True
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@ -336,7 +361,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||
short_chunk_size=params.short_chunk_size,
|
||||
num_left_chunks=params.num_left_chunks,
|
||||
causal=True if params.dynamic_chunk_training else False,
|
||||
causal=params.causal_convolution,
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -789,6 +814,11 @@ def run(rank, world_size, args):
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.dynamic_chunk_training:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "dynamic_chunk_training requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
@ -105,6 +105,26 @@ class Conformer(Transformer):
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
normalize_before (bool): whether to use layer_norm before the first block.
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
dynamic_chunk_training (bool): whether to use dynamic chunk training, if
|
||||
you want to train a streaming model, this is expected to be True.
|
||||
When setting True, it will use a masking strategy to make the attention
|
||||
see only limited left and right context.
|
||||
short_chunk_threshold (float): a threshold to determinize the chunk size
|
||||
to be used in masking training, if the randomly generated chunk size
|
||||
is greater than ``max_len * short_chunk_threshold`` (max_len is the
|
||||
max sequence length of current batch) then it will use
|
||||
full context in training (i.e. with chunk size equals to max_len).
|
||||
This will be used only when dynamic_chunk_training is True.
|
||||
short_chunk_size (int): see docs above, if the randomly generated chunk
|
||||
size equals to or less than ``max_len * short_chunk_threshold``, the
|
||||
chunk size will be sampled uniformly from 1 to short_chunk_size.
|
||||
This also will be used only when dynamic_chunk_training is True.
|
||||
num_left_chunks (int): the left context attention can see in chunks, the
|
||||
chunk size is decided by short_chunk_threshold and short_chunk_size.
|
||||
A minus value means seeing full left context.
|
||||
This also will be used only when dynamic_chunk_training is True.
|
||||
causal (bool): Whether to use causal convolution in conformer encoder
|
||||
layer. This MUST be True when using dynamic_chunk_training.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -229,16 +249,45 @@ class Conformer(Transformer):
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
decode_states: Optional[DecodeStates] = None,
|
||||
chunk_size: int = 32,
|
||||
chunk_size: int = 16,
|
||||
left_context: int = 64,
|
||||
streaming_data: bool = True,
|
||||
simulate_streaming: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]:
|
||||
# x: [N, T, C]
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
decode_states:
|
||||
The decode states for previous frames which contains the cached data
|
||||
and the offset of current chunk in the whole sequence.
|
||||
chunk_size:
|
||||
The chunk size for decoding, this will be used to simulate streaming
|
||||
decoding using masking.
|
||||
left_context:
|
||||
How many old frames the attention can see in current chunk, it MUST
|
||||
be equal to left_context in decode_states.
|
||||
simulate_streaming:
|
||||
If setting True, it will use a masking strategy to simulate streaming
|
||||
fashion (i.e. every chunk data only see limited left context and
|
||||
right context). The whole sequence is supposed to be send at a time
|
||||
When using simulate_streaming.
|
||||
Returns:
|
||||
Return a tuple containing 2 tensors:
|
||||
- logits, its shape is (batch_size, output_seq_len, output_dim)
|
||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `logits` before padding.
|
||||
- decode_states, the updated DecodeStates including the information
|
||||
of current chunk.
|
||||
"""
|
||||
|
||||
# x: [N, T, C]
|
||||
# Caution: We assume the subsampling factor is 4!
|
||||
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||
|
||||
if streaming_data:
|
||||
if not simulate_streaming:
|
||||
assert (
|
||||
decode_states is not None
|
||||
), "Require cache when sending data in streaming mode"
|
||||
@ -309,7 +358,9 @@ class ConformerEncoderLayer(nn.Module):
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
normalize_before: whether to use layer_norm before the first block.
|
||||
normalize_before (bool): whether to use layer_norm before the first block.
|
||||
causal (bool): Whether to use causal convolution in conformer encoder
|
||||
layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
@ -386,10 +437,14 @@ class ConformerEncoderLayer(nn.Module):
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
attn_cache: attention cache for previous frames.
|
||||
conv_cache: convolution cache for previous frames.
|
||||
left_context: left context in frames used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
@ -502,10 +557,16 @@ class ConformerEncoder(nn.Module):
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
attn_cache: attention cache for previous frames.
|
||||
conv_cache: convolution cache for previous frames.
|
||||
left_context: left context in frames used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
Shape:
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
@ -606,7 +667,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
context (int): left context in frames used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
@ -699,6 +762,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
left_context (int): left context in frames used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
- Inputs:
|
||||
@ -753,6 +819,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
left_context (int): left context in frames used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, time1, time2)
|
||||
@ -810,6 +879,9 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
left_context (int): left context in frames used during streaming decoding.
|
||||
this is used only in real streaming decoding, in other circumstances,
|
||||
it MUST be 0.
|
||||
|
||||
Shape:
|
||||
Inputs:
|
||||
@ -1083,7 +1155,7 @@ class ConvolutionModule(nn.Module):
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernerl size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
|
||||
causal (bool): Whether to use causal convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
Loading…
x
Reference in New Issue
Block a user