Support pure CTC decoding requiring neither a lexicon nor an n-gram LM.

Fangjun Kuang 2021-09-26 12:55:39 +08:00
parent cd7a36b0a2
commit be34a1feed
7 changed files with 104 additions and 28 deletions
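In short, the new ctc-decoding path intersects the network output with a plain CTC topology H (instead of an HLG graph) and maps the resulting word-piece IDs back to words with a SentencePiece model. A minimal sketch of that flow, assuming a trained model's output and the icefall helpers touched in the diff below (import paths and beam values are illustrative, not the exact script):

import k2
import sentencepiece as spm

from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import get_texts


def ctc_decode(nnet_output, supervision_segments, max_token_id):
    # nnet_output: (N, T, C) log-probs from the trained conformer_ctc model;
    # supervision_segments: the 2-D int32 CPU tensor described in get_lattice.
    H = k2.ctc_topo(
        max_token=max_token_id,
        modified=False,
        device=nnet_output.device,
    )
    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=H,  # no HLG, hence no lexicon and no n-gram LM
        supervision_segments=supervision_segments,
        search_beam=20,          # beam values below are recipe-style defaults,
        output_beam=8,           # shown here only for illustration
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=4,
    )
    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    token_ids = get_texts(best_path)  # word-piece IDs, since aux_labels come from H

    # Map word pieces back to words with the BPE model shipped in lang_dir.
    bpe_model = spm.SentencePieceProcessor()
    bpe_model.load("data/lang_bpe/bpe.model")
    return [s.split() for s in bpe_model.decode(token_ids)]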


@@ -23,6 +23,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
+import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
@@ -77,6 +78,9 @@ def get_parser():
        default="attention-decoder",
        help="""Decoding method.
        Supported values are:
+        - (0) ctc. Use CTC decoding. It uses a sentence piece
+          model, i.e., lang_dir/bpe.model, to convert word pieces to words.
+          It needs neither a lexicon nor an n-gram LM.
        - (1) 1best. Extract the best path from the decoding lattice as the
          decoding result.
        - (2) nbest. Extract n paths from the decoding lattice; the path
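The word-piece-to-word mapping the new help text refers to is SentencePiece's decode step. A tiny illustration, assuming lang_dir is the default data/lang_bpe and using an arbitrary sentence:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe/bpe.model")  # the file the help text calls lang_dir/bpe.model

ids = sp.encode("HELLO WORLD", out_type=int)  # word-piece IDs
text = sp.decode(ids)                         # back to words: "HELLO WORLD"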
@@ -128,14 +132,26 @@ def get_parser():
        """,
    )

+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe",
+        help="The lang dir",
+    )
+
    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
-            "exp_dir": Path("conformer_ctc/exp"),
-            "lang_dir": Path("data/lang_bpe"),
            "lm_dir": Path("data/lm"),
            # parameters for conformer
            "subsampling_factor": 4,
@@ -159,13 +175,15 @@ def get_params() -> AttributeDict:

def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
-    HLG: k2.Fsa,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
    batch: dict,
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -190,7 +208,11 @@ def decode_one_batch(
      model:
        The neural model.
      HLG:
-        The decoding graph.
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
@@ -209,7 +231,10 @@ def decode_one_batch(
      Return the decoding result. See above description for the format of
      the returned dict.
    """
-    device = HLG.device
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)
@@ -229,9 +254,17 @@ def decode_one_batch(
        1,
    ).to(torch.int32)

+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
    lattice = get_lattice(
        nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=decoding_graph,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
@@ -240,6 +273,24 @@ def decode_one_batch(
        subsampling_factor=params.subsampling_factor,
    )

+    if params.method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-list of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ...]
+        hyps = [s.split() for s in hyps]
+        key = f"ctc-decoding"
+        return {key: hyps}
+
    if params.method == "nbest-oracle":
        # Note: You can also pass rescored lattices to it.
        # We choose the HLG decoded lattice for speed reasons
@@ -340,12 +391,14 @@ def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
-    HLG: k2.Fsa,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
    word_table: k2.SymbolTable,
    sos_id: int,
    eos_id: int,
    G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[List[int], List[int]]]]:
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

    Args:
@@ -356,7 +409,11 @@ def decode_dataset(
      model:
        The neural model.
      HLG:
-        The decoding graph.
+        The decoding graph. Used only when params.method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.method is ctc-decoding.
      word_table:
        It is the word symbol table.
      sos_id:
@@ -391,6 +448,8 @@ def decode_dataset(
            params=params,
            model=model,
            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
            batch=batch,
            word_table=word_table,
            G=G,
@@ -469,6 +528,8 @@ def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)

    params = get_params()
    params.update(vars(args))
@@ -496,6 +557,18 @@ def main():
    sos_id = graph_compiler.sos_id
    eos_id = graph_compiler.eos_id

+    if params.method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
+            max_token=max_token_id,
+            modified=False,
+            device=device,
+        )
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(str(params.lang_dir / "bpe.model"))
+    else:
+        H = None
+        bpe_model = None
    HLG = k2.Fsa.from_dict(
        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
    )
@@ -593,6 +666,8 @@ def main():
        params=params,
        model=model,
        HLG=HLG,
+        H=H,
+        bpe_model=bpe_model,
        word_table=lexicon.word_table,
        G=G,
        sos_id=sos_id,


@@ -301,7 +301,7 @@ def main():

    lattice = get_lattice(
        nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,


@@ -146,7 +146,7 @@ def decode_one_batch(
    batch: dict,
    lexicon: Lexicon,
    G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -210,7 +210,7 @@ def decode_one_batch(
    lattice = get_lattice(
        nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
@@ -272,7 +272,7 @@ def decode_dataset(
    HLG: k2.Fsa,
    lexicon: Lexicon,
    G: Optional[k2.Fsa] = None,
-) -> Dict[str, List[Tuple[List[int], List[int]]]]:
+) -> Dict[str, List[Tuple[List[str], List[str]]]]:
    """Decode dataset.

    Args:


@@ -232,7 +232,7 @@ def main():

    lattice = get_lattice(
        nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,


@@ -124,7 +124,7 @@ def decode_one_batch(
    lattice = get_lattice(
        nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,


@@ -175,7 +175,7 @@ def main():

    lattice = get_lattice(
        nnet_output=nnet_output,
-        HLG=HLG,
+        decoding_graph=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,


@@ -66,7 +66,7 @@ def _intersect_device(

def get_lattice(
    nnet_output: torch.Tensor,
-    HLG: k2.Fsa,
+    decoding_graph: k2.Fsa,
    supervision_segments: torch.Tensor,
    search_beam: float,
    output_beam: float,
@@ -79,8 +79,9 @@ def get_lattice(
    Args:
      nnet_output:
        It is the output of a neural model of shape `(N, T, C)`.
-      HLG:
-        An Fsa, the decoding graph. See also `compile_HLG.py`.
+      decoding_graph:
+        An Fsa, the decoding graph. It can be either an HLG
+        (see `compile_HLG.py`) or an H (see `k2.ctc_topo`).
      supervision_segments:
        A 2-D **CPU** tensor of dtype `torch.int32` with 3 columns.
        Each row contains information for a supervision segment. Column 0
@@ -117,7 +118,7 @@ def get_lattice(
    )

    lattice = k2.intersect_dense_pruned(
-        HLG,
+        decoding_graph,
        dense_fsa_vec,
        search_beam=search_beam,
        output_beam=output_beam,
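Since decoding_graph may now be either an HLG or an H, a self-contained toy sketch of the abstraction that get_lattice wraps, using only the k2 calls shown above (vocabulary size, frame count, and beam values here are illustrative, not the recipe's):

import torch
import k2

vocab_size = 4  # word pieces 0 (blank) .. 3, toy setting
H = k2.ctc_topo(max_token=vocab_size - 1, modified=False)

# One fake utterance of 20 frames with random log-probs over the vocabulary.
log_probs = torch.randn(1, 20, vocab_size).log_softmax(dim=-1)
supervision_segments = torch.tensor([[0, 0, 20]], dtype=torch.int32)

dense_fsa_vec = k2.DenseFsaVec(log_probs, supervision_segments)
lattice = k2.intersect_dense_pruned(
    H,  # any decoding graph works here, e.g. an HLG loaded from HLG.pt
    dense_fsa_vec,
    search_beam=20.0,
    output_beam=8.0,
    min_active_states=30,
    max_active_states=10000,
)
best_path = k2.shortest_path(lattice, use_double_scores=True)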