Refactoring: Remove unused code.

Fangjun Kuang 2021-08-22 22:13:13 +08:00
parent f246f0c24b
commit 09587d1108
3 changed files with 87 additions and 190 deletions

View File

@@ -333,7 +333,7 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(
torch.load("data/lang_phone/HLG.pt", map_location="cpu")
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False

View File

@@ -61,20 +61,11 @@ def get_params() -> AttributeDict:
"lang_dir": Path("data/lang_phone"),
"lm_dir": Path("data/lm"),
"feature_dim": 23,
"subsampling_factor": 1,
"search_beam": 20,
"output_beam": 5,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
# Possible values for method:
# - 1best
# - nbest
# - nbest-rescoring
# - whole-lattice-rescoring
"method": "1best",
# num_paths is used when method is "nbest" and "nbest-rescoring"
"num_paths": 30,
}
)
return params
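get_params() returns an AttributeDict, so the decoding options above can be read and overridden either as keys or as attributes. A toy stand-in follows; the real AttributeDict comes from icefall's utilities, and this sketch only assumes it is a dict subclass exposing keys as attributes:

class AttributeDict(dict):
    """Toy stand-in for icefall's AttributeDict."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value


params = AttributeDict({"search_beam": 20, "output_beam": 8})
assert params.output_beam == params["output_beam"] == 8
params.output_beam = 10  # override a decoding option via attribute access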
@@ -85,29 +76,17 @@ def decode_one_batch(
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
word_table: k2.SymbolTable,
) -> List[List[int]]:
"""Decode one batch and return the result in a list-of-list.
Each sub list contains the word IDs for an utterance in the batch.
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
- params.method is "1best", it uses 1best decoding without LM rescoring.
- params.method is "nbest", it uses nbest decoding without LM rescoring.
- params.method is "nbest-rescoring", it uses nbest LM rescoring.
- params.method is "whole-lattice-rescoring", it uses whole lattice LM
rescoring.
- params.method is "1best", it uses 1best decoding.
- params.method is "nbest", it uses nbest decoding.
model:
The neural model.
@@ -117,15 +96,11 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
lexicon:
It contains word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
(https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
word_table:
It is the word symbol table.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
Return the decoding result. `len(ans)` equals the batch size.
"""
device = HLG.device
feature = batch["inputs"]
@@ -141,8 +116,8 @@ def decode_one_batch(
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
supervisions["start_frame"],
supervisions["num_frames"],
),
1,
).to(torch.int32)
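With subsampling_factor gone (the yesno TDNN does not subsample frames), supervision_segments is built directly from the frame counts that lhotse.dataset.K2SpeechRecognitionDataset puts into batch["supervisions"]. A self-contained sketch with made-up values:

import torch

# Toy supervisions dict; the real one comes from the lhotse dataset and
# also carries "text" and "cut" entries (frame counts here are made up).
supervisions = {
    "sequence_idx": torch.tensor([0, 1]),    # row of `feature` for each segment
    "start_frame": torch.tensor([0, 0]),     # first frame of each segment
    "num_frames": torch.tensor([620, 595]),  # number of frames in each segment
}

# Shape (num_segments, 3), dtype int32: the layout expected by k2 when
# building a DenseFsaVec / lattice from the network output.
supervision_segments = torch.stack(
    (
        supervisions["sequence_idx"],
        supervisions["start_frame"],
        supervisions["num_frames"],
    ),
    1,
).to(torch.int32)
print(supervision_segments.shape)  # torch.Size([2, 3])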
@@ -157,46 +132,12 @@ def decode_one_batch(
max_active_states=params.max_active_states,
)
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = "no_rescore"
else:
best_path = nbest_decoding(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
)
key = f"no_rescore-{params.num_paths}"
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
return {key: hyps}
assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
if params.method == "nbest-rescoring":
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
)
else:
best_path_dict = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
)
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans
hyps = [[word_table[i] for i in ids] for ids in hyps]
return hyps
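The refactored decode_one_batch maps word IDs to words with a plain k2.SymbolTable instead of going through the Lexicon object. A toy sketch of that last step; the table contents are invented, and the real one is lexicon.word_table (read from the lang dir's words.txt):

import k2

# Invented word table; the real one is loaded from data/lang_phone/words.txt.
word_table = k2.SymbolTable.from_str(
    """<eps> 0
NO 1
YES 2"""
)

hyp_ids = [[2, 1, 2], [1, 1]]  # word IDs, e.g. from get_texts(best_path)
hyps = [[word_table[i] for i in ids] for ids in hyp_ids]
print(hyps)  # [['YES', 'NO', 'YES'], ['NO', 'NO']]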
def decode_dataset(
@@ -204,9 +145,8 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
word_table: k2.SymbolTable,
) -> List[Tuple[List[int], List[int]]]:
"""Decode dataset.
Args:
@@ -218,16 +158,10 @@ def decode_dataset(
The neural model.
HLG:
The decoding graph.
lexicon:
It contains word symbol table.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
is a 3-gram LM, while this G is a 4-gram LM.
word_table:
It is the word symbol table.
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
Return a list of tuples (ref_words, hyp_words):
The first is the reference transcript, and the second is the
predicted result.
"""
@@ -240,27 +174,25 @@ def decode_dataset(
except TypeError:
num_batches = "?"
results = defaultdict(list)
results = []
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
hyps = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
G=G,
word_table=word_table,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[lm_scale].extend(this_batch)
results.extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
@@ -274,39 +206,47 @@ def decode_dataset(
def save_results(
params: AttributeDict,
exp_dir: Path,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
results: List[Tuple[List[int], List[int]]],
) -> None:
"""Save results to `exp_dir`.
Args:
exp_dir:
The output directory. This function creates the following files inside
this directory:
- recogs-{test_set_name}.txt
It contains the reference and hypothesis results, like below::
ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
- errs-{test_set_name}.txt
It contains the detailed WER.
test_set_name:
The name of the test set, which will be part of the result filename.
results:
A list of tuples, each of which contains (ref_words, hyp_words).
Returns:
Return None.
"""
recog_path = exp_dir / f"recogs-{test_set_name}.txt"
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
errs_filename = exp_dir / f"errs-{test_set_name}.txt"
with open(errs_filename, "w") as f:
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
test_set_wers[key] = wer
wer = write_error_stats(f, f"{test_set_name}", results)
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
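write_error_stats returns the WER that is stored in errs-{test_set_name}.txt. For reference, a generic word-error-rate sketch based on edit distance; this is not icefall's implementation, only the underlying idea:

def word_error_rate(ref, hyp):
    """Edit distance (substitutions + deletions + insertions) / len(ref)."""
    # d[j] holds the edit distance between the ref prefix processed so far
    # and the first j words of hyp (standard dynamic programming).
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            cur = min(
                d[j] + 1,         # deletion: ref word r missing from hyp
                d[j - 1] + 1,     # insertion: extra hyp word h
                prev + (r != h),  # substitution, or match if the words agree
            )
            prev, d[j] = d[j], cur
    return d[len(hyp)] / max(len(ref), 1)


print(word_error_rate(["NO", "NO", "YES"], ["NO", "YES", "YES"]))  # ~0.333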
@torch.no_grad()
def main():
@@ -322,7 +262,7 @@ def main():
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_phone_id = max(lexicon.tokens)
max_token_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
@@ -331,53 +271,14 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(
torch.load("data/lang_phone/HLG.pt", map_location="cpu")
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()
if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
if not (params.lm_dir / "G_4_gram.pt").is_file():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.word_table["#0"]
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION: The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
G = k2.Fsa.from_fsas([G]).to(device)
G = k2.arc_sort(G)
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
else:
G = None
model = Tdnn(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
num_classes=max_token_id + 1, # +1 for the blank symbol
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -394,25 +295,17 @@ def main():
model.eval()
yes_no = YesNoAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_sets = ["test"]
for test_set, test_dl in zip(test_sets, [yes_no.test_dataloaders()]):
results_dict = decode_dataset(
test_dl = yes_no.test_dataloaders()
results = decode_dataset(
dl=test_dl,
params=params,
model=model,
HLG=HLG,
lexicon=lexicon,
G=G,
word_table=lexicon.word_table,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
exp_dir=params.exp_dir, test_set_name="test_set", results=results
)
logging.info("Done!")

View File

@@ -496,6 +496,10 @@ def run(rank, world_size, args):
yes_no = YesNoAsrDataModule(args)
train_dl = yes_no.train_dataloaders()
# There are only 60 waves: 30 files are used for training
# and the remaining 30 files are used for testing.
# We use test data as validation.
valid_dl = yes_no.test_dataloaders()
for epoch in range(params.start_epoch, params.num_epochs):