Mirror of https://github.com/k2-fsa/icefall.git
pre commit hooks

This commit is contained in: parent e0536c9aee, commit a2fb1859db

@@ -22,15 +22,15 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 600 \
+    --max-duration 100 \
     --decoding-method greedy_search
 
-(2) beam search (not recommended)
+(2) beam search
 ./pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 600 \
+    --max-duration 100 \
     --decoding-method beam_search \
     --beam-size 4
 
@@ -39,7 +39,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 600 \
+    --max-duration 100 \
     --decoding-method modified_beam_search \
     --beam-size 4
 
@@ -48,7 +48,7 @@ Usage:
     --epoch 28 \
     --avg 15 \
     --exp-dir ./pruned_transducer_stateless2/exp \
-    --max-duration 600 \
+    --max-duration 1500 \
     --decoding-method fast_beam_search \
     --beam 4 \
     --max-contexts 4 \
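
A note on the `--max-duration` values changed above: the flag caps the total seconds of audio per decoding batch, so greedy/beam/modified beam search drop to 100 while `fast_beam_search`, which batches well, goes up to 1500. A toy sketch of duration-capped batching (illustrative only; in icefall the grouping is actually done by lhotse's samplers):

```
# Toy sketch: group utterances so each batch holds at most max_duration
# seconds of audio. Illustrative only; icefall delegates this to lhotse.
def batch_by_duration(durations, max_duration):
    batches, current, total = [], [], 0.0
    for idx, dur in enumerate(durations):
        if current and total + dur > max_duration:
            batches.append(current)
            current, total = [], 0.0
        current.append(idx)
        total += dur
    if current:
        batches.append(current)
    return batches

print(batch_by_duration([30.0, 45.0, 60.0, 25.0], max_duration=100))
# [[0, 1], [2, 3]]
```
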
@@ -69,7 +69,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
-    fast_beam_search_one_best,
+    fast_beam_search,
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
@@ -98,28 +98,27 @@ def get_parser():
         "--epoch",
         type=int,
         default=28,
-        help="""It specifies the checkpoint to use for decoding.
-        Note: Epoch counts from 0.
-        You can specify --avg to use more checkpoints for model averaging.""",
+        help="It specifies the checkpoint to use for decoding."
+        "Note: Epoch counts from 0.",
     )
 
-    parser.add_argument(
-        "--iter",
-        type=int,
-        default=0,
-        help="""If positive, --epoch is ignored and it
-        will use the checkpoint exp_dir/checkpoint-iter.pt.
-        You can specify --avg to use more checkpoints for model averaging.
-        """,
-    )
-
     parser.add_argument(
         "--avg",
         type=int,
         default=15,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch' and '--iter'",
+        "'--epoch'. ",
     )
 
+    parser.add_argument(
+        "--avg-last-n",
+        type=int,
+        default=0,
+        help="""If positive, --epoch and --avg are ignored and it
+        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
+        where xxx is the number of processed batches while
+        saving that checkpoint.
+        """,
+    )
     parser.add_argument(
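
This hunk trades `--iter` for `--avg-last-n`. The three selection modes described in the help strings can be sketched as follows (a simplified, hypothetical illustration, not the real `find_checkpoints` helper from icefall.checkpoint; the filename parsing is naive):

```
# Sketch of the three checkpoint-selection modes from the help strings.
from pathlib import Path

def select_checkpoints(exp_dir, epoch, avg, iteration=0, avg_last_n=0):
    exp_dir = Path(exp_dir)
    # Batch-level checkpoints, newest first (checkpoint-<num-batches>.pt).
    by_batch = sorted(
        exp_dir.glob("checkpoint-*.pt"),
        key=lambda p: int(p.stem.split("-")[1]),
        reverse=True,
    )
    if avg_last_n > 0:
        # --avg-last-n: just the last n batch-level checkpoints.
        return by_batch[:avg_last_n]
    if iteration > 0:
        # --iter: `avg` consecutive checkpoints up to checkpoint-{iter}.pt.
        recent = [p for p in by_batch if int(p.stem.split("-")[1]) <= iteration]
        return recent[:avg]
    # Default: `avg` consecutive epoch-level checkpoints ending at --epoch.
    return [exp_dir / f"epoch-{e}.pt" for e in range(epoch - avg + 1, epoch + 1)]
```
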
@@ -152,7 +151,7 @@ def get_parser():
         "--beam-size",
         type=int,
         default=4,
-        help="""An integer indicating how many candidates we will keep for each
+        help="""An interger indicating how many candidates we will keep for each
         frame. Used only when --decoding-method is beam_search or
         modified_beam_search.""",
     )
@@ -252,7 +251,7 @@ def decode_one_batch(
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
-        hyp_tokens = fast_beam_search_one_best(
+        hyp_tokens = fast_beam_search(
             model=model,
             decoding_graph=decoding_graph,
             encoder_out=encoder_out,
@@ -270,7 +269,6 @@ def decode_one_batch(
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -278,7 +276,6 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
-            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
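
The dropped `encoder_out_lens=encoder_out_lens,` arguments matter for batched search: utterances in a batch are padded to the longest one, and frames past each utterance's true length must not emit symbols. A toy, runnable illustration of the idea (an argmax stands in for the real joiner, and blank id 0 is assumed):

```
# Toy illustration: why batched greedy search needs the true lengths.
import torch

def toy_greedy_batch(logits, lens):
    # logits: (N, T, V) per-frame scores, padded to a common T.
    # lens:   (N,) number of valid frames per utterance.
    best = logits.argmax(dim=-1)  # (N, T)
    hyps = []
    for n in range(best.size(0)):
        valid = best[n, : lens[n]]  # ignore padded frames
        hyps.append([t for t in valid.tolist() if t != 0])  # drop blanks
    return hyps

print(toy_greedy_batch(torch.randn(2, 5, 4), torch.tensor([5, 3])))
```
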
@@ -358,9 +355,9 @@ def decode_dataset(
         num_batches = "?"
 
     if params.decoding_method == "greedy_search":
-        log_interval = 50
+        log_interval = 100
     else:
-        log_interval = 10
+        log_interval = 2
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -456,19 +453,13 @@ def main():
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
-    if params.iter > 0:
-        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
-    else:
-        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += (
-            f"-{params.decoding_method}-beam-size-{params.beam_size}"
-        )
+        params.suffix += f"-beam-{params.beam_size}"
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
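
For reference, the suffix logic on the removed side of this hunk, restated as one function (a sketch; `p` stands for the params object):

```
# Sketch of how the results-file suffix is built (old side of the hunk).
def make_suffix(p):
    if p.iter > 0:
        s = f"iter-{p.iter}-avg-{p.avg}"
    else:
        s = f"epoch-{p.epoch}-avg-{p.avg}"
    if "fast_beam_search" in p.decoding_method:
        s += f"-beam-{p.beam}-max-contexts-{p.max_contexts}-max-states-{p.max_states}"
    elif "beam_search" in p.decoding_method:
        s += f"-{p.decoding_method}-beam-size-{p.beam_size}"
    else:  # greedy_search
        s += f"-context-{p.context_size}-max-sym-per-frame-{p.max_sym_per_frame}"
    return s
```
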
@@ -485,9 +476,8 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> and <unk> is defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
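
Context for the removed `params.unk_id` line: both ids come straight from the sentencepiece model. A minimal sketch (the model path here is hypothetical):

```
# Looking up special-token ids from a trained BPE model with sentencepiece.
# <blk> is defined by local/train_bpe_model.py; the path is hypothetical.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path
blank_id = sp.piece_to_id("<blk>")
unk_id = sp.piece_to_id("<unk>")
vocab_size = sp.get_piece_size()
```
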
@@ -495,20 +485,8 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.iter > 0:
-        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-            : params.avg
-        ]
-        if len(filenames) == 0:
-            raise ValueError(
-                f"No checkpoints found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
-        elif len(filenames) < params.avg:
-            raise ValueError(
-                f"Not enough checkpoints ({len(filenames)}) found for"
-                f" --iter {params.iter}, --avg {params.avg}"
-            )
+    if params.avg_last_n > 0:
+        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
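
Conceptually, `average_checkpoints` takes an element-wise mean of the selected state dicts. A minimal sketch, assuming each file stores its weights under a "model" key (the real helper in icefall.checkpoint handles more cases):

```
# Minimal sketch of checkpoint averaging: mean of the saved model weights.
import torch

def average_state_dicts(filenames, device="cpu"):
    avg = None
    for f in filenames:
        # Assumes the layout {"model": state_dict}; real files may differ.
        sd = torch.load(f, map_location=device)["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in sd.items()}
        else:
            for k in avg:
                avg[k] += sd[k].float()
    return {k: v / len(filenames) for k, v in avg.items()}
```
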
@@ -27,8 +27,7 @@ ArXiv link: https://arxiv.org/abs/2104.02014
 |---------------------------|------------|
 | greedy search             | 2.40       |
 | beam search               | 2.24       |
-| modified beam search      | 2.30       |
+| modified beam search      | 2.24       |
 | fast beam search          | 2.35       |
 
 See [RESULTS](/egs/spgispeech/ASR/RESULTS.md) for details.
 
@@ -16,7 +16,7 @@ The WERs are
 |---------------------------|------------|------------|------------------------------------------|
 | greedy search             | 2.46       | 2.40       | --avg-last-n 10 --max-duration 500       |
 | beam search               | 2.27       | 2.24       | --avg-last-n 10 --max-duration 500 --beam-size 4 |
-| modified beam search      | 2.34       | 2.30       | --avg-last-n 10 --max-duration 500 --beam-size 4 |
+| modified beam search      | 2.28       | 2.24       | --avg-last-n 10 --max-duration 500 --beam-size 4 |
 | fast beam search          | 2.38       | 2.35       | --avg-last-n 10 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 |
 
 **NOTE:** SPGISpeech transcripts can be prepared in `ortho` or `norm` ways, which refer to whether the
@@ -44,14 +44,14 @@ The decoding command is:
 ```
 # greedy search
 ./pruned_transducer_stateless2/decode.py \
-    --avg-last-n 10 \
+    --iter 696000 --avg 10 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method greedy_search
 
 # beam search
 ./pruned_transducer_stateless2/decode.py \
-    --avg-last-n 10 \
+    --iter 696000 --avg 10 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method beam_search \
@@ -59,7 +59,7 @@ The decoding command is:
 
 # modified beam search
 ./pruned_transducer_stateless2/decode.py \
-    --avg-last-n 10 \
+    --iter 696000 --avg 10 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method modified_beam_search \
@@ -67,7 +67,7 @@ The decoding command is:
 
 # fast beam search
 ./pruned_transducer_stateless2/decode.py \
-    --avg-last-n 10 \
+    --iter 696000 --avg 10 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 1500 \
     --decoding-method fast_beam_search \
@@ -19,14 +19,16 @@
 Usage:
 (1) greedy search
 ./pruned_transducer_stateless2/decode.py \
-    --avg-last-n 10 \
+    --iter 696000 \
+    --avg 10 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method greedy_search
 
 (2) beam search
 ./pruned_transducer_stateless2/decode.py \
-    --avg-last-n 10 \
+    --iter 696000 \
+    --avg 10 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method beam_search \
@@ -34,7 +36,8 @@ Usage:
 
 (3) modified beam search
 ./pruned_transducer_stateless2/decode.py \
-    --avg-last-n 10 \
+    --iter 696000 \
+    --avg 10 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method modified_beam_search \
@@ -42,7 +45,8 @@ Usage:
 
 (4) fast beam search
 ./pruned_transducer_stateless2/decode.py \
-    --avg-last-n 10 \
+    --iter 696000 \
+    --avg 10 \
     --exp-dir ./pruned_transducer_stateless2/exp \
     --max-duration 1500 \
     --decoding-method fast_beam_search \
@@ -93,30 +97,31 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=28,
-        help="It specifies the checkpoint to use for decoding."
-        "Note: Epoch counts from 0.",
-    )
-    parser.add_argument(
-        "--avg",
-        type=int,
-        default=15,
-        help="Number of checkpoints to average. Automatically select "
-        "consecutive checkpoints before the checkpoint specified by "
-        "'--epoch'. ",
+        default=20,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
     )
 
     parser.add_argument(
-        "--avg-last-n",
+        "--iter",
         type=int,
         default=0,
-        help="""If positive, --epoch and --avg are ignored and it
-        will use the last n checkpoints exp_dir/checkpoint-xxx.pt
-        where xxx is the number of processed batches while
-        saving that checkpoint.
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
         """,
     )
 
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=10,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' and '--iter'",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -182,7 +187,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
     parser.add_argument(
         "--max-sym-per-frame",
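
The `--context-size` help ("1 means bigram; 2 means tri-gram") refers to the stateless transducer decoder: its prediction network sees only the last `context-size` emitted symbols. A toy sketch in the spirit of icefall's Decoder, i.e. an embedding followed by a grouped 1-D convolution over that window (the dimensions are made up):

```
# Toy stateless decoder: conditions on only the last `context_size` symbols.
import torch
import torch.nn as nn

class ToyStatelessDecoder(nn.Module):
    def __init__(self, vocab_size=500, dim=512, context_size=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        # Grouped conv mixes each channel over the context window only.
        self.conv = nn.Conv1d(dim, dim, kernel_size=context_size,
                              groups=dim, bias=False)

    def forward(self, y):
        # y: (N, U) symbol ids, assumed left-padded with blanks.
        emb = self.embedding(y).permute(0, 2, 1)  # (N, dim, U)
        return self.conv(emb).permute(0, 2, 1)    # (N, U - context_size + 1, dim)

out = ToyStatelessDecoder()(torch.zeros(2, 5, dtype=torch.long))
print(out.shape)  # torch.Size([2, 4, 512])
```
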
@@ -240,7 +246,9 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+    encoder_out, encoder_out_lens = model.encoder(
+        x=feature, x_lens=feature_lens
+    )
     hyps = []
 
     if params.decoding_method == "fast_beam_search":
@@ -255,10 +263,14 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
-    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+    elif (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
@@ -266,6 +278,7 @@ def decode_one_batch(
         hyp_tokens = modified_beam_search(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
@@ -375,7 +388,9 @@ def decode_dataset(
         if batch_idx % log_interval == 0:
             batch_str = f"{batch_idx}/{num_batches}"
 
-            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+            logging.info(
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
+            )
     return results
 
 
@@ -407,7 +422,8 @@ def save_results(
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = (
-        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+        params.res_dir
+        / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
     )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
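
For reference, the summary file that `save_results` writes via `errs_info` is a `settings<TAB>WER` header followed by (key, WER) pairs sorted by WER. A sketch with illustrative values and an illustrative filename:

```
# Sketch of the WER summary written by save_results(). Values are made up.
test_set_wers = {"greedy_search": 2.46, "beam_search": 2.27}
rows = sorted(test_set_wers.items(), key=lambda x: x[1])
with open("wer-summary-example.txt", "w") as f:  # illustrative filename
    print("settings\tWER", file=f)
    for key, wer in rows:
        print(f"{key}\t{wer}", file=f)
```
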
@@ -440,13 +456,19 @@ def main():
     )
     params.res_dir = params.exp_dir / params.decoding_method
 
-    params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
 
     if "fast_beam_search" in params.decoding_method:
         params.suffix += f"-beam-{params.beam}"
         params.suffix += f"-max-contexts-{params.max_contexts}"
         params.suffix += f"-max-states-{params.max_states}"
     elif "beam_search" in params.decoding_method:
-        params.suffix += f"-beam-{params.beam_size}"
+        params.suffix += (
+            f"-{params.decoding_method}-beam-size-{params.beam_size}"
+        )
     else:
         params.suffix += f"-context-{params.context_size}"
         params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -472,8 +494,20 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    if params.avg_last_n > 0:
-        filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n]
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
@@ -60,7 +60,6 @@ from asr_datamodule import SPGISpeechAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
@@ -78,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist
 from icefall.env import get_env_info
 from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
 
-LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+LRSchedulerType = Union[
+    torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
+]
 
 
 def get_parser():
@@ -154,7 +155,8 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="The initial learning rate. This value should not need to be changed.",
+        help="The initial learning rate. This value should not need to be "
+        "changed.",
     )
 
     parser.add_argument(
@@ -177,7 +179,8 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
     )
 
     parser.add_argument(
@@ -200,7 +203,8 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network)"
+        "part.",
     )
 
     parser.add_argument(
@@ -550,16 +554,23 @@ def compute_loss(
     # overwhelming the simple_loss and causing it to diverge,
     # in case it had not fully learned the alignment yet.
     pruned_loss_scale = (
-        0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+        0.0
+        if warmup < 1.0
+        else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+    )
+    loss = (
+        params.simple_loss_scale * simple_loss
+        + pruned_loss_scale * pruned_loss
     )
-    loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
 
     assert loss.requires_grad == is_training
 
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+        info["frames"] = (
+            (feature_lens // params.subsampling_factor).sum().item()
+        )
 
     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
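
The reflowed ternary encodes a three-phase warmup for the pruned loss. Written out as a plain function with the same thresholds (note that `warmup == 1.0` already gets full weight, exactly as in the expression above):

```
# The pruned-loss warmup schedule from compute_loss(), restated.
def pruned_loss_scale(warmup: float) -> float:
    if warmup < 1.0:
        return 0.0  # let simple_loss learn the alignment first
    if 1.0 < warmup < 2.0:
        return 0.1  # phase the pruned loss in gently
    return 1.0      # full weight once warmed up

# loss = params.simple_loss_scale * simple_loss \
#        + pruned_loss_scale(warmup) * pruned_loss
assert [pruned_loss_scale(w) for w in (0.5, 1.5, 3.0)] == [0.0, 0.1, 1.0]
```
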
@@ -722,7 +733,9 @@ def train_one_epoch(
             loss_info.write_summary(
                 tb_writer, "train/current_", params.batch_idx_train
             )
-            tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+            tot_loss.write_summary(
+                tb_writer, "train/tot_", params.batch_idx_train
+            )
 
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             logging.info("Computing validation loss")
@@ -820,7 +833,7 @@ def run(rank, world_size, args):
 
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2**22
+            2 ** 22
         ) # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
 
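
One sanity check on the `2**22` / `2 ** 22` hunk: the spacing is only a formatter-version difference (black changed power-operator spacing in 22.1.0), and the comment's arithmetic holds:

```
# "allow 4 megabytes per sub-module": 2 ** 22 bytes is exactly 4 MiB.
assert 2 ** 22 == 4 * 1024 * 1024
```
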
@@ -828,9 +841,10 @@ def run(rank, world_size, args):
 
     train_cuts = spgispeech.train_cuts()
 
-    # Ideally we should filter utterances that are too long or too short, but SPGISpeech
-    # contains regular length utterances so we don't need to do that. Here are the
-    # statistics of the training data (obtained by `train_cuts.describe()`):
+    # Ideally we should filter utterances that are too long or too short,
+    # but SPGISpeech contains regular length utterances so we don't need to
+    # do that. Here are the statistics of the training data (obtained by
+    # `train_cuts.describe()`):
 
     # Cuts count: 5886320
     # Total duration (hours): 15070.1