pre commit hooks

This commit is contained in:
Desh Raj 2022-05-13 13:58:20 -04:00
parent e0536c9aee
commit a2fb1859db
5 changed files with 137 additions and 112 deletions

View File

@ -22,15 +22,15 @@ Usage:
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 100 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (not recommended) (2) beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 100 \
--decoding-method beam_search \ --decoding-method beam_search \
--beam-size 4 --beam-size 4
@ -39,7 +39,7 @@ Usage:
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 100 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
--beam-size 4 --beam-size 4
@ -48,7 +48,7 @@ Usage:
--epoch 28 \ --epoch 28 \
--avg 15 \ --avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \ --max-duration 1500 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
--beam 4 \ --beam 4 \
--max-contexts 4 \ --max-contexts 4 \
@ -69,7 +69,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search_one_best, fast_beam_search,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -98,28 +98,27 @@ def get_parser():
"--epoch", "--epoch",
type=int, type=int,
default=28, default=28,
help="""It specifies the checkpoint to use for decoding. help="It specifies the checkpoint to use for decoding."
Note: Epoch counts from 0. "Note: Epoch counts from 0.",
You can specify --avg to use more checkpoints for model averaging.""",
) )
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument( parser.add_argument(
"--avg", "--avg",
type=int, type=int,
default=15, default=15,
help="Number of checkpoints to average. Automatically select " help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by " "consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'", "'--epoch'. ",
)
parser.add_argument(
"--avg-last-n",
type=int,
default=0,
help="""If positive, --epoch and --avg are ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
where xxx is the number of processed batches while
saving that checkpoint.
""",
) )
parser.add_argument( parser.add_argument(
@ -152,7 +151,7 @@ def get_parser():
"--beam-size", "--beam-size",
type=int, type=int,
default=4, default=4,
help="""An integer indicating how many candidates we will keep for each help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or frame. Used only when --decoding-method is beam_search or
modified_beam_search.""", modified_beam_search.""",
) )
@ -252,7 +251,7 @@ def decode_one_batch(
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best( hyp_tokens = fast_beam_search(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -270,7 +269,6 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -278,7 +276,6 @@ def decode_one_batch(
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
@ -358,9 +355,9 @@ def decode_dataset(
num_batches = "?" num_batches = "?"
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
log_interval = 50 log_interval = 100
else: else:
log_interval = 10 log_interval = 2
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
@ -456,19 +453,13 @@ def main():
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += ( params.suffix += f"-beam-{params.beam_size}"
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -485,9 +476,8 @@ def main():
sp = spm.SentencePieceProcessor() sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model) sp.load(params.bpe_model)
# <blk> and <unk> is defined in local/train_bpe_model.py # <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>") params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
logging.info(params) logging.info(params)
@ -495,20 +485,8 @@ def main():
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params)
if params.iter > 0: if params.avg_last_n > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))

View File

@ -1,34 +1,33 @@
# SPGISpeech # SPGISpeech
SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective
transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in
length to allow easy training for speech recognition systems. Calls represent a broad length to allow easy training for speech recognition systems. Calls represent a broad
cross-section of international business English; SPGISpeech contains approximately 50,000 cross-section of international business English; SPGISpeech contains approximately 50,000
speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and
L2 English accents. The format of each WAV file is single channel, 16kHz, 16 bit audio. L2 English accents. The format of each WAV file is single channel, 16kHz, 16 bit audio.
Transcription text represents the output of several stages of manual post-processing. Transcription text represents the output of several stages of manual post-processing.
As such, the text contains polished English orthography following a detailed style guide, As such, the text contains polished English orthography following a detailed style guide,
including proper casing, punctuation, and denormalized non-standard words such as numbers including proper casing, punctuation, and denormalized non-standard words such as numbers
and acronyms, making SPGISpeech suited for training fully formatted end-to-end models. and acronyms, making SPGISpeech suited for training fully formatted end-to-end models.
Official reference: Official reference:
ONeill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam, ONeill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam,
J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G. J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G.
(2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted (2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted
end-to-end speech recognition. ArXiv, abs/2104.02014. end-to-end speech recognition. ArXiv, abs/2104.02014.
ArXiv link: https://arxiv.org/abs/2104.02014 ArXiv link: https://arxiv.org/abs/2104.02014
## Performance Record ## Performance Record
| Decoding method | val | | Decoding method | val |
|---------------------------|------------| |---------------------------|------------|
| greedy search | 2.40 | | greedy search | 2.40 |
| beam search | 2.24 | | beam search | 2.24 |
| modified beam search | 2.30 | | modified beam search | 2.24 |
| fast beam search | 2.35 | | fast beam search | 2.35 |
See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details. See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.

View File

@ -16,7 +16,7 @@ The WERs are
|---------------------------|------------|------------|------------------------------------------| |---------------------------|------------|------------|------------------------------------------|
| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 | | greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 |
| beam search | 2.27 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 | | beam search | 2.27 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
| modified beam search | 2.34 | 2.30 | --avg-last-n 10 --max-duration 500 --beam-size 4 | | modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | | fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
**NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the **NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the
@ -44,14 +44,14 @@ The decoding command is:
``` ```
# greedy search # greedy search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \ --iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method greedy_search --decoding-method greedy_search
# beam search # beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \ --iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --decoding-method beam_search \
@ -59,7 +59,7 @@ The decoding command is:
# modified beam search # modified beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \ --iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
@ -67,7 +67,7 @@ The decoding command is:
# fast beam search # fast beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \ --iter 696000 --avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \ --max-duration 1500 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \

View File

@ -19,14 +19,16 @@
Usage: Usage:
(1) greedy search (1) greedy search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \ --iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method greedy_search --decoding-method greedy_search
(2) beam search (2) beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \ --iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method beam_search \ --decoding-method beam_search \
@ -34,7 +36,8 @@ Usage:
(3) modified beam search (3) modified beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \ --iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 100 \ --max-duration 100 \
--decoding-method modified_beam_search \ --decoding-method modified_beam_search \
@ -42,7 +45,8 @@ Usage:
(4) fast beam search (4) fast beam search
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--avg-last-n 10 \ --iter 696000 \
--avg 10 \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 1500 \ --max-duration 1500 \
--decoding-method fast_beam_search \ --decoding-method fast_beam_search \
@ -93,30 +97,31 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--epoch", "--epoch",
type=int, type=int,
default=28, default=20,
help="It specifies the checkpoint to use for decoding." help="""It specifies the checkpoint to use for decoding.
"Note: Epoch counts from 0.", Note: Epoch counts from 0.
) You can specify --avg to use more checkpoints for model averaging.""",
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
) )
parser.add_argument( parser.add_argument(
"--avg-last-n", "--iter",
type=int, type=int,
default=0, default=0,
help="""If positive, --epoch and --avg are ignored and it help="""If positive, --epoch is ignored and it
will use the last n checkpoints exp_dir/checkpoint-xxx.pt will use the checkpoint exp_dir/checkpoint-iter.pt.
where xxx is the number of processed batches while You can specify --avg to use more checkpoints for model averaging.
saving that checkpoint.
""", """,
) )
parser.add_argument(
"--avg",
type=int,
default=10,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument( parser.add_argument(
"--exp-dir", "--exp-dir",
type=str, type=str,
@ -182,7 +187,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
"--max-sym-per-frame", "--max-sym-per-frame",
@ -240,7 +246,9 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
@ -255,10 +263,14 @@ def decode_one_batch(
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
@ -266,6 +278,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
@ -375,7 +388,9 @@ def decode_dataset(
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}" batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results return results
@ -407,7 +422,8 @@ def save_results(
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = ( errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
) )
with open(errs_info, "w") as f: with open(errs_info, "w") as f:
print("settings\tWER", file=f) print("settings\tWER", file=f)
@ -440,13 +456,19 @@ def main():
) )
params.res_dir = params.exp_dir / params.decoding_method params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method: if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}" params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}" params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method: elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}" params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else: else:
params.suffix += f"-context-{params.context_size}" params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -472,8 +494,20 @@ def main():
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params)
if params.avg_last_n > 0: if params.iter > 0:
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}") logging.info(f"averaging {filenames}")
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))

View File

@ -60,7 +60,6 @@ from asr_datamodule import SPGISpeechAsrDataModule
from conformer import Conformer from conformer import Conformer
from decoder import Decoder from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
@ -78,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def get_parser(): def get_parser():
@ -154,7 +155,8 @@ def get_parser():
"--initial-lr", "--initial-lr",
type=float, type=float,
default=0.003, default=0.003,
help="The initial learning rate. This value should not need to be changed.", help="The initial learning rate. This value should not need to be "
"changed.",
) )
parser.add_argument( parser.add_argument(
@ -177,7 +179,8 @@ def get_parser():
"--context-size", "--context-size",
type=int, type=int,
default=2, default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
) )
parser.add_argument( parser.add_argument(
@ -200,7 +203,8 @@ def get_parser():
"--am-scale", "--am-scale",
type=float, type=float,
default=0.0, default=0.0,
help="The scale to smooth the loss with am (output of encoder network)" "part.", help="The scale to smooth the loss with am (output of encoder network)"
"part.",
) )
parser.add_argument( parser.add_argument(
@ -550,16 +554,23 @@ def compute_loss(
# overwhelming the simple_loss and causing it to diverge, # overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet. # in case it had not fully learned the alignment yet.
pruned_loss_scale = ( pruned_loss_scale = (
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) 0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
) )
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
info = MetricsTracker() info = MetricsTracker()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
info["frames"] = (feature_lens // params.subsampling_factor).sum().item() info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss. # Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item()
@ -722,7 +733,9 @@ def train_one_epoch(
loss_info.write_summary( loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train tb_writer, "train/current_", params.batch_idx_train
) )
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0: if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss") logging.info("Computing validation loss")
@ -820,7 +833,7 @@ def run(rank, world_size, args):
if params.print_diagnostics: if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions( opts = diagnostics.TensorDiagnosticOptions(
2**22 2 ** 22
) # allow 4 megabytes per sub-module ) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts) diagnostic = diagnostics.attach_diagnostics(model, opts)
@ -828,9 +841,10 @@ def run(rank, world_size, args):
train_cuts = spgispeech.train_cuts() train_cuts = spgispeech.train_cuts()
# Ideally we should filter utterances that are too long or too short, but SPGISpeech # Ideally we should filter utterances that are too long or too short,
# contains regular length utterances so we don't need to do that. Here are the # but SPGISpeech contains regular length utterances so we don't need to
# statistics of the training data (obtained by `train_cuts.describe()`): # do that. Here are the statistics of the training data (obtained by
# `train_cuts.describe()`):
# Cuts count: 5886320 # Cuts count: 5886320
# Total duration (hours): 15070.1 # Total duration (hours): 15070.1