mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-10 18:42:19 +00:00
Support decoding with LM rescoring and attention-decoder rescoring.
This commit is contained in:
parent
a73d3ed917
commit
eae1674ffa
@ -1,5 +1,5 @@
|
|||||||
|
|
||||||
# How to use a pre-trained model to transcript a sound file
|
# How to use a pre-trained model to transcribe a sound file or multiple sound files
|
||||||
|
|
||||||
You need to prepare 4 files:
|
You need to prepare 4 files:
|
||||||
|
|
||||||
@ -13,6 +13,14 @@ You need to prepare 4 files:
|
|||||||
Also, you need to install `kaldifeat`. Please refer to
|
Also, you need to install `kaldifeat`. Please refer to
|
||||||
<https://github.com/csukuangfj/kaldifeat> for installation.
|
<https://github.com/csukuangfj/kaldifeat> for installation.
|
||||||
|
|
||||||
|
```
|
||||||
|
./conformer_ctc/pretrained.py --help
|
||||||
|
```
|
||||||
|
|
||||||
|
displays the help information.
|
||||||
|
|
||||||
|
## HLG decoding
|
||||||
|
|
||||||
Once you have the above files ready and have `kaldifeat` installed,
|
Once you have the above files ready and have `kaldifeat` installed,
|
||||||
you can run:
|
you can run:
|
||||||
|
|
||||||
@ -20,7 +28,7 @@ you can run:
|
|||||||
./conformer_ctc/pretrained.py \
|
./conformer_ctc/pretrained.py \
|
||||||
--checkpoint /path/to/your/checkpoint.pt \
|
--checkpoint /path/to/your/checkpoint.pt \
|
||||||
--words-file /path/to/words.txt \
|
--words-file /path/to/words.txt \
|
||||||
--hlg /path/to/HLG.pt \
|
--HLG /path/to/HLG.pt \
|
||||||
/path/to/your/sound.wav
|
/path/to/your/sound.wav
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -32,7 +40,60 @@ If you want to transcribe multiple files at the same time, you can use:
|
|||||||
./conformer_ctc/pretrained.py \
|
./conformer_ctc/pretrained.py \
|
||||||
--checkpoint /path/to/your/checkpoint.pt \
|
--checkpoint /path/to/your/checkpoint.pt \
|
||||||
--words-file /path/to/words.txt \
|
--words-file /path/to/words.txt \
|
||||||
--hlg /path/to/HLG.pt \
|
--HLG /path/to/HLG.pt \
|
||||||
|
/path/to/your/sound1.wav \
|
||||||
|
/path/to/your/sound2.wav \
|
||||||
|
/path/to/your/sound3.wav \
|
||||||
|
```
|
||||||
|
|
||||||
|
**Note**: This is the fastest decoding method.
|
||||||
|
|
||||||
|
## HLG decoding + LM rescoring
|
||||||
|
|
||||||
|
`./conformer_ctc/pretrained.py` also supports `whole lattice LM rescoring`
|
||||||
|
and `attention decoder rescoring`.
|
||||||
|
|
||||||
|
To use whole lattice LM rescoring, you also need the following files:
|
||||||
|
|
||||||
|
- G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
|
||||||
|
|
||||||
|
The command to run decoding with LM rescoring is:
|
||||||
|
|
||||||
|
```
|
||||||
|
./conformer_ctc/pretrained.py \
|
||||||
|
--checkpoint /path/to/your/checkpoint.pt \
|
||||||
|
--words-file /path/to/words.txt \
|
||||||
|
--HLG /path/to/HLG.pt \
|
||||||
|
--method whole-lattice-rescoring \
|
||||||
|
--G data/lm/G_4_gram.pt \
|
||||||
|
--ngram-lm-scale 0.8 \
|
||||||
|
/path/to/your/sound1.wav \
|
||||||
|
/path/to/your/sound2.wav \
|
||||||
|
/path/to/your/sound3.wav \
|
||||||
|
```
|
||||||
|
|
||||||
|
## HLG Decoding + LM rescoring + attention decoder rescoring
|
||||||
|
|
||||||
|
To use attention decoder for rescoring, you need the following extra information:
|
||||||
|
|
||||||
|
- sos token ID
|
||||||
|
- eos token ID
|
||||||
|
|
||||||
|
The command to run decoding with attention decoder rescoring is:
|
||||||
|
|
||||||
|
```
|
||||||
|
./conformer_ctc/pretrained.py \
|
||||||
|
--checkpoint /path/to/your/checkpoint.pt \
|
||||||
|
--words-file /path/to/words.txt \
|
||||||
|
--HLG /path/to/HLG.pt \
|
||||||
|
--method attention-decoder \
|
||||||
|
--G data/lm/G_4_gram.pt \
|
||||||
|
--ngram-lm-scale 1.3 \
|
||||||
|
--attention-decoder-scale 1.2 \
|
||||||
|
--lattice-score-scale 0.5 \
|
||||||
|
--num-paths 100 \
|
||||||
|
--sos-id 1 \
|
||||||
|
--eos-id 1 \
|
||||||
/path/to/your/sound1.wav \
|
/path/to/your/sound1.wav \
|
||||||
/path/to/your/sound2.wav \
|
/path/to/your/sound2.wav \
|
||||||
/path/to/your/sound3.wav \
|
/path/to/your/sound3.wav \
|
||||||
|
@ -12,7 +12,12 @@ import torchaudio
|
|||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.decode import get_lattice, one_best_decoding
|
from icefall.decode import (
|
||||||
|
get_lattice,
|
||||||
|
one_best_decoding,
|
||||||
|
rescore_with_attention_decoder,
|
||||||
|
rescore_with_whole_lattice,
|
||||||
|
)
|
||||||
from icefall.utils import AttributeDict, get_texts
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
@ -26,7 +31,7 @@ def get_parser():
|
|||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
help="Path to the checkpoint. "
|
help="Path to the checkpoint. "
|
||||||
"The checkpoint is assume to be saved by "
|
"The checkpoint is assumed to be saved by "
|
||||||
"icefall.checkpoint.save_checkpoint().",
|
"icefall.checkpoint.save_checkpoint().",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -38,7 +43,102 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hlg", type=str, required=True, help="Path to HLG.pt."
|
"--HLG", type=str, required=True, help="Path to HLG.pt."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--method",
|
||||||
|
type=str,
|
||||||
|
default="1best",
|
||||||
|
help="""Decoding method.
|
||||||
|
Possible values are:
|
||||||
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
|
the transformer encoder output is used for decoding.
|
||||||
|
We call it HLG decoding.
|
||||||
|
(2) whole-lattice-rescoring - Use an LM to rescore the
|
||||||
|
decoding lattice and then use 1best to decode the
|
||||||
|
rescored lattice.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring.
|
||||||
|
(3) attention-decoder - Extract n paths from he rescored
|
||||||
|
lattice and use the transformer attention decoder for
|
||||||
|
rescoring.
|
||||||
|
We call it HLG decoding + n-gram LM rescoring + attention
|
||||||
|
decoder rescoring.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--G",
|
||||||
|
type=str,
|
||||||
|
help="""An LM for rescoring.
|
||||||
|
Used only when method is
|
||||||
|
whole-lattice-rescoring or attention-decoder.
|
||||||
|
It's usually a 4-gram LM.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-paths",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="""
|
||||||
|
Used only when method is attention-decoder.
|
||||||
|
It specifies the size of n-best list.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ngram-lm-scale",
|
||||||
|
type=float,
|
||||||
|
default=1.3,
|
||||||
|
help="""
|
||||||
|
Used only when method is whole-lattice-rescoring and attention-decoder.
|
||||||
|
It specifies the scale for n-gram LM scores.
|
||||||
|
(Note: You need to tune it on a dataset.)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--attention-decoder-scale",
|
||||||
|
type=float,
|
||||||
|
default=1.2,
|
||||||
|
help="""
|
||||||
|
Used only when method is attention-decoder.
|
||||||
|
It specifies the scale for attention decoder scores.
|
||||||
|
(Note: You need to tune it on a dataset.)
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lattice-score-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="""
|
||||||
|
Used only when method is attention-decoder.
|
||||||
|
It specifies the scale for lattice.scores when
|
||||||
|
extracting n-best lists. A smaller value results in
|
||||||
|
more unique number of paths with the risk of missing
|
||||||
|
the best path.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sos-id",
|
||||||
|
type=float,
|
||||||
|
default=1,
|
||||||
|
help="""
|
||||||
|
Used only when method is attention-decoder.
|
||||||
|
It specifies ID for the SOS token.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--eos-id",
|
||||||
|
type=float,
|
||||||
|
default=1,
|
||||||
|
help="""
|
||||||
|
Used only when method is attention-decoder.
|
||||||
|
It specifies ID for the EOS token.
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -46,8 +146,8 @@ def get_parser():
|
|||||||
type=str,
|
type=str,
|
||||||
nargs="+",
|
nargs="+",
|
||||||
help="The input sound file(s) to transcribe. "
|
help="The input sound file(s) to transcribe. "
|
||||||
"Supported formats are those that supported by torchaudio.load(). "
|
"Supported formats are those supported by torchaudio.load(). "
|
||||||
"For example, wav, flac are supported. "
|
"For example, wav and flac are supported. "
|
||||||
"The sample rate has to be 16kHz.",
|
"The sample rate has to be 16kHz.",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -108,6 +208,7 @@ def main():
|
|||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
logging.info(f"{params}")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
@ -115,7 +216,7 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
logging.info("Create model")
|
logging.info("Creating model")
|
||||||
model = Conformer(
|
model = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
nhead=params.nhead,
|
nhead=params.nhead,
|
||||||
@ -134,9 +235,24 @@ def main():
|
|||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
HLG = k2.Fsa.from_dict(torch.load(params.hlg))
|
logging.info(f"Loading HLG from {params.HLG}")
|
||||||
|
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||||
HLG = HLG.to(device)
|
HLG = HLG.to(device)
|
||||||
|
if not hasattr(HLG, "lm_scores"):
|
||||||
|
# For whole-lattice-rescoring and attention-decoder
|
||||||
|
HLG.lm_scores = HLG.scores.clone()
|
||||||
|
|
||||||
|
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||||
|
logging.info(f"Loading G from {params.G}")
|
||||||
|
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||||
|
# Add epsilon self-loops to G as we will compose
|
||||||
|
# it with the whole lattice later
|
||||||
|
G = k2.add_epsilon_self_loops(G)
|
||||||
|
G = k2.arc_sort(G)
|
||||||
|
G = G.to(device)
|
||||||
|
G.lm_scores = G.scores.clone()
|
||||||
|
|
||||||
|
logging.info("Constructing Fbank computer")
|
||||||
opts = kaldifeat.FbankOptions()
|
opts = kaldifeat.FbankOptions()
|
||||||
opts.device = device
|
opts.device = device
|
||||||
opts.frame_opts.dither = 0
|
opts.frame_opts.dither = 0
|
||||||
@ -146,6 +262,7 @@ def main():
|
|||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
fbank = kaldifeat.Fbank(opts)
|
||||||
|
|
||||||
|
logging.info(f"Reading sound files: {params.sound_files}")
|
||||||
waves = read_sound_files(
|
waves = read_sound_files(
|
||||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||||
)
|
)
|
||||||
@ -158,8 +275,9 @@ def main():
|
|||||||
features, batch_first=True, padding_value=math.log(1e-10)
|
features, batch_first=True, padding_value=math.log(1e-10)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Note: We don't use key padding mask for attention during decoding
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
nnet_output, _, _ = model(features)
|
nnet_output, memory, memory_key_padding_mask = model(features)
|
||||||
|
|
||||||
batch_size = nnet_output.shape[0]
|
batch_size = nnet_output.shape[0]
|
||||||
supervision_segments = torch.tensor(
|
supervision_segments = torch.tensor(
|
||||||
@ -178,9 +296,37 @@ def main():
|
|||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if params.method == "1best":
|
||||||
|
logging.info("Use HLG decoding")
|
||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
lattice=lattice, use_double_scores=params.use_double_scores
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
)
|
)
|
||||||
|
elif params.method == "whole-lattice-rescoring":
|
||||||
|
logging.info("Use HLG decoding + LM rescoring")
|
||||||
|
best_path_dict = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice,
|
||||||
|
G_with_epsilon_loops=G,
|
||||||
|
lm_scale_list=[params.ngram_lm_scale],
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
elif params.method == "attention-decoder":
|
||||||
|
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
|
||||||
|
rescored_lattice = rescore_with_whole_lattice(
|
||||||
|
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
|
||||||
|
)
|
||||||
|
best_path_dict = rescore_with_attention_decoder(
|
||||||
|
lattice=rescored_lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
model=model,
|
||||||
|
memory=memory,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
|
sos_id=params.sos_id,
|
||||||
|
eos_id=params.eos_id,
|
||||||
|
scale=params.lattice_score_scale,
|
||||||
|
ngram_lm_scale=params.ngram_lm_scale,
|
||||||
|
attention_scale=params.attention_decoder_scale,
|
||||||
|
)
|
||||||
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
hyps = get_texts(best_path)
|
||||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
|
@ -546,6 +546,8 @@ def rescore_with_whole_lattice(
|
|||||||
del lattice.lm_scores
|
del lattice.lm_scores
|
||||||
assert hasattr(lattice, "lm_scores") is False
|
assert hasattr(lattice, "lm_scores") is False
|
||||||
|
|
||||||
|
assert hasattr(G_with_epsilon_loops, "lm_scores")
|
||||||
|
|
||||||
# Now, lattice.scores contains only am_scores
|
# Now, lattice.scores contains only am_scores
|
||||||
|
|
||||||
# inv_lattice has word IDs as labels.
|
# inv_lattice has word IDs as labels.
|
||||||
@ -677,10 +679,12 @@ def rescore_with_attention_decoder(
|
|||||||
num_paths: int,
|
num_paths: int,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
memory: torch.Tensor,
|
memory: torch.Tensor,
|
||||||
memory_key_padding_mask: torch.Tensor,
|
memory_key_padding_mask: Optional[torch.Tensor],
|
||||||
sos_id: int,
|
sos_id: int,
|
||||||
eos_id: int,
|
eos_id: int,
|
||||||
scale: float = 1.0,
|
scale: float = 1.0,
|
||||||
|
ngram_lm_scale: Optional[float] = None,
|
||||||
|
attention_scale: Optional[float] = None,
|
||||||
) -> Dict[str, k2.Fsa]:
|
) -> Dict[str, k2.Fsa]:
|
||||||
"""This function extracts n paths from the given lattice and uses
|
"""This function extracts n paths from the given lattice and uses
|
||||||
an attention decoder to rescore them. The path with the highest
|
an attention decoder to rescore them. The path with the highest
|
||||||
@ -707,6 +711,10 @@ def rescore_with_attention_decoder(
|
|||||||
scale:
|
scale:
|
||||||
It's the scale applied to the lattice.scores. A smaller value
|
It's the scale applied to the lattice.scores. A smaller value
|
||||||
yields more unique paths.
|
yields more unique paths.
|
||||||
|
ngram_lm_scale:
|
||||||
|
Optional. It specifies the scale for n-gram LM scores.
|
||||||
|
attention_scale:
|
||||||
|
Optional. It specifies the scale for attention decoder scores.
|
||||||
Returns:
|
Returns:
|
||||||
A dict of FsaVec, whose key contains a string
|
A dict of FsaVec, whose key contains a string
|
||||||
ngram_lm_scale_attention_scale and the value is the
|
ngram_lm_scale_attention_scale and the value is the
|
||||||
@ -794,11 +802,13 @@ def rescore_with_attention_decoder(
|
|||||||
path_to_seq_map_long = path_to_seq_map.to(torch.long)
|
path_to_seq_map_long = path_to_seq_map.to(torch.long)
|
||||||
expanded_memory = memory.index_select(1, path_to_seq_map_long)
|
expanded_memory = memory.index_select(1, path_to_seq_map_long)
|
||||||
|
|
||||||
|
if memory_key_padding_mask is not None:
|
||||||
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||||
0, path_to_seq_map_long
|
0, path_to_seq_map_long
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
expanded_memory_key_padding_mask = None
|
||||||
|
|
||||||
# TODO: pass the sos_token_id and eos_token_id via function arguments
|
|
||||||
nll = model.decoder_nll(
|
nll = model.decoder_nll(
|
||||||
memory=expanded_memory,
|
memory=expanded_memory,
|
||||||
memory_key_padding_mask=expanded_memory_key_padding_mask,
|
memory_key_padding_mask=expanded_memory_key_padding_mask,
|
||||||
@ -813,11 +823,17 @@ def rescore_with_attention_decoder(
|
|||||||
assert attention_scores.ndim == 1
|
assert attention_scores.ndim == 1
|
||||||
assert attention_scores.numel() == num_word_seqs
|
assert attention_scores.numel() == num_word_seqs
|
||||||
|
|
||||||
|
if ngram_lm_scale is None:
|
||||||
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||||
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||||
|
else:
|
||||||
|
ngram_lm_scale_list = [ngram_lm_scale]
|
||||||
|
|
||||||
|
if attention_scale is None:
|
||||||
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||||
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||||
|
else:
|
||||||
|
attention_scale_list = [attention_scale]
|
||||||
|
|
||||||
path_2axes = k2.ragged.remove_axis(path, 0)
|
path_2axes = k2.ragged.remove_axis(path, 0)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user