Add more documents

This commit is contained in:
pkufool 2022-05-20 15:26:45 +08:00
parent 5bd2490b44
commit 118b09463d
3 changed files with 189 additions and 32 deletions

View File

@ -66,6 +66,18 @@ Usage:
--beam 8 \ --beam 8 \
--max-contexts 8 \ --max-contexts 8 \
--max-states 64 --max-states 64
(6) decode in streaming mode (take greedy search as an example)
./pruned_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
--simulate-streaming 1 \
--causal-convolution 1 \
--right-chunk-size 16 \
--left-context 64 \
--exp-dir ./pruned_transducer_stateless/exp \
--max-duration 600 \
--decoding-method greedy_search
""" """
@ -249,24 +261,66 @@ def get_parser():
help="""Maximum number of symbols per frame. help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""", Used only when --decoding_method is greedy_search""",
) )
parser.add_argument( parser.add_argument(
"--streaming-mode", "--dynamic-chunk-training",
type=str2bool, type=str2bool,
default=False, default=False,
help=""" help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="""How many left context can be seen in chunks when calculating attention.
Note: not needed for decoding, adding it here to construct transducer model,
as we reuse the code in train.py.
""",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""", """,
) )
parser.add_argument( parser.add_argument(
"--right-chunk-size", "--right-chunk-size",
type=int, type=int,
default=16, default=16,
help="right context to attend during decoding", help="The chunk size for decoding (in frames after subsampling)",
) )
parser.add_argument( parser.add_argument(
"--left-context", "--left-context",
type=int, type=int,
default=64, default=64,
help="left context to attend during decoding", help="left context can be seen during decoding (in frames after subsampling)",
) )
return parser return parser
@ -320,13 +374,13 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
if params.streaming_mode: if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature, x=feature,
x_lens=feature_lens, x_lens=feature_lens,
chunk_size=params.right_chunk_size, chunk_size=params.right_chunk_size,
left_context=params.left_context, left_context=params.left_context,
streaming_data=False simulate_streaming=True
) )
else: else:
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(
@ -554,7 +608,7 @@ def main():
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.streaming_mode: if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}" params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
params.suffix += f"-left-context-{params.left_context}" params.suffix += f"-left-context-{params.left_context}"
@ -590,13 +644,14 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")
# TODO(wei kang): make following config more elegant
params.dynamic_chunk_training=params.streaming_mode
params.short_chunk_size=25
params.num_left_chunks=params.left_context // params.right_chunk_size
model = get_transducer_model(params) model = get_transducer_model(params)
if params.iter > 0: if params.iter > 0:

View File

@ -28,6 +28,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--exp-dir pruned_transducer_stateless/exp \ --exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 300 --max-duration 300
# train a streaming model
./pruned_transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 4 \
--max-duration 300
""" """
@ -222,28 +235,40 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument( parser.add_argument(
"--short-chunk-size", "--short-chunk-size",
type=int, type=int,
default=25, default=25,
help="chunk length of dynamic training", help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
""",
) )
parser.add_argument( parser.add_argument(
"--num-left-chunks", "--num-left-chunks",
type=int, type=int,
default=4, default=4,
help="chunk length of dynamic training", help="How many left context can be seen in chunks when calculating attention.",
) )
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True
""",
)
return parser return parser
@ -336,7 +361,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
dynamic_chunk_training=params.dynamic_chunk_training, dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size, short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks, num_left_chunks=params.num_left_chunks,
causal=True if params.dynamic_chunk_training else False, causal=params.causal_convolution,
) )
return encoder return encoder
@ -789,6 +814,11 @@ def run(rank, world_size, args):
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.dynamic_chunk_training:
assert (
params.causal_convolution
), "dynamic_chunk_training requires causal convolution"
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")

View File

@ -105,6 +105,26 @@ class Conformer(Transformer):
cnn_module_kernel (int): Kernel size of convolution module cnn_module_kernel (int): Kernel size of convolution module
normalize_before (bool): whether to use layer_norm before the first block. normalize_before (bool): whether to use layer_norm before the first block.
vgg_frontend (bool): whether to use vgg frontend. vgg_frontend (bool): whether to use vgg frontend.
dynamic_chunk_training (bool): whether to use dynamic chunk training, if
you want to train a streaming model, this is expected to be True.
When setting True, it will use a masking strategy to make the attention
see only limited left and right context.
short_chunk_threshold (float): a threshold to determinize the chunk size
to be used in masking training, if the randomly generated chunk size
is greater than ``max_len * short_chunk_threshold`` (max_len is the
max sequence length of current batch) then it will use
full context in training (i.e. with chunk size equals to max_len).
This will be used only when dynamic_chunk_training is True.
short_chunk_size (int): see docs above, if the randomly generated chunk
size equals to or less than ``max_len * short_chunk_threshold``, the
chunk size will be sampled uniformly from 1 to short_chunk_size.
This also will be used only when dynamic_chunk_training is True.
num_left_chunks (int): the left context attention can see in chunks, the
chunk size is decided by short_chunk_threshold and short_chunk_size.
A minus value means seeing full left context.
This also will be used only when dynamic_chunk_training is True.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training.
""" """
def __init__( def __init__(
@ -229,16 +249,45 @@ class Conformer(Transformer):
x: torch.Tensor, x: torch.Tensor,
x_lens: torch.Tensor, x_lens: torch.Tensor,
decode_states: Optional[DecodeStates] = None, decode_states: Optional[DecodeStates] = None,
chunk_size: int = 32, chunk_size: int = 16,
left_context: int = 64, left_context: int = 64,
streaming_data: bool = True, simulate_streaming: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]: ) -> Tuple[torch.Tensor, torch.Tensor, DecodeStates]:
# x: [N, T, C] """
Args:
x:
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in
`x` before padding.
decode_states:
The decode states for previous frames which contains the cached data
and the offset of current chunk in the whole sequence.
chunk_size:
The chunk size for decoding, this will be used to simulate streaming
decoding using masking.
left_context:
How many old frames the attention can see in current chunk, it MUST
be equal to left_context in decode_states.
simulate_streaming:
If setting True, it will use a masking strategy to simulate streaming
fashion (i.e. every chunk data only see limited left context and
right context). The whole sequence is supposed to be send at a time
When using simulate_streaming.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
- decode_states, the updated DecodeStates including the information
of current chunk.
"""
# x: [N, T, C]
# Caution: We assume the subsampling factor is 4! # Caution: We assume the subsampling factor is 4!
lengths = ((x_lens - 1) // 2 - 1) // 2 lengths = ((x_lens - 1) // 2 - 1) // 2
if streaming_data: if not simulate_streaming:
assert ( assert (
decode_states is not None decode_states is not None
), "Require cache when sending data in streaming mode" ), "Require cache when sending data in streaming mode"
@ -309,7 +358,9 @@ class ConformerEncoderLayer(nn.Module):
dim_feedforward: the dimension of the feedforward network model (default=2048). dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1). dropout: the dropout value (default=0.1).
cnn_module_kernel (int): Kernel size of convolution module. cnn_module_kernel (int): Kernel size of convolution module.
normalize_before: whether to use layer_norm before the first block. normalize_before (bool): whether to use layer_norm before the first block.
causal (bool): Whether to use causal convolution in conformer encoder
layer. This MUST be True when using dynamic_chunk_training and streaming decoding.
Examples:: Examples::
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
@ -386,10 +437,14 @@ class ConformerEncoderLayer(nn.Module):
pos_emb: Positional embedding tensor (required). pos_emb: Positional embedding tensor (required).
src_mask: the mask for the src sequence (optional). src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
attn_cache: attention cache for previous frames.
conv_cache: convolution cache for previous frames.
left_context: left context in frames used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*S-1, E) pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
src_mask: (S, S). src_mask: (S, S).
src_key_padding_mask: (N, S). src_key_padding_mask: (N, S).
S is the source sequence length, N is the batch size, E is the feature number S is the source sequence length, N is the batch size, E is the feature number
@ -502,10 +557,16 @@ class ConformerEncoder(nn.Module):
pos_emb: Positional embedding tensor (required). pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional). mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional). src_key_padding_mask: the mask for the src keys per batch (optional).
attn_cache: attention cache for previous frames.
conv_cache: convolution cache for previous frames.
left_context: left context in frames used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape:
Shape: Shape:
src: (S, N, E). src: (S, N, E).
pos_emb: (N, 2*S-1, E) pos_emb: (N, 2*S-1, E), for streaming decoding it is (N, 2*(S+left_context)-1, E).
mask: (S, S). mask: (S, S).
src_key_padding_mask: (N, S). src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
@ -606,7 +667,9 @@ class RelPositionalEncoding(torch.nn.Module):
Args: Args:
x (torch.Tensor): Input tensor (batch, time, `*`). x (torch.Tensor): Input tensor (batch, time, `*`).
context (int): left context in frames used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns: Returns:
torch.Tensor: Encoded tensor (batch, time, `*`). torch.Tensor: Encoded tensor (batch, time, `*`).
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
@ -699,6 +762,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights. need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch. the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context in frames used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
- Inputs: - Inputs:
@ -753,6 +819,9 @@ class RelPositionMultiheadAttention(nn.Module):
Args: Args:
x: Input tensor (batch, head, time1, 2*time1-1). x: Input tensor (batch, head, time1, 2*time1-1).
time1 means the length of query vector. time1 means the length of query vector.
left_context (int): left context in frames used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Returns: Returns:
Tensor: tensor of shape (batch, head, time1, time2) Tensor: tensor of shape (batch, head, time1, time2)
@ -810,6 +879,9 @@ class RelPositionMultiheadAttention(nn.Module):
need_weights: output attn_output_weights. need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows to specify a different mask for the entries of each batch. the batches while a 3D mask allows to specify a different mask for the entries of each batch.
left_context (int): left context in frames used during streaming decoding.
this is used only in real streaming decoding, in other circumstances,
it MUST be 0.
Shape: Shape:
Inputs: Inputs:
@ -1083,7 +1155,7 @@ class ConvolutionModule(nn.Module):
channels (int): The number of channels of conv layers. channels (int): The number of channels of conv layers.
kernel_size (int): Kernerl size of conv layers. kernel_size (int): Kernerl size of conv layers.
bias (bool): Whether to use bias in conv layers (default=True). bias (bool): Whether to use bias in conv layers (default=True).
causal (bool): Whether to use causal convolution.
""" """
def __init__( def __init__(