mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
streaming for pruned_transducer_stateless4
This commit is contained in:
parent
e923b1b336
commit
4f2ef237eb
@ -54,6 +54,18 @@ Usage:
|
|||||||
--beam 4 \
|
--beam 4 \
|
||||||
--max-contexts 4 \
|
--max-contexts 4 \
|
||||||
--max-states 8
|
--max-states 8
|
||||||
|
|
||||||
|
(5) decode in streaming mode (take greedy search as an example)
|
||||||
|
./pruned_transducer_stateless4/decode.py \
|
||||||
|
--epoch 30 \
|
||||||
|
--avg 15 \
|
||||||
|
--simulate-streaming 1 \
|
||||||
|
--causal-convolution 1 \
|
||||||
|
--right-chunk-size 16 \
|
||||||
|
--left-context 64 \
|
||||||
|
--exp-dir ./pruned_transducer_stateless4/exp \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method greedy_search
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -91,6 +103,7 @@ from icefall.utils import (
|
|||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -212,6 +225,70 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-chunk-training",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||||
|
model, this requires to be True.
|
||||||
|
Note: not needed for decoding, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="""Chunk length of dynamic training, the chunk size would be either
|
||||||
|
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||||
|
Note: not needed for decoding, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-left-chunks",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="""How many left context can be seen in chunks when calculating attention.
|
||||||
|
Note: not needed for decoding, adding it here to construct transducer model,
|
||||||
|
as we reuse the code in train.py.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--simulate-streaming",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||||
|
test a streaming model.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--causal-convolution",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
|
using dynamic_chunk_training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--right-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=16,
|
||||||
|
help="The chunk size for decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--left-context",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
help="left context can be seen during decoding (in frames after subsampling)",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -260,9 +337,26 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
feature_lens += params.left_context
|
||||||
x=feature, x_lens=feature_lens
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, params.left_context),
|
||||||
|
value=LOG_EPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
|
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||||
|
x=feature,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
chunk_size=params.right_chunk_size,
|
||||||
|
left_context=params.left_context,
|
||||||
|
simulate_streaming=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
|
x=feature, x_lens=feature_lens
|
||||||
|
)
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
if params.decoding_method == "fast_beam_search":
|
if params.decoding_method == "fast_beam_search":
|
||||||
@ -475,6 +569,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
|
params.suffix += f"-streaming-chunk-size-{params.right_chunk_size}"
|
||||||
|
params.suffix += f"-left-context-{params.left_context}"
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
@ -507,6 +605,11 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
if params.simulate_streaming:
|
||||||
|
assert (
|
||||||
|
params.causal_convolution
|
||||||
|
), "Decoding in streaming requires causal convolution"
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
@ -41,6 +41,19 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
|||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 550
|
--max-duration 550
|
||||||
|
|
||||||
|
# train a streaming model
|
||||||
|
./pruned_transducer_stateless4/train.py \
|
||||||
|
--world-size 4 \
|
||||||
|
--num-epochs 30 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--exp-dir pruned_transducer_stateless4/exp \
|
||||||
|
--full-libri 1 \
|
||||||
|
--dynamic-chunk-training 1 \
|
||||||
|
--causal-convolution 1 \
|
||||||
|
--short-chunk-size 25 \
|
||||||
|
--num-left-chunks 4 \
|
||||||
|
--max-duration 300
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -281,6 +294,59 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dynamic-chunk-training",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||||
|
model, this requires to be True.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--causal-convolution",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to use causal convolution, this requires to be True when
|
||||||
|
using dynamic_chunk_training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--short-chunk-size",
|
||||||
|
type=int,
|
||||||
|
default=25,
|
||||||
|
help="""Chunk length of dynamic training, the chunk size would be either
|
||||||
|
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-left-chunks",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="How many left context can be seen in chunks when calculating attention.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--delay-penalty",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="""A constant value to penalize symbol delay, this may be
|
||||||
|
needed when training with time masking, to avoid the time masking
|
||||||
|
encouraging the network to delay symbols.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--return-sym-delay",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="""Whether to return `sym_delay` during training, this is a stat
|
||||||
|
to measure symbols emission delay, especially for time masking training.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -367,6 +433,10 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
nhead=params.nhead,
|
nhead=params.nhead,
|
||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
|
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||||
|
short_chunk_size=params.short_chunk_size,
|
||||||
|
num_left_chunks=params.num_left_chunks,
|
||||||
|
causal=params.causal_convolution,
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -571,7 +641,7 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
simple_loss, pruned_loss = model(
|
simple_loss, pruned_loss, sym_delay = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
x_lens=feature_lens,
|
x_lens=feature_lens,
|
||||||
y=y,
|
y=y,
|
||||||
@ -579,6 +649,8 @@ def compute_loss(
|
|||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
warmup=warmup,
|
warmup=warmup,
|
||||||
|
delay_penalty=params.delay_penalty,
|
||||||
|
return_sym_delay=params.return_sym_delay,
|
||||||
)
|
)
|
||||||
# after the main warmup step, we keep pruned_loss_scale small
|
# after the main warmup step, we keep pruned_loss_scale small
|
||||||
# for the same amount of time (model_warm_step), to avoid
|
# for the same amount of time (model_warm_step), to avoid
|
||||||
@ -608,6 +680,9 @@ def compute_loss(
|
|||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
|
|
||||||
|
if params.return_sym_delay:
|
||||||
|
info["sym_delay"] = sym_delay.detach().cpu().item()
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
|
|
||||||
@ -847,6 +922,15 @@ def run(rank, world_size, args):
|
|||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = sp.piece_to_id("<blk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
|
if params.dynamic_chunk_training:
|
||||||
|
assert (
|
||||||
|
params.causal_convolution
|
||||||
|
), "dynamic_chunk_training requires causal convolution"
|
||||||
|
else:
|
||||||
|
assert (
|
||||||
|
params.delay_penalty == 0.0
|
||||||
|
), "delay_penalty is intended for dynamic_chunk_training"
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user