mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
modify decode.py pretrained.py test_model.py train.py
This commit is contained in:
parent
b1be6ea475
commit
4a0dea2aa2
@ -18,36 +18,36 @@
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 20.0 \
|
||||
@ -55,10 +55,10 @@ Usage:
|
||||
--max-states 64
|
||||
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
@ -68,10 +68,10 @@ Usage:
|
||||
--nbest-scale 0.5
|
||||
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
@ -81,30 +81,15 @@ Usage:
|
||||
--nbest-scale 0.5
|
||||
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
./lstm_transducer_stateless/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--exp-dir ./lstm_transducer_stateless/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
|
||||
(8) decode in streaming mode (take greedy search as an example)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--simulate-streaming 1 \
|
||||
--causal-convolution 1 \
|
||||
--decode-chunk-size 16 \
|
||||
--left-context 64 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
@ -130,7 +115,7 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
@ -142,7 +127,6 @@ from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
@ -286,29 +270,6 @@ def get_parser():
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||
test a streaming model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
@ -327,7 +288,6 @@ def get_parser():
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
return parser
|
||||
|
||||
|
||||
@ -387,18 +347,9 @@ def decode_one_batch(
|
||||
value=LOG_EPS,
|
||||
)
|
||||
|
||||
if params.simulate_streaming:
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
chunk_size=params.decode_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
|
||||
hyps = []
|
||||
|
||||
@ -674,10 +625,6 @@ def main():
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if params.simulate_streaming:
|
||||
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}"
|
||||
params.suffix += f"-left-context-{params.left_context}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
@ -712,11 +659,6 @@ def main():
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "Decoding in streaming requires causal convolution"
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
|
@ -77,9 +77,7 @@ from beam_search import (
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -180,30 +178,6 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--simulate-streaming",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to simulate streaming in decoding, this is a good way to
|
||||
test a streaming model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decode-chunk-size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="The chunk size for decoding (in frames after subsampling)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--left-context",
|
||||
type=int,
|
||||
default=64,
|
||||
help="left context can be seen during decoding (in frames after subsampling)",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -248,11 +222,6 @@ def main():
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
if params.simulate_streaming:
|
||||
assert (
|
||||
params.causal_convolution
|
||||
), "Decoding in streaming requires causal convolution"
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
@ -299,18 +268,9 @@ def main():
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
if params.simulate_streaming:
|
||||
encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
chunk_size=params.decode_chunk_size,
|
||||
left_context=params.left_context,
|
||||
simulate_streaming=True,
|
||||
)
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
|
@ -34,31 +34,6 @@ def test_model():
|
||||
params.context_size = 2
|
||||
params.unk_id = 2
|
||||
|
||||
params.dynamic_chunk_training = False
|
||||
params.short_chunk_size = 25
|
||||
params.num_left_chunks = 4
|
||||
params.causal_convolution = False
|
||||
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
torch.jit.script(model)
|
||||
|
||||
|
||||
def test_model_streaming():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.unk_id = 2
|
||||
|
||||
params.dynamic_chunk_training = True
|
||||
params.short_chunk_size = 25
|
||||
params.num_left_chunks = 4
|
||||
params.causal_convolution = True
|
||||
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
@ -69,7 +44,6 @@ def test_model_streaming():
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
test_model_streaming()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -21,37 +21,24 @@ Usage:
|
||||
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
./lstm_transducer_stateless/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 300
|
||||
|
||||
# For mix precision training:
|
||||
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
./lstm_transducer_stateless/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--exp-dir lstm_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
|
||||
# train a streaming model
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir pruned_transducer_stateless/exp \
|
||||
--full-libri 1 \
|
||||
--dynamic-chunk-training 1 \
|
||||
--causal-convolution 1 \
|
||||
--short-chunk-size 25 \
|
||||
--num-left-chunks 4 \
|
||||
--max-duration 300
|
||||
"""
|
||||
|
||||
|
||||
@ -69,7 +56,7 @@ import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from lstm import RNN
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
@ -95,42 +82,6 @@ LRSchedulerType = Union[
|
||||
]
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--dynamic-chunk-training",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use dynamic_chunk_training, if you want a streaming
|
||||
model, this requires to be True.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--causal-convolution",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""Whether to use causal convolution, this requires to be True when
|
||||
using dynamic_chunk_training.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--short-chunk-size",
|
||||
type=int,
|
||||
default=25,
|
||||
help="""Chunk length of dynamic training, the chunk size would be either
|
||||
max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-left-chunks",
|
||||
type=int,
|
||||
default=4,
|
||||
help="How many left context can be seen in chunks when calculating attention.",
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
@ -311,8 +262,6 @@ def get_parser():
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -374,7 +323,6 @@ def get_params() -> AttributeDict:
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
"encoder_dim": 512,
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
# parameters for decoder
|
||||
@ -392,17 +340,12 @@ def get_params() -> AttributeDict:
|
||||
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
# TODO: We can add an option to switch between Conformer and Transformer
|
||||
encoder = Conformer(
|
||||
encoder = RNN(
|
||||
num_features=params.feature_dim,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
d_model=params.encoder_dim,
|
||||
nhead=params.nhead,
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
dynamic_chunk_training=params.dynamic_chunk_training,
|
||||
short_chunk_size=params.short_chunk_size,
|
||||
num_left_chunks=params.num_left_chunks,
|
||||
causal=params.causal_convolution,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user