Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Daniel Povey 2021-08-22 11:56:45 +08:00
commit 24d3a98378
13 changed files with 1074 additions and 137 deletions

View File

@ -53,6 +53,13 @@ It should print the path to `icefall`.
At present, only LibriSpeech recipe is provided. Please
follow [egs/librispeech/ASR/README.md][LibriSpeech] to run it.
## Use Pre-trained models
See [egs/librispeech/ASR/conformer_ctc/README.md](egs/librispeech/ASR/conformer_ctc/README.md)
for how to use pre-trained models.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
[LibriSpeech]: egs/librispeech/ASR/README.md
[k2-install]: https://k2.readthedocs.io/en/latest/installation/index.html#
[k2]: https://github.com/k2-fsa/k2

View File

@ -0,0 +1,23 @@
## Results
### LibriSpeech BPE training results (Conformer-CTC)
#### 2021-08-19
(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/13
TensorBoard log is available at https://tensorboard.dev/experiment/GnRzq8WWQW62dK4bklXBTg/#scalars
Pretrained model is available at https://huggingface.co/pkufool/conformer_ctc
The best decoding results (WER) are listed below, we got this results by averaging models from epoch 15 to 34, and using `attention-decoder` decoder with num_paths equals to 100.
||test-clean|test-other|
|--|--|--|
|WER| 2.57% | 5.94% |
To get more unique paths, we scaled the lattice.scores with 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details), we searched the lm_score_scale and attention_score_scale for best results, the scales that produced the WER above are also listed below.
||lm_scale|attention_scale|
|--|--|--|
|test-clean|1.3|1.2|
|test-other|1.2|1.1|

View File

@ -0,0 +1,351 @@
# How to use a pre-trained model to transcribe a sound file or multiple sound files
(See the bottom of this document for the link to a colab notebook.)
You need to prepare 4 files:
- a model checkpoint file, e.g., epoch-20.pt
- HLG.pt, the decoding graph
- words.txt, the word symbol table
- a sound file, whose sampling rate has to be 16 kHz.
Supported formats are those supported by `torchaudio.load()`,
e.g., wav and flac.
Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.
```bash
./conformer_ctc/pretrained.py --help
```
displays the help information.
## HLG decoding
Once you have the above files ready and have `kaldifeat` installed,
you can run:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound.wav
```
and you will see the transcribed result.
If you want to transcribe multiple files at the same time, you can use:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
**Note**: This is the fastest decoding method.
## HLG decoding + LM rescoring
`./conformer_ctc/pretrained.py` also supports `whole lattice LM rescoring`
and `attention decoder rescoring`.
To use whole lattice LM rescoring, you also need the following files:
- G.pt, e.g., `data/lm/G_4_gram.pt` if you have run `./prepare.sh`
The command to run decoding with LM rescoring is:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method whole-lattice-rescoring \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
## HLG Decoding + LM rescoring + attention decoder rescoring
To use attention decoder for rescoring, you need the following extra information:
- sos token ID
- eos token ID
The command to run decoding with attention decoder rescoring is:
```bash
./conformer_ctc/pretrained.py \
--checkpoint /path/to/your/checkpoint.pt \
--words-file /path/to/words.txt \
--HLG /path/to/HLG.pt \
--method attention-decoder \
--G data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \
--num-paths 100 \
--sos-id 1 \
--eos-id 1 \
/path/to/your/sound1.wav \
/path/to/your/sound2.wav \
/path/to/your/sound3.wav
```
# Decoding with a pre-trained model in action
We have uploaded a pre-trained model to <https://huggingface.co/pkufool/conformer_ctc>
The following shows the steps about the usage of the provided pre-trained model.
### (1) Download the pre-trained model
```bash
sudo apt-get install git-lfs
cd /path/to/icefall/egs/librispeech/ASR
git lfs install
mkdir tmp
cd tmp
git clone https://huggingface.co/pkufool/conformer_ctc
```
**CAUTION**: You have to install `git-lfst` to download the pre-trained model.
You will find the following files:
```
tmp
`-- conformer_ctc
|-- README.md
|-- data
| |-- lang_bpe
| | |-- HLG.pt
| | |-- bpe.model
| | |-- tokens.txt
| | `-- words.txt
| `-- lm
| `-- G_4_gram.pt
|-- exp
| `-- pretraind.pt
`-- test_wavs
|-- 1089-134686-0001.flac
|-- 1221-135766-0001.flac
|-- 1221-135766-0002.flac
`-- trans.txt
6 directories, 11 files
```
**File descriptions**:
- `data/lang_bpe/HLG.pt`
It is the decoding graph.
- `data/lang_bpe/bpe.model`
It is a sentencepiece model. You can use it to reproduce our results.
- `data/lang_bpe/tokens.txt`
It contains tokens and their IDs, generated from `bpe.model`.
Provided only for convienice so that you can look up the SOS/EOS ID easily.
- `data/lang_bpe/words.txt`
It contains words and their IDs.
- `data/lm/G_4_gram.pt`
It is a 4-gram LM, useful for LM rescoring.
- `exp/pretrained.pt`
It contains pre-trained model parameters, obtained by averaging
checkpoints from `epoch-15.pt` to `epoch-34.pt`.
Note: We have removed optimizer `state_dict` to reduce file size.
- `test_waves/*.flac`
It contains some test sound files from LibriSpeech `test-clean` dataset.
- `test_waves/trans.txt`
It contains the reference transcripts for the sound files in `test_waves/`.
The information of the test sound files is listed below:
```
$ soxi tmp/conformer_ctc/test_wavs/*.flac
Input File : 'tmp/conformer_ctc/test_wavs/1089-134686-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:06.62 = 106000 samples ~ 496.875 CDDA sectors
File Size : 116k
Bit Rate : 140k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0001.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:16.71 = 267440 samples ~ 1253.62 CDDA sectors
File Size : 343k
Bit Rate : 164k
Sample Encoding: 16-bit FLAC
Input File : 'tmp/conformer_ctc/test_wavs/1221-135766-0002.flac'
Channels : 1
Sample Rate : 16000
Precision : 16-bit
Duration : 00:00:04.83 = 77200 samples ~ 361.875 CDDA sectors
File Size : 105k
Bit Rate : 174k
Sample Encoding: 16-bit FLAC
Total Duration of 3 files: 00:00:28.16
```
### (2) Use HLG decoding
```bash
cd /path/to/icefall/egs/librispeech/ASR
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \
--words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac
```
The output is given below:
```
2021-08-20 11:03:05,712 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:03:05,712 INFO [pretrained.py:219] Creating model
2021-08-20 11:03:11,345 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:03:18,442 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:03:18,444 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:03:18,507 INFO [pretrained.py:271] Decoding started
2021-08-20 11:03:18,795 INFO [pretrained.py:300] Use HLG decoding
2021-08-20 11:03:19,149 INFO [pretrained.py:339]
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:03:19,149 INFO [pretrained.py:341] Decoding Done
```
### (3) Use HLG decoding + LM rescoring
```bash
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \
--words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \
--method whole-lattice-rescoring \
--G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 0.8 \
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac
```
The output is:
```
2021-08-20 11:12:17,565 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:12:17,565 INFO [pretrained.py:219] Creating model
2021-08-20 11:12:23,728 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:12:30,035 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:13:10,779 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:13:10,787 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:13:10,798 INFO [pretrained.py:271] Decoding started
2021-08-20 11:13:11,085 INFO [pretrained.py:305] Use HLG decoding + LM rescoring
2021-08-20 11:13:11,736 INFO [pretrained.py:339]
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:13:11,737 INFO [pretrained.py:341] Decoding Done
```
### (4) Use HLG decoding + LM rescoring + attention decoder rescoring
```bash
./conformer_ctc/pretrained.py \
--checkpoint ./tmp/conformer_ctc/exp/pretraind.pt \
--words-file ./tmp/conformer_ctc/data/lang_bpe/words.txt \
--HLG ./tmp/conformer_ctc/data/lang_bpe/HLG.pt \
--method attention-decoder \
--G ./tmp/conformer_ctc/data/lm/G_4_gram.pt \
--ngram-lm-scale 1.3 \
--attention-decoder-scale 1.2 \
--lattice-score-scale 0.5 \
--num-paths 100 \
--sos-id 1 \
--eos-id 1 \
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac \
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac
```
The output is:
```
2021-08-20 11:19:11,397 INFO [pretrained.py:217] device: cuda:0
2021-08-20 11:19:11,397 INFO [pretrained.py:219] Creating model
2021-08-20 11:19:17,354 INFO [pretrained.py:238] Loading HLG from ./tmp/conformer_ctc/data/lang_bpe/HLG.pt
2021-08-20 11:19:24,615 INFO [pretrained.py:246] Loading G from ./tmp/conformer_ctc/data/lm/G_4_gram.pt
2021-08-20 11:20:04,576 INFO [pretrained.py:255] Constructing Fbank computer
2021-08-20 11:20:04,584 INFO [pretrained.py:265] Reading sound files: ['./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0001.flac', './tmp/conformer_ctc/test_wavs/1221-135766-0002.flac']
2021-08-20 11:20:04,595 INFO [pretrained.py:271] Decoding started
2021-08-20 11:20:04,854 INFO [pretrained.py:313] Use HLG + LM rescoring + attention decoder rescoring
2021-08-20 11:20:05,805 INFO [pretrained.py:339]
./tmp/conformer_ctc/test_wavs/1089-134686-0001.flac:
AFTER EARLY NIGHTFALL THE YELLOW LAMPS WOULD LIGHT UP HERE AND THERE THE SQUALID QUARTER OF THE BROTHELS
./tmp/conformer_ctc/test_wavs/1221-135766-0001.flac:
GOD AS A DIRECT CONSEQUENCE OF THE SIN WHICH MAN THUS PUNISHED HAD GIVEN HER A LOVELY CHILD WHOSE PLACE WAS ON THAT SAME DISHONOURED
BOSOM TO CONNECT HER PARENT FOR EVER WITH THE RACE AND DESCENT OF MORTALS AND TO BE FINALLY A BLESSED SOUL IN HEAVEN
./tmp/conformer_ctc/test_wavs/1221-135766-0002.flac:
YET THESE THOUGHTS AFFECTED HESTER PRYNNE LESS WITH HOPE THAN APPREHENSION
2021-08-20 11:20:05,805 INFO [pretrained.py:341] Decoding Done
```
**NOTE**: We provide a colab notebook for demonstration.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1huyupXAcHsUrKaWfI83iMEJ6J0Nh0213?usp=sharing)
Due to limited memory provided by Colab, you have to upgrade to Colab Pro to
run `HLG decoding + LM rescoring` and `HLG decoding + LM rescoring + attention decoder rescoring`.
Otherwise, you can only run `HLG decoding` with Colab.

View File

@ -0,0 +1 @@
../tdnn_lstm_ctc/asr_datamodule.py

View File

@ -13,14 +13,15 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
get_lattice,
nbest_decoding,
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_n_best_list,
@ -56,6 +57,18 @@ def get_parser():
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=1.0,
help="The scale to be applied to `lattice.scores`."
"It's needed if you use any kinds of n-best based rescoring. "
"Currently, it is used when the decoding method is: nbest, "
"nbest-rescoring, attention-decoder, and nbest-oracle. "
"A smaller value results in more unique paths.",
)
return parser
@ -85,10 +98,14 @@ def get_params() -> AttributeDict:
# - nbest-rescoring
# - whole-lattice-rescoring
# - attention-decoder
# - nbest-oracle
# "method": "nbest",
# "method": "nbest-rescoring",
# "method": "whole-lattice-rescoring",
"method": "attention-decoder",
# "method": "nbest-oracle",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
# attention-decoder, and nbest-oracle
"num_paths": 100,
}
)
@ -179,6 +196,19 @@ def decode_one_batch(
subsampling_factor=params.subsampling_factor,
)
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
# as HLG decoding is faster and the oracle WER
# is slightly worse than that of rescored lattices.
return nbest_oracle(
lattice=lattice,
num_paths=params.num_paths,
ref_texts=supervisions["text"],
lexicon=lexicon,
scale=params.lattice_score_scale,
)
if params.method in ["1best", "nbest"]:
if params.method == "1best":
best_path = one_best_decoding(
@ -190,8 +220,9 @@ def decode_one_batch(
lattice=lattice,
num_paths=params.num_paths,
use_double_scores=params.use_double_scores,
scale=params.lattice_score_scale,
)
key = f"no_rescore-{params.num_paths}"
key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa
hyps = get_texts(best_path)
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@ -212,6 +243,7 @@ def decode_one_batch(
G=G,
num_paths=params.num_paths,
lm_scale_list=lm_scale_list,
scale=params.lattice_score_scale,
)
elif params.method == "whole-lattice-rescoring":
best_path_dict = rescore_with_whole_lattice(
@ -231,6 +263,7 @@ def decode_one_batch(
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
scale=params.lattice_score_scale,
)
else:
assert False, f"Unsupported decoding method: {params.method}"
@ -284,7 +317,11 @@ def decode_dataset(
results = []
num_cuts = 0
tot_num_cuts = len(dl.dataset.cuts)
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -313,10 +350,10 @@ def decode_dataset(
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_idx}, cuts processed until now is "
f"{num_cuts}/{tot_num_cuts} "
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -376,7 +413,7 @@ def main():
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-decode")
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
@ -399,7 +436,9 @@ def main():
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
HLG = k2.Fsa.from_dict(
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
@ -430,7 +469,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:

View File

@ -0,0 +1,350 @@
#!/usr/bin/env python3
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_whole_lattice,
)
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--words-file",
type=str,
required=True,
help="Path to words.txt",
)
parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt."
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method.
Possible values are:
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + n-gram LM rescoring.
(3) attention-decoder - Extract n paths from he rescored
lattice and use the transformer attention decoder for
rescoring.
We call it HLG decoding + n-gram LM rescoring + attention
decoder rescoring.
""",
)
parser.add_argument(
"--G",
type=str,
help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring or attention-decoder.
It's usually a 4-gram LM.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""
Used only when method is attention-decoder.
It specifies the size of n-best list.""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=1.3,
help="""
Used only when method is whole-lattice-rescoring and attention-decoder.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--attention-decoder-scale",
type=float,
default=1.2,
help="""
Used only when method is attention-decoder.
It specifies the scale for attention decoder scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--lattice-score-scale",
type=float,
default=0.5,
help="""
Used only when method is attention-decoder.
It specifies the scale for lattice.scores when
extracting n-best lists. A smaller value results in
more unique number of paths with the risk of missing
the best path.
""",
)
parser.add_argument(
"--sos-id",
type=float,
default=1,
help="""
Used only when method is attention-decoder.
It specifies ID for the SOS token.
""",
)
parser.add_argument(
"--eos-id",
type=float,
default=1,
help="""
Used only when method is attention-decoder.
It specifies ID for the EOS token.
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"feature_dim": 80,
"nhead": 8,
"num_classes": 5000,
"sample_rate": 16000,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
"vgg_frontend": False,
"is_espnet_structure": True,
"mmi_loss": False,
"use_feat_batchnorm": True,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
is_espnet_structure=params.is_espnet_structure,
mmi_loss=params.mmi_loss,
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
G = G.to(device)
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G.lm_scores = G.scores.clone()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info(f"Decoding started")
features = fbank(waves)
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
nnet_output, memory, memory_key_padding_mask = model(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
lattice = get_lattice(
nnet_output=nnet_output,
HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
elif params.method == "whole-lattice-rescoring":
logging.info("Use HLG decoding + LM rescoring")
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=[params.ngram_lm_scale],
)
best_path = next(iter(best_path_dict.values()))
elif params.method == "attention-decoder":
logging.info("Use HLG + LM rescoring + attention decoder rescoring")
rescored_lattice = rescore_with_whole_lattice(
lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None
)
best_path_dict = rescore_with_attention_decoder(
lattice=rescored_lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
scale=params.lattice_score_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info(f"Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -13,6 +13,7 @@ import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch.nn.parallel import DistributedDataParallel as DDP
@ -23,7 +24,6 @@ from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.lexicon import Lexicon
from icefall.utils import (
@ -60,9 +60,6 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
# TODO: add extra arguments and support DDP training.
# Currently, only single GPU training is implemented. Will add
# DDP training once single GPU training is finished.
return parser
@ -127,7 +124,7 @@ def get_params() -> AttributeDict:
"""
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp_new"),
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
"feature_dim": 80,
"weight_decay": 1e-6,
@ -145,7 +142,6 @@ def get_params() -> AttributeDict:
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
#
"accum_grad": 1,
"att_rate": 0.7,
"attention_dim": 512,

View File

@ -1,14 +1,16 @@
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import List, Union
from lhotse import Fbank, FbankConfig, load_manifest
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
from lhotse.dataset import (
BucketingSampler,
CutConcatenate,
CutMix,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SingleCutSampler,
SpecAugment,
)
@ -19,7 +21,7 @@ from icefall.dataset.datamodule import DataModule
from icefall.utils import str2bool
class AsrDataModule(DataModule):
class LibriSpeechAsrDataModule(DataModule):
"""
DataModule for K2 ASR experiments.
It assumes there is always one train and valid dataloader,
@ -47,6 +49,13 @@ class AsrDataModule(DataModule):
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech. "
"Otherwise, use 100h subset.",
)
group.add_argument(
"--feature-dir",
type=Path,
@ -77,7 +86,7 @@ class AsrDataModule(DataModule):
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=True,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
@ -104,6 +113,29 @@ class AsrDataModule(DataModule):
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
def train_dataloaders(self) -> DataLoader:
logging.info("About to get train cuts")
@ -138,9 +170,9 @@ class AsrDataModule(DataModule):
]
train = K2SpeechRecognitionDataset(
cuts_train,
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
@ -154,14 +186,13 @@ class AsrDataModule(DataModule):
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
cuts_train = cuts_train.drop_features()
train = K2SpeechRecognitionDataset(
cuts=cuts_train,
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
@ -169,44 +200,60 @@ class AsrDataModule(DataModule):
train_sampler = BucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=True,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
bucket_method="equal_duration",
drop_last=True,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SingleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=True,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=4,
persistent_workers=True,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def valid_dataloaders(self) -> DataLoader:
logging.info("About to get dev cuts")
cuts_valid = self.valid_cuts()
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
cuts_valid = cuts_valid.drop_features()
validate = K2SpeechRecognitionDataset(
cuts_valid.drop_features(),
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(cuts_valid)
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = SingleCutSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
@ -214,8 +261,9 @@ class AsrDataModule(DataModule):
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=True,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
@ -228,10 +276,12 @@ class AsrDataModule(DataModule):
for cuts_test in cuts:
logging.debug("About to create test dataset")
test = K2SpeechRecognitionDataset(
cuts_test,
input_strategy=OnTheFlyFeatures(
Fbank(FbankConfig(num_mel_bins=80))
),
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = SingleCutSampler(
cuts_test, max_duration=self.args.max_duration
@ -246,3 +296,42 @@ class AsrDataModule(DataModule):
return test_loaders
else:
return test_loaders[0]
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz"
)
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
return cuts_valid
@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test-clean", "test-other"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
)
)
return cuts

View File

@ -10,10 +10,10 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
get_lattice,
nbest_decoding,
@ -236,7 +236,11 @@ def decode_dataset(
results = []
num_cuts = 0
tot_num_cuts = len(dl.dataset.cuts)
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
@ -263,10 +267,10 @@ def decode_dataset(
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_idx}, cuts processed until now is "
f"{num_cuts}/{tot_num_cuts} "
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
@ -328,7 +332,9 @@ def main():
logging.info(f"device: {device}")
HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
HLG = k2.Fsa.from_dict(
torch.load("data/lang_phone/HLG.pt", map_location="cpu")
)
HLG = HLG.to(device)
assert HLG.requires_grad is False
@ -355,7 +361,7 @@ def main():
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
else:
logging.info("Loading pre-compiled G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt")
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
G = k2.Fsa.from_dict(d).to(device)
if params.method == "whole-lattice-rescoring":

View File

@ -1,7 +1,5 @@
#!/usr/bin/env python3
# This is just at the very beginning ...
import argparse
import logging
from pathlib import Path
@ -14,16 +12,16 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse.utils import fix_random_seed
from model import TdnnLstm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_value_
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
@ -61,9 +59,6 @@ def get_parser():
help="Should various information be logged in tensorboard.",
)
# TODO: add extra arguments and support DDP training.
# Currently, only single GPU training is implemented. Will add
# DDP training once single GPU training is finished.
return parser
@ -406,7 +401,7 @@ def train_one_epoch(
optimizer.zero_grad()
loss.backward()
clip_grad_value_(model.parameters(), 5.0)
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
loss_cpu = loss.detach().cpu().item()

View File

@ -91,7 +91,7 @@ def load_checkpoint(
checkpoint.pop("model")
def load(name, obj):
s = checkpoint[name]
s = checkpoint.get(name, None)
if obj and s:
obj.load_state_dict(s)
checkpoint.pop(name)

View File

@ -1,68 +0,0 @@
import argparse
import logging
from functools import lru_cache
from typing import List
from lhotse import CutSet, load_manifest
from icefall.dataset.asr_datamodule import AsrDataModule
from icefall.utils import str2bool
class LibriSpeechAsrDataModule(AsrDataModule):
"""
LibriSpeech ASR data module. Can be used for 100h subset
(``--full-libri false``) or full 960h set.
The train and valid cuts for standard Libri splits are
concatenated into a single CutSet/DataLoader.
"""
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
super().add_arguments(parser)
group = parser.add_argument_group(title="LibriSpeech specific options")
group.add_argument(
"--full-libri",
type=str2bool,
default=True,
help="When enabled, use 960h LibriSpeech.",
)
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
cuts_train = load_manifest(
self.args.feature_dir / "cuts_train-clean-100.json.gz"
)
if self.args.full_libri:
cuts_train = (
cuts_train
+ load_manifest(
self.args.feature_dir / "cuts_train-clean-360.json.gz"
)
+ load_manifest(
self.args.feature_dir / "cuts_train-other-500.json.gz"
)
)
return cuts_train
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
cuts_valid = load_manifest(
self.args.feature_dir / "cuts_dev-clean.json.gz"
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
return cuts_valid
@lru_cache()
def test_cuts(self) -> List[CutSet]:
test_sets = ["test-clean", "test-other"]
cuts = []
for test_set in test_sets:
logging.debug("About to get test cuts")
cuts.append(
load_manifest(
self.args.feature_dir / f"cuts_{test_set}.json.gz"
)
)
return cuts

View File

@ -2,9 +2,42 @@ import logging
from typing import Dict, List, Optional, Tuple, Union
import k2
import kaldialign
import torch
import torch.nn as nn
from icefall.lexicon import Lexicon
def _get_random_paths(
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
scale: float = 1.0,
):
"""
Args:
lattice:
The decoding lattice, returned by :func:`get_lattice`.
num_paths:
It specifies the size `n` in n-best. Note: Paths are selected randomly
and those containing identical word sequences are remove dand only one
of them is kept.
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
Return a k2.RaggedInt with 3 axes [seq][path][arc_pos]
"""
saved_scores = lattice.scores.clone()
lattice.scores *= scale
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
lattice.scores = saved_scores
return path
def _intersect_device(
a_fsas: k2.Fsa,
@ -129,7 +162,10 @@ def one_best_decoding(
def nbest_decoding(
lattice: k2.Fsa, num_paths: int, use_double_scores: bool = True
lattice: k2.Fsa,
num_paths: int,
use_double_scores: bool = True,
scale: float = 1.0,
) -> k2.Fsa:
"""It implements something like CTC prefix beam search using n-best lists.
@ -152,12 +188,18 @@ def nbest_decoding(
use_double_scores:
True to use double precision floating point in the computation.
False to use single precision.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
An FsaVec containing linear FSAs.
"""
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
scale=scale,
)
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
@ -320,7 +362,11 @@ def compute_am_and_lm_scores(
def rescore_with_n_best_list(
lattice: k2.Fsa, G: k2.Fsa, num_paths: int, lm_scale_list: List[float]
lattice: k2.Fsa,
G: k2.Fsa,
num_paths: int,
lm_scale_list: List[float],
scale: float = 1.0,
) -> Dict[str, k2.Fsa]:
"""Decode using n-best list with LM rescoring.
@ -342,6 +388,9 @@ def rescore_with_n_best_list(
It is the size `n` in `n-best` list.
lm_scale_list:
A list containing lm_scale values.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
A dict of FsaVec, whose key is an lm_scale and the value is the
best decoding path for each sequence in the lattice.
@ -356,9 +405,12 @@ def rescore_with_n_best_list(
assert G.device == device
assert hasattr(G, "aux_labels") is False
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
@ -376,7 +428,7 @@ def rescore_with_n_best_list(
#
# num_repeats is also a k2.RaggedInt with 2 axes containing the
# multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.num_elements()
# num_repeats.num_elements() == unique_word_seqs.tot_size(1)
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
@ -494,6 +546,8 @@ def rescore_with_whole_lattice(
del lattice.lm_scores
assert hasattr(lattice, "lm_scores") is False
assert hasattr(G_with_epsilon_loops, "lm_scores")
# Now, lattice.scores contains only am_scores
# inv_lattice has word IDs as labels.
@ -549,14 +603,88 @@ def rescore_with_whole_lattice(
return ans
def nbest_oracle(
lattice: k2.Fsa,
num_paths: int,
ref_texts: List[str],
lexicon: Lexicon,
scale: float = 1.0,
) -> Dict[str, List[List[int]]]:
"""Select the best hypothesis given a lattice and a reference transcript.
The basic idea is to extract n paths from the given lattice, unique them,
and select the one that has the minimum edit distance with the corresponding
reference transcript as the decoding output.
The decoding result returned from this function is the best result that
we can obtain using n-best decoding with all kinds of rescoring techniques.
Args:
lattice:
An FsaVec. It can be the return value of :func:`get_lattice`.
Note: We assume its aux_labels contain word IDs.
num_paths:
The size of `n` in n-best.
ref_texts:
A list of reference transcript. Each entry contains space(s)
separated words
lexicon:
It is used to convert word IDs to word symbols.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Return:
Return a dict. Its key contains the information about the parameters
when calling this function, while its value contains the decoding output.
`len(ans_dict) == len(ref_texts)`
"""
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
word_seq = k2.index(lattice.aux_labels, path)
word_seq = k2.ragged.remove_values_leq(word_seq, 0)
unique_word_seq, _, _ = k2.ragged.unique_sequences(
word_seq, need_num_repeats=False, need_new2old_indexes=False
)
unique_word_ids = k2.ragged.to_list(unique_word_seq)
assert len(unique_word_ids) == len(ref_texts)
# unique_word_ids[i] contains all hypotheses of the i-th utterance
results = []
for hyps, ref in zip(unique_word_ids, ref_texts):
# Note hyps is a list-of-list ints
# Each sublist contains a hypothesis
ref_words = ref.strip().split()
# CAUTION: We don't convert ref_words to ref_words_ids
# since there may exist OOV words in ref_words
best_hyp_words = None
min_error = float("inf")
for hyp_words in hyps:
hyp_words = [lexicon.word_table[i] for i in hyp_words]
this_error = kaldialign.edit_distance(ref_words, hyp_words)["total"]
if this_error < min_error:
min_error = this_error
best_hyp_words = hyp_words
results.append(best_hyp_words)
return {f"nbest_{num_paths}_scale_{scale}_oracle": results}
def rescore_with_attention_decoder(
lattice: k2.Fsa,
num_paths: int,
model: nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
scale: float = 1.0,
ngram_lm_scale: Optional[float] = None,
attention_scale: Optional[float] = None,
) -> Dict[str, k2.Fsa]:
"""This function extracts n paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest
@ -580,6 +708,13 @@ def rescore_with_attention_decoder(
The token ID for SOS.
eos_id:
The token ID for EOS.
scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
ngram_lm_scale:
Optional. It specifies the scale for n-gram LM scores.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
Returns:
A dict of FsaVec, whose key contains a string
ngram_lm_scale_attention_scale and the value is the
@ -587,7 +722,12 @@ def rescore_with_attention_decoder(
"""
# First, extract `num_paths` paths for each sequence.
# path is a k2.RaggedInt with axes [seq][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
path = _get_random_paths(
lattice=lattice,
num_paths=num_paths,
use_double_scores=True,
scale=scale,
)
# word_seq is a k2.RaggedInt sharing the same shape as `path`
# but it contains word IDs. Note that it also contains 0s and -1s.
@ -605,12 +745,12 @@ def rescore_with_attention_decoder(
#
# num_repeats is also a k2.RaggedInt with 2 axes containing the
# multiplicities of each path.
# num_repeats.num_elements() == unique_word_seqs.num_elements()
# num_repeats.num_elements() == unique_word_seqs.tot_size(1)
#
# Since k2.ragged.unique_sequences will reorder paths within a seq,
# `new2old` is a 1-D torch.Tensor mapping from the output path index
# to the input path index.
# new2old.numel() == unique_word_seqs.tot_size(1)
# new2old.numel() == unique_word_seq.tot_size(1)
unique_word_seq, num_repeats, new2old = k2.ragged.unique_sequences(
word_seq, need_num_repeats=True, need_new2old_indexes=True
)
@ -662,11 +802,13 @@ def rescore_with_attention_decoder(
path_to_seq_map_long = path_to_seq_map.to(torch.long)
expanded_memory = memory.index_select(1, path_to_seq_map_long)
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_seq_map_long
)
if memory_key_padding_mask is not None:
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_seq_map_long
)
else:
expanded_memory_key_padding_mask = None
# TODO: pass the sos_token_id and eos_token_id via function arguments
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
@ -681,11 +823,17 @@ def rescore_with_attention_decoder(
assert attention_scores.ndim == 1
assert attention_scores.numel() == num_word_seqs
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
if ngram_lm_scale is None:
ngram_lm_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
ngram_lm_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
ngram_lm_scale_list = [ngram_lm_scale]
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
if attention_scale is None:
attention_scale_list = [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
else:
attention_scale_list = [attention_scale]
path_2axes = k2.ragged.remove_axis(path, 0)