commit 4a0dea2aa2 (parent b1be6ea475)
Author: yaozengwei
Date: 2022-07-17 15:38:53 +08:00

    modify decode.py pretrained.py test_model.py train.py

4 changed files with 28 additions and 209 deletions

File: lstm_transducer_stateless/decode.py

@@ -18,36 +18,36 @@
"""
Usage:
(1) greedy search
- ./pruned_transducer_stateless2/decode.py \
+ ./lstm_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless2/exp \
+ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
- ./pruned_transducer_stateless2/decode.py \
+ ./lstm_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless2/exp \
+ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
- ./pruned_transducer_stateless2/decode.py \
+ ./lstm_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless2/exp \
+ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
- ./pruned_transducer_stateless2/decode.py \
+ ./lstm_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless2/exp \
+ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
@@ -55,10 +55,10 @@ Usage:
--max-states 64
(5) fast beam search (nbest)
- ./pruned_transducer_stateless2/decode.py \
+ ./lstm_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless2/exp \
+ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
@@ -68,10 +68,10 @@ Usage:
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
- ./pruned_transducer_stateless2/decode.py \
+ ./lstm_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless2/exp \
+ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
@@ -81,30 +81,15 @@ Usage:
--nbest-scale 0.5
(7) fast beam search (with LG)
- ./pruned_transducer_stateless2/decode.py \
+ ./lstm_transducer_stateless/decode.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless2/exp \
+ --exp-dir ./lstm_transducer_stateless/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
- (8) decode in streaming mode (take greedy search as an example)
- ./pruned_transducer_stateless2/decode.py \
- --epoch 28 \
- --avg 15 \
- --simulate-streaming 1 \
- --causal-convolution 1 \
- --decode-chunk-size 16 \
- --left-context 64 \
- --exp-dir ./pruned_transducer_stateless2/exp \
- --max-duration 600 \
- --decoding-method greedy_search
- --beam 20.0 \
- --max-contexts 8 \
- --max-states 64
"""
@@ -130,7 +115,7 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
- from train import add_model_arguments, get_params, get_transducer_model
+ from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
@@ -142,7 +127,6 @@ from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
- str2bool,
write_error_stats,
)
@@ -286,29 +270,6 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
- parser.add_argument(
- "--simulate-streaming",
- type=str2bool,
- default=False,
- help="""Whether to simulate streaming in decoding, this is a good way to
- test a streaming model.
- """,
- )
- parser.add_argument(
- "--decode-chunk-size",
- type=int,
- default=16,
- help="The chunk size for decoding (in frames after subsampling)",
- )
- parser.add_argument(
- "--left-context",
- type=int,
- default=64,
- help="left context can be seen during decoding (in frames after subsampling)",
- )
parser.add_argument(
"--num-paths",
type=int,
@@ -327,7 +288,6 @@ def get_parser():
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
- add_model_arguments(parser)
return parser
@@ -387,18 +347,9 @@ def decode_one_batch(
value=LOG_EPS,
)
- if params.simulate_streaming:
- encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
- x=feature,
- x_lens=feature_lens,
- chunk_size=params.decode_chunk_size,
- left_context=params.left_context,
- simulate_streaming=True,
- )
- else:
- encoder_out, encoder_out_lens = model.encoder(
- x=feature, x_lens=feature_lens
- )
+ encoder_out, encoder_out_lens = model.encoder(
+ x=feature, x_lens=feature_lens
+ )
hyps = []
@@ -674,10 +625,6 @@ def main():
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
- if params.simulate_streaming:
- params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
- params.suffix += f"-left-context-{params.left_context}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
@@ -712,11 +659,6 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
- if params.simulate_streaming:
- assert (
- params.causal_convolution
- ), "Decoding in streaming requires causal convolution"
logging.info(params)
logging.info("About to create model")

File: lstm_transducer_stateless/pretrained.py

@@ -77,9 +77,7 @@ from beam_search import (
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
- from train import add_model_arguments, get_params, get_transducer_model
- from icefall.utils import str2bool
+ from train import get_params, get_transducer_model
def get_parser():
@@ -180,30 +178,6 @@ def get_parser():
""",
)
- parser.add_argument(
- "--simulate-streaming",
- type=str2bool,
- default=False,
- help="""Whether to simulate streaming in decoding, this is a good way to
- test a streaming model.
- """,
- )
- parser.add_argument(
- "--decode-chunk-size",
- type=int,
- default=16,
- help="The chunk size for decoding (in frames after subsampling)",
- )
- parser.add_argument(
- "--left-context",
- type=int,
- default=64,
- help="left context can be seen during decoding (in frames after subsampling)",
- )
- add_model_arguments(parser)
return parser
@@ -248,11 +222,6 @@ def main():
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
- if params.simulate_streaming:
- assert (
- params.causal_convolution
- ), "Decoding in streaming requires causal convolution"
logging.info(f"{params}")
device = torch.device("cpu")
@@ -299,18 +268,9 @@ def main():
feature_lengths = torch.tensor(feature_lengths, device=device)
- if params.simulate_streaming:
- encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
- x=features,
- x_lens=feature_lengths,
- chunk_size=params.decode_chunk_size,
- left_context=params.left_context,
- simulate_streaming=True,
- )
- else:
- encoder_out, encoder_out_lens = model.encoder(
- x=features, x_lens=feature_lengths
- )
+ encoder_out, encoder_out_lens = model.encoder(
+ x=features, x_lens=feature_lengths
+ )
num_waves = encoder_out.size(0)
hyps = []
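Before the encoder call above, pretrained.py turns the input sound files into 80-dim log-Mel filterbank features and pads them into a batch (hence the pad_sequence import). A sketch of such a front end follows; the torchaudio-based extractor here is an assumption, not necessarily what the script uses, and -23.0 approximates the log(1e-10) padding value (LOG_EPS) seen in decode.py.

```python
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence


def compute_feature_batch(wav_paths: list, expected_sr: int = 16000):
    """Sketch: 80-dim fbank features for a list of wavs, padded to a batch."""
    features = []
    for path in wav_paths:
        wave, sr = torchaudio.load(path)  # wave: (channels, num_samples)
        assert sr == expected_sr, f"expected {expected_sr} Hz, got {sr}"
        feat = torchaudio.compliance.kaldi.fbank(wave, num_mel_bins=80)
        features.append(feat)  # (num_frames, 80)
    feature_lengths = torch.tensor([f.size(0) for f in features])
    # Pad to (N, T_max, 80); -23.0 is roughly log(1e-10).
    features = pad_sequence(features, batch_first=True, padding_value=-23.0)
    return features, feature_lengths
```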

File: lstm_transducer_stateless/test_model.py

@@ -34,31 +34,6 @@ def test_model():
params.context_size = 2
params.unk_id = 2
- params.dynamic_chunk_training = False
- params.short_chunk_size = 25
- params.num_left_chunks = 4
- params.causal_convolution = False
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
- model.__class__.forward = torch.jit.ignore(model.__class__.forward)
- torch.jit.script(model)
- def test_model_streaming():
- params = get_params()
- params.vocab_size = 500
- params.blank_id = 0
- params.context_size = 2
- params.unk_id = 2
- params.dynamic_chunk_training = True
- params.short_chunk_size = 25
- params.num_left_chunks = 4
- params.causal_convolution = True
- model = get_transducer_model(params)
- num_param = sum([p.numel() for p in model.parameters()])
@@ -69,7 +44,6 @@ def test_model_streaming():
def main():
test_model()
- test_model_streaming()
if __name__ == "__main__":

File: lstm_transducer_stateless/train.py

@@ -21,37 +21,24 @@ Usage:
export CUDA_VISIBLE_DEVICES="0,1,2,3"
- ./pruned_transducer_stateless2/train.py \
+ ./lstm_transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
- --exp-dir pruned_transducer_stateless2/exp \
+ --exp-dir lstm_transducer_stateless/exp \
--full-libri 1 \
--max-duration 300
# For mix precision training:
- ./pruned_transducer_stateless2/train.py \
+ ./lstm_transducer_stateless/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--use-fp16 1 \
- --exp-dir pruned_transducer_stateless2/exp \
+ --exp-dir lstm_transducer_stateless/exp \
--full-libri 1 \
--max-duration 550
- # train a streaming model
- ./pruned_transducer_stateless2/train.py \
- --world-size 4 \
- --num-epochs 30 \
- --start-epoch 0 \
- --exp-dir pruned_transducer_stateless/exp \
- --full-libri 1 \
- --dynamic-chunk-training 1 \
- --causal-convolution 1 \
- --short-chunk-size 25 \
- --num-left-chunks 4 \
- --max-duration 300
"""
@@ -69,7 +56,7 @@ import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
- from conformer import Conformer
+ from lstm import RNN
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
@@ -95,42 +82,6 @@ LRSchedulerType = Union[
]
- def add_model_arguments(parser: argparse.ArgumentParser):
- parser.add_argument(
- "--dynamic-chunk-training",
- type=str2bool,
- default=False,
- help="""Whether to use dynamic_chunk_training, if you want a streaming
- model, this requires to be True.
- """,
- )
- parser.add_argument(
- "--causal-convolution",
- type=str2bool,
- default=False,
- help="""Whether to use causal convolution, this requires to be True when
- using dynamic_chunk_training.
- """,
- )
- parser.add_argument(
- "--short-chunk-size",
- type=int,
- default=25,
- help="""Chunk length of dynamic training, the chunk size would be either
- max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
- """,
- )
- parser.add_argument(
- "--num-left-chunks",
- type=int,
- default=4,
- help="How many left context can be seen in chunks when calculating attention.",
- )
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -311,8 +262,6 @@ def get_parser():
help="Whether to use half precision training.",
)
- add_model_arguments(parser)
return parser
@@ -374,7 +323,6 @@ def get_params() -> AttributeDict:
"feature_dim": 80,
"subsampling_factor": 4,
"encoder_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
# parameters for decoder
@@ -392,17 +340,12 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
- encoder = Conformer(
+ encoder = RNN(
num_features=params.feature_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.encoder_dim,
- nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
- dynamic_chunk_training=params.dynamic_chunk_training,
- short_chunk_size=params.short_chunk_size,
- num_left_chunks=params.num_left_chunks,
- causal=params.causal_convolution,
)
return encoder
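The new encoder comes from lstm_transducer_stateless/lstm.py. Its exact internals are not part of this diff; what the call sites pin down is the constructor signature above and a forward(x, x_lens) that returns (encoder_out, encoder_out_lens), as seen in decode.py and pretrained.py. A minimal sketch of that interface, with placeholder internals (a linear frontend plus a plain nn.LSTM stack; dim_feedforward is accepted but unused in this simplification):

```python
from typing import Tuple

import torch
import torch.nn as nn


class RNN(nn.Module):
    """Sketch of the encoder interface; not the actual lstm.py implementation."""

    def __init__(
        self,
        num_features: int,
        subsampling_factor: int = 4,
        d_model: int = 512,
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
    ) -> None:
        super().__init__()
        self.subsampling_factor = subsampling_factor
        # Placeholder frontend; a real encoder would subsample with
        # convolutions before the recurrent stack.
        self.input_proj = nn.Linear(num_features, d_model)
        # A unidirectional LSTM is streaming-friendly by construction,
        # so no dynamic-chunk or causal-convolution flags are needed.
        self.lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=d_model,
            num_layers=num_encoder_layers,
            batch_first=True,
        )

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # x: (N, T, num_features); x_lens: (N,)
        x = x[:, :: self.subsampling_factor, :]  # crude time subsampling
        x_lens = (x_lens + self.subsampling_factor - 1) // self.subsampling_factor
        x = self.input_proj(x)
        encoder_out, _ = self.lstm(x)
        return encoder_out, x_lens
```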