modify decode.py pretrained.py test_model.py train.py
This commit is contained in:
parent
b1be6ea475
commit
4a0dea2aa2
@ -18,36 +18,36 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./lstm_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search (not recommended)
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./lstm_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./lstm_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search (one best)
|
(4) fast beam search (one best)
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./lstm_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 20.0 \
|
--beam 20.0 \
|
||||||
@ -55,10 +55,10 @@ Usage:
|
|||||||
--max-states 64
|
--max-states 64
|
||||||
|
|
||||||
(5) fast beam search (nbest)
|
(5) fast beam search (nbest)
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./lstm_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search_nbest \
|
--decoding-method fast_beam_search_nbest \
|
||||||
--beam 20.0 \
|
--beam 20.0 \
|
||||||
@ -68,10 +68,10 @@ Usage:
|
|||||||
--nbest-scale 0.5
|
--nbest-scale 0.5
|
||||||
|
|
||||||
(6) fast beam search (nbest oracle WER)
|
(6) fast beam search (nbest oracle WER)
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./lstm_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search_nbest_oracle \
|
--decoding-method fast_beam_search_nbest_oracle \
|
||||||
--beam 20.0 \
|
--beam 20.0 \
|
||||||
@ -81,30 +81,15 @@ Usage:
|
|||||||
--nbest-scale 0.5
|
--nbest-scale 0.5
|
||||||
|
|
||||||
(7) fast beam search (with LG)
|
(7) fast beam search (with LG)
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./lstm_transducer_stateless/decode.py \
|
||||||
--epoch 28 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search_nbest_LG \
|
--decoding-method fast_beam_search_nbest_LG \
|
||||||
--beam 20.0 \
|
--beam 20.0 \
|
||||||
--max-contexts 8 \
|
--max-contexts 8 \
|
||||||
--max-states 64
|
--max-states 64
|
||||||
|
|
||||||
(8) decode in streaming mode (take greedy search as an example)
|
|
||||||
./pruned_transducer_stateless2/decode.py \
|
|
||||||
--epoch 28 \
|
|
||||||
--avg 15 \
|
|
||||||
--simulate-streaming 1 \
|
|
||||||
--causal-convolution 1 \
|
|
||||||
--decode-chunk-size 16 \
|
|
||||||
--left-context 64 \
|
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
|
||||||
--max-duration 600 \
|
|
||||||
--decoding-method greedy_search
|
|
||||||
--beam 20.0 \
|
|
||||||
--max-contexts 8 \
|
|
||||||
--max-states 64
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -130,7 +115,7 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -142,7 +127,6 @@ from icefall.utils import (
|
|||||||
AttributeDict,
|
AttributeDict,
|
||||||
setup_logger,
|
setup_logger,
|
||||||
store_transcripts,
|
store_transcripts,
|
||||||
str2bool,
|
|
||||||
write_error_stats,
|
write_error_stats,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -286,29 +270,6 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--simulate-streaming",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
|
||||||
test a streaming model.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decode-chunk-size",
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--left-context",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-paths",
|
"--num-paths",
|
||||||
type=int,
|
type=int,
|
||||||
@ -327,7 +288,6 @@ def get_parser():
|
|||||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -387,18 +347,9 @@ def decode_one_batch(
|
|||||||
value=LOG_EPS,
|
value=LOG_EPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
x=feature, x_lens=feature_lens
|
||||||
x=feature,
|
)
|
||||||
x_lens=feature_lens,
|
|
||||||
chunk_size=params.decode_chunk_size,
|
|
||||||
left_context=params.left_context,
|
|
||||||
simulate_streaming=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
|
||||||
x=feature, x_lens=feature_lens
|
|
||||||
)
|
|
||||||
|
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|
||||||
@ -674,10 +625,6 @@ def main():
|
|||||||
else:
|
else:
|
||||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
|
||||||
params.suffix += f"-left-context-{params.left_context}"
|
|
||||||
|
|
||||||
if "fast_beam_search" in params.decoding_method:
|
if "fast_beam_search" in params.decoding_method:
|
||||||
params.suffix += f"-beam-{params.beam}"
|
params.suffix += f"-beam-{params.beam}"
|
||||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||||
@ -712,11 +659,6 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
assert (
|
|
||||||
params.causal_convolution
|
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
|||||||
@ -77,9 +77,7 @@ from beam_search import (
|
|||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -180,30 +178,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--simulate-streaming",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
|
||||||
test a streaming model.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--decode-chunk-size",
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help="The chunk size for decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--left-context",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="left context can be seen during decoding (in frames after subsampling)",
|
|
||||||
)
|
|
||||||
|
|
||||||
add_model_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -248,11 +222,6 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
assert (
|
|
||||||
params.causal_convolution
|
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
@ -299,18 +268,9 @@ def main():
|
|||||||
|
|
||||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
if params.simulate_streaming:
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
x=features, x_lens=feature_lengths
|
||||||
x=features,
|
)
|
||||||
x_lens=feature_lengths,
|
|
||||||
chunk_size=params.decode_chunk_size,
|
|
||||||
left_context=params.left_context,
|
|
||||||
simulate_streaming=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
|
||||||
x=features, x_lens=feature_lengths
|
|
||||||
)
|
|
||||||
|
|
||||||
num_waves = encoder_out.size(0)
|
num_waves = encoder_out.size(0)
|
||||||
hyps = []
|
hyps = []
|
||||||
|
|||||||
@ -34,31 +34,6 @@ def test_model():
|
|||||||
params.context_size = 2
|
params.context_size = 2
|
||||||
params.unk_id = 2
|
params.unk_id = 2
|
||||||
|
|
||||||
params.dynamic_chunk_training = False
|
|
||||||
params.short_chunk_size = 25
|
|
||||||
params.num_left_chunks = 4
|
|
||||||
params.causal_convolution = False
|
|
||||||
|
|
||||||
model = get_transducer_model(params)
|
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
|
||||||
print(f"Number of model parameters: {num_param}")
|
|
||||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
|
||||||
torch.jit.script(model)
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_streaming():
|
|
||||||
params = get_params()
|
|
||||||
params.vocab_size = 500
|
|
||||||
params.blank_id = 0
|
|
||||||
params.context_size = 2
|
|
||||||
params.unk_id = 2
|
|
||||||
|
|
||||||
params.dynamic_chunk_training = True
|
|
||||||
params.short_chunk_size = 25
|
|
||||||
params.num_left_chunks = 4
|
|
||||||
params.causal_convolution = True
|
|
||||||
|
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
@ -69,7 +44,6 @@ def test_model_streaming():
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
test_model()
|
test_model()
|
||||||
test_model_streaming()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -21,37 +21,24 @@ Usage:
|
|||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
./pruned_transducer_stateless2/train.py \
|
./lstm_transducer_stateless/train.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--exp-dir pruned_transducer_stateless2/exp \
|
--exp-dir lstm_transducer_stateless/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 300
|
--max-duration 300
|
||||||
|
|
||||||
# For mix precision training:
|
# For mix precision training:
|
||||||
|
|
||||||
./pruned_transducer_stateless2/train.py \
|
./lstm_transducer_stateless/train.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 0 \
|
--start-epoch 0 \
|
||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir pruned_transducer_stateless2/exp \
|
--exp-dir lstm_transducer_stateless/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 550
|
--max-duration 550
|
||||||
|
|
||||||
# train a streaming model
|
|
||||||
./pruned_transducer_stateless2/train.py \
|
|
||||||
--world-size 4 \
|
|
||||||
--num-epochs 30 \
|
|
||||||
--start-epoch 0 \
|
|
||||||
--exp-dir pruned_transducer_stateless/exp \
|
|
||||||
--full-libri 1 \
|
|
||||||
--dynamic-chunk-training 1 \
|
|
||||||
--causal-convolution 1 \
|
|
||||||
--short-chunk-size 25 \
|
|
||||||
--num-left-chunks 4 \
|
|
||||||
--max-duration 300
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -69,7 +56,7 @@ import torch
|
|||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from conformer import Conformer
|
from lstm import RNN
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
@ -95,42 +82,6 @@ LRSchedulerType = Union[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
|
||||||
parser.add_argument(
|
|
||||||
"--dynamic-chunk-training",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
|
||||||
model, this requires to be True.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--causal-convolution",
|
|
||||||
type=str2bool,
|
|
||||||
default=False,
|
|
||||||
help="""Whether to use causal convolution, this requires to be True when
|
|
||||||
using dynamic_chunk_training.
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--short-chunk-size",
|
|
||||||
type=int,
|
|
||||||
default=25,
|
|
||||||
help="""Chunk length of dynamic training, the chunk size would be either
|
|
||||||
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-left-chunks",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="How many left context can be seen in chunks when calculating attention.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -311,8 +262,6 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
add_model_arguments(parser)
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -374,7 +323,6 @@ def get_params() -> AttributeDict:
|
|||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"encoder_dim": 512,
|
"encoder_dim": 512,
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
"dim_feedforward": 2048,
|
||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
# parameters for decoder
|
# parameters for decoder
|
||||||
@ -392,17 +340,12 @@ def get_params() -> AttributeDict:
|
|||||||
|
|
||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
# TODO: We can add an option to switch between Conformer and Transformer
|
||||||
encoder = Conformer(
|
encoder = RNN(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
d_model=params.encoder_dim,
|
d_model=params.encoder_dim,
|
||||||
nhead=params.nhead,
|
|
||||||
dim_feedforward=params.dim_feedforward,
|
dim_feedforward=params.dim_feedforward,
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
num_encoder_layers=params.num_encoder_layers,
|
||||||
dynamic_chunk_training=params.dynamic_chunk_training,
|
|
||||||
short_chunk_size=params.short_chunk_size,
|
|
||||||
num_left_chunks=params.num_left_chunks,
|
|
||||||
causal=params.causal_convolution,
|
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user