mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Support decoding with LM rescoring and attention-decoder rescoring.
This commit is contained in:
parent
a73d3ed917
commit
eae1674ffa
@ -1,5 +1,5 @@
|
||||
|
||||
# How to use a pre-trained model to transcript a sound file
|
||||
# How to use a pre-trained model to transcribe a sound file or multiple sound files
|
||||
|
||||
You need to prepare 4 files:
|
||||
|
||||
@ -13,6 +13,14 @@ You need to prepare 4 files:
|
||||
Also, you need to install `kaldifeat`. Please refer to
|
||||
<https://github.com/csukuangfj/kaldifeat> for installation.
|
||||
|
||||
```
|
||||
./conformer_ctc/pretrained.py --help
|
||||
```
|
||||
|
||||
displays the help information.
|
||||
|
||||
## HLG decoding
|
||||
|
||||
Once you have the above files ready and have `kaldifeat` installed,
|
||||
you can run:
|
||||
|
||||
@ -20,7 +28,7 @@ you can run:
|
||||
./conformer_ctc/pretrained.py \
|
||||
--checkpoint /path/to/your/checkpoint.pt \
|
||||
--words-file /path/to/words.txt \
|
||||
--hlg /path/to/HLG.pt \
|
||||
--HLG /path/to/HLG.pt \
|
||||
/path/to/your/sound.wav
|
||||
```
|
||||
|
||||
@ -32,7 +40,60 @@ If you want to transcribe multiple files at the same time, you can use:
|
||||
./conformer_ctc/pretrained.py \
|
||||
--checkpoint /path/to/your/checkpoint.pt \
|
||||
--words-file /path/to/words.txt \
|
||||
--hlg /path/to/HLG.pt \
|
||||
--HLG /path/to/HLG.pt \
|
||||
/path/to/your/sound1.wav \
|
||||
/path/to/your/sound2.wav \
|
||||
/path/to/your/sound3.wav \
|
||||
```
|
||||
|
||||
**Note**: This is the fastest decoding method.
|
||||
|
||||
## HLG decoding + LM rescoring
|
||||
|
||||
`./conformer_ctc/pretrained.py` also supports `whole lattice LM rescoring`
|
||||
and `attention decoder rescoring`.
|
||||
|
||||
To use whole lattice LM rescoring, you also need the following files:
|
||||
|
||||
- G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
|
||||
|
||||
The command to run decoding with LM rescoring is:
|
||||
|
||||
```
|
||||
./conformer_ctc/pretrained.py \
|
||||
--checkpoint /path/to/your/checkpoint.pt \
|
||||
--words-file /path/to/words.txt \
|
||||
--HLG /path/to/HLG.pt \
|
||||
--method whole-lattice-rescoring \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--ngram-lm-scale 0.8 \
|
||||
/path/to/your/sound1.wav \
|
||||
/path/to/your/sound2.wav \
|
||||
/path/to/your/sound3.wav \
|
||||
```
|
||||
|
||||
## HLG Decoding + LM rescoring + attention decoder rescoring
|
||||
|
||||
To use attention decoder for rescoring, you need the following extra information:
|
||||
|
||||
- sos token ID
|
||||
- eos token ID
|
||||
|
||||
The command to run decoding with attention decoder rescoring is:
|
||||
|
||||
```
|
||||
./conformer_ctc/pretrained.py \
|
||||
--checkpoint /path/to/your/checkpoint.pt \
|
||||
--words-file /path/to/words.txt \
|
||||
--HLG /path/to/HLG.pt \
|
||||
--method attention-decoder \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--ngram-lm-scale 1.3 \
|
||||
--attention-decoder-scale 1.2 \
|
||||
--lattice-score-scale 0.5 \
|
||||
--num-paths 100 \
|
||||
--sos-id 1 \
|
||||
--eos-id 1 \
|
||||
/path/to/your/sound1.wav \
|
||||
/path/to/your/sound2.wav \
|
||||
/path/to/your/sound3.wav \
|
||||
|
@ -12,7 +12,12 @@ import torchaudio
|
||||
from conformer import Conformer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.decode import get_lattice, one_best_decoding
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.utils import AttributeDict, get_texts
|
||||
|
||||
|
||||
@ -25,8 +30,8 @@ def get_parser():
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint."
|
||||
"The checkpoint is assume to be saved by "
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
@ -38,7 +43,102 @@ def get_parser():
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hlg", type=str, required=True, help="Path to HLG.pt."
|
||||
"--HLG", type=str, required=True, help="Path to HLG.pt."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Possible values are:
|
||||
(1) 1best - Use the best path as decoding output. Only
|
||||
the transformer encoder output is used for decoding.
|
||||
We call it HLG decoding.
|
||||
(2) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
(3) attention-decoder - Extract n paths from he rescored
|
||||
lattice and use the transformer attention decoder for
|
||||
rescoring.
|
||||
We call it HLG decoding + n-gram LM rescoring + attention
|
||||
decoder rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--G",
|
||||
type=str,
|
||||
help="""An LM for rescoring.
|
||||
Used only when method is
|
||||
whole-lattice-rescoring or attention-decoder.
|
||||
It's usually a 4-gram LM.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies the size of n-best list.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=1.3,
|
||||
help="""
|
||||
Used only when method is whole-lattice-rescoring and attention-decoder.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--attention-decoder-scale",
|
||||
type=float,
|
||||
default=1.2,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies the scale for attention decoder scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lattice-score-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies the scale for lattice.scores when
|
||||
extracting n-best lists. A smaller value results in
|
||||
more unique number of paths with the risk of missing
|
||||
the best path.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sos-id",
|
||||
type=float,
|
||||
default=1,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies ID for the SOS token.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--eos-id",
|
||||
type=float,
|
||||
default=1,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies ID for the EOS token.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -46,8 +146,8 @@ def get_parser():
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those that supported by torchaudio.load(). "
|
||||
"For example, wav, flac are supported. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
@ -108,6 +208,7 @@ def main():
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
@ -115,7 +216,7 @@ def main():
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Create model")
|
||||
logging.info("Creating model")
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
@ -134,9 +235,24 @@ def main():
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.hlg))
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G = G.to(device)
|
||||
G.lm_scores = G.scores.clone()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
@ -146,6 +262,7 @@ def main():
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
@ -158,8 +275,9 @@ def main():
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
|
||||
# Note: We don't use key padding mask for attention during decoding
|
||||
with torch.no_grad():
|
||||
nnet_output, _, _ = model(features)
|
||||
nnet_output, memory, memory_key_padding_mask = model(features)
|
||||
|
||||
batch_size = nnet_output.shape[0]
|
||||
supervision_segments = torch.tensor(
|
||||
@ -178,9 +296,37 @@ def main():
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
if params.method == "1best":
|
||||
logging.info("Use HLG decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
elif params.method == "attention-decoder":
|
||||
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
|
||||
)
|
||||
best_path_dict = rescore_with_attention_decoder(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=params.sos_id,
|
||||
eos_id=params.eos_id,
|
||||
scale=params.lattice_score_scale,
|
||||
ngram_lm_scale=params.ngram_lm_scale,
|
||||
attention_scale=params.attention_decoder_scale,
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
|
@ -546,6 +546,8 @@ def rescore_with_whole_lattice(
|
||||
del lattice.lm_scores
|
||||
assert hasattr(lattice, "lm_scores") is False
|
||||
|
||||
assert hasattr(G_with_epsilon_loops, "lm_scores")
|
||||
|
||||
# Now, lattice.scores contains only am_scores
|
||||
|
||||
# inv_lattice has word IDs as labels.
|
||||
@ -677,10 +679,12 @@ def rescore_with_attention_decoder(
|
||||
num_paths: int,
|
||||
model: nn.Module,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
memory_key_padding_mask: Optional[torch.Tensor],
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
scale: float = 1.0,
|
||||
ngram_lm_scale: Optional[float] = None,
|
||||
attention_scale: Optional[float] = None,
|
||||
) -> Dict[str, k2.Fsa]:
|
||||
"""This function extracts n paths from the given lattice and uses
|
||||
an attention decoder to rescore them. The path with the highest
|
||||
@ -707,6 +711,10 @@ def rescore_with_attention_decoder(
|
||||
scale:
|
||||
It's the scale applied to the lattice.scores. A smaller value
|
||||
yields more unique paths.
|
||||
ngram_lm_scale:
|
||||
Optional. It specifies the scale for n-gram LM scores.
|
||||
attention_scale:
|
||||
Optional. It specifies the scale for attention decoder scores.
|
||||
Returns:
|
||||
A dict of FsaVec, whose key contains a string
|
||||
ngram_lm_scale_attention_scale and the value is the
|
||||
@ -794,11 +802,13 @@ def rescore_with_attention_decoder(
|
||||
path_to_seq_map_long = path_to_seq_map.to(torch.long)
|
||||
expanded_memory = memory.index_select(1, path_to_seq_map_long)
|
||||
|
||||
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||
0, path_to_seq_map_long
|
||||
)
|
||||
if memory_key_padding_mask is not None:
|
||||
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||
0, path_to_seq_map_long
|
||||
)
|
||||
else:
|
||||
expanded_memory_key_padding_mask = None
|
||||
|
||||
# TODO: pass the sos_token_id and eos_token_id via function arguments
|
||||
nll = model.decoder_nll(
|
||||
memory=expanded_memory,
|
||||
memory_key_padding_mask=expanded_memory_key_padding_mask,
|
||||
@ -813,11 +823,17 @@ def rescore_with_attention_decoder(
|
||||
assert attention_scores.ndim == 1
|
||||
assert attention_scores.numel() == num_word_seqs
|
||||
|
||||
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
if ngram_lm_scale is None:
|
||||
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
else:
|
||||
ngram_lm_scale_list = [ngram_lm_scale]
|
||||
|
||||
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
if attention_scale is None:
|
||||
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
else:
|
||||
attention_scale_list = [attention_scale]
|
||||
|
||||
path_2axes = k2.ragged.remove_axis(path, 0)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user