Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-12 19:42:19 +00:00
Fix decode.py
parent 4d849cfd03
commit ce9b23327f
@@ -16,7 +16,6 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer

from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.decode import (
get_lattice,
@@ -28,6 +27,7 @@ from icefall.decode import (
rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import (
AttributeDict,
get_texts,
@@ -58,6 +58,11 @@ def get_parser():
"'--epoch'. ",
)

parser.add_argument(
"--method",
type=str,
)

parser.add_argument(
"--lattice-score-scale",
type=float,
@@ -82,7 +87,7 @@ def get_params() -> AttributeDict:
"nhead": 8,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"num_decoder_layers": 0,
"vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
@@ -102,7 +107,7 @@ def get_params() -> AttributeDict:
# "method": "nbest",
# "method": "nbest-rescoring",
# "method": "whole-lattice-rescoring",
"method": "attention-decoder",
# "method": "attention-decoder",
# "method": "nbest-oracle",
# num_paths is used when method is "nbest", "nbest-rescoring",
# attention-decoder, and nbest-oracle
@@ -118,8 +123,6 @@ def decode_one_batch(
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
"""Decode one batch and return the result in a dict. The dict has the
@@ -153,10 +156,6 @@ def decode_one_batch(
for the format of the `batch`.
lexicon:
It contains word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
@@ -234,7 +233,8 @@ def decode_one_batch(
"attention-decoder",
]

lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

if params.method == "nbest-rescoring":
@@ -261,8 +261,6 @@ def decode_one_batch(
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
scale=params.lattice_score_scale,
)
else:
@@ -282,8 +280,6 @@ def decode_dataset(
model: nn.Module,
HLG: k2.Fsa,
lexicon: Lexicon,
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
"""Decode dataset.
@@ -299,10 +295,6 @@ def decode_dataset(
The decoding graph.
lexicon:
It contains word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
G:
An LM. It is not None when params.method is "nbest-rescoring"
or "whole-lattice-rescoring". In general, the G in HLG
@@ -334,8 +326,6 @@ def decode_dataset(
batch=batch,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)

for lm_scale, hyps in hyps_dict.items():
@@ -427,14 +417,10 @@ def main():

logging.info(f"device: {device}")

graph_compiler = BpeCtcTrainingGraphCompiler(
graph_compiler = MmiTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id

HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
@@ -530,8 +516,6 @@ def main():
HLG=HLG,
lexicon=lexicon,
G=G,
sos_id=sos_id,
eos_id=eos_id,
)

save_results(
@@ -4,8 +4,11 @@ profile = "black"
[tool.black]
line-length = 80
exclude = '''
/(
\.git
| \.github
)/
(
/(
\.git
| \.github
| icefall/shared/*
)/
)
'''
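Note: the get_parser() hunk above is cut off right after "type=float,", so the full definitions of the two new command-line options are not visible here. As a rough illustration only, a minimal argparse sketch for options named "--method" and "--lattice-score-scale" could look like the following. Only the option names and the decoding-method names are taken from the diff; the defaults and help strings are illustrative assumptions, not the repository's actual code.

import argparse


def get_parser() -> argparse.ArgumentParser:
    """Hypothetical sketch of the two options added in this commit.

    Only the option names ("--method", "--lattice-score-scale") and the
    decoding-method names come from the diff; defaults and help text are
    illustrative assumptions.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "--method",
        type=str,
        default="attention-decoder",  # assumed default, mirrors get_params()
        help="Decoding method: nbest, nbest-rescoring, "
        "whole-lattice-rescoring, attention-decoder, or nbest-oracle.",
    )
    parser.add_argument(
        "--lattice-score-scale",
        type=float,
        default=1.0,  # assumed default
        help="Scale applied to lattice scores when sampling n-best paths.",
    )
    return parser


if __name__ == "__main__":
    # Example: python decode_args_sketch.py --method nbest-rescoring
    args = get_parser().parse_args()
    print(vars(args))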