mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-02 21:54:18 +00:00)

Commit b2a2b1d387: add decoding
Parent commit: 0c27ba45e7
@@ -104,7 +104,7 @@ class SPGISpeechAsrDataModule:
         group.add_argument(
             "--max-duration",
             type=int,
-            default=140.0,
+            default=100.0,
             help="Maximum pooled recordings duration (seconds) in a "
             "single batch. You can reduce it if it causes CUDA OOM.",
         )
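For context, --max-duration is the knob the Lhotse sampler uses to cap the total audio per batch. A minimal sketch, assuming Lhotse's DynamicBucketingSampler; the sampler choice and manifest path here are illustrative, not taken from this diff:

    from lhotse import load_manifest_lazy
    from lhotse.dataset import DynamicBucketingSampler

    # max_duration bounds the summed duration (in seconds) of all cuts in one
    # batch, so lowering it from 140 to 100 reduces peak GPU memory per step.
    cuts = load_manifest_lazy("data/manifests/cuts_train.jsonl.gz")  # hypothetical path
    sampler = DynamicBucketingSampler(
        cuts,
        max_duration=100.0,  # mirrors the new default above
        shuffle=True,
    )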
@@ -296,7 +296,7 @@ class SPGISpeechAsrDataModule:
         return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")

     @lru_cache()
-    def test_cuts(self) -> CutSet:
+    def val_cuts(self) -> CutSet:
         logging.info("About to get SPGISpeech val cuts")
         return load_manifest_lazy(self.args.manifest_dir / "cuts_val.jsonl.gz")
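Read together with the context lines, the SPGISpeech cut accessors after this change look roughly as follows (a sketch; the dev_cuts body is assumed symmetric to the val_cuts code shown above):

    from functools import lru_cache
    import logging

    from lhotse import CutSet, load_manifest_lazy

    class SPGISpeechAsrDataModule:
        # ... argument parsing and dataloader construction elided ...

        @lru_cache()
        def dev_cuts(self) -> CutSet:
            logging.info("About to get SPGISpeech dev cuts")  # assumed message
            return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")

        @lru_cache()
        def val_cuts(self) -> CutSet:
            logging.info("About to get SPGISpeech val cuts")
            return load_manifest_lazy(self.args.manifest_dir / "cuts_val.jsonl.gz")

The @lru_cache decorator keeps each manifest from being re-read when the same cut set is requested more than once in a run.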
@@ -23,10 +23,11 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import k2
+from numpy import True_
 import sentencepiece as spm
 import torch
 import torch.nn as nn
-from asr_datamodule import LibriSpeechAsrDataModule
+from asr_datamodule import SPGISpeechAsrDataModule
 from conformer import Conformer

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
@@ -131,7 +132,7 @@ def get_parser():
     parser.add_argument(
         "--lang-dir",
         type=str,
-        default="data/lang_bpe_500",
+        default="data/lang_bpe_5000",
         help="The lang dir",
     )

@@ -140,7 +141,7 @@ def get_parser():
         type=str,
         default="data/lm",
         help="""The LM dir.
-        It should contain either G_4_gram.pt or G_4_gram.fst.txt
+        It should contain either G_3_gram.pt or G_3_gram.fst.txt
         """,
     )

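Tied to the new --lang-dir default, a small sketch of loading the corresponding SentencePiece model (the bpe.model filename inside the lang dir is an assumption based on the usual icefall layout):

    import sentencepiece as spm

    lang_dir = "data/lang_bpe_5000"  # new default for --lang-dir
    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load(f"{lang_dir}/bpe.model")  # assumed filename
    # For a lang_bpe_5000 directory the vocabulary size should be 5000.
    print(bpe_model.get_piece_size())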
@@ -420,7 +421,7 @@ def decode_dataset(
     G:
       An LM. It is not None when params.method is "nbest-rescoring"
       or "whole-lattice-rescoring". In general, the G in HLG
-      is a 3-gram LM, while this G is a 4-gram LM.
+      is a 3-gram LM.
   Returns:
     Return a dict, whose key may be "no-rescore" if no LM rescoring
     is used, or it may be "lm_scale_0.7" if LM rescoring is used.
@@ -462,9 +463,7 @@ def decode_dataset(

                 results[lm_scale].extend(this_batch)
         else:
-            assert (
-                len(results) > 0
-            ), "It should not decode to empty in the first batch!"
+            assert len(results) > 0, "It should not decode to empty in the first batch!"
             this_batch = []
             hyp_words = []
             for ref_text in texts:
@@ -479,9 +478,7 @@ def decode_dataset(
         if batch_idx % 100 == 0:
             batch_str = f"{batch_idx}/{num_batches}"

-            logging.info(
-                f"batch {batch_str}, cuts processed until now is {num_cuts}"
-            )
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
     return results

@@ -512,9 +509,7 @@ def save_results(
         test_set_wers[key] = wer

         if enable_log:
-            logging.info(
-                "Wrote detailed error stats to {}".format(errs_filename)
-            )
+            logging.info("Wrote detailed error stats to {}".format(errs_filename))

     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
     errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
@@ -534,7 +529,7 @@ def save_results(
 @torch.no_grad()
 def main():
     parser = get_parser()
-    LibriSpeechAsrDataModule.add_arguments(parser)
+    SPGISpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
     args.lang_dir = Path(args.lang_dir)
@@ -570,7 +565,7 @@ def main():
         HLG = None
         H = k2.ctc_topo(
             max_token=max_token_id,
-            modified=False,
+            modified=True,  # Use modified topology since vocab size is large
             device=device,
         )
         bpe_model = spm.SentencePieceProcessor()
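The switch to modified=True matters because the CTC topology is built over the whole BPE vocabulary. A minimal sketch using k2's ctc_topo; the max_token_id value is an assumption that follows from the 5000-piece lang dir above:

    import k2
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_token_id = 4999  # assumption: 5000 BPE pieces with ids 0..4999

    # The standard topology's arc count grows roughly quadratically with the
    # vocabulary; the modified topology is much smaller, which keeps
    # ctc-decoding tractable for a large vocabulary like this one.
    H = k2.ctc_topo(max_token=max_token_id, modified=True, device=device)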
@@ -591,10 +586,10 @@ def main():
         "whole-lattice-rescoring",
         "attention-decoder",
     ):
-        if not (params.lm_dir / "G_4_gram.pt").is_file():
-            logging.info("Loading G_4_gram.fst.txt")
+        if not (params.lm_dir / "G_3_gram.pt").is_file():
+            logging.info("Loading G_3_gram.fst.txt")
             logging.warning("It may take 8 minutes.")
-            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
+            with open(params.lm_dir / "G_3_gram.fst.txt") as f:
                 first_word_disambig_id = lexicon.word_table["#0"]

                 G = k2.Fsa.from_openfst(f.read(), acceptor=False)
@@ -615,10 +610,10 @@ def main():
             # for why we need to do this.
             G.dummy = 1

-            torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
+            torch.save(G.as_dict(), params.lm_dir / "G_3_gram.pt")
         else:
-            logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
+            logging.info("Loading pre-compiled G_3_gram.pt")
+            d = torch.load(params.lm_dir / "G_3_gram.pt", map_location=device)
             G = k2.Fsa.from_dict(d)

         if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
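Taken together, the last two hunks implement a load-or-compile pattern for the word-level rescoring LM. A condensed sketch of that pattern (lexicon handling, disambiguation-symbol removal, and arc sorting from the surrounding code are elided here):

    import logging
    from pathlib import Path

    import k2
    import torch

    def load_G(lm_dir: Path, device: torch.device) -> k2.Fsa:
        """Load the 3-gram word LM, compiling from OpenFST text on first use."""
        pt_path = lm_dir / "G_3_gram.pt"
        if not pt_path.is_file():
            logging.info("Loading G_3_gram.fst.txt")
            with open(lm_dir / "G_3_gram.fst.txt") as f:
                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            # Cache the compiled FSA so later runs skip the slow text parse.
            torch.save(G.as_dict(), pt_path)
        else:
            logging.info("Loading pre-compiled G_3_gram.pt")
            d = torch.load(pt_path, map_location=device)
            G = k2.Fsa.from_dict(d)
        return G.to(device)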
@@ -662,16 +657,16 @@ def main():
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

-    librispeech = LibriSpeechAsrDataModule(args)
+    spgispeech = SPGISpeechAsrDataModule(args)

-    test_clean_cuts = librispeech.test_clean_cuts()
-    test_other_cuts = librispeech.test_other_cuts()
+    dev_cuts = spgispeech.dev_cuts()
+    val_cuts = spgispeech.val_cuts()

-    test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
-    test_other_dl = librispeech.test_dataloaders(test_other_cuts)
+    dev_dl = spgispeech.test_dataloaders(dev_cuts)
+    val_dl = spgispeech.test_dataloaders(val_cuts)

-    test_sets = ["test-clean", "test-other"]
-    test_dl = [test_clean_dl, test_other_dl]
+    test_sets = ["dev", "val"]
+    test_dl = [dev_dl, val_dl]

     for test_set, test_dl in zip(test_sets, test_dl):
         results_dict = decode_dataset(
@@ -687,9 +682,7 @@ def main():
             eos_id=eos_id,
         )

-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+        save_results(params=params, test_set_name=test_set, results_dict=results_dict)

     logging.info("Done!")
egs/spgispeech/ASR/decode.sh (new executable file, +10 lines)
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+. ./path.sh
+. parse_options.sh || exit 1
+
+# Decode with the Conformer CTC model
+utils/queue-freegpu.pl --gpu 1 --mem 10G -l "hostname=c*" -q g.q conformer_ctc/exp/decode.log \
+  python conformer_ctc/decode.py --epoch 12 --avg 3 --method ctc-decoding --max-duration 50 --num-paths 20