Support computing nbest oracle WER. (#10)

* Support computing nbest oracle WER.

* Add scale to all nbest based decoding/rescoring methods.

* Add script to run pretrained models.

* Use torchaudio to extract features.

* Support decoding multiple files at the same time.

Also, use kaldifeat for feature extraction.

* Support decoding with LM rescoring and attention-decoder rescoring.

* Minor fixes.

* Replace scale with lattice-score-scale.

* Add usage example with a provided pretrained model.
Fangjun Kuang 2021-08-20 11:53:37 +08:00 committed by GitHub
parent ef233486ae
commit 9d0cc9d829
5 changed files with 893 additions and 23 deletions

View File

@ -0,0 +1,339 @@
# How to use a pre-trained model to transcribe a sound file or multiple sound files
You need to prepare 4 files:
- a model checkpoint file, e.g., epoch-20.pt
- HLG.pt, the decoding graph
- words.txt, the word symbol table
- a sound file, whose sampling rate has to be 16 kHz.
Supported formats are those supported by `torchaudio.load()`,
e.g., wav and flac.
Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.
```bash
./conformer_ctc/pretrained.py --help
```
displays the help information.
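Internally, `pretrained.py` uses `kaldifeat` to compute 80-dimensional log-mel filterbank features and pads them into a batch. The following is a minimal sketch of that step, mirroring the code in the script; the sound file path is a placeholder:
```python
# Fbank feature extraction as done in ./conformer_ctc/pretrained.py.
import math

import kaldifeat
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence

opts = kaldifeat.FbankOptions()
opts.device = torch.device("cpu")
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000  # must match the sound file
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

# Read a 16 kHz sound file; only the first channel is used.
wave, sample_rate = torchaudio.load("/path/to/your/sound.wav")
assert sample_rate == 16000
features = fbank([wave[0]])

# Pad to a batch; log(1e-10) is the padding value used by the script.
features = pad_sequence(
    features, batch_first=True, padding_value=math.log(1e-10)
)
```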
## HLG decoding
Once you have the above files ready and have `kaldifeat` installed,
you can run:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound.wav
```
and you will see the transcribed result.
If you want to transcribe multiple files at the same time, you can use:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
**Note**: This is the fastest decoding method.
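For reference, the sketch below shows the essential calls `pretrained.py` makes for HLG decoding. Here `nnet_output` is the `(N, T, C)` tensor of CTC log-probabilities produced by the conformer encoder and `HLG` is the graph loaded with `k2.Fsa.from_dict()`; the beam settings mirror the defaults in the script.
```python
# A sketch of the HLG (1best) decoding step in ./conformer_ctc/pretrained.py.
import k2
import torch
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import get_texts


def hlg_decode(nnet_output: torch.Tensor, HLG: k2.Fsa, words_file: str):
    # One supervision segment per utterance, covering all output frames.
    batch_size = nnet_output.shape[0]
    supervision_segments = torch.tensor(
        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
        dtype=torch.int32,
    )
    lattice = get_lattice(
        nnet_output=nnet_output,
        HLG=HLG,
        supervision_segments=supervision_segments,
        search_beam=20,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=4,
    )
    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    hyps = get_texts(best_path)  # word IDs, one list per utterance
    word_table = k2.SymbolTable.from_file(words_file)
    return [" ".join(word_table[i] for i in ids) for ids in hyps]
```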
## HLG decoding + LM rescoring
`./conformer_ctc/pretrained.py` also supports `whole lattice LM rescoring`
and `attention decoder rescoring`.
To use whole lattice LM rescoring, you also need the following file:
- `G.pt`, e.g., `data/lm/G_4_gram.pt`, which is generated if you have run `./prepare.sh`
The command to run decoding with LM rescoring is:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method whole-lattice-rescoring \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
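Under the hood, the script loads `G`, adds epsilon self-loops so it can be composed with the whole lattice, and calls `rescore_with_whole_lattice`. Below is a minimal sketch, assuming `lattice` is the HLG decoding lattice returned by `get_lattice` as in the script:
```python
# A sketch of the whole-lattice LM rescoring step in pretrained.py.
import k2
import torch
from icefall.decode import rescore_with_whole_lattice


def lm_rescore(lattice: k2.Fsa, G_path: str, ngram_lm_scale: float = 0.8):
    device = lattice.device
    G = k2.Fsa.from_dict(torch.load(G_path, map_location="cpu"))
    # Add epsilon self-loops so G can be composed with the whole lattice.
    G = k2.add_epsilon_self_loops(G)
    G = k2.arc_sort(G)
    G = G.to(device)
    G.lm_scores = G.scores.clone()

    best_path_dict = rescore_with_whole_lattice(
        lattice=lattice,
        G_with_epsilon_loops=G,
        lm_scale_list=[ngram_lm_scale],  # --ngram-lm-scale
    )
    # The dict is keyed by the lm_scale; take the (only) entry.
    return next(iter(best_path_dict.values()))
```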
## HLG decoding + LM rescoring + attention decoder rescoring
To use attention decoder for rescoring, you need the following extra information:
- sos token ID
- eos token ID
The command to run decoding with attention decoder rescoring is:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method attention-decoder \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \
--num-paths 100 \
--sos-id 1 \
--eos-id 1 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
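For reference, here is a sketch of what the script does for this method, assuming `lattice`, `G`, `model`, `memory`, and `memory_key_padding_mask` have been prepared as in `pretrained.py`; the keyword values mirror the flags above:
```python
# A sketch of the attention-decoder rescoring step in pretrained.py.
import k2
import torch
import torch.nn as nn
from icefall.decode import (
    rescore_with_attention_decoder,
    rescore_with_whole_lattice,
)


def attention_rescore(
    lattice: k2.Fsa,
    G: k2.Fsa,
    model: nn.Module,
    memory: torch.Tensor,
    memory_key_padding_mask: torch.Tensor,
) -> k2.Fsa:
    # lm_scale_list=None keeps the LM-rescored scores in the lattice
    # instead of extracting a 1best path per scale.
    rescored_lattice = rescore_with_whole_lattice(
        lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
    )
    best_path_dict = rescore_with_attention_decoder(
        lattice=rescored_lattice,
        num_paths=100,  # --num-paths
        model=model,
        memory=memory,
        memory_key_padding_mask=memory_key_padding_mask,
        sos_id=1,  # --sos-id
        eos_id=1,  # --eos-id
        scale=0.5,  # --lattice-score-scale
        ngram_lm_scale=1.3,  # --ngram-lm-scale
        attention_scale=1.2,  # --attention-decoder-scale
    )
    return next(iter(best_path_dict.values()))
```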
# Decoding with a pretrained model in action
We have uploaded a pretrained model to <https://huggingface.co/pkufool/conformer_ctc>.
The following shows how to use the provided pretrained model.
### (1) Download the pretrained model
```bash
cd /path/to/icefall/egs/librispeech/ASR
mkdir tmp
cd tmp
git clone https://huggingface.co/pkufool/conformer_ctc
```
You will find the following files:
```
tmp
`-- conformer_ctc
|-- README.md
|-- data
| |-- lang_bpe
| | |-- HLG.pt
| | |-- bpe.model
| | |-- tokens.txt
| | `-- words.txt
| `-- lm
| `-- G_4_gram.pt
|-- exp
| `-- pretraind.pt
`-- test_wavs
|-- 1089-134686-0001.flac
|-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac
`-- trans.txt
6 directories, 11 files
```
**File descriptions**:
- `data/lang_bpe/HLG.pt`
It is the decoding graph.
- `data/lang_bpe/bpe.model`
It is a sentencepiece model. You can use it to reproduce our results.
- `data/lang_bpe/tokens.txt`
It contains tokens and their IDs, generated from `bpe.model`.
It is provided only for convenience so that you can look up the SOS/EOS ID easily (see the sketch after this list).
- `data/lang_bpe/words.txt`
It contains words and their IDs.
- `data/lm/G_4_gram.pt`
It is a 4-gram LM, useful for LM rescoring.
- `exp/pretraind.pt`
It contains pretrained model parameters, obtained by averaging
checkpoints from `epoch-15.pt` to `epoch-34.pt`.
Note: We have removed optimizer `state_dict` to reduce file size.
- `test_wavs/*.flac`
These are some test sound files from the LibriSpeech `test-clean` dataset.
- `test_wavs/trans.txt`
It contains the reference transcripts for the sound files in `test_wavs/`.
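If you want to double-check the SOS/EOS ID used by the attention-decoder method, here is a minimal sketch using the provided `bpe.model`. It assumes the symbol is named `<sos/eos>`; you can verify the exact name and ID in `tokens.txt`.
```python
# Look up the SOS/EOS ID from the provided BPE model.
# Assumption: the symbol is named "<sos/eos>"; verify it in tokens.txt.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("tmp/conformer_ctc/data/lang_bpe/bpe.model")
print(sp.piece_to_id("<sos/eos>"))  # expected to print 1 (--sos-id/--eos-id)
```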
Information about the test sound files is listed below:
```
$ soxi tmp/conformer_ctc/test_wavs/*.flac
Input File : 'tmp/conformer_ctc/test_wavs/1089-134686-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
File Size : 116k
Bit Rate : 140k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
File Size : 343k
Bit Rate : 164k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0002.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
File Size : 105k
Bit Rate : 174k
Sample Encoding: 16-bit FLAC
Total Duration of 3 files: 00:00:28.16
```
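If `sox`/`soxi` is not installed, similar information can be obtained with `torchaudio`; the snippet below is only a small sketch, not part of the recipe:
```python
# Print sample rate, channel count, and duration of the test files.
import torchaudio

for f in [
    "tmp/conformer_ctc/test_wavs/1089-134686-0001.flac",
    "tmp/conformer_ctc/test_wavs/1221-135766-0001.flac",
    "tmp/conformer_ctc/test_wavs/1221-135766-0002.flac",
]:
    info = torchaudio.info(f)
    seconds = info.num_frames / info.sample_rate
    print(f"{f}: {info.sample_rate} Hz, {info.num_channels} ch, {seconds:.2f} s")
```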
### (2) Use HLG decoding
```bash
cd /path/to/icefall/egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \
--words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac
```
The output is given below:
```
2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model
2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started
2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding
2021-08-20 11:03:19,149 INFO [pretrained.py:339]
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done
```
### (3) Use HLG decoding + LM rescoring
```bash
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \
--words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \
--method whole-lattice-rescoring \
--G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac
```
The output is:
```
2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model
2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started
2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring
2021-08-20 11:13:11,736 INFO [pretrained.py:339]
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done
```
### (4) Use HLG decoding + LM rescoring + attention decoder rescoring
```bash
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \
--words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \
--method attention-decoder \
--G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \
--num-paths 100 \
--sos-id 1 \
--eos-id 1 \
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac
```
The output is:
```
2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model
2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started
2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring
2021-08-20 11:20:05,805 INFO [pretrained.py:339]
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done
```

View File

@ -21,6 +21,7 @@ from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
@ -56,6 +57,18 @@ def get_parser():
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=1.0,
help="The scale to be applied to `lattice.scores`."
"It's needed if you use any kinds of n-best based rescoring. "
"Currently, it is used when the decoding method is: nbest, "
"nbest-rescoring, attention-decoder, and nbest-oracle. "
"A smaller value results in more unique paths.",
)
return parser
@ -85,10 +98,14 @@ def get_params() -> AttributeDict:
# - nbest-rescoring
# - whole-lattice-rescoring
# - attention-decoder
# - nbest-oracle
# "method": "nbest",
# "method": "nbest-rescoring",
# "method": "whole-lattice-rescoring",
"method": "attention-decoder",
# "method": "nbest-oracle",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
# attention-decoder, and nbest-oracle
"num_paths": 100,
}
)
@ -179,6 +196,19 @@ def decode_one_batch(
subsampling_factor=params.subsampling_factor,
)
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG-decoded lattice for speed, even though its
# oracle WER is slightly worse than that of rescored lattices.
return nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
lexicon=lexicon,
scale=params.lattice_score_scale,
)
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
@ -190,8 +220,9 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
scale=params.lattice_score_scale,
)
key = f"no_rescore-{params.num_paths}"
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@ -212,6 +243,7 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
scale=params.lattice_score_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
@ -231,6 +263,7 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
scale=params.lattice_score_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"

View File

@ -0,0 +1,350 @@
#!/usr/bin/env python3
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method.
Possible values are:
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
(3) attention-decoder - Extract n paths from the rescored
lattice and use the transformer attention decoder for
rescoring.
We call it HLG decoding + n-gram LM rescoring + attention
decoder rescoring.
""",
)
parser.add_argument(
"--G",
type=str,
help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring or attention-decoder.
It's usually a 4-gram LM.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""
Used only when method is attention-decoder.
It specifies the size of n-best list.""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=1.3,
help="""
Used only when method is whole-lattice-rescoring and attention-decoder.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--attention-decoder-scale",
type=float,
default=1.2,
help="""
Used only when method is attention-decoder.
It specifies the scale for attention decoder scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=0.5,
help="""
Used only when method is attention-decoder.
It specifies the scale for lattice.scores when
extracting n-best lists. A smaller value results in
more unique number of paths with the risk of missing
the best path.
""",
)
parser.add_argument(
"--sos-id",
type=int,
default=1,
help="""
Used only when method is attention-decoder.
It specifies the ID of the SOS token.
""",
)
parser.add_argument(
"--eos-id",
type=int,
default=1,
help="""
Used only when method is attention-decoder.
It specifies the ID of the EOS token.
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"nhead": 8,
"num_classes": 5000,
"sample_rate": 16000,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
G.lm_scores = G.scores.clone()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info(f"Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
nnet_output, memory, memory_key_padding_mask = model(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
elif params.method == "whole-lattice-rescoring":
logging.info("Use HLG decoding + LM rescoring")
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=[params.ngram_lm_scale],
)
best_path = next(iter(best_path_dict.values()))
elif params.method == "attention-decoder":
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
scale=params.lattice_score_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info(f"Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -128,7 +128,7 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp_new"),
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"feature_dim": 80,
"weight_decay": 1e-6,

View File

@ -2,9 +2,42 @@ import logging
from typing import Dict, List, Optional, Tuple, Union
import k2
import kaldialign
import torch
import torch.nn as nn
from icefall.lexicon import Lexicon
def _get_random_paths(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
scale: float = 1.0,
):
"""
Args:
lattice:
The decoding lattice, returned by :func:`get_lattice`.
num_paths:
It specifies the size `n` in n-best. Note: Paths are selected randomly
and those containing identical word sequences are removed and only one
of them is kept.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
Return a k2.RaggedInt with 3 axes [seq][path][arc_pos]
"""
saved_scores = lattice.scores.clone()
lattice.scores *= scale
path = k2.random_paths(
lattice, num_paths=num_paths, use_double_scores=use_double_scores
)
lattice.scores = saved_scores
return path
def _intersect_device(
a_fsas: k2.Fsa,
@ -129,7 +162,10 @@ def one_best_decoding(
def nbest_decoding(
lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
scale: float = 1.0,
) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists.
@ -152,12 +188,18 @@ def nbest_decoding(
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
An FsaVec containing linear FSAs.
"""
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
scale=scale,
)
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
@ -320,7 +362,11 @@ def compute_am_and_lm_scores(
def rescore_with_n_best_list(
lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float]
lattice: k2.Fsa,
G: k2.Fsa,
num_paths: int,
lm_scale_list: List[float],
scale: float = 1.0,
) -> Dict[str, k2.Fsa]:
"""Decode using n-best list with LM rescoring.
@ -342,6 +388,9 @@ def rescore_with_n_best_list(
It is the size `n` in `n-best` list.
lm_scale_list:
A list containing lm_scale values.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value is the
best decoding path for each sequence in the lattice.
@ -356,9 +405,12 @@ def rescore_with_n_best_list(
assert G.device == device
assert hasattr(G, "aux_labels") is False
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
@ -376,7 +428,7 @@ def rescore_with_n_best_list(
#
# num_repeats is also a k2.RaggedInt with 2 axes containing the
# multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.num_elements()
# num_repeats.num_elements() == unique_word_seqs.tot_size(1)
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
@ -494,6 +546,8 @@ def rescore_with_whole_lattice(
del lattice.lm_scores
assert hasattr(lattice, "lm_scores") is False
assert hasattr(G_with_epsilon_loops, "lm_scores")
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
@ -549,14 +603,88 @@ def rescore_with_whole_lattice(
return ans
def nbest_oracle(
lattice: k2.Fsa,
num_paths: int,
ref_texts: List[str],
lexicon: Lexicon,
scale: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""Select the best hypothesis given a lattice and a reference transcript.
The basic idea is to extract n paths from the given lattice, remove
duplicates, and select the one that has the minimum edit distance to the
corresponding reference transcript as the decoding output.
The decoding result returned from this function is the best result that
any n-best-based rescoring method can obtain with the same n-best list.
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
Note: We assume its aux_labels contain word IDs.
num_paths:
The size of `n` in n-best.
ref_texts:
A list of reference transcripts. Each entry contains
space-separated words.
lexicon:
It is used to convert word IDs to word symbols.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Return:
Return a dict. Its key encodes the parameters used when calling this
function, while its value contains the decoding output, a list whose
length equals `len(ref_texts)`.
"""
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
word_seq = k2.index(lattice.aux_labels, path)
word_seq = k2.ragged.remove_values_leq(word_seq, 0)
unique_word_seq, _, _ = k2.ragged.unique_sequences(
word_seq, need_num_repeats=False, need_new2old_indexes=False
)
unique_word_ids = k2.ragged.to_list(unique_word_seq)
assert len(unique_word_ids) == len(ref_texts)
# unique_word_ids[i] contains all hypotheses of the i-th utterance
results = []
for hyps, ref in zip(unique_word_ids, ref_texts):
# Note: hyps is a list of lists of ints.
# Each sublist contains a hypothesis.
ref_words = ref.strip().split()
# CAUTION: We don't convert ref_words to ref_words_ids
# since there may exist OOV words in ref_words
best_hyp_words = None
min_error = float("inf")
for hyp_words in hyps:
hyp_words = [lexicon.word_table[i] for i in hyp_words]
this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
if this_error < min_error:
min_error = this_error
best_hyp_words = hyp_words
results.append(best_hyp_words)
return {f"nbest_{num_paths}_scale_{scale}_oracle": results}
def rescore_with_attention_decoder(
lattice: k2.Fsa,
num_paths: int,
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
) -> Dict[str, k2.Fsa]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
@ -580,6 +708,13 @@ def rescore_with_attention_decoder(
The token ID for SOS.
eos_id:
The token ID for EOS.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
@ -587,7 +722,12 @@ def rescore_with_attention_decoder(
"""
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
@ -605,7 +745,7 @@ def rescore_with_attention_decoder(
#
# num_repeats is also a k2.RaggedInt with 2 axes containing the
# multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.num_elements()
# num_repeats.num_elements() == unique_word_seqs.tot_size(1)
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
@ -662,11 +802,13 @@ def rescore_with_attention_decoder(
path_to_seq_map_long = path_to_seq_map.to(torch.long)
expanded_memory = memory.index_select(1, path_to_seq_map_long)
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_seq_map_long
)
if memory_key_padding_mask is not None:
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_seq_map_long
)
else:
expanded_memory_key_padding_mask = None
# TODO: pass the sos_token_id and eos_token_id via function arguments
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
@ -681,11 +823,17 @@ def rescore_with_attention_decoder(
assert attention_scores.ndim == 1
assert attention_scores.numel() == num_word_seqs
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
if attention_scale is None:
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
attention_scale_list = [attention_scale]
path_2axes = k2.ragged.remove_axis(path, 0)