Fix decode.py

Fangjun Kuang 2021-09-09 15:15:35 +08:00
parent 4d849cfd03
commit ce9b23327f
2 changed files with 18 additions and 31 deletions

decode.py

@@ -16,7 +16,6 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import (
     get_lattice,
@@ -28,6 +27,7 @@ from icefall.decode import (
     rescore_with_whole_lattice,
 )
 from icefall.lexicon import Lexicon
+from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -58,6 +58,11 @@ def get_parser():
         "'--epoch'. ",
     )
+    parser.add_argument(
+        "--method",
+        type=str,
+    )
     parser.add_argument(
         "--lattice-score-scale",
         type=float,
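
Note: the new --method flag makes the decoding method selectable on the
command line instead of by editing get_params(). A hypothetical invocation
(script name abbreviated; the value is one of the methods listed in
get_params() below):

    python decode.py --method whole-lattice-rescoring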
@@ -82,7 +87,7 @@ def get_params() -> AttributeDict:
             "nhead": 8,
             "attention_dim": 512,
             "subsampling_factor": 4,
-            "num_decoder_layers": 6,
+            "num_decoder_layers": 0,
             "vgg_frontend": False,
             "is_espnet_structure": True,
             "mmi_loss": False,
@@ -102,7 +107,7 @@ def get_params() -> AttributeDict:
             # "method": "nbest",
             # "method": "nbest-rescoring",
             # "method": "whole-lattice-rescoring",
-            "method": "attention-decoder",
+            # "method": "attention-decoder",
             # "method": "nbest-oracle",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # attention-decoder, and nbest-oracle
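
Note: this hunk and the previous one work together: with
"num_decoder_layers": 0 the model is built without an attention decoder, so
"attention-decoder" can no longer be the default method. A hypothetical
guard (not part of the commit) would make the dependency explicit:

    # Attention-decoder rescoring needs a model trained with decoder layers.
    if params.method == "attention-decoder":
        assert params.num_decoder_layers > 0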
@@ -118,8 +123,6 @@ def decode_one_batch(
     HLG: k2.Fsa,
     batch: dict,
     lexicon: Lexicon,
-    sos_id: int,
-    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[List[int]]]:
     """Decode one batch and return the result in a dict. The dict has the
@@ -153,10 +156,6 @@
         for the format of the `batch`.
       lexicon:
         It contains word symbol table.
-      sos_id:
-        The token ID of the SOS.
-      eos_id:
-        The token ID of the EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -234,7 +233,8 @@
         "attention-decoder",
     ]
-    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
     lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
     if params.method == "nbest-rescoring":
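
Note: the rescoring LM scales now cover 0.1 through 2.0 in steps of 0.1 (20
values) instead of only 0.8 through 1.3. An equivalent one-liner, as a
sketch rather than part of the commit:

    # round() trims float noise, e.g. 0.1 * 3 == 0.30000000000000004 -> 0.3
    lm_scale_list = [round(0.1 * i, 1) for i in range(1, 21)]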
@@ -261,8 +261,6 @@
             model=model,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
-            sos_id=sos_id,
-            eos_id=eos_id,
             scale=params.lattice_score_scale,
         )
     else:
@@ -282,8 +280,6 @@ def decode_dataset(
     model: nn.Module,
     HLG: k2.Fsa,
     lexicon: Lexicon,
-    sos_id: int,
-    eos_id: int,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[List[int], List[int]]]]:
     """Decode dataset.
@@ -299,10 +295,6 @@
         The decoding graph.
       lexicon:
         It contains word symbol table.
-      sos_id:
-        The token ID for SOS.
-      eos_id:
-        The token ID for EOS.
       G:
         An LM. It is not None when params.method is "nbest-rescoring"
         or "whole-lattice-rescoring". In general, the G in HLG
@@ -334,8 +326,6 @@
             batch=batch,
             lexicon=lexicon,
             G=G,
-            sos_id=sos_id,
-            eos_id=eos_id,
         )
         for lm_scale, hyps in hyps_dict.items():
@@ -427,14 +417,10 @@ def main():
     logging.info(f"device: {device}")

-    graph_compiler = BpeCtcTrainingGraphCompiler(
+    graph_compiler = MmiTrainingGraphCompiler(
         params.lang_dir,
         device=device,
-        sos_token="<sos/eos>",
-        eos_token="<sos/eos>",
     )
-    sos_id = graph_compiler.sos_id
-    eos_id = graph_compiler.eos_id

     HLG = k2.Fsa.from_dict(
         torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
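
Note: together with the import swap in the first hunk, decoding now builds
the same MmiTrainingGraphCompiler used for MMI training. As called here it
takes no sos_token/eos_token arguments, which is why the sos_id/eos_id
parameters, docstring entries, and call-site keywords are deleted throughout
rather than rewired:

    # Before: BPE CTC compiler, which required SOS/EOS bookkeeping.
    # graph_compiler = BpeCtcTrainingGraphCompiler(
    #     params.lang_dir,
    #     device=device,
    #     sos_token="<sos/eos>",
    #     eos_token="<sos/eos>",
    # )
    # sos_id = graph_compiler.sos_id
    # eos_id = graph_compiler.eos_id

    # After: MMI compiler; no SOS/EOS token IDs to thread through decoding.
    graph_compiler = MmiTrainingGraphCompiler(
        params.lang_dir,
        device=device,
    )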
@@ -530,8 +516,6 @@ def main():
         HLG=HLG,
         lexicon=lexicon,
         G=G,
-        sos_id=sos_id,
-        eos_id=eos_id,
     )
     save_results(

pyproject.toml

@@ -4,8 +4,11 @@ profile = "black"
 [tool.black]
 line-length = 80
 exclude = '''
-/(
-    \.git
-  | \.github
-)/
+(
+  /(
+      \.git
+    | \.github
+    | icefall/shared/*
+  )/
+)
 '''
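
Note: black compiles a multiline exclude with re.VERBOSE and matches it
against "/"-prefixed paths relative to the project root, so the added
alternative skips everything under icefall/shared. A standalone sanity
check, as a sketch (file names are illustrative):

    import re

    # The new exclude pattern, compiled roughly the way black compiles a
    # multiline exclude; re.VERBOSE ignores the layout whitespace.
    exclude = re.compile(
        r"""
        (
          /(
              \.git
            | \.github
            | icefall/shared/*
          )/
        )
        """,
        re.VERBOSE,
    )

    # Regex, not glob: "/*" means "zero or more slashes" here, but any path
    # under icefall/shared/ contains "icefall/shared/" and still matches.
    assert exclude.search("/icefall/shared/parse_options.sh")
    assert exclude.search("/.github/workflows/test.yml")
    assert not exclude.search("/icefall/utils.py")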