Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-07 08:04:18 +00:00

Commit 118b09463d: Add more documents
Parent: 5bd2490b44
pruned_transducer_stateless/decode.py

@@ -66,6 +66,18 @@ Usage:
         --beam 8 \
         --max-contexts 8 \
         --max-states 64
+
+(6) decode in streaming mode (take greedy search as an example)
+./pruned_transducer_stateless/decode.py \
+        --epoch 28 \
+        --avg 15 \
+        --simulate-streaming 1 \
+        --causal-convolution 1 \
+        --right-chunk-size 16 \
+        --left-context 64 \
+        --exp-dir ./pruned_transducer_stateless/exp \
+        --max-duration 600 \
+        --decoding-method greedy_search
 """


@@ -249,24 +261,66 @@ def get_parser():
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )

     parser.add_argument(
-        "--streaming-mode",
+        "--dynamic-chunk-training",
         type=str2bool,
         default=False,
-        help="""
+        help="""Whether to use dynamic_chunk_training; if you want a streaming
+        model, this must be True.
+        Note: not needed for decoding; it is added here only to construct the
+        transducer model, as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--short-chunk-size",
+        type=int,
+        default=25,
+        help="""Chunk length for dynamic training; the chunk size will be either
+        the max sequence length of the current batch or uniformly sampled from
+        (1, short_chunk_size).
+        Note: not needed for decoding; it is added here only to construct the
+        transducer model, as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="""How many chunks of left context the attention can see.
+        Note: not needed for decoding; it is added here only to construct the
+        transducer model, as we reuse the code in train.py.
+        """,
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        type=str2bool,
+        default=False,
+        help="""Whether to simulate streaming in decoding; this is a good way
+        to test a streaming model.
+        """,
+    )
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution; this must be True when
+        using dynamic_chunk_training.
         """,
     )
     parser.add_argument(
         "--right-chunk-size",
         type=int,
         default=16,
-        help="right context to attend during decoding",
+        help="The chunk size for decoding (in frames after subsampling)",
     )
     parser.add_argument(
         "--left-context",
         type=int,
         default=64,
-        help="left context to attend during decoding",
+        help="Left context the attention can see during decoding (in frames after subsampling)",
     )

     return parser
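The `type=str2bool` used by the new flags is icefall's boolean-argument parser, which lets `--simulate-streaming 1` or `--simulate-streaming true` both work on the command line. A minimal sketch of what such a helper does (an illustration, not necessarily icefall's exact code):

import argparse

def str2bool(v: str) -> bool:
    # Accept the usual spellings so "1", "true", "yes", etc. parse as True.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")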
@@ -320,13 +374,13 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

-    if params.streaming_mode:
+    if params.simulate_streaming:
         encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
             x=feature,
             x_lens=feature_lens,
             chunk_size=params.right_chunk_size,
             left_context=params.left_context,
-            streaming_data=False
+            simulate_streaming=True
         )
     else:
         encoder_out, encoder_out_lens = model.encoder(
@@ -554,7 +608,7 @@ def main():
     else:
         params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

-    if params.streaming_mode:
+    if params.simulate_streaming:
         params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
         params.suffix += f"-left-context-{params.left_context}"

@@ -590,13 +644,14 @@ def main():
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

+    if params.simulate_streaming:
+        assert (
+            params.causal_convolution
+        ), "Decoding in streaming requires causal convolution"
+
     logging.info(params)

     logging.info("About to create model")
-    # TODO(wei kang): make following config more elegant
-    params.dynamic_chunk_training=params.streaming_mode
-    params.short_chunk_size=25
-    params.num_left_chunks=params.left_context // params.right_chunk_size
     model = get_transducer_model(params)

     if params.iter > 0:
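The TODO block removed above shows where decode.py used to derive the streaming configuration internally: dynamic_chunk_training mirrored the old --streaming-mode flag, short_chunk_size was hard-coded to 25, and num_left_chunks was computed as left_context // right_chunk_size (64 // 16 == 4 with the defaults above). After this commit the same values arrive through the new command-line flags, so decode.py and train.py share one model-construction path.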
pruned_transducer_stateless/train.py

@@ -28,6 +28,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir pruned_transducer_stateless/exp \
   --full-libri 1 \
   --max-duration 300
+
+# train a streaming model
+./pruned_transducer_stateless/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --exp-dir pruned_transducer_stateless/exp \
+  --full-libri 1 \
+  --dynamic-chunk-training 1 \
+  --causal-convolution 1 \
+  --short-chunk-size 25 \
+  --num-left-chunks 4 \
+  --max-duration 300
 """


@@ -222,28 +235,40 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--dynamic-chunk-training",
+        type=str2bool,
+        default=False,
+        help="""Whether to use dynamic_chunk_training; if you want a streaming
+        model, this must be True.
+        """,
+    )
+
+    parser.add_argument(
+        "--causal-convolution",
+        type=str2bool,
+        default=False,
+        help="""Whether to use causal convolution; this must be True when
+        using dynamic_chunk_training.
+        """,
+    )
+
     parser.add_argument(
         "--short-chunk-size",
         type=int,
         default=25,
-        help="chunk length of dynamic training",
+        help="""Chunk length for dynamic training; the chunk size will be either
+        the max sequence length of the current batch or uniformly sampled from
+        (1, short_chunk_size).
+        """,
     )

     parser.add_argument(
         "--num-left-chunks",
         type=int,
         default=4,
-        help="chunk length of dynamic training",
+        help="How many chunks of left context the attention can see.",
     )

-    parser.add_argument(
-        "--dynamic-chunk-training",
-        type=str2bool,
-        default=False,
-        help="""Whether to use dynamic_chunk_training, if you want a streaming
-        model, this requires to be True
-        """,
-    )

     return parser

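To make --num-left-chunks concrete: during dynamic chunk training, attention is restricted by a mask in which each frame can see its own chunk plus a limited number of chunks to its left. A self-contained sketch of such a mask (an illustration of the idea, not code this commit touches):

import torch

def chunk_attention_mask(
    size: int, chunk_size: int, num_left_chunks: int = -1
) -> torch.Tensor:
    """Return a (size, size) boolean mask; True = position may be attended to.

    Each query frame sees every frame of its own chunk plus num_left_chunks
    chunks of left context (all left context if num_left_chunks < 0).
    """
    idx = torch.arange(size)
    chunk_idx = idx // chunk_size
    chunk_end = (chunk_idx + 1) * chunk_size  # end of own chunk, exclusive
    if num_left_chunks < 0:
        chunk_start = torch.zeros_like(chunk_end)
    else:
        chunk_start = ((chunk_idx - num_left_chunks) * chunk_size).clamp(min=0)
    keys = idx.unsqueeze(0)  # (1, size): key positions
    return (keys >= chunk_start.unsqueeze(1)) & (keys < chunk_end.unsqueeze(1))

# e.g. chunk_size=16 with num_left_chunks=4 lets a frame attend to at most
# 64 past frames, matching --left-context 64 used at decoding time above.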
@@ -336,7 +361,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         dynamic_chunk_training=params.dynamic_chunk_training,
         short_chunk_size=params.short_chunk_size,
         num_left_chunks=params.num_left_chunks,
-        causal=True if params.dynamic_chunk_training else False,
+        causal=params.causal_convolution,
     )
     return encoder

@@ -789,6 +814,11 @@ def run(rank, world_size, args):
     params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()

+    if params.dynamic_chunk_training:
+        assert (
+            params.causal_convolution
+        ), "dynamic_chunk_training requires causal convolution"
+
     logging.info(params)

     logging.info("About to create model")
pruned_transducer_stateless/conformer.py

@@ -105,6 +105,26 @@ class Conformer(Transformer):
         cnn_module_kernel (int): Kernel size of convolution module
         normalize_before (bool): whether to use layer_norm before the first block.
         vgg_frontend (bool): whether to use vgg frontend.
+        dynamic_chunk_training (bool): whether to use dynamic chunk training;
+            if you want to train a streaming model, this is expected to be True.
+            When True, a masking strategy makes the attention see only a
+            limited left and right context.
+        short_chunk_threshold (float): a threshold used to determine the chunk
+            size in masked training. If the randomly generated chunk size is
+            greater than ``max_len * short_chunk_threshold`` (max_len is the
+            max sequence length of the current batch), full context is used in
+            training (i.e. the chunk size equals max_len).
+            Used only when dynamic_chunk_training is True.
+        short_chunk_size (int): see the docs above; if the randomly generated
+            chunk size is less than or equal to ``max_len * short_chunk_threshold``,
+            the chunk size is sampled uniformly from 1 to short_chunk_size.
+            Also used only when dynamic_chunk_training is True.
+        num_left_chunks (int): the left context the attention can see, in
+            chunks; the chunk size is decided by short_chunk_threshold and
+            short_chunk_size. A negative value means seeing the full left
+            context. Also used only when dynamic_chunk_training is True.
+        causal (bool): Whether to use causal convolution in the conformer
+            encoder layer. This MUST be True when using dynamic_chunk_training.
     """

     def __init__(
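A minimal sketch of the chunk-size policy this docstring describes (the helper name is made up for illustration; the real logic lives inside the encoder):

import random

def pick_training_chunk_size(
    max_len: int, short_chunk_threshold: float, short_chunk_size: int
) -> int:
    # Randomly generate a chunk size for this batch.
    chunk_size = random.randint(1, max_len)
    if chunk_size > max_len * short_chunk_threshold:
        # Long draw: train this batch with full context.
        return max_len
    # Short draw: sample uniformly from 1 to short_chunk_size.
    return random.randint(1, short_chunk_size)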
@@ -229,16 +249,45 @@ class Conformer(Transformer):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         decode_states: Optional[DecodeStates] = None,
-        chunk_size: int = 32,
+        chunk_size: int = 16,
         left_context: int = 64,
-        streaming_data: bool = True,
+        simulate_streaming: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]:
-        # x: [N, T, C]
+        """
+        Args:
+          x:
+            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
+          x_lens:
+            A tensor of shape (batch_size,) containing the number of frames in
+            `x` before padding.
+          decode_states:
+            The decode states for previous frames, containing the cached data
+            and the offset of the current chunk in the whole sequence.
+          chunk_size:
+            The chunk size for decoding; this is used to simulate streaming
+            decoding using masking.
+          left_context:
+            How many old frames the attention can see in the current chunk; it
+            MUST be equal to left_context in decode_states.
+          simulate_streaming:
+            If True, a masking strategy is used to simulate streaming fashion
+            (i.e. every chunk sees only a limited left and right context). The
+            whole sequence is supposed to be sent at once when using
+            simulate_streaming.
+        Returns:
+          Return a tuple containing 3 tensors:
+            - logits, its shape is (batch_size, output_seq_len, output_dim)
+            - logit_lens, a tensor of shape (batch_size,) containing the number
+              of frames in `logits` before padding.
+            - decode_states, the updated DecodeStates including the information
+              of the current chunk.
+        """
+
+        # x: [N, T, C]
         # Caution: We assume the subsampling factor is 4!
         lengths = ((x_lens - 1) // 2 - 1) // 2

-        if streaming_data:
+        if not simulate_streaming:
             assert (
                 decode_states is not None
             ), "Require cache when sending data in streaming mode"
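Putting the new signature together with the decode.py change above, simulated streaming decoding is invoked roughly as follows (a usage sketch assembled from the diff; `model`, `feature` and `feature_lens` are assumed to exist):

# feature: (batch_size, seq_len, feature_dim); feature_lens: (batch_size,)
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
    x=feature,
    x_lens=feature_lens,
    chunk_size=16,            # --right-chunk-size
    left_context=64,          # --left-context
    simulate_streaming=True,  # mask-based simulation; full utterance sent at once
)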
@@ -309,7 +358,9 @@ class ConformerEncoderLayer(nn.Module):
         dim_feedforward: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
-        normalize_before: whether to use layer_norm before the first block.
+        normalize_before (bool): whether to use layer_norm before the first block.
+        causal (bool): Whether to use causal convolution in conformer encoder
+            layer. This MUST be True when using dynamic_chunk_training and streaming decoding.

     Examples::
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@@ -386,10 +437,14 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
+            attn_cache: attention cache for previous frames.
+            conv_cache: convolution cache for previous frames.
+            left_context: left context in frames used during streaming decoding;
+                this is used only in real streaming decoding, in other
+                circumstances it MUST be 0.

         Shape:
             src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
+            pos_emb: (N, 2*S-1, E); for streaming decoding it is (N, 2*(S+left_context)-1, E).
             src_mask: (S, S).
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
@@ -502,10 +557,16 @@ class ConformerEncoder(nn.Module):
             pos_emb: Positional embedding tensor (required).
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
+            attn_cache: attention cache for previous frames.
+            conv_cache: convolution cache for previous frames.
+            left_context: left context in frames used during streaming decoding;
+                this is used only in real streaming decoding, in other
+                circumstances it MUST be 0.

         Shape:
             src: (S, N, E).
-            pos_emb: (N, 2*S-1, E)
+            pos_emb: (N, 2*S-1, E); for streaming decoding it is (N, 2*(S+left_context)-1, E).
             mask: (S, S).
             src_key_padding_mask: (N, S).
             S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
@@ -606,7 +667,9 @@ class RelPositionalEncoding(torch.nn.Module):

         Args:
             x (torch.Tensor): Input tensor (batch, time, `*`).
+            context (int): left context in frames used during streaming decoding;
+                this is used only in real streaming decoding, in other
+                circumstances it MUST be 0.
         Returns:
             torch.Tensor: Encoded tensor (batch, time, `*`).
             torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
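The shape notes above follow from relative positional attention: offline, a sequence of length S needs 2*S-1 relative positions, and streaming decoding extends the keys by left_context cached frames, hence 2*(S+left_context)-1. For example:

S, left_context = 16, 64          # one decoding chunk plus 64 cached frames
offline_positions = 2 * S - 1                     # 31
streaming_positions = 2 * (S + left_context) - 1  # 159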
@@ -699,6 +762,9 @@ class RelPositionMultiheadAttention(nn.Module):
             need_weights: output attn_output_weights.
             attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                 the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+            left_context (int): left context in frames used during streaming decoding;
+                this is used only in real streaming decoding, in other
+                circumstances it MUST be 0.

         Shape:
             - Inputs:
@@ -753,6 +819,9 @@ class RelPositionMultiheadAttention(nn.Module):
         Args:
             x: Input tensor (batch, head, time1, 2*time1-1).
                 time1 means the length of query vector.
+            left_context (int): left context in frames used during streaming decoding;
+                this is used only in real streaming decoding, in other
+                circumstances it MUST be 0.

         Returns:
             Tensor: tensor of shape (batch, head, time1, time2)
@@ -810,6 +879,9 @@ class RelPositionMultiheadAttention(nn.Module):
             need_weights: output attn_output_weights.
             attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                 the batches while a 3D mask allows to specify a different mask for the entries of each batch.
+            left_context (int): left context in frames used during streaming decoding;
+                this is used only in real streaming decoding, in other
+                circumstances it MUST be 0.

         Shape:
             Inputs:
@@ -1083,7 +1155,8 @@ class ConvolutionModule(nn.Module):
         channels (int): The number of channels of conv layers.
         kernel_size (int): Kernel size of conv layers.
         bias (bool): Whether to use bias in conv layers (default=True).
+        causal (bool): Whether to use causal convolution.
     """

     def __init__(
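For the new causal option, the usual trick is to replace the symmetric padding of the depthwise convolution with left-only padding of kernel_size - 1 frames, so no output frame depends on the future. A self-contained sketch of that idea (not the module's actual code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalDepthwiseConv1d(nn.Module):
    """Depthwise 1-D convolution that never looks at future frames."""

    def __init__(self, channels: int, kernel_size: int, bias: bool = True):
        super().__init__()
        self.lorder = kernel_size - 1  # frames of left padding
        self.conv = nn.Conv1d(
            channels, channels, kernel_size,
            groups=channels, padding=0, bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time); pad on the left only, so output
        # frame t depends on inputs t - lorder .. t.
        x = F.pad(x, (self.lorder, 0))
        return self.conv(x)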