pre commit hooks

This commit is contained in:
Desh Raj 2022-05-13 13:58:20 -04:00
parent e0536c9aee
commit a2fb1859db
5 changed files with 137 additions and 112 deletions

View File

@ -22,15 +22,15 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search (not recommended)
(2) beam search
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--max-duration 100 \
--decoding-method beam_search \
--beam-size 4
@ -39,7 +39,7 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
@ -48,7 +48,7 @@ Usage:
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
@ -69,7 +69,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_one_best,
fast_beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -98,28 +98,27 @@ def get_parser():
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
"'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
)
parser.add_argument(
@ -152,7 +151,7 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
@ -252,7 +251,7 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
hyp_tokens = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -270,7 +269,6 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -278,7 +276,6 @@ def decode_one_batch(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
@ -358,9 +355,9 @@ def decode_dataset(
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
log_interval = 100
else:
log_interval = 10
log_interval = 2
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -456,19 +453,13 @@ def main():
)
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
params.suffix += f"-beam-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -485,9 +476,8 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -495,20 +485,8 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))

View File

@ -1,34 +1,33 @@
# SPGISpeech
SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective
transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in
length to allow easy training for speech recognition systems. Calls represent a broad
cross-section of international business English; SPGISpeech contains approximately 50,000
speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and
SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective
transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in
length to allow easy training for speech recognition systems. Calls represent a broad
cross-section of international business English; SPGISpeech contains approximately 50,000
speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and
L2 English accents. The format of each WAV file is single channel, 16kHz, 16 bit audio.
Transcription text represents the output of several stages of manual post-processing.
As such, the text contains polished English orthography following a detailed style guide,
including proper casing, punctuation, and denormalized non-standard words such as numbers
Transcription text represents the output of several stages of manual post-processing.
As such, the text contains polished English orthography following a detailed style guide,
including proper casing, punctuation, and denormalized non-standard words such as numbers
and acronyms, making SPGISpeech suited for training fully formatted end-to-end models.
Official reference:
ONeill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam,
J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G.
(2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted
ONeill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam,
J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G.
(2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted
end-to-end speech recognition. ArXiv, abs/2104.02014.
ArXiv link: https://arxiv.org/abs/2104.02014
## Performance Record
| Decoding method | val |
| Decoding method | val |
|---------------------------|------------|
| greedy search | 2.40 |
| beam search | 2.24 |
| modified beam search | 2.30 |
| modified beam search | 2.24 |
| fast beam search | 2.35 |
See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.

View File

@ -16,7 +16,7 @@ The WERs are
|---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 |
| beam search | 2.27 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
| modified beam search | 2.34 | 2.30 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
**NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the
@ -44,14 +44,14 @@ The decoding command is:
```
# greedy search
./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
# beam search
./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method beam_search \
@ -59,7 +59,7 @@ The decoding command is:
# modified beam search
./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
@ -67,7 +67,7 @@ The decoding command is:
# fast beam search
./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \
--iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \

View File

@ -19,14 +19,16 @@
Usage:
(1) greedy search
./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \
--iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \
--iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method beam_search \
@ -34,7 +36,8 @@ Usage:
(3) modified beam search
./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \
--iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
@ -42,7 +45,8 @@ Usage:
(4) fast beam search
./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \
--iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
@ -93,30 +97,31 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=28,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
default=20,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--avg-last-n",
"--iter",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--exp-dir",
type=str,
@ -182,7 +187,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
@ -240,7 +246,9 @@ def decode_one_batch(
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = []
if params.decoding_method == "fast_beam_search":
@ -255,10 +263,14 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -266,6 +278,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
@ -375,7 +388,9 @@ def decode_dataset(
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -407,7 +422,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
@ -440,13 +456,19 @@ def main():
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -472,8 +494,20 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))

View File

@ -60,7 +60,6 @@ from asr_datamodule import SPGISpeechAsrDataModule
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
@ -78,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def get_parser():
@ -154,7 +155,8 @@ def get_parser():
"--initial-lr",
type=float,
default=0.003,
help="The initial learning rate. This value should not need to be changed.",
help="The initial learning rate. This value should not need to be "
"changed.",
)
parser.add_argument(
@ -177,7 +179,8 @@ def get_parser():
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
@ -200,7 +203,8 @@ def get_parser():
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" "part.",
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
@ -550,16 +554,23 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
@ -722,7 +733,9 @@ def train_one_epoch(
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
@ -820,7 +833,7 @@ def run(rank, world_size, args):
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2**22
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -828,9 +841,10 @@ def run(rank, world_size, args):
train_cuts = spgispeech.train_cuts()
# Ideally we should filter utterances that are too long or too short, but SPGISpeech
# contains regular length utterances so we don't need to do that. Here are the
# statistics of the training data (obtained by `train_cuts.describe()`):
# Ideally we should filter utterances that are too long or too short,
# but SPGISpeech contains regular length utterances so we don't need to
# do that. Here are the statistics of the training data (obtained by
# `train_cuts.describe()`):
# Cuts count: 5886320
# Total duration (hours): 15070.1