Update decoding script.

Fangjun Kuang 2021-10-18 14:38:07 +08:00
parent 28f1aabf99
commit b8dbad5156
2 changed files with 97 additions and 30 deletions

View File

@@ -23,6 +23,7 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple

 import k2
+import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -77,6 +78,9 @@ def get_parser():
         default="attention-decoder",
         help="""Decoding method.
         Supported values are:
+            - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
+              model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+              It needs neither a lexicon nor an n-gram LM.
             - (1) 1best. Extract the best path from the decoding lattice as the
               decoding result.
             - (2) nbest. Extract n paths from the decoding lattice; the path
@@ -106,7 +110,7 @@ def get_parser():
     )

     parser.add_argument(
-        "--lattice-score-scale",
+        "--nbest-scale",
         type=float,
         default=0.5,
         help="""The scale to be applied to `lattice.scores`.
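The renamed option scales `lattice.scores` before n-best paths are sampled from the lattice. A small illustration of the effect, with made-up log-scores rather than anything taken from the code: shrinking the scores flattens the distribution over competing paths, so sampling returns more distinct paths.

    import torch

    # Three competing (hypothetical) path log-scores.
    scores = torch.tensor([-1.0, -5.0, -9.0])

    for scale in (1.0, 0.5):
        probs = torch.softmax(scores * scale, dim=0)
        print(f"scale={scale}: {probs.tolist()}")
    # scale=1.0 puts almost all of the mass on the best path; scale=0.5 flattens
    # the distribution, so repeated sampling visits more unique paths.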
@@ -122,7 +126,7 @@ def get_parser():
         type=str2bool,
         default=False,
         help="""When enabled, the averaged model is saved to
-        conformer_mmi/exp/pretrained.pt. Note: only model.state_dict() is saved.
+        conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
         pretrained.pt contains a dict {"model": model.state_dict()},
         which can be loaded by `icefall.checkpoint.load_checkpoint()`.
         """,
@@ -131,17 +135,24 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="conformer_mmi/exp",
+        default="conformer_mmi/exp_500",
         help="The experiment dir",
     )

     parser.add_argument(
         "--lang-dir",
         type=str,
-        default="data/lang_bpe",
+        default="data/lang_bpe_500",
         help="The lang dir",
     )

+    parser.add_argument(
+        "--num-decoder-layers",
+        type=int,
+        default=6,
+        help="Number of attention decoder layers",
+    )
+
     return parser
@@ -156,7 +167,6 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "nhead": 8,
             "attention_dim": 512,
-            "num_decoder_layers": 6,
             # parameters for decoding
             "search_beam": 20,
             "output_beam": 8,
@@ -171,13 +181,15 @@ def get_params() -> AttributeDict:

 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    HLG: k2.Fsa,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
     batch: dict,
     word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -202,7 +214,11 @@ def decode_one_batch(
       model:
         The neural model.
       HLG:
-        The decoding graph.
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
       batch:
         It is the return value from iterating
         `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -221,7 +237,10 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict.
     """
-    device = HLG.device
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -241,9 +260,17 @@ def decode_one_batch(
         1,
     ).to(torch.int32)

+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
     lattice = get_lattice(
         nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=decoding_graph,
         supervision_segments=supervision_segments,
         search_beam=params.search_beam,
         output_beam=params.output_beam,
@@ -252,6 +279,24 @@ def decode_one_batch(
         subsampling_factor=params.subsampling_factor,
     )

+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-list of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
     if params.method == "nbest-oracle":
         # Note: You can also pass rescored lattices to it.
         # We choose the HLG decoded lattice for speed reasons
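For reference, a standalone sketch of the token-ID-to-word step used in the new branch above. The IDs and the bpe.model path are invented for illustration; only the sentencepiece calls mirror the code.

    import sentencepiece as spm

    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load("data/lang_bpe_500/bpe.model")  # path assumed, see --lang-dir

    token_ids = [[23, 8, 401], [15, 7]]   # stand-in for get_texts(best_path)
    texts = bpe_model.decode(token_ids)   # list of str, one utterance per entry
    hyps = [s.split() for s in texts]     # list of word lists; no lexicon involved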
@@ -262,12 +307,12 @@ def decode_one_batch(
             num_paths=params.num_paths,
             ref_texts=supervisions["text"],
             word_table=word_table,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
             oov="<UNK>",
         )
         hyps = get_texts(best_path)
         hyps = [[word_table[i] for i in ids] for ids in hyps]
-        key = f"oracle_{params.num_paths}_lattice_score_scale_{params.lattice_score_scale}"  # noqa
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
         return {key: hyps}

     if params.method in ["1best", "nbest"]:
@@ -281,9 +326,9 @@ def decode_one_batch(
             lattice=lattice,
             num_paths=params.num_paths,
             use_double_scores=params.use_double_scores,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
-        key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
+        key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa

     hyps = get_texts(best_path)
     hyps = [[word_table[i] for i in ids] for ids in hyps]
@@ -305,7 +350,7 @@ def decode_one_batch(
             G=G,
             num_paths=params.num_paths,
             lm_scale_list=lm_scale_list,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     elif params.method == "whole-lattice-rescoring":
         best_path_dict = rescore_with_whole_lattice(
@@ -331,7 +376,7 @@ def decode_one_batch(
             memory_key_padding_mask=memory_key_padding_mask,
             sos_id=sos_id,
             eos_id=eos_id,
-            lattice_score_scale=params.lattice_score_scale,
+            nbest_scale=params.nbest_scale,
         )
     else:
         assert False, f"Unsupported decoding method: {params.method}"
@@ -344,7 +389,7 @@ def decode_one_batch(
             ans[lm_scale_str] = hyps
     else:
         for lm_scale in lm_scale_list:
-            ans[lm_scale_str] = [[] * lattice.shape[0]]
+            ans["empty"] = [[] * lattice.shape[0]]

     return ans
@@ -352,12 +397,14 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    HLG: k2.Fsa,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
     word_table: k2.SymbolTable,
     sos_id: int,
     eos_id: int,
     G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[List[int], List[int]]]]:
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
     """Decode dataset.

     Args:
@@ -368,7 +415,11 @@ def decode_dataset(
       model:
         The neural model.
       HLG:
-        The decoding graph.
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
       word_table:
         It is the word symbol table.
       sos_id:
@@ -403,6 +454,8 @@ def decode_dataset(
             params=params,
             model=model,
             HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
             batch=batch,
             word_table=word_table,
             G=G,
@@ -481,11 +534,11 @@ def main():
     parser = get_parser()
     LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
-    args.exp_dir = Path(args.exp_dir)
-    args.lang_dir = Path(args.lang_dir)

     params = get_params()
     params.update(vars(args))
+    params.exp_dir = Path(params.exp_dir)
+    params.lang_dir = Path(params.lang_dir)

     setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
     logging.info("Decoding started")
@@ -510,6 +563,18 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id

-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    if params.method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
+        HLG = k2.Fsa.from_dict(
+            torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+        )
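For context, a minimal sketch of what this branch constructs, assuming a 500-piece BPE model at data/lang_bpe_500/bpe.model; in the sketch max_token_id is derived from the BPE model, whereas the script defines it elsewhere (not shown in this hunk).

    import k2
    import sentencepiece as spm
    import torch

    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load("data/lang_bpe_500/bpe.model")

    # Largest word-piece ID; ID 0 serves as the CTC blank in the topology.
    max_token_id = bpe_model.get_piece_size() - 1
    H = k2.ctc_topo(
        max_token=max_token_id,
        modified=False,
        device=torch.device("cpu"),
    )
    # H is a CTC topology over {0 (blank), 1, ..., max_token_id}. Intersecting it
    # with the network output yields a lattice whose aux_labels are token IDs,
    # which is why ctc-decoding needs only bpe.model and no HLG.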
@@ -607,6 +672,8 @@ def main():
             params=params,
             model=model,
             HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
             word_table=lexicon.word_table,
             G=G,
             sos_id=sos_id,

View File

@@ -373,7 +373,7 @@ def compute_loss(
             params.batch_idx_train > params.use_ali_until
             and params.beam_size < 8
         ):
-            logging.info("Change beam size to 8")
+            # logging.info("Change beam size to 8")
             params.beam_size = 8
         else:
             params.beam_size = 6