mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
pre commit hooks
This commit is contained in:
parent
e0536c9aee
commit
a2fb1859db
@ -22,15 +22,15 @@ Usage:
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
@ -39,7 +39,7 @@ Usage:
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
@ -48,7 +48,7 @@ Usage:
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
@ -69,7 +69,7 @@ import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
fast_beam_search,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
@ -98,28 +98,27 @@ def get_parser():
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg-last-n",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch and --avg are ignored and it
|
||||
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
|
||||
where xxx is the number of processed batches while
|
||||
saving that checkpoint.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -152,7 +151,7 @@ def get_parser():
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
help="""An interger indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
@ -252,7 +251,7 @@ def decode_one_batch(
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
hyp_tokens = fast_beam_search(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
@ -270,7 +269,6 @@ def decode_one_batch(
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -278,7 +276,6 @@ def decode_one_batch(
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
@ -358,9 +355,9 @@ def decode_dataset(
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
log_interval = 100
|
||||
else:
|
||||
log_interval = 10
|
||||
log_interval = 2
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
@ -456,19 +453,13 @@ def main():
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -485,9 +476,8 @@ def main():
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
@ -495,20 +485,8 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
if params.avg_last_n > 0:
|
||||
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
@ -1,34 +1,33 @@
|
||||
# SPGISpeech
|
||||
|
||||
SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective
|
||||
transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in
|
||||
length to allow easy training for speech recognition systems. Calls represent a broad
|
||||
cross-section of international business English; SPGISpeech contains approximately 50,000
|
||||
speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and
|
||||
SPGISpeech consists of 5,000 hours of recorded company earnings calls and their respective
|
||||
transcriptions. The original calls were split into slices ranging from 5 to 15 seconds in
|
||||
length to allow easy training for speech recognition systems. Calls represent a broad
|
||||
cross-section of international business English; SPGISpeech contains approximately 50,000
|
||||
speakers, one of the largest numbers of any speech corpus, and offers a variety of L1 and
|
||||
L2 English accents. The format of each WAV file is single channel, 16kHz, 16 bit audio.
|
||||
|
||||
Transcription text represents the output of several stages of manual post-processing.
|
||||
As such, the text contains polished English orthography following a detailed style guide,
|
||||
including proper casing, punctuation, and denormalized non-standard words such as numbers
|
||||
Transcription text represents the output of several stages of manual post-processing.
|
||||
As such, the text contains polished English orthography following a detailed style guide,
|
||||
including proper casing, punctuation, and denormalized non-standard words such as numbers
|
||||
and acronyms, making SPGISpeech suited for training fully formatted end-to-end models.
|
||||
|
||||
Official reference:
|
||||
|
||||
O’Neill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam,
|
||||
J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G.
|
||||
(2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted
|
||||
O’Neill, P.K., Lavrukhin, V., Majumdar, S., Noroozi, V., Zhang, Y., Kuchaiev, O., Balam,
|
||||
J., Dovzhenko, Y., Freyberg, K., Shulman, M.D., Ginsburg, B., Watanabe, S., & Kucsko, G.
|
||||
(2021). SPGISpeech: 5, 000 hours of transcribed financial audio for fully formatted
|
||||
end-to-end speech recognition. ArXiv, abs/2104.02014.
|
||||
|
||||
ArXiv link: https://arxiv.org/abs/2104.02014
|
||||
|
||||
## Performance Record
|
||||
|
||||
| Decoding method | val |
|
||||
| Decoding method | val |
|
||||
|---------------------------|------------|
|
||||
| greedy search | 2.40 |
|
||||
| beam search | 2.24 |
|
||||
| modified beam search | 2.30 |
|
||||
| modified beam search | 2.24 |
|
||||
| fast beam search | 2.35 |
|
||||
|
||||
See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.
|
||||
|
||||
|
@ -16,7 +16,7 @@ The WERs are
|
||||
|---------------------------|------------|------------|------------------------------------------|
|
||||
| greedy search | 2.46 | 2.40 | --avg-last-n 10 --max-duration 500 |
|
||||
| beam search | 2.27 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
|
||||
| modified beam search | 2.34 | 2.30 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
|
||||
| modified beam search | 2.28 | 2.24 | --avg-last-n 10 --max-duration 500 --beam-size 4 |
|
||||
| fast beam search | 2.38 | 2.35 | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
|
||||
|
||||
**NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the
|
||||
@ -44,14 +44,14 @@ The decoding command is:
|
||||
```
|
||||
# greedy search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--iter 696000 --avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
# beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--iter 696000 --avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
@ -59,7 +59,7 @@ The decoding command is:
|
||||
|
||||
# modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--iter 696000 --avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
@ -67,7 +67,7 @@ The decoding command is:
|
||||
|
||||
# fast beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--iter 696000 --avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
|
@ -19,14 +19,16 @@
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--iter 696000 \
|
||||
--avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--iter 696000 \
|
||||
--avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
@ -34,7 +36,8 @@ Usage:
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--iter 696000 \
|
||||
--avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 100 \
|
||||
--decoding-method modified_beam_search \
|
||||
@ -42,7 +45,8 @@ Usage:
|
||||
|
||||
(4) fast beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--avg-last-n 10 \
|
||||
--iter 696000 \
|
||||
--avg 10 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 1500 \
|
||||
--decoding-method fast_beam_search \
|
||||
@ -93,30 +97,31 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=28,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
default=20,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 0.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg-last-n",
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch and --avg are ignored and it
|
||||
will use the last n checkpoints exp_dir/checkpoint-xxx.pt
|
||||
where xxx is the number of processed batches while
|
||||
saving that checkpoint.
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
@ -182,7 +187,8 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
@ -240,7 +246,9 @@ def decode_one_batch(
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=feature, x_lens=feature_lens
|
||||
)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
@ -255,10 +263,14 @@ def decode_one_batch(
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
elif (
|
||||
params.decoding_method == "greedy_search"
|
||||
and params.max_sym_per_frame == 1
|
||||
):
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
@ -266,6 +278,7 @@ def decode_one_batch(
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
@ -375,7 +388,9 @@ def decode_dataset(
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@ -407,7 +422,8 @@ def save_results(
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
params.res_dir
|
||||
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
@ -440,13 +456,19 @@ def main():
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
params.suffix += (
|
||||
f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
)
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -472,8 +494,20 @@ def main():
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if params.avg_last_n > 0:
|
||||
filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
|
@ -60,7 +60,6 @@ from asr_datamodule import SPGISpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
@ -78,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
LRSchedulerType = Union[
|
||||
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
|
||||
]
|
||||
|
||||
|
||||
def get_parser():
|
||||
@ -154,7 +155,8 @@ def get_parser():
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
default=0.003,
|
||||
help="The initial learning rate. This value should not need to be changed.",
|
||||
help="The initial learning rate. This value should not need to be "
|
||||
"changed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -177,7 +179,8 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -200,7 +203,8 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network)"
|
||||
"part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -550,16 +554,23 @@ def compute_loss(
|
||||
# overwhelming the simple_loss and causing it to diverge,
|
||||
# in case it had not fully learned the alignment yet.
|
||||
pruned_loss_scale = (
|
||||
0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
0.0
|
||||
if warmup < 1.0
|
||||
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
|
||||
)
|
||||
loss = (
|
||||
params.simple_loss_scale * simple_loss
|
||||
+ pruned_loss_scale * pruned_loss
|
||||
)
|
||||
loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
info = MetricsTracker()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
|
||||
info["frames"] = (
|
||||
(feature_lens // params.subsampling_factor).sum().item()
|
||||
)
|
||||
|
||||
# Note: We use reduction=sum while computing the loss.
|
||||
info["loss"] = loss.detach().cpu().item()
|
||||
@ -722,7 +733,9 @@ def train_one_epoch(
|
||||
loss_info.write_summary(
|
||||
tb_writer, "train/current_", params.batch_idx_train
|
||||
)
|
||||
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
|
||||
tot_loss.write_summary(
|
||||
tb_writer, "train/tot_", params.batch_idx_train
|
||||
)
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
logging.info("Computing validation loss")
|
||||
@ -820,7 +833,7 @@ def run(rank, world_size, args):
|
||||
|
||||
if params.print_diagnostics:
|
||||
opts = diagnostics.TensorDiagnosticOptions(
|
||||
2**22
|
||||
2 ** 22
|
||||
) # allow 4 megabytes per sub-module
|
||||
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||
|
||||
@ -828,9 +841,10 @@ def run(rank, world_size, args):
|
||||
|
||||
train_cuts = spgispeech.train_cuts()
|
||||
|
||||
# Ideally we should filter utterances that are too long or too short, but SPGISpeech
|
||||
# contains regular length utterances so we don't need to do that. Here are the
|
||||
# statistics of the training data (obtained by `train_cuts.describe()`):
|
||||
# Ideally we should filter utterances that are too long or too short,
|
||||
# but SPGISpeech contains regular length utterances so we don't need to
|
||||
# do that. Here are the statistics of the training data (obtained by
|
||||
# `train_cuts.describe()`):
|
||||
|
||||
# Cuts count: 5886320
|
||||
# Total duration (hours): 15070.1
|
||||
|
Loading…
x
Reference in New Issue
Block a user