Add phone-based training and decoding for GigaSpeech

This commit is contained in:
yfy62 2023-04-26 17:35:19 +08:00
parent 45c13e90e4
commit 9cbe54732a
3 changed files with 522 additions and 121 deletions

View File

@ -387,7 +387,7 @@ class GigaSpeechAsrDataModule:
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info(f"About to get train_{self.args.subset} cuts")
path = self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz"
path = self.args.manifest_dir / f"cuts_{self.args.subset}.jsonl.gz"
cuts_train = CutSet.from_jsonl_lazy(path)
return cuts_train
@ -395,7 +395,7 @@ class GigaSpeechAsrDataModule:
def dev_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz"
self.args.manifest_dir / "cuts_DEV.jsonl.gz"
)
if self.args.small_dev:
return cuts_valid.subset(first=1000)
@ -406,5 +406,5 @@ class GigaSpeechAsrDataModule:
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz"
self.args.manifest_dir / "cuts_TEST.jsonl.gz"
)
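The three getters above now expect manifests named cuts_*.jsonl.gz instead of gigaspeech_cuts_*.jsonl.gz. A minimal sketch of the equivalent lazy loading, assuming a hypothetical data/fbank manifest directory and the XL training subset:

    from pathlib import Path

    from lhotse import CutSet, load_manifest_lazy

    manifest_dir = Path("data/fbank")  # hypothetical; normally taken from --manifest-dir
    # The XL train set is opened lazily so it never has to fit in memory.
    train_cuts = CutSet.from_jsonl_lazy(manifest_dir / "cuts_XL.jsonl.gz")
    # Dev cuts, optionally truncated to the first 1000 cuts (--small-dev).
    dev_cuts = load_manifest_lazy(manifest_dir / "cuts_DEV.jsonl.gz").subset(first=1000)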

View File

@ -24,8 +24,7 @@ Usage:
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search
(2) beam search (not recommended)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
@ -33,7 +32,6 @@ Usage:
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
@ -42,17 +40,60 @@ Usage:
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
(4) fast beam search (one best)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
--max-contexts 8 \
--max-states 64 \
--num-paths 200 \
--nbest-scale 0.5
(7) fast beam search (with LG)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(8) fast beam search (nbest with LG)
./pruned_transducer_stateless2/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./pruned_transducer_stateless2/exp \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
"""
@ -69,6 +110,9 @@ import torch.nn as nn
from asr_datamodule import GigaSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search_nbest,
fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
@ -76,12 +120,14 @@ from beam_search import (
)
from gigaspeech_scoring import asr_text_post_processing
from train import get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import UniqLexicon
from icefall.utils import (
AttributeDict,
setup_logger,
@ -127,7 +173,7 @@ def get_parser():
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=False,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
@ -145,10 +191,17 @@ def get_parser():
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
default=None,
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_phone",
help="The lang dir contains word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
@ -158,6 +211,20 @@ def get_parser():
- beam_search
- modified_beam_search
- fast_beam_search
- fast_beam_search_LG
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
- fast_beam_search_nbest_LG
""",
)
parser.add_argument(
"--metrics",
type=str,
default="WER",
help="""Possible values are:
- WER
- PER
""",
)
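WER and PER are the same Levenshtein-based error rate produced by write_error_stats; they only differ in the unit being compared (words vs. phone ids rendered as strings). A rough illustration, not taken from this file, using the editdistance package:

    import editdistance  # assumption: any edit-distance helper would do here

    def error_rate(refs, hyps):
        errors = sum(editdistance.eval(r, h) for r, h in zip(refs, hyps))
        return errors / sum(len(r) for r in refs)

    print(error_rate([["HELLO", "WORLD"]], [["HELLO", "WORD"]]))  # WER: 1 error / 2 words = 0.5
    print(error_rate([["41", "7", "23"]], [["41", "23"]]))        # PER: 1 error / 3 phones ~ 0.33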
@ -173,27 +240,45 @@ def get_parser():
parser.add_argument(
"--beam",
type=float,
default=4,
default=20.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_LG,
and fast_beam_search_nbest_oracle
""",
)
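The pruning rule this help text describes can be illustrated in a few lines; this is only a sketch of the cutoff, not k2's frame-synchronous implementation:

    def prune_by_beam(scores, beam=20.0):
        cutoff = max(scores) - beam          # cutoff = max-score - beam
        return [s for s in scores if s >= cutoff]

    print(prune_by_beam([-1.2, -5.0, -30.0]))  # keeps -1.2 and -5.0, drops -30.0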
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_LG or
fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
default=64,
help="""Used only when --decoding-method is
fast_beam_search""",
fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
fast_beam_search_nbest_LG and fast_beam_search_nbest_oracle
""",
)
parser.add_argument(
@ -202,6 +287,7 @@ def get_parser():
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
@ -210,6 +296,24 @@ def get_parser():
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--num-paths",
type=int,
default=200,
help="""Number of paths for nbest decoding.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""Scale applied to lattice scores when computing nbest paths.
Used only when the decoding method is fast_beam_search_nbest,
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
)
return parser
@ -228,7 +332,9 @@ def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
pl: UniqLexicon,
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
@ -248,18 +354,24 @@ def decode_one_batch(
The neural model.
sp:
The BPE model.
pl:
The phone lexicon.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an LG graph. Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_oracle
and fast_beam_search_nbest_LG.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = model.device
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
@ -272,7 +384,10 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
hyps = []
if params.decoding_method == "fast_beam_search":
if (
params.decoding_method == "fast_beam_search"
or params.decoding_method == "fast_beam_search_LG"
):
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
@ -282,6 +397,58 @@ def decode_one_batch(
max_contexts=params.max_contexts,
max_states=params.max_states,
)
if params.decoding_method == "fast_beam_search":
if sp is not None:
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([str(i) for i in hyp])
else:
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest_LG":
hyp_tokens = fast_beam_search_nbest_LG(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in hyp_tokens:
hyps.append([word_table[i] for i in hyp])
elif params.decoding_method == "fast_beam_search_nbest":
hyp_tokens = fast_beam_search_nbest(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "fast_beam_search_nbest_oracle":
hyp_tokens = fast_beam_search_nbest_oracle(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
num_paths=params.num_paths,
ref_texts=sp.encode(supervisions["text"]),
nbest_scale=params.nbest_scale,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
@ -290,8 +457,12 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
if sp is not None:
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([str(i) for i in hyp])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -299,8 +470,12 @@ def decode_one_batch(
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
if sp is not None:
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
for hyp in hyp_tokens:
hyps.append([str(i) for i in hyp])
else:
batch_size = encoder_out.size(0)
@ -324,18 +499,24 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp).split())
if sp is not None:
hyps.append(sp.decode(hyp).split())
else:
hyps.append([str(i) for i in hyp])
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
elif "fast_beam_search" in params.decoding_method:
key = f"beam_{params.beam}_"
key += f"max_contexts_{params.max_contexts}_"
key += f"max_states_{params.max_states}"
if "nbest" in params.decoding_method:
key += f"_num_paths_{params.num_paths}_"
key += f"nbest_scale_{params.nbest_scale}"
if "LG" in params.decoding_method:
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -345,6 +526,8 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
pl: UniqLexicon,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
@ -358,9 +541,15 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
pl:
The phone lexicon.
word_table:
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an LG graph. Used
only when --decoding_method is fast_beam_search.
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
fast_beam_search_nbest, fast_beam_search_nbest_oracle
and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -375,29 +564,82 @@ def decode_dataset(
except TypeError:
num_batches = "?"
log_interval = 20
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
if sp is not None:
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
pl=pl,
word_table=word_table,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
results[name].extend(this_batch)
else:
if params.metrics == "WER":
texts = batch["supervisions"]["text"]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
pl=pl,
word_table=word_table,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
ref_words = ref_text.split()
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
elif params.metrics == "PER":
texts = batch["supervisions"]["text"]
token_ids = pl.texts_to_token_ids(texts).tolist()
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
pl=pl,
word_table=word_table,
decoding_graph=decoding_graph,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(token_ids)
for cut_id, hyp_id, ref_token_id in zip(cut_ids, hyps, token_ids):
ref_token_id = [str(i) for i in ref_token_id]
this_batch.append((cut_id, ref_token_id, hyp_id))
results[name].extend(this_batch)
num_cuts += len(texts)
@ -413,38 +655,73 @@ def save_results(
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = post_processing(results)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
if params.metrics == "WER":
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = post_processing(results)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
elif params.metrics == "PER":
test_set_pers = dict()
for key, results in results_dict.items():
recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
results = post_processing(results)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out PERs, per-phone error statistics and aligned
# ref/hyp pairs.
errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
with open(errs_filename, "w") as f:
per = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_pers[key] = per
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_pers = sorted(test_set_pers.items(), key=lambda x: x[1])
errs_info = params.res_dir / f"per-summary-{test_set_name}-{params.suffix}.txt"
with open(errs_info, "w") as f:
print("settings\tPER", file=f)
for key, val in test_set_pers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, PER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_pers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
@ -457,12 +734,32 @@ def main():
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
if params.bpe_model is not None:
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"fast_beam_search_LG",
"fast_beam_search_nbest",
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
)
else:
if params.metrics == "PER":
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
"fast_beam_search",
), "Decoding method without L or LG must use PER"
elif params.metrics == "WER":
assert params.decoding_method in (
"fast_beam_search_LG",
"fast_beam_search_LG",
"fast_beam_search_nbest_LG",
), "Decoding method with L or LG must use WER"
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
@ -474,8 +771,13 @@ def main():
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if "nbest" in params.decoding_method:
params.suffix += f"-nbest-scale-{params.nbest_scale}"
params.suffix += f"-num-paths-{params.num_paths}"
if "LG" in params.decoding_method:
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -492,13 +794,19 @@ def main():
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
if params.bpe_model is not None:
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
pl = None  # so the later pl=pl argument does not raise a NameError in the BPE path
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
else:
pl = UniqLexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(pl.tokens) + 1
sp = None
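A small usage sketch of the phone branch above; UniqLexicon and texts_to_token_ids come from icefall.lexicon, while the lang dir and the printed ids are made up for illustration:

    from pathlib import Path

    from icefall.lexicon import UniqLexicon

    pl = UniqLexicon(Path("data/lang_phone"))
    ids = pl.texts_to_token_ids(["HELLO WORLD"])  # k2.RaggedTensor of phone ids
    print(ids.tolist())                           # e.g. [[31, 5, 24, 42, 19, 22, 5]]
    print(max(pl.tokens) + 1)                     # vocab size = largest phone id + 1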
logging.info(params)
@ -585,10 +893,24 @@ def main():
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
if "fast_beam_search" in params.decoding_method:
if (
params.decoding_method == "fast_beam_search_LG"
or params.decoding_method == "fast_beam_search_nbest_LG"
):
word_table = pl.word_table
lg_filename = params.lang_dir / "LG.pt"
logging.info(f"Loading {lg_filename}")
decoding_graph = k2.Fsa.from_dict(
torch.load(lg_filename, map_location=device)
)
decoding_graph.scores *= params.ngram_lm_scale
else:
word_table = None
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
word_table = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -612,6 +934,8 @@ def main():
params=params,
model=model,
sp=sp,
pl=pl,
word_table=word_table,
decoding_graph=decoding_graph,
)

View File

@ -1,7 +1,8 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Mingshuang Luo,
# Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -19,13 +20,15 @@
"""
Usage:
(1) bpe
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 0 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless2/exp \
--subset XL \
--max-duration 120
# For mixed precision training:
@ -33,11 +36,37 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 0 \
--use_fp16 1 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--subset XL \
--max-duration 200
(2) phone
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
./pruned_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless2/exp \
--subset XL \
--lang-type phone \
--context-size 4 \
--max-duration 300
# For mixed precision training:
./pruned_transducer_stateless2/train.py \
--world-size 8 \
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--subset XL \
--lang-type phone \
--context-size 4 \
--max-duration 750
"""
@ -77,6 +106,7 @@ from icefall.checkpoint import (
)
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import UniqLexicon
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -119,8 +149,8 @@ def get_parser():
"--start-epoch",
type=int,
default=1,
help="""Resume training from this epoch.
If larger than 1, it will load checkpoint from
help="""Resume training from from this epoch.
If it is large than 1, it will load checkpoint from
exp-dir/epoch-{start_epoch-1}.pt
""",
)
@ -144,6 +174,13 @@ def get_parser():
""",
)
parser.add_argument(
"--lang-type",
type=str,
default="bpe",
help="Either bpe or phone",
)
parser.add_argument(
"--bpe-model",
type=str,
@ -151,6 +188,13 @@ def get_parser():
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_phone",
help="the lang dir contains lexicon",
)
parser.add_argument(
"--initial-lr",
type=float,
@ -231,7 +275,7 @@ def get_parser():
parser.add_argument(
"--save-every-n",
type=int,
default=8000,
default=20000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
@ -244,7 +288,7 @@ def get_parser():
parser.add_argument(
"--keep-last-k",
type=int,
default=30,
default=20,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
@ -261,7 +305,7 @@ def get_parser():
in which each floating-point parameter is the average of all the
parameters from the start of training. Each time we take the average,
we do: `model_avg = model * (average_period / batch_idx_train) +
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
""",
)
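The update rule quoted in this help text, written out for a single scalar parameter (update_averaged_model applies the same rule tensor by tensor):

    def update_average(param, param_avg, batch_idx_train, average_period):
        w = average_period / batch_idx_train
        return param * w + param_avg * (1.0 - w)   # model_avg = model * w + model_avg * (1 - w)

    # After 1000 batches with average_period=100, the current weights carry 10% of the average.
    print(update_average(2.0, 1.0, batch_idx_train=1000, average_period=100))  # 1.1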
@ -470,7 +514,7 @@ def load_checkpoint_if_available(
def save_checkpoint(
params: AttributeDict,
model: Union[nn.Module, DDP],
model: nn.Module,
model_avg: Optional[nn.Module] = None,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
@ -520,14 +564,15 @@ def save_checkpoint(
def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
model: nn.Module,
sp: spm.SentencePieceProcessor,
pl: UniqLexicon,
batch: dict,
is_training: bool,
warmup: float = 1.0,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute transducer loss given the model and its inputs.
Compute the pruned transducer loss given the model and its inputs.
Args:
params:
@ -554,8 +599,12 @@ def compute_loss(
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"]
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
if sp is not None:
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)
else:
y = pl.texts_to_token_ids(texts).to(device)
with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
@ -593,8 +642,9 @@ def compute_loss(
def compute_validation_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
model: nn.Module,
sp: spm.SentencePieceProcessor,
pl: UniqLexicon,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
@ -608,6 +658,7 @@ def compute_validation_loss(
params=params,
model=model,
sp=sp,
pl=pl,
batch=batch,
is_training=False,
)
@ -627,10 +678,11 @@ def compute_validation_loss(
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
pl: UniqLexicon,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
@ -688,8 +740,8 @@ def train_one_epoch(
loss, loss_info = compute_loss(
params=params,
model=model,
model_avg=model_avg,
sp=sp,
pl=pl,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step),
@ -708,6 +760,17 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 30:
return
if (
rank == 0
and params.batch_idx_train > 0
and params.batch_idx_train % params.average_period == 0
):
update_averaged_model(
params=params,
model_cur=model,
model_avg=model_avg,
)
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
@ -757,6 +820,7 @@ def train_one_epoch(
params=params,
model=model,
sp=sp,
pl=pl,
valid_dl=valid_dl,
world_size=world_size,
)
@ -806,12 +870,21 @@ def run(rank, world_size, args):
device = torch.device("cuda", rank)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
if params.lang_type == "bpe":
logging.info(f"Using bpe model")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
pl = None
elif params.lang_type == "phone":
logging.info(f"Using phone lexion")
pl = UniqLexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(pl.tokens) + 1
sp = None
logging.info(params)
@ -875,12 +948,13 @@ def run(rank, world_size, args):
valid_cuts = gigaspeech.dev_cuts()
valid_dl = gigaspeech.valid_dataloaders(valid_cuts)
if not params.print_diagnostics:
if 0 and not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
pl=pl,
params=params,
)
@ -906,6 +980,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
scheduler=scheduler,
sp=sp,
pl=pl,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
@ -937,10 +1012,11 @@ def run(rank, world_size, args):
def scan_pessimistic_batches_for_oom(
model: Union[nn.Module, DDP],
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
pl: UniqLexicon,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
@ -960,6 +1036,7 @@ def scan_pessimistic_batches_for_oom(
params=params,
model=model,
sp=sp,
pl=pl,
batch=batch,
is_training=True,
warmup=0.0,