mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 20:12:24 +00:00
Fix decode.py
This commit is contained in:
parent
4d849cfd03
commit
ce9b23327f
@ -16,7 +16,6 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
|
||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.decode import (
|
from icefall.decode import (
|
||||||
get_lattice,
|
get_lattice,
|
||||||
@ -28,6 +27,7 @@ from icefall.decode import (
|
|||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
get_texts,
|
get_texts,
|
||||||
@ -58,6 +58,11 @@ def get_parser():
|
|||||||
"'--epoch'. ",
|
"'--epoch'. ",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lattice-score-scale",
|
"--lattice-score-scale",
|
||||||
type=float,
|
type=float,
|
||||||
@ -82,7 +87,7 @@ def get_params() -> AttributeDict:
|
|||||||
"nhead": 8,
|
"nhead": 8,
|
||||||
"attention_dim": 512,
|
"attention_dim": 512,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"num_decoder_layers": 6,
|
"num_decoder_layers": 0,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
"is_espnet_structure": True,
|
"is_espnet_structure": True,
|
||||||
"mmi_loss": False,
|
"mmi_loss": False,
|
||||||
@ -102,7 +107,7 @@ def get_params() -> AttributeDict:
|
|||||||
# "method": "nbest",
|
# "method": "nbest",
|
||||||
# "method": "nbest-rescoring",
|
# "method": "nbest-rescoring",
|
||||||
# "method": "whole-lattice-rescoring",
|
# "method": "whole-lattice-rescoring",
|
||||||
"method": "attention-decoder",
|
# "method": "attention-decoder",
|
||||||
# "method": "nbest-oracle",
|
# "method": "nbest-oracle",
|
||||||
# num_paths is used when method is "nbest", "nbest-rescoring",
|
# num_paths is used when method is "nbest", "nbest-rescoring",
|
||||||
# attention-decoder, and nbest-oracle
|
# attention-decoder, and nbest-oracle
|
||||||
@ -118,8 +123,6 @@ def decode_one_batch(
|
|||||||
HLG: k2.Fsa,
|
HLG: k2.Fsa,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
sos_id: int,
|
|
||||||
eos_id: int,
|
|
||||||
G: Optional[k2.Fsa] = None,
|
G: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[List[int]]]:
|
) -> Dict[str, List[List[int]]]:
|
||||||
"""Decode one batch and return the result in a dict. The dict has the
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
@ -153,10 +156,6 @@ def decode_one_batch(
|
|||||||
for the format of the `batch`.
|
for the format of the `batch`.
|
||||||
lexicon:
|
lexicon:
|
||||||
It contains word symbol table.
|
It contains word symbol table.
|
||||||
sos_id:
|
|
||||||
The token ID of the SOS.
|
|
||||||
eos_id:
|
|
||||||
The token ID of the EOS.
|
|
||||||
G:
|
G:
|
||||||
An LM. It is not None when params.method is "nbest-rescoring"
|
An LM. It is not None when params.method is "nbest-rescoring"
|
||||||
or "whole-lattice-rescoring". In general, the G in HLG
|
or "whole-lattice-rescoring". In general, the G in HLG
|
||||||
@ -234,7 +233,8 @@ def decode_one_batch(
|
|||||||
"attention-decoder",
|
"attention-decoder",
|
||||||
]
|
]
|
||||||
|
|
||||||
lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||||
|
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
||||||
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
||||||
|
|
||||||
if params.method == "nbest-rescoring":
|
if params.method == "nbest-rescoring":
|
||||||
@ -261,8 +261,6 @@ def decode_one_batch(
|
|||||||
model=model,
|
model=model,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
memory_key_padding_mask=memory_key_padding_mask,
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
sos_id=sos_id,
|
|
||||||
eos_id=eos_id,
|
|
||||||
scale=params.lattice_score_scale,
|
scale=params.lattice_score_scale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -282,8 +280,6 @@ def decode_dataset(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
HLG: k2.Fsa,
|
HLG: k2.Fsa,
|
||||||
lexicon: Lexicon,
|
lexicon: Lexicon,
|
||||||
sos_id: int,
|
|
||||||
eos_id: int,
|
|
||||||
G: Optional[k2.Fsa] = None,
|
G: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
|
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
|
||||||
"""Decode dataset.
|
"""Decode dataset.
|
||||||
@ -299,10 +295,6 @@ def decode_dataset(
|
|||||||
The decoding graph.
|
The decoding graph.
|
||||||
lexicon:
|
lexicon:
|
||||||
It contains word symbol table.
|
It contains word symbol table.
|
||||||
sos_id:
|
|
||||||
The token ID for SOS.
|
|
||||||
eos_id:
|
|
||||||
The token ID for EOS.
|
|
||||||
G:
|
G:
|
||||||
An LM. It is not None when params.method is "nbest-rescoring"
|
An LM. It is not None when params.method is "nbest-rescoring"
|
||||||
or "whole-lattice-rescoring". In general, the G in HLG
|
or "whole-lattice-rescoring". In general, the G in HLG
|
||||||
@ -334,8 +326,6 @@ def decode_dataset(
|
|||||||
batch=batch,
|
batch=batch,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
G=G,
|
G=G,
|
||||||
sos_id=sos_id,
|
|
||||||
eos_id=eos_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for lm_scale, hyps in hyps_dict.items():
|
for lm_scale, hyps in hyps_dict.items():
|
||||||
@ -427,14 +417,10 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
graph_compiler = MmiTrainingGraphCompiler(
|
||||||
params.lang_dir,
|
params.lang_dir,
|
||||||
device=device,
|
device=device,
|
||||||
sos_token="<sos/eos>",
|
|
||||||
eos_token="<sos/eos>",
|
|
||||||
)
|
)
|
||||||
sos_id = graph_compiler.sos_id
|
|
||||||
eos_id = graph_compiler.eos_id
|
|
||||||
|
|
||||||
HLG = k2.Fsa.from_dict(
|
HLG = k2.Fsa.from_dict(
|
||||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
|
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
|
||||||
@ -530,8 +516,6 @@ def main():
|
|||||||
HLG=HLG,
|
HLG=HLG,
|
||||||
lexicon=lexicon,
|
lexicon=lexicon,
|
||||||
G=G,
|
G=G,
|
||||||
sos_id=sos_id,
|
|
||||||
eos_id=eos_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
save_results(
|
save_results(
|
||||||
|
@ -4,8 +4,11 @@ profile = "black"
|
|||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 80
|
line-length = 80
|
||||||
exclude = '''
|
exclude = '''
|
||||||
/(
|
(
|
||||||
|
/(
|
||||||
\.git
|
\.git
|
||||||
| \.github
|
| \.github
|
||||||
)/
|
| icefall/shared/*
|
||||||
|
)/
|
||||||
|
)
|
||||||
'''
|
'''
|
||||||
|
Loading…
x
Reference in New Issue
Block a user