Support decoding with LM rescoring and attention-decoder rescoring.

Fangjun Kuang 2021-08-19 16:10:38 +08:00
parent a73d3ed917
commit eae1674ffa
3 changed files with 247 additions and 24 deletions

View File

@@ -1,5 +1,5 @@
# How to use a pre-trained model to transcript a sound file
# How to use a pre-trained model to transcribe a sound file or multiple sound files
You need to prepare 4 files:
@@ -13,6 +13,14 @@ You need to prepare 4 files:
Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.
```
./conformer_ctc/pretrained.py --help
```
displays the help information.
## HLG decoding
Once you have the above files ready and have `kaldifeat` installed,
you can run:
@@ -20,7 +28,7 @@ you can run:
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--hlg /path/to/HLG.pt \
--HLG /path/to/HLG.pt \
/path/to/your/sound.wav
```
@@ -32,7 +40,60 @@ If you want to transcribe multiple files at the same time, you can use:
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--hlg /path/to/HLG.pt \
--HLG /path/to/HLG.pt \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
**Note**: This is the fastest decoding method.
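Under the hood, `1best` decoding simply picks the highest-scoring path through the decoding lattice. Below is a minimal sketch of that step, assuming `lattice` has already been built by `get_lattice()` as in `conformer_ctc/pretrained.py`, and using a placeholder path for `words.txt`:
```
import k2

from icefall.decode import one_best_decoding
from icefall.utils import get_texts

# Pick the single best path through the decoding lattice.
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)

# Map the word IDs on the best path back to words.
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file("/path/to/words.txt")
for hyp in hyps:
    print(" ".join(word_sym_table[i] for i in hyp))
```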
## HLG decoding + LM rescoring
`./conformer_ctc/pretrained.py` also supports `whole lattice LM rescoring`
and `attention decoder rescoring`.
To use whole lattice LM rescoring, you also need the following file:
- G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
The command to run decoding with LM rescoring is:
```
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method whole-lattice-rescoring \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
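Conceptually, whole lattice LM rescoring swaps the LM scores baked into `HLG` for the scores of `G` by composing the decoding lattice with `G`, and then takes the best path of the rescored lattice. Here is a hedged sketch mirroring the code added to `conformer_ctc/pretrained.py` below; `lattice` is assumed to come from `get_lattice()`:
```
import torch
import k2

from icefall.decode import rescore_with_whole_lattice

# G must be arc-sorted, have epsilon self-loops, and carry lm_scores,
# since it is composed with the whole lattice.
G = k2.Fsa.from_dict(torch.load("data/lm/G_4_gram.pt", map_location="cpu"))
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G.lm_scores = G.scores.clone()

# Rescore the whole lattice with G and take the best path.
best_path_dict = rescore_with_whole_lattice(
    lattice=lattice,
    G_with_epsilon_loops=G,
    lm_scale_list=[0.8],  # corresponds to --ngram-lm-scale
)
best_path = next(iter(best_path_dict.values()))
```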
## HLG decoding + LM rescoring + attention decoder rescoring
To use attention decoder for rescoring, you need the following extra information:
- sos token ID
- eos token ID
The command to run decoding with attention decoder rescoring is:
```
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method attention-decoder \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \
--num-paths 100 \
--sos-id 1 \
--eos-id 1 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
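The values for `--sos-id` and `--eos-id` depend on how your lang directory was built. The following is a hedged sketch of looking them up, assuming a sentencepiece BPE model in which a single `<sos/eos>` piece serves as both tokens (as in the LibriSpeech BPE recipe); the model path is a placeholder:
```
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("/path/to/bpe.model")

# A single <sos/eos> piece is assumed to serve as both tokens.
sos_id = sp.piece_to_id("<sos/eos>")
eos_id = sp.piece_to_id("<sos/eos>")
print(sos_id, eos_id)
```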

View File

@@ -12,7 +12,12 @@ import torchaudio
from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import get_lattice, one_best_decoding
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_texts
@@ -25,8 +30,8 @@ def get_parser():
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint."
"The checkpoint is assume to be saved by "
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
@@ -38,7 +43,102 @@ def get_parser():
)
parser.add_argument(
"--hlg", type=str, required=True, help="Path to HLG.pt."
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method.
Possible values are:
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
(3) attention-decoder - Extract n paths from the rescored
lattice and use the transformer attention decoder for
rescoring.
We call it HLG decoding + n-gram LM rescoring + attention
decoder rescoring.
""",
)
parser.add_argument(
"--G",
type=str,
help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring or attention-decoder.
It's usually a 4-gram LM.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""
Used only when method is attention-decoder.
It specifies the size of the n-best list.""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=1.3,
help="""
Used only when method is whole-lattice-rescoring or attention-decoder.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--attention-decoder-scale",
type=float,
default=1.2,
help="""
Used only when method is attention-decoder.
It specifies the scale for attention decoder scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=0.5,
help="""
Used only when method is attention-decoder.
It specifies the scale for lattice.scores when
extracting n-best lists. A smaller value results in
more unique paths, at the risk of missing
the best path.
""",
)
parser.add_argument(
"--sos-id",
type=int,
default=1,
help="""
Used only when method is attention-decoder.
It specifies the ID for the SOS token.
""",
)
parser.add_argument(
"--eos-id",
type=int,
default=1,
help="""
Used only when method is attention-decoder.
It specifies the ID for the EOS token.
""",
)
parser.add_argument(
@@ -46,8 +146,8 @@ def get_parser():
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those that supported by torchaudio.load(). "
"For example, wav, flac are supported. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
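Since the help text above requires 16 kHz input, a sound file with a different sample rate has to be converted first. A hedged sketch with torchaudio, using placeholder paths:
```
import torchaudio

wave, sample_rate = torchaudio.load("/path/to/input.wav")
if sample_rate != 16000:
    # Resample to the 16 kHz expected by pretrained.py.
    resampler = torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=16000
    )
    wave = resampler(wave)
torchaudio.save("/path/to/input-16k.wav", wave, 16000)
```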
@@ -108,6 +208,7 @@ def main():
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
@@ -115,7 +216,7 @@ def main():
logging.info(f"device: {device}")
logging.info("Create model")
logging.info("Creating model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
@@ -134,9 +235,24 @@ def main():
model.to(device)
model.eval()
HLG = k2.Fsa.from_dict(torch.load(params.hlg))
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
G.lm_scores = G.scores.clone()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
@@ -146,6 +262,7 @@ def main():
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
@@ -158,8 +275,9 @@ def main():
features, batch_first=True, padding_value=math.log(1e-10)
)
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
nnet_output, _, _ = model(features)
nnet_output, memory, memory_key_padding_mask = model(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
@@ -178,9 +296,37 @@ def main():
subsampling_factor=params.subsampling_factor,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
elif params.method == "whole-lattice-rescoring":
logging.info("Use HLG decoding + LM rescoring")
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=[params.ngram_lm_scale],
)
best_path = next(iter(best_path_dict.values()))
elif params.method == "attention-decoder":
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
scale=params.lattice_score_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)

View File

@@ -546,6 +546,8 @@ def rescore_with_whole_lattice(
del lattice.lm_scores
assert hasattr(lattice, "lm_scores") is False
assert hasattr(G_with_epsilon_loops, "lm_scores")
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
@@ -677,10 +679,12 @@ def rescore_with_attention_decoder(
num_paths: int,
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
) -> Dict[str, k2.Fsa]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
@@ -707,6 +711,10 @@ def rescore_with_attention_decoder(
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
@@ -794,11 +802,13 @@ def rescore_with_attention_decoder(
path_to_seq_map_long = path_to_seq_map.to(torch.long)
expanded_memory = memory.index_select(1, path_to_seq_map_long)
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_seq_map_long
)
if memory_key_padding_mask is not None:
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_seq_map_long
)
else:
expanded_memory_key_padding_mask = None
# TODO: pass the sos_token_id and eos_token_id via function arguments
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
@@ -813,11 +823,17 @@ def rescore_with_attention_decoder(
assert attention_scores.ndim == 1
assert attention_scores.numel() == num_word_seqs
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
if attention_scale is None:
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
attention_scale_list = [attention_scale]
path_2axes = k2.ragged.remove_axis(path, 0)
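For reference, this is roughly how `scale` (exposed as `--lattice-score-scale`) enters the n-best extraction above: the lattice scores are scaled down before paths are sampled, which flattens the score distribution and yields more unique paths. A hedged sketch of that step, assuming `lattice` is the rescored decoding lattice:
```
import k2

# Scale the scores only for path sampling, then restore them.
saved_scores = lattice.scores.clone()
lattice.scores *= 0.5  # corresponds to --lattice-score-scale

# Sample num_paths paths per supervision from the scaled lattice.
path = k2.random_paths(lattice, use_double_scores=True, num_paths=100)

lattice.scores = saved_scores
```
Each sampled path is then scored with a weighted combination of acoustic, n-gram LM, and attention decoder scores; passing `--ngram-lm-scale` and `--attention-decoder-scale` pins those two weights instead of searching over the score lists above.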