mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-12 19:42:19 +00:00

Refactoring: Remove unused code.

parent f246f0c24b
commit 09587d1108

@@ -61,20 +61,11 @@ def get_params() -> AttributeDict:
             "lang_dir": Path("data/lang_phone"),
             "lm_dir": Path("data/lm"),
             "feature_dim": 23,
-            "subsampling_factor": 1,
             "search_beam": 20,
-            "output_beam": 5,
+            "output_beam": 8,
             "min_active_states": 30,
             "max_active_states": 10000,
             "use_double_scores": True,
-            # Possible values for method:
-            # - 1best
-            # - nbest
-            # - nbest-rescoring
-            # - whole-lattice-rescoring
-            "method": "1best",
-            # num_paths is used when method is "nbest" and "nbest-rescoring"
-            "num_paths": 30,
         }
     )
     return params
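Everything rescoring-specific is gone from the defaults. For readers unfamiliar with icefall's `AttributeDict`, here is a minimal sketch of how the remaining parameters are consumed; the tiny `AttributeDict` class below is a stand-in written only for illustration, not icefall's actual implementation:

```python
from pathlib import Path


class AttributeDict(dict):
    """Minimal stand-in for icefall.utils.AttributeDict: a dict whose
    keys can also be read as attributes, e.g. params.search_beam."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)


params = AttributeDict(
    {
        "lang_dir": Path("data/lang_phone"),
        "feature_dim": 23,
        "search_beam": 20,
        "output_beam": 8,
        "min_active_states": 30,
        "max_active_states": 10000,
        "use_double_scores": True,
    }
)

print(params.output_beam)          # 8
print(params.lang_dir / "HLG.pt")  # data/lang_phone/HLG.pt
```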
@@ -85,29 +76,17 @@ def decode_one_batch(
     model: nn.Module,
     HLG: k2.Fsa,
     batch: dict,
-    lexicon: Lexicon,
-    G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[int]]]:
-    """Decode one batch and return the result in a dict. The dict has the
-    following format:
+    word_table: k2.SymbolTable,
+) -> List[List[int]]:
+    """Decode one batch and return the result in a list-of-list.
+    Each sub list contains the word IDs for an utterance in the batch.
 
-    - key: It indicates the setting used for decoding. For example,
-      if no rescoring is used, the key is the string `no_rescore`.
-      If LM rescoring is used, the key is the string `lm_scale_xxx`,
-      where `xxx` is the value of `lm_scale`. An example key is
-      `lm_scale_0.7`
-    - value: It contains the decoding result. `len(value)` equals to
-      batch size. `value[i]` is the decoding result for the i-th
-      utterance in the given batch.
     Args:
       params:
         It's the return value of :func:`get_params`.
 
-        - params.method is "1best", it uses 1best decoding without LM rescoring.
-        - params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
-        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
-          rescoring.
+        - params.method is "1best", it uses 1best decoding.
+        - params.method is "nbest", it uses nbest decoding.
 
       model:
         The neural model.
@@ -117,15 +96,11 @@ def decode_one_batch(
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
         for the format of the `batch`.
-      lexicon:
-        It contains word symbol table.
-      G:
-        An LM. It is not None when params.method is "nbest-rescoring"
-        or "whole-lattice-rescoring". In general, the G in HLG
-        is a 3-gram LM, while this G is a 4-gram LM.
+        (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py)
+      word_table:
+        It is the word symbol table.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return the decoding result. `len(ans)` == batch size.
     """
     device = HLG.device
     feature = batch["inputs"]
@@ -141,8 +116,8 @@ def decode_one_batch(
     supervision_segments = torch.stack(
         (
             supervisions["sequence_idx"],
-            supervisions["start_frame"] // params.subsampling_factor,
-            supervisions["num_frames"] // params.subsampling_factor,
+            supervisions["start_frame"],
+            supervisions["num_frames"],
         ),
         1,
     ).to(torch.int32)
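With `subsampling_factor` removed (it was fixed at 1), the start and end frames go into `supervision_segments` unchanged. A runnable toy example of the tensor this code builds, with made-up frame counts:

```python
import torch

# Toy supervisions dict shaped like the one yielded by
# lhotse.dataset.K2SpeechRecognitionDataset (values are made up).
supervisions = {
    "sequence_idx": torch.tensor([0, 1]),
    "start_frame": torch.tensor([0, 0]),
    "num_frames": torch.tensor([50, 48]),
}

supervision_segments = torch.stack(
    (
        supervisions["sequence_idx"],
        supervisions["start_frame"],
        supervisions["num_frames"],
    ),
    1,
).to(torch.int32)

print(supervision_segments)
# tensor([[ 0,  0, 50],
#         [ 1,  0, 48]], dtype=torch.int32)
```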
@@ -157,46 +132,12 @@ def decode_one_batch(
         max_active_states=params.max_active_states,
     )
 
-    if params.method in ["1best", "nbest"]:
-        if params.method == "1best":
-            best_path = one_best_decoding(
-                lattice=lattice, use_double_scores=params.use_double_scores
-            )
-            key = "no_rescore"
-        else:
-            best_path = nbest_decoding(
-                lattice=lattice,
-                num_paths=params.num_paths,
-                use_double_scores=params.use_double_scores,
-            )
-            key = f"no_rescore-{params.num_paths}"
-        hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
-        return {key: hyps}
-
-    assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]
-
-    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
-    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
-
-    if params.method == "nbest-rescoring":
-        best_path_dict = rescore_with_n_best_list(
-            lattice=lattice,
-            G=G,
-            num_paths=params.num_paths,
-            lm_scale_list=lm_scale_list,
-        )
-    else:
-        best_path_dict = rescore_with_whole_lattice(
-            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
-        )
-
-    ans = dict()
-    for lm_scale_str, best_path in best_path_dict.items():
-        hyps = get_texts(best_path)
-        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
-        ans[lm_scale_str] = hyps
-    return ans
+    best_path = one_best_decoding(
+        lattice=lattice, use_double_scores=params.use_double_scores
+    )
+    hyps = get_texts(best_path)
+    hyps = [[word_table[i] for i in ids] for ids in hyps]
+    return hyps
 
 
 def decode_dataset(
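`decode_one_batch` now always runs 1best decoding and returns a plain list-of-lists instead of a dict keyed by rescoring setting. The final mapping step is easy to see in isolation; in this sketch a plain dict stands in for the `k2.SymbolTable` (the IDs are hypothetical, but YES/NO are the only two words in this corpus):

```python
# Word IDs as get_texts(best_path) would return them, one sub-list
# per utterance in the batch (hypothetical values).
hyps_ids = [[2, 2, 3], [3, 2]]

# Stand-in for the k2.SymbolTable loaded from lang_dir/words.txt.
word_table = {2: "NO", 3: "YES"}

hyps = [[word_table[i] for i in ids] for ids in hyps_ids]
print(hyps)  # [['NO', 'NO', 'YES'], ['YES', 'NO']]
```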
@@ -204,9 +145,8 @@ def decode_dataset(
     params: AttributeDict,
     model: nn.Module,
     HLG: k2.Fsa,
-    lexicon: Lexicon,
-    G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[List[int], List[int]]]]:
+    word_table: k2.SymbolTable,
+) -> List[Tuple[List[int], List[int]]]:
     """Decode dataset.
 
     Args:
@@ -218,16 +158,10 @@ def decode_dataset(
         The neural model.
       HLG:
         The decoding graph.
-      lexicon:
-        It contains word symbol table.
-      G:
-        An LM. It is not None when params.method is "nbest-rescoring"
-        or "whole-lattice-rescoring". In general, the G in HLG
-        is a 3-gram LM, while this G is a 4-gram LM.
+      word_table:
+        It is word symbol table.
     Returns:
-      Return a dict, whose key may be "no-rescore" if no LM rescoring
-      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
-      Its value is a list of tuples. Each tuple contains two elements:
+      Return a tuple contains two elements (ref_text, hyp_text):
       The first is the reference transcript, and the second is the
       predicted result.
     """
@@ -240,27 +174,25 @@ def decode_dataset(
     except TypeError:
         num_batches = "?"
 
-    results = defaultdict(list)
+    results = []
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
 
-        hyps_dict = decode_one_batch(
+        hyps = decode_one_batch(
             params=params,
             model=model,
             HLG=HLG,
             batch=batch,
-            lexicon=lexicon,
-            G=G,
+            word_table=word_table,
         )
 
-        for lm_scale, hyps in hyps_dict.items():
-            this_batch = []
-            assert len(hyps) == len(texts)
-            for hyp_words, ref_text in zip(hyps, texts):
-                ref_words = ref_text.split()
-                this_batch.append((ref_words, hyp_words))
-
-            results[lm_scale].extend(this_batch)
+        this_batch = []
+        assert len(hyps) == len(texts)
+        for hyp_words, ref_text in zip(hyps, texts):
+            ref_words = ref_text.split()
+            this_batch.append((ref_words, hyp_words))
+
+        results.extend(this_batch)
 
         num_cuts += len(batch["supervisions"]["text"])
 
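`decode_dataset` accordingly collects everything into one flat list of `(ref_words, hyp_words)` pairs instead of a per-`lm_scale` dict. A small runnable sketch of consuming that shape (the pairs are invented; the error count here is a crude per-position tally, not the aligned WER that `write_error_stats` computes):

```python
results = [
    (["NO", "NO", "YES"], ["NO", "NO", "YES"]),
    (["YES", "NO"], ["YES", "YES"]),
]

num_words = sum(len(ref) for ref, _ in results)
num_errs = sum(
    sum(r != h for r, h in zip(ref, hyp)) + abs(len(ref) - len(hyp))
    for ref, hyp in results
)
print(f"{num_errs} / {num_words} words differ")  # 1 / 5 words differ
```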
@@ -274,39 +206,47 @@ def save_results(
 
 
 def save_results(
-    params: AttributeDict,
+    exp_dir: Path,
     test_set_name: str,
-    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
-):
-    test_set_wers = dict()
-    for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+    results: List[Tuple[List[int], List[int]]],
+) -> None:
+    """Save results to `exp_dir`.
+    Args:
+      exp_dir:
+        The output directory. This function create the following files inside
+        this directory:
+
+          - recogs-{test_set_name}.text
+
+            It contains the reference and hypothesis results, like below::
+
+                ref=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+                hyp=['NO', 'NO', 'NO', 'YES', 'NO', 'NO', 'NO', 'YES']
+                ref=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
+                hyp=['NO', 'NO', 'YES', 'NO', 'YES', 'NO', 'NO', 'YES']
+
+          - errs-{test_set_name}.txt
+
+            It contains the detailed WER.
+      test_set_name:
+        The name of the test set, which will be part of the result filename.
+      results:
+        A list of tuples, each of which contains (ref_words, hyp_words).
+    Returns:
+      Return None.
+    """
+    recog_path = exp_dir / f"recogs-{test_set_name}.txt"
     store_transcripts(filename=recog_path, texts=results)
     logging.info(f"The transcripts are stored in {recog_path}")
 
     # The following prints out WERs, per-word error statistics and aligned
     # ref/hyp pairs.
-    errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+    errs_filename = exp_dir / f"errs-{test_set_name}.txt"
     with open(errs_filename, "w") as f:
-        wer = write_error_stats(f, f"{test_set_name}-{key}", results)
-        test_set_wers[key] = wer
+        wer = write_error_stats(f, f"{test_set_name}", results)
 
     logging.info("Wrote detailed error stats to {}".format(errs_filename))
-
-    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
-    with open(errs_info, "w") as f:
-        print("settings\tWER", file=f)
-        for key, val in test_set_wers:
-            print("{}\t{}".format(key, val), file=f)
-
-    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
-    note = "\tbest for {}".format(test_set_name)
-    for key, val in test_set_wers:
-        s += "{}\t{}{}\n".format(key, val, note)
-        note = ""
-    logging.info(s)
 
 
 @torch.no_grad()
 def main():
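The new `save_results` writes exactly one `recogs-*.txt` and one `errs-*.txt` per call; the WER-summary file that compared rescoring settings is gone. A hedged usage sketch, assuming the `results` list from `decode_dataset` above and an experiment directory of the recipe's usual layout:

```python
from pathlib import Path

exp_dir = Path("tdnn/exp")  # assumed experiment directory
exp_dir.mkdir(parents=True, exist_ok=True)

save_results(exp_dir=exp_dir, test_set_name="test_set", results=results)
# -> tdnn/exp/recogs-test_set.txt and tdnn/exp/errs-test_set.txt
```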
@@ -322,7 +262,7 @@ def main():
     logging.info(params)
 
     lexicon = Lexicon(params.lang_dir)
-    max_phone_id = max(lexicon.tokens)
+    max_token_id = max(lexicon.tokens)
 
     device = torch.device("cpu")
     if torch.cuda.is_available():
@@ -331,53 +271,14 @@ def main():
     logging.info(f"device: {device}")
 
     HLG = k2.Fsa.from_dict(
-        torch.load("data/lang_phone/HLG.pt", map_location="cpu")
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
     )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
-    if not hasattr(HLG, "lm_scores"):
-        HLG.lm_scores = HLG.scores.clone()
-
-    if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
-        if not (params.lm_dir / "G_4_gram.pt").is_file():
-            logging.info("Loading G_4_gram.fst.txt")
-            logging.warning("It may take 8 minutes.")
-            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-                first_word_disambig_id = lexicon.word_table["#0"]
-
-                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
-                # G.aux_labels is not needed in later computations, so
-                # remove it here.
-                del G.aux_labels
-                # CAUTION: The following line is crucial.
-                # Arcs entering the back-off state have label equal to #0.
-                # We have to change it to 0 here.
-                G.labels[G.labels >= first_word_disambig_id] = 0
-                G = k2.Fsa.from_fsas([G]).to(device)
-                G = k2.arc_sort(G)
-                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
-        else:
-            logging.info("Loading pre-compiled G_4_gram.pt")
-            d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
-            G = k2.Fsa.from_dict(d).to(device)
-
-        if params.method == "whole-lattice-rescoring":
-            # Add epsilon self-loops to G as we will compose
-            # it with the whole lattice later
-            G = k2.add_epsilon_self_loops(G)
-            G = k2.arc_sort(G)
-            G = G.to(device)
-
-        # G.lm_scores is used to replace HLG.lm_scores during
-        # LM rescoring.
-        G.lm_scores = G.scores.clone()
-    else:
-        G = None
-
     model = Tdnn(
         num_features=params.feature_dim,
-        num_classes=max_phone_id + 1,  # +1 for the blank symbol
+        num_classes=max_token_id + 1,  # +1 for the blank symbol
     )
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
@@ -394,25 +295,17 @@ def main():
     model.eval()
 
     yes_no = YesNoAsrDataModule(args)
-    # CAUTION: `test_sets` is for displaying only.
-    # If you want to skip test-clean, you have to skip
-    # it inside the for loop. That is, use
-    #
-    #   if test_set == 'test-clean': continue
-    #
-    test_sets = ["test"]
-    for test_set, test_dl in zip(test_sets, [yes_no.test_dataloaders()]):
-        results_dict = decode_dataset(
-            dl=test_dl,
-            params=params,
-            model=model,
-            HLG=HLG,
-            lexicon=lexicon,
-            G=G,
-        )
-
-        save_results(
-            params=params, test_set_name=test_set, results_dict=results_dict
-        )
+    test_dl = yes_no.test_dataloaders()
+    results = decode_dataset(
+        dl=test_dl,
+        params=params,
+        model=model,
+        HLG=HLG,
+        word_table=lexicon.word_table,
+    )
+
+    save_results(
+        exp_dir=params.exp_dir, test_set_name="test_set", results=results
+    )
 
     logging.info("Done!")
@@ -496,6 +496,10 @@ def run(rank, world_size, args):
 
     yes_no = YesNoAsrDataModule(args)
     train_dl = yes_no.train_dataloaders()
+
+    # There are only 60 waves: 30 files are used for training
+    # and the remaining 30 files are used for testing.
+    # We use test data as validation.
     valid_dl = yes_no.test_dataloaders()
 
     for epoch in range(params.start_epoch, params.num_epochs):