streaming for pruned_transducer_stateless4

pkufool 2022-05-26 10:22:05 +08:00
parent e923b1b336
commit 4f2ef237eb
2 changed files with 190 additions and 3 deletions

pruned_transducer_stateless4/decode.py

@@ -54,6 +54,18 @@ Usage:
    --beam 4 \
    --max-contexts 4 \
    --max-states 8
(5) decode in streaming mode (take greedy search as an example)
./pruned_transducer_stateless4/decode.py \
    --epoch 30 \
    --avg 15 \
    --simulate-streaming 1 \
    --causal-convolution 1 \
    --right-chunk-size 16 \
    --left-context 64 \
    --exp-dir ./pruned_transducer_stateless4/exp \
    --max-duration 600 \
    --decoding-method greedy_search
""" """
@@ -91,6 +103,7 @@ from icefall.utils import (
    write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
    parser = argparse.ArgumentParser(

@@ -212,6 +225,70 @@ def get_parser():
        Used only when --decoding_method is greedy_search""",
    )
    parser.add_argument(
        "--dynamic-chunk-training",
        type=str2bool,
        default=False,
        help="""Whether to use dynamic_chunk_training; this must be True
        if you want a streaming model.
        Note: not needed for decoding; it is added here so that we can
        construct the transducer model, since we reuse the code in train.py.
        """,
    )

    parser.add_argument(
        "--short-chunk-size",
        type=int,
        default=25,
        help="""Chunk length for dynamic-chunk training; the chunk size is
        either the maximum sequence length of the current batch or uniformly
        sampled from (1, short_chunk_size).
        Note: not needed for decoding; it is added here so that we can
        construct the transducer model, since we reuse the code in train.py.
        """,
    )

    parser.add_argument(
        "--num-left-chunks",
        type=int,
        default=4,
        help="""How many left chunks can be seen when calculating attention.
        Note: not needed for decoding; it is added here so that we can
        construct the transducer model, since we reuse the code in train.py.
        """,
    )

    parser.add_argument(
        "--simulate-streaming",
        type=str2bool,
        default=False,
        help="""Whether to simulate streaming during decoding; this is a
        good way to test a streaming model.
        """,
    )

    parser.add_argument(
        "--causal-convolution",
        type=str2bool,
        default=False,
        help="""Whether to use causal convolution; this must be True when
        using dynamic_chunk_training.
        """,
    )

    parser.add_argument(
        "--right-chunk-size",
        type=int,
        default=16,
        help="The chunk size for decoding (in frames after subsampling).",
    )

    parser.add_argument(
        "--left-context",
        type=int,
        default=64,
        help="""How much left context can be seen during decoding
        (in frames after subsampling).""",
    )
    return parser

@@ -260,9 +337,26 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder( feature_lens += params.left_context
x=feature, x_lens=feature_lens feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, params.left_context),
value=LOG_EPS,
) )
if params.simulate_streaming:
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
x=feature,
x_lens=feature_lens,
chunk_size=params.right_chunk_size,
left_context=params.left_context,
simulate_streaming=True
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
    hyps = []

    if params.decoding_method == "fast_beam_search":

@@ -475,6 +569,10 @@ def main():
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
    if params.simulate_streaming:
        params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
        params.suffix += f"-left-context-{params.left_context}"

    if "fast_beam_search" in params.decoding_method:
        params.suffix += f"-beam-{params.beam}"
        params.suffix += f"-max-contexts-{params.max_contexts}"

@@ -507,6 +605,11 @@ def main():
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()
    if params.simulate_streaming:
        assert (
            params.causal_convolution
        ), "Decoding in streaming mode requires causal convolution"

    logging.info(params)
    logging.info("About to create model")

pruned_transducer_stateless4/train.py

@ -41,6 +41,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--full-libri 1 \ --full-libri 1 \
--max-duration 550 --max-duration 550
# train a streaming model
./pruned_transducer_stateless4/train.py \
    --world-size 4 \
    --num-epochs 30 \
    --start-epoch 1 \
    --exp-dir pruned_transducer_stateless4/exp \
    --full-libri 1 \
    --dynamic-chunk-training 1 \
    --causal-convolution 1 \
    --short-chunk-size 25 \
    --num-left-chunks 4 \
    --max-duration 300
""" """
@ -281,6 +294,59 @@ def get_parser():
help="Whether to use half precision training.", help="Whether to use half precision training.",
) )
    parser.add_argument(
        "--dynamic-chunk-training",
        type=str2bool,
        default=False,
        help="""Whether to use dynamic_chunk_training; this must be True
        if you want a streaming model.
        """,
    )

    parser.add_argument(
        "--causal-convolution",
        type=str2bool,
        default=False,
        help="""Whether to use causal convolution; this must be True when
        using dynamic_chunk_training.
        """,
    )

    parser.add_argument(
        "--short-chunk-size",
        type=int,
        default=25,
        help="""Chunk length for dynamic-chunk training; the chunk size is
        either the maximum sequence length of the current batch or uniformly
        sampled from (1, short_chunk_size).
        """,
    )

    parser.add_argument(
        "--num-left-chunks",
        type=int,
        default=4,
        help="How many left chunks can be seen when calculating attention.",
    )

    parser.add_argument(
        "--delay-penalty",
        type=float,
        default=0.0,
        help="""A constant used to penalize symbol delay. This may be
        needed when training with time masking, to keep the time masking
        from encouraging the network to delay symbols.
        """,
    )

    parser.add_argument(
        "--return-sym-delay",
        type=str2bool,
        default=False,
        help="""Whether to return `sym_delay` during training; this is a
        statistic for measuring symbol emission delay, especially for
        time-masking training.
        """,
    )
    return parser
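The --short-chunk-size help above describes the usual dynamic-chunk recipe:
per batch, either train with full context or with a small random chunk. Here
is a sketch of such a sampling rule (an assumption about the conformer code,
which this diff does not show; with 4x subsampling and a 10 ms frame shift,
25 subsampled frames is roughly one second of audio):

import torch

def sample_chunk_size(max_seq_len, short_chunk_size=25):
    # Roughly half the time train with full context (non-streaming
    # behaviour); otherwise use a chunk size in [1, short_chunk_size].
    r = torch.randint(1, max_seq_len + 1, (1,)).item()
    if r > max_seq_len // 2:
        return max_seq_len
    return r % short_chunk_size + 1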
@@ -367,6 +433,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
        nhead=params.nhead,
        dim_feedforward=params.dim_feedforward,
        num_encoder_layers=params.num_encoder_layers,
        dynamic_chunk_training=params.dynamic_chunk_training,
        short_chunk_size=params.short_chunk_size,
        num_left_chunks=params.num_left_chunks,
        causal=params.causal_convolution,
    )

    return encoder
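The new causal=params.causal_convolution argument matters because a streaming
encoder must not let its convolution module see future frames. A minimal,
self-contained illustration of left-only padding (not the icefall
ConvolutionModule itself):

import torch
import torch.nn as nn
import torch.nn.functional as F

kernel_size = 15
conv = nn.Conv1d(4, 4, kernel_size, padding=0, groups=4)  # depthwise conv

x = torch.randn(1, 4, 100)          # (batch, channels, time)
x = F.pad(x, (kernel_size - 1, 0))  # pad on the left only, never the right
y = conv(x)                         # frame t depends only on frames <= t
assert y.shape[-1] == 100           # output length matches the input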
@@ -571,7 +641,7 @@ def compute_loss(
    y = k2.RaggedTensor(y).to(device)

    with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss, sym_delay = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
@@ -579,6 +649,8 @@ def compute_loss(
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
            warmup=warmup,
            delay_penalty=params.delay_penalty,
            return_sym_delay=params.return_sym_delay,
        )
        # after the main warmup step, we keep pruned_loss_scale small
        # for the same amount of time (model_warm_step), to avoid

@@ -608,6 +680,9 @@ def compute_loss(
    info["simple_loss"] = simple_loss.detach().cpu().item()
    info["pruned_loss"] = pruned_loss.detach().cpu().item()
    if params.return_sym_delay:
        info["sym_delay"] = sym_delay.detach().cpu().item()

    return loss, info
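This diff does not show how delay_penalty enters the loss inside the model.
Conceptually (a paraphrase of the idea, not k2's exact formulation), each
non-blank emission at frame t is charged an extra delay_penalty * t, so paths
that emit symbols late accumulate a larger loss:

delay_penalty = 0.1
early = [0, 1, 2]  # frames at which 3 symbols are emitted
late = [6, 7, 8]   # the same 3 symbols, emitted later
print(delay_penalty * sum(early))  # 0.3 -> small extra loss
print(delay_penalty * sum(late))   # 2.1 -> larger extra loss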
@@ -847,6 +922,15 @@ def run(rank, world_size, args):
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()
    if params.dynamic_chunk_training:
        assert (
            params.causal_convolution
        ), "dynamic_chunk_training requires causal convolution"
    else:
        assert (
            params.delay_penalty == 0.0
        ), "delay_penalty is intended for dynamic_chunk_training"
    logging.info(params)
    logging.info("About to create model")