modify decode.py pretrained.py test_model.py train.py

This commit is contained in:
yaozengwei 2022-07-17 15:38:53 +08:00
parent b1be6ea475
commit 4a0dea2aa2
4 changed files with 28 additions and 209 deletions

View File

@ -18,36 +18,36 @@
""" """
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless2/decode.py \ ./lstm_transducer_stateless/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (not recommended) (2) beam search (not recommended)
./pruned_transducer_stateless2/decode.py \ ./lstm_transducer_stateless/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless2/decode.py \ ./lstm_transducer_stateless/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
(4) fast beam search (one best) (4) fast beam search (one best)
./pruned_transducer_stateless2/decode.py \ ./lstm_transducer_stateless/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 20.0 \ --beam 20.0 \
@ -55,10 +55,10 @@ Usage:
--max-states 64 --max-states 64
(5) fast beam search (nbest) (5) fast beam search (nbest)
./pruned_transducer_stateless2/decode.py \ ./lstm_transducer_stateless/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest \ --decoding-method fast_beam_search_nbest \
--beam 20.0 \ --beam 20.0 \
@ -68,10 +68,10 @@ Usage:
--nbest-scale 0.5 --nbest-scale 0.5
(6) fast beam search (nbest oracle WER) (6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless2/decode.py \ ./lstm_transducer_stateless/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \ --decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \ --beam 20.0 \
@ -81,30 +81,15 @@ Usage:
--nbest-scale 0.5 --nbest-scale 0.5
(7) fast beam search (with LG) (7) fast beam search (with LG)
./pruned_transducer_stateless2/decode.py \ ./lstm_transducer_stateless/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \ --max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \ --decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \ --beam 20.0 \
--max-contexts 8 \ --max-contexts 8 \
--max-states 64 --max-states 64
(8) decode in streaming mode (take greedy search as an example)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--simulate-streaming 1 \
--causal-convolution 1 \
--decode-chunk-size 16 \
--left-context 64 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method greedy_search
--beam 20.0 \
--max-contexts 8 \
--max-states 64
""" """
@ -130,7 +115,7 @@ from beam_search import (
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
) )
from train import add_model_arguments, get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.checkpoint import ( from icefall.checkpoint import (
average_checkpoints, average_checkpoints,
@ -142,7 +127,6 @@ from icefall.utils import (
AttributeDict, AttributeDict,
setup_logger, setup_logger,
store_transcripts, store_transcripts,
str2bool,
write_error_stats, write_error_stats,
) )
@ -286,29 +270,6 @@ def get_parser():
Used only when --decoding_method is greedy_search""", Used only when --decoding_method is greedy_search""",
) )
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
parser.add_argument( parser.add_argument(
"--num-paths", "--num-paths",
type=int, type=int,
@ -327,7 +288,6 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
) )
add_model_arguments(parser)
return parser return parser
@ -387,18 +347,9 @@ def decode_one_batch(
value=LOG_EPS, value=LOG_EPS,
) )
if params.simulate_streaming: encoder_out, encoder_out_lens = model.encoder(
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=feature, x_lens=feature_lens
x=feature, )
x_lens=feature_lens,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
@ -674,10 +625,6 @@ def main():
else: else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.simulate_streaming:
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
@ -712,11 +659,6 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")

View File

@ -77,9 +77,7 @@ from beam_search import (
modified_beam_search, modified_beam_search,
) )
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model from train import get_params, get_transducer_model
from icefall.utils import str2bool
def get_parser(): def get_parser():
@ -180,30 +178,6 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="""Whether to simulate streaming in decoding, this is a good way to
test a streaming model.
""",
)
parser.add_argument(
"--decode-chunk-size",
type=int,
default=16,
help="The chunk size for decoding (in frames after subsampling)",
)
parser.add_argument(
"--left-context",
type=int,
default=64,
help="left context can be seen during decoding (in frames after subsampling)",
)
add_model_arguments(parser)
return parser return parser
@ -248,11 +222,6 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@ -299,18 +268,9 @@ def main():
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
if params.simulate_streaming: encoder_out, encoder_out_lens = model.encoder(
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( x=features, x_lens=feature_lengths
x=features, )
x_lens=feature_lengths,
chunk_size=params.decode_chunk_size,
left_context=params.left_context,
simulate_streaming=True,
)
else:
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0) num_waves = encoder_out.size(0)
hyps = [] hyps = []

View File

@ -34,31 +34,6 @@ def test_model():
params.context_size = 2 params.context_size = 2
params.unk_id = 2 params.unk_id = 2
params.dynamic_chunk_training = False
params.short_chunk_size = 25
params.num_left_chunks = 4
params.causal_convolution = False
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)
def test_model_streaming():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
params.unk_id = 2
params.dynamic_chunk_training = True
params.short_chunk_size = 25
params.num_left_chunks = 4
params.causal_convolution = True
model = get_transducer_model(params) model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()]) num_param = sum([p.numel() for p in model.parameters()])
@ -69,7 +44,6 @@ def test_model_streaming():
def main(): def main():
test_model() test_model()
test_model_streaming()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -21,37 +21,24 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3" export CUDA_VISIBLE_DEVICES="0,1,2,3"
./pruned_transducer_stateless2/train.py \ ./lstm_transducer_stateless/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir lstm_transducer_stateless/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 300 --max-duration 300
# For mix precision training: # For mix precision training:
./pruned_transducer_stateless2/train.py \ ./lstm_transducer_stateless/train.py \
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--use-fp16 1 \ --use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir lstm_transducer_stateless/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 550 --max-duration 550
# train a streaming model
./pruned_transducer_stateless2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless/exp \
--full-libri 1 \
--dynamic-chunk-training 1 \
--causal-convolution 1 \
--short-chunk-size 25 \
--num-left-chunks 4 \
--max-duration 300
""" """
@ -69,7 +56,7 @@ import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer from lstm import RNN
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from lhotse.cut import Cut from lhotse.cut import Cut
@ -95,42 +82,6 @@ LRSchedulerType = Union[
] ]
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=False,
help="""Whether to use dynamic_chunk_training, if you want a streaming
model, this requires to be True.
""",
)
parser.add_argument(
"--causal-convolution",
type=str2bool,
default=False,
help="""Whether to use causal convolution, this requires to be True when
using dynamic_chunk_training.
""",
)
parser.add_argument(
"--short-chunk-size",
type=int,
default=25,
help="""Chunk length of dynamic training, the chunk size would be either
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
""",
)
parser.add_argument(
"--num-left-chunks",
type=int,
default=4,
help="How many left context can be seen in chunks when calculating attention.",
)
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -311,8 +262,6 @@ def get_parser():
help="Whether to use half precision training.", help="Whether to use half precision training.",
) )
add_model_arguments(parser)
return parser return parser
@ -374,7 +323,6 @@ def get_params() -> AttributeDict:
"feature_dim": 80, "feature_dim": 80,
"subsampling_factor": 4, "subsampling_factor": 4,
"encoder_dim": 512, "encoder_dim": 512,
"nhead": 8,
"dim_feedforward": 2048, "dim_feedforward": 2048,
"num_encoder_layers": 12, "num_encoder_layers": 12,
# parameters for decoder # parameters for decoder
@ -392,17 +340,12 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module: def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer # TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer( encoder = RNN(
num_features=params.feature_dim, num_features=params.feature_dim,
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
d_model=params.encoder_dim, d_model=params.encoder_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward, dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
) )
return encoder return encoder