mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge 4858e2b0367a7ebaf20641743d91a745777aca63 into abd9437e6d5419a497707748eb935e50976c3b7b
This commit is contained in:
commit
2c7dcd65f2
@ -388,16 +388,14 @@ class GigaSpeechAsrDataModule:
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info(f"About to get train_{self.args.subset} cuts")
|
||||
path = self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
|
||||
path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
|
||||
cuts_train = CutSet.from_jsonl_lazy(path)
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def dev_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = load_manifest_lazy(
|
||||
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
|
||||
)
|
||||
cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")
|
||||
if self.args.small_dev:
|
||||
return cuts_valid.subset(first=1000)
|
||||
else:
|
||||
@ -406,6 +404,4 @@ class GigaSpeechAsrDataModule:
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
|
||||
)
|
||||
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
|
||||
|
@ -24,8 +24,7 @@ Usage:
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search
|
||||
(2) beam search (not recommended)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
@ -33,7 +32,6 @@ Usage:
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
@ -42,17 +40,60 @@ Usage:
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(4) fast beam search
|
||||
(4) fast beam search (one best)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search \
|
||||
--beam 4 \
|
||||
--max-contexts 4 \
|
||||
--max-states 8
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
(5) fast beam search (nbest)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
(6) fast beam search (nbest oracle WER)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_oracle \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64 \
|
||||
--num-paths 200 \
|
||||
--nbest-scale 0.5
|
||||
(7) fast beam search (with LG)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
(8) fast beam search (nbest with LG)
|
||||
./pruned_transducer_stateless2/decode.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
||||
--max-duration 600 \
|
||||
--decoding-method fast_beam_search_nbest_LG \
|
||||
--beam 20.0 \
|
||||
--max-contexts 8 \
|
||||
--max-states 64
|
||||
"""
|
||||
|
||||
|
||||
@ -69,6 +110,9 @@ import torch.nn as nn
|
||||
from asr_datamodule import GigaSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
@ -83,6 +127,7 @@ from icefall.checkpoint import (
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import UniqLexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
@ -128,7 +173,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
@ -146,10 +191,17 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
default=None,
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_phone",
|
||||
help="The lang dir contains word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
@ -159,6 +211,20 @@ def get_parser():
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--metrics",
|
||||
type=str,
|
||||
default="WER",
|
||||
help="""Possible values are:
|
||||
- WER
|
||||
- PER
|
||||
""",
|
||||
)
|
||||
|
||||
@ -174,27 +240,45 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search""",
|
||||
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.01,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG or
|
||||
fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search""",
|
||||
fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -203,6 +287,7 @@ def get_parser():
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
@ -211,6 +296,24 @@ def get_parser():
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=200,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -229,7 +332,9 @@ def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
pl: UniqLexicon,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
@ -249,13 +354,19 @@ def decode_one_batch(
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
pl:
|
||||
The phone lexicon.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_oracle
|
||||
and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
@ -273,7 +384,10 @@ def decode_one_batch(
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
hyps = []
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search"
|
||||
or params.decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
@ -283,6 +397,58 @@ def decode_one_batch(
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if sp is not None:
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([str(i) for i in hyp])
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
@ -291,8 +457,12 @@ def decode_one_batch(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
if sp is not None:
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([str(i) for i in hyp])
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
@ -300,8 +470,12 @@ def decode_one_batch(
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
if sp is not None:
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([str(i) for i in hyp])
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
@ -325,18 +499,24 @@ def decode_one_batch(
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
if sp is not None:
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
else:
|
||||
hyps.append([str(i) for i in hyp])
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif params.decoding_method == "fast_beam_search":
|
||||
return {
|
||||
(
|
||||
f"beam_{params.beam}_"
|
||||
f"max_contexts_{params.max_contexts}_"
|
||||
f"max_states_{params.max_states}"
|
||||
): hyps
|
||||
}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
@ -346,6 +526,8 @@ def decode_dataset(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
pl: UniqLexicon,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
@ -359,9 +541,15 @@ def decode_dataset(
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
pl:
|
||||
The phone lexicon.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
|
||||
only when --decoding_method is fast_beam_search.
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_oracle
|
||||
and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
@ -376,10 +564,14 @@ def decode_dataset(
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
if sp is not None:
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
@ -387,6 +579,8 @@ def decode_dataset(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
@ -399,6 +593,53 @@ def decode_dataset(
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
else:
|
||||
if params.metrics == "WER":
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
elif params.metrics == "PER":
|
||||
texts = batch["supervisions"]["text"]
|
||||
token_ids = pl.texts_to_token_ids(texts).tolist()
|
||||
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(token_ids)
|
||||
for cut_id, hyp_id, ref_token_id in zip(cut_ids, hyps, token_ids):
|
||||
ref_token_id = [str(i) for i in ref_token_id]
|
||||
this_batch.append((cut_id, ref_token_id, hyp_id))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
@ -414,6 +655,7 @@ def save_results(
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
if params.metrics == "WER":
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
@ -447,6 +689,40 @@ def save_results(
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
elif params.metrics == "PER":
|
||||
test_set_pers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
|
||||
results = post_processing(results)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out PERs, per-phone error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
per = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_pers[key] = per
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_pers = sorted(test_set_pers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"per-summary-{test_set_name}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tPER", file=f)
|
||||
for key, val in test_set_pers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, PER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_pers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
@ -458,12 +734,32 @@ def main():
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if params.bpe_model is not None:
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
else:
|
||||
if params.metrics == "PER":
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"modified_beam_search",
|
||||
"fast_beam_search",
|
||||
), "Decoding method without L or LG must use PER"
|
||||
elif params.metrics == "WER":
|
||||
assert params.decoding_method in (
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest_LG",
|
||||
), "Decoding method with L or LG must use WER"
|
||||
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
@ -475,8 +771,13 @@ def main():
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam_size}"
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
@ -493,13 +794,19 @@ def main():
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
if params.bpe_model is not None:
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
# <blk> and <unk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
else:
|
||||
pl = UniqLexicon(params.lang_dir)
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(pl.tokens) + 1
|
||||
sp = None
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -586,10 +893,24 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search_LG"
|
||||
or params.decoding_method == "fast_beam_search_nbest_LG"
|
||||
):
|
||||
word_table = pl.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
@ -613,6 +934,8 @@ def main():
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
|
@ -1,7 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang
|
||||
# Mingshuang Luo)
|
||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang,
|
||||
# Mingshuang Luo,
|
||||
# Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -19,13 +20,15 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) bpe
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--subset XL \
|
||||
--max-duration 120
|
||||
|
||||
# For mix precision training:
|
||||
@ -33,11 +36,37 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--use_fp16 1 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--subset XL \
|
||||
--max-duration 200
|
||||
|
||||
(2) phone
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--subset XL \
|
||||
--lang-type phone \
|
||||
--context-size 4 \
|
||||
--max-duration 300
|
||||
|
||||
# For mix precision training:
|
||||
|
||||
./pruned_transducer_stateless2/train.py \
|
||||
--world-size 8 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--subset XL \
|
||||
--lang-type phone \
|
||||
--context-size 4 \
|
||||
--max-duration 750
|
||||
"""
|
||||
|
||||
|
||||
@ -77,6 +106,7 @@ from icefall.checkpoint import (
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.lexicon import UniqLexicon
|
||||
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
|
||||
|
||||
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
|
||||
@ -119,8 +149,8 @@ def get_parser():
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Resume training from this epoch.
|
||||
If larger than 1, it will load checkpoint from
|
||||
help="""Resume training from from this epoch.
|
||||
If it is large than 1, it will load checkpoint from
|
||||
exp-dir/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
@ -144,6 +174,13 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-type",
|
||||
type=str,
|
||||
default="bpe",
|
||||
help="Either bpe or phone",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
@ -151,6 +188,13 @@ def get_parser():
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_phone",
|
||||
help="the lang dir contains lexicon",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--initial-lr",
|
||||
type=float,
|
||||
@ -231,7 +275,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--save-every-n",
|
||||
type=int,
|
||||
default=8000,
|
||||
default=20000,
|
||||
help="""Save checkpoint after processing this number of batches"
|
||||
periodically. We save checkpoint to exp-dir/ whenever
|
||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||
@ -244,7 +288,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--keep-last-k",
|
||||
type=int,
|
||||
default=30,
|
||||
default=20,
|
||||
help="""Only keep this number of checkpoints on disk.
|
||||
For instance, if it is 3, there are only 3 checkpoints
|
||||
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
||||
@ -519,6 +563,7 @@ def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
sp: spm.SentencePieceProcessor,
|
||||
pl: UniqLexicon,
|
||||
batch: dict,
|
||||
is_training: bool,
|
||||
warmup: float = 1.0,
|
||||
@ -551,8 +596,12 @@ def compute_loss(
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
texts = batch["supervisions"]["text"]
|
||||
|
||||
if sp is not None:
|
||||
y = sp.encode(texts, out_type=int)
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
else:
|
||||
y = pl.texts_to_token_ids(texts).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss = model(
|
||||
@ -592,6 +641,7 @@ def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
sp: spm.SentencePieceProcessor,
|
||||
pl: UniqLexicon,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> MetricsTracker:
|
||||
@ -605,6 +655,7 @@ def compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
batch=batch,
|
||||
is_training=False,
|
||||
)
|
||||
@ -628,6 +679,7 @@ def train_one_epoch(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
scheduler: LRSchedulerType,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
pl: UniqLexicon,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
@ -679,8 +731,8 @@ def train_one_epoch(
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=(params.batch_idx_train / params.model_warm_step),
|
||||
@ -699,6 +751,17 @@ def train_one_epoch(
|
||||
if params.print_diagnostics and batch_idx == 30:
|
||||
return
|
||||
|
||||
if (
|
||||
rank == 0
|
||||
and params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.average_period == 0
|
||||
):
|
||||
update_averaged_model(
|
||||
params=params,
|
||||
model_cur=model,
|
||||
model_avg=model_avg,
|
||||
)
|
||||
|
||||
if (
|
||||
params.batch_idx_train > 0
|
||||
and params.batch_idx_train % params.save_every_n == 0
|
||||
@ -746,6 +809,7 @@ def train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
@ -795,12 +859,21 @@ def run(rank, world_size, args):
|
||||
device = torch.device("cuda", rank)
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
if params.lang_type == "bpe":
|
||||
logging.info(f"Using bpe model")
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
pl = None
|
||||
elif params.lang_type == "phone":
|
||||
logging.info(f"Using phone lexion")
|
||||
pl = UniqLexicon(params.lang_dir)
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(pl.tokens) + 1
|
||||
sp = None
|
||||
|
||||
logging.info(params)
|
||||
|
||||
@ -870,6 +943,7 @@ def run(rank, world_size, args):
|
||||
train_dl=train_dl,
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
params=params,
|
||||
)
|
||||
|
||||
@ -895,6 +969,7 @@ def run(rank, world_size, args):
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
scaler=scaler,
|
||||
@ -930,6 +1005,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
pl: UniqLexicon,
|
||||
params: AttributeDict,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
@ -949,6 +1025,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
pl=pl,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=0.0,
|
||||
|
Loading…
x
Reference in New Issue
Block a user