mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00

Commit d2ae1ba060 ("Fix conflicts")

Changed files: two README.md files and several LibriSpeech recipe scripts; the diff hunks are shown below.
README.md: @@ -1 +1,60 @@

-Working in progress.
+# Table of Contents
+
+- [Installation](#installation)
+  * [Install k2](#install-k2)
+  * [Install lhotse](#install-lhotse)
+  * [Install icefall](#install-icefall)
+- [Run recipes](#run-recipes)
+
+## Installation
+
+`icefall` depends on [k2][k2] for FSA operations and [lhotse][lhotse] for
+data preparation. To use `icefall`, you have to install its dependencies first.
+The following subsections describe how to set up the environment.
+
+CAUTION: There are various ways to set up the environment. What we describe
+here is just one alternative.
+
+### Install k2
+
+Please refer to [k2's installation documentation][k2-install] to install k2.
+If you have any issues about installing k2, please open an issue at
+<https://github.com/k2-fsa/k2/issues>.
+
+### Install lhotse
+
+Please refer to [lhotse's installation documentation][lhotse-install] to install
+lhotse.
+
+### Install icefall
+
+`icefall` is a set of Python scripts. What you need to do is just to set
+the environment variable `PYTHONPATH`:
+
+```bash
+cd $HOME/open-source
+git clone https://github.com/k2-fsa/icefall
+cd icefall
+pip install -r requirements.txt
+export PYTHONPATH=$HOME/open-source/icefall:$PYTHONPATH
+```
+
+To verify `icefall` was installed successfully, you can run:
+
+```bash
+python3 -c "import icefall; print(icefall.__file__)"
+```
+
+It should print the path to `icefall`.
+
+## Run recipes
+
+At present, only the LibriSpeech recipe is provided. Please
+follow [egs/librispeech/ASR/README.md][LibriSpeech] to run it.
+
+[LibriSpeech]: egs/librispeech/ASR/README.md
+[k2-install]: https://k2.readthedocs.io/en/latest/installation/index.html#
+[k2]: https://github.com/k2-fsa/k2
+[lhotse]: https://github.com/lhotse-speech/lhotse
+[lhotse-install]: https://lhotse.readthedocs.io/en/latest/getting-started.html#installation
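The verification command above only checks `icefall` itself. A slightly broader sanity check (just a sketch, not part of the README) is to import all three packages from one interpreter and print where each one resolves from:

```python
# Quick sanity check that k2, lhotse and icefall are importable.
import importlib

for name in ("k2", "lhotse", "icefall"):
    try:
        module = importlib.import_module(name)
        # __file__ shows which installation is actually picked up,
        # which helps debug PYTHONPATH / virtualenv mix-ups.
        print(f"{name:8s} -> {module.__file__}")
    except ImportError as e:
        print(f"{name:8s} -> NOT FOUND ({e})")
```

With `PYTHONPATH` set as above, this should print three paths; a `NOT FOUND` line points at the dependency that still needs installing.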
@@ -1,121 +1,64 @@ (the LibriSpeech recipe README.md)

-Run `./prepare.sh` to prepare the data.
-
-Run `./xxx_train.py` (to be added) to train a model.
-
-## Conformer-CTC
-Results of the pre-trained model from
-`<https://huggingface.co/GuoLiyong/snowfall_bpe_model/tree/main/exp-duration-200-feat_batchnorm-bpe-lrfactor5.0-conformer-512-8-noam>`
-are given below
-
-### HLG - no LM rescoring
-
-(output beam size is 8)
-
-#### 1-best decoding
-
-```
-[test-clean-no_rescore] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ]
-[test-other-no_rescore] %WER 7.03% [3682 / 52343, 220 ins, 1024 del, 2438 sub ]
-```
-
-#### n-best decoding
-
-For n=100,
-
-```
-[test-clean-no_rescore-100] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ]
-[test-other-no_rescore-100] %WER 7.14% [3737 / 52343, 275 ins, 1020 del, 2442 sub ]
-```
-
-For n=200,
-
-```
-[test-clean-no_rescore-200] %WER 3.16% [1660 / 52576, 125 ins, 378 del, 1157 sub ]
-[test-other-no_rescore-200] %WER 7.04% [3684 / 52343, 228 ins, 1012 del, 2444 sub ]
-```
-
-### HLG - with LM rescoring
-
-#### Whole lattice rescoring
-
-```
-[test-clean-lm_scale_0.8] %WER 2.77% [1456 / 52576, 150 ins, 210 del, 1096 sub ]
-[test-other-lm_scale_0.8] %WER 6.23% [3262 / 52343, 246 ins, 635 del, 2381 sub ]
-```
-
-WERs of different LM scales are:
-
-```
-For test-clean, WER of different settings are:
-lm_scale_0.8 2.77 best for test-clean
-lm_scale_0.9 2.87
-lm_scale_1.0 3.06
-lm_scale_1.1 3.34
-lm_scale_1.2 3.71
-lm_scale_1.3 4.18
-lm_scale_1.4 4.8
-lm_scale_1.5 5.48
-lm_scale_1.6 6.08
-lm_scale_1.7 6.79
-lm_scale_1.8 7.49
-lm_scale_1.9 8.14
-lm_scale_2.0 8.82
-
-For test-other, WER of different settings are:
-lm_scale_0.8 6.23 best for test-other
-lm_scale_0.9 6.37
-lm_scale_1.0 6.62
-lm_scale_1.1 6.99
-lm_scale_1.2 7.46
-lm_scale_1.3 8.13
-lm_scale_1.4 8.84
-lm_scale_1.5 9.61
-lm_scale_1.6 10.32
-lm_scale_1.7 11.17
-lm_scale_1.8 12.12
-lm_scale_1.9 12.93
-lm_scale_2.0 13.77
-```
-
-#### n-best LM rescoring
-
-n = 100
-
-```
-[test-clean-lm_scale_0.8] %WER 2.79% [1469 / 52576, 149 ins, 212 del, 1108 sub ]
-[test-other-lm_scale_0.8] %WER 6.36% [3329 / 52343, 259 ins, 666 del, 2404 sub ]
-```
-
-WERs of different LM scales are:
-
-```
-For test-clean, WER of different settings are:
-lm_scale_0.8 2.79 best for test-clean
-lm_scale_0.9 2.89
-lm_scale_1.0 3.03
-lm_scale_1.1 3.28
-lm_scale_1.2 3.52
-lm_scale_1.3 3.78
-lm_scale_1.4 4.04
-lm_scale_1.5 4.24
-lm_scale_1.6 4.45
-lm_scale_1.7 4.58
-lm_scale_1.8 4.7
-lm_scale_1.9 4.8
-lm_scale_2.0 4.92
-For test-other, WER of different settings are:
-lm_scale_0.8 6.36 best for test-other
-lm_scale_0.9 6.45
-lm_scale_1.0 6.64
-lm_scale_1.1 6.92
-lm_scale_1.2 7.25
-lm_scale_1.3 7.59
-lm_scale_1.4 7.88
-lm_scale_1.5 8.13
-lm_scale_1.6 8.36
-lm_scale_1.7 8.54
-lm_scale_1.8 8.71
-lm_scale_1.9 8.88
-lm_scale_2.0 9.02
-```
+## Data preparation
+
+If you want to use `./prepare.sh` to download everything for you,
+you can just run
+
+```
+./prepare.sh
+```
+
+If you have pre-downloaded the LibriSpeech dataset, please
+read `./prepare.sh` and modify it to point to the location
+of your dataset so that it won't re-download it. After modification,
+please run
+
+```
+./prepare.sh
+```
+
+The script `./prepare.sh` prepares features, lexicon, LMs, etc.
+All generated files are saved in the folder `./data`.
+
+**HINT:** `./prepare.sh` supports the options `--stage` and `--stop-stage`.
+
+## TDNN-LSTM CTC training
+
+The folder `tdnn_lstm_ctc` contains scripts for CTC training
+with TDNN-LSTM models.
+
+Pre-configured parameters for training and decoding are set in the function
+`get_params()` within `tdnn_lstm_ctc/train.py`
+and `tdnn_lstm_ctc/decode.py`.
+
+Parameters that can be passed from the command line can be found by running
+
+```
+./tdnn_lstm_ctc/train.py --help
+./tdnn_lstm_ctc/decode.py --help
+```
+
+If you have 4 GPUs on a machine and want to use GPUs 0, 2 and 3 for
+multi-GPU training, you can run
+
+```
+export CUDA_VISIBLE_DEVICES="0,2,3"
+./tdnn_lstm_ctc/train.py \
+  --master-port 12345 \
+  --world-size 3
+```
+
+If you want to decode by averaging the checkpoints `epoch-8.pt`,
+`epoch-9.pt` and `epoch-10.pt`, you can run
+
+```
+./tdnn_lstm_ctc/decode.py \
+  --epoch 10 \
+  --avg 3
+```
+
+## Conformer CTC training
+
+The folder `conformer_ctc` contains scripts for CTC training
+with conformer models. The steps for running the training and
+decoding are similar to `tdnn_lstm_ctc`.
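The `--epoch 10 --avg 3` option above refers to averaging the model weights of the last few epoch checkpoints before decoding. The sketch below only illustrates the idea and is not the recipe's actual helper (icefall ships its own); the `"model"` key and the example paths are assumptions that merely follow the `epoch-N.pt` naming used in the README text.

```python
import torch

def average_state_dicts(paths):
    """Element-wise average of the state_dicts stored in `paths` (a sketch)."""
    avg = None
    for path in paths:
        # Assumption: each checkpoint stores the weights under a "model" key.
        state = torch.load(path, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += v.float()
    return {k: v / len(paths) for k, v in avg.items()}

# --epoch 10 --avg 3  ->  average epoch-8.pt, epoch-9.pt and epoch-10.pt
paths = [f"exp/epoch-{i}.pt" for i in (8, 9, 10)]  # illustrative paths only
# model.load_state_dict(average_state_dicts(paths))
```

Averaging a handful of late checkpoints usually smooths out epoch-to-epoch noise and gives a slightly better WER than any single checkpoint.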
@@ -316,6 +316,7 @@ def decode_dataset(
             logging.info(
                 f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
                 f"{num_cuts}"
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )
     return results
@@ -398,7 +399,9 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id

-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -429,7 +432,7 @@ def main():
         torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
     else:
         logging.info("Loading pre-compiled G_4_gram.pt")
-        d = torch.load(params.lm_dir / "G_4_gram.pt")
+        d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
         G = k2.Fsa.from_dict(d).to(device)

     if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
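Both hunks above add `map_location="cpu"` to `torch.load`. Without it, `torch.load` restores tensors to the device they were saved from (for example `cuda:1`), which fails on a machine that does not have that GPU; loading to CPU first and then calling `.to(device)`, as the surrounding code already does, works everywhere. A minimal self-contained sketch of the pattern (the toy file name is not from the diff):

```python
import torch

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

# Save a small checkpoint as a stand-in for HLG.pt / G_4_gram.pt.
torch.save({"scores": torch.zeros(3)}, "toy_checkpoint.pt")

# Load to CPU regardless of where the tensors were originally saved,
# then move them explicitly to whatever device is available here.
d = torch.load("toy_checkpoint.pt", map_location="cpu")
print({k: v.to(device).device for k, v in d.items()})
```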
@@ -17,6 +17,7 @@ from conformer import Conformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_value_
+from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
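The hunk above imports `clip_grad_norm_` alongside the existing `clip_grad_value_`. The diff does not show how the training loop uses it, so the snippet below is only a generic PyTorch illustration of the difference between the two (in practice a training loop would pick one of them, applied between `backward()` and the optimizer step):

```python
import torch
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

model = torch.nn.Linear(10, 10)
loss = model(torch.randn(4, 10)).sum()
loss.backward()

# clip_grad_value_ clips each gradient element independently into [-5.0, 5.0].
clip_grad_value_(model.parameters(), clip_value=5.0)

# clip_grad_norm_ rescales all gradients together so their global L2 norm
# does not exceed max_norm; it returns the norm measured before clipping.
total_norm = clip_grad_norm_(model.parameters(), max_norm=5.0)
print(float(total_norm))
```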
@@ -127,13 +128,13 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc/exp_new"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "start_epoch": 0,
-            "num_epochs": 50,
+            "num_epochs": 20,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
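`get_params()` above returns an `AttributeDict`. Judging by how it is used in the recipes, it behaves like a dict whose keys can also be read and written as attributes (`params.num_epochs`). The snippet below is a minimal stand-in that captures the idea; it is a sketch, not icefall's actual class:

```python
class AttributeDict(dict):
    """Minimal stand-in: a dict whose keys are also reachable as attributes."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value


params = AttributeDict({"num_epochs": 20, "feature_dim": 80})
assert params.num_epochs == 20  # attribute access
params.start_epoch = 0          # writes back into the dict
assert params["start_epoch"] == 0
```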
@@ -4,12 +4,9 @@
 import math
 from typing import Dict, List, Optional, Tuple

-import k2
 import torch
 import torch.nn as nn
 from subsampling import Conv2dSubsampling, VggSubsampling
-
-from icefall.utils import get_texts
 from torch.nn.utils.rnn import pad_sequence

 # Note: TorchScript requires Dict/List/etc. to be fully typed.
@@ -1,18 +1,18 @@
 #!/usr/bin/env python3

 """
-This script compiles HLG from
+This script takes as input lang_dir and generates HLG from

-    - H, the ctc topology, built from tokens contained in lexicon.txt
-    - L, the lexicon, built from L_disambig.pt
+    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
+    - L, the lexicon, built from lang_dir/L_disambig.pt

 Caution: We use a lexicon that contains disambiguation symbols

     - G, the LM, built from data/lm/G_3_gram.fst.txt

-The generated HLG is saved in data/lm/HLG.pt (phone based)
-or data/lm/HLG_bpe.pt (BPE based)
+The generated HLG is saved in $lang_dir/HLG.pt
 """
+import argparse
 import logging
 from pathlib import Path
@@ -22,11 +22,23 @@ import torch
 from icefall.lexicon import Lexicon


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+
+    return parser.parse_args()
+
+
 def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang_phone or data/lang_bpe.
+        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.

     Return:
       An FSA representing HLG.
@@ -104,17 +116,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:


 def main():
-    for d in ["data/lang_phone", "data/lang_bpe"]:
-        d = Path(d)
-        logging.info(f"Processing {d}")
+    args = get_args()
+    lang_dir = Path(args.lang_dir)

-        if (d / "HLG.pt").is_file():
-            logging.info(f"{d}/HLG.pt already exists - skipping")
-            continue
+    if (lang_dir / "HLG.pt").is_file():
+        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+        return

-        HLG = compile_HLG(d)
-        logging.info(f"Saving HLG.pt to {d}")
-        torch.save(HLG.as_dict(), f"{d}/HLG.pt")
+    logging.info(f"Processing {lang_dir}")
+
+    HLG = compile_HLG(lang_dir)
+    logging.info(f"Saving HLG.pt to {lang_dir}")
+    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


 if __name__ == "__main__":
@@ -3,12 +3,13 @@
 # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)

 """
-This script takes as inputs the following two files:
-
-    - data/lang_bpe/bpe.model,
-    - data/lang_bpe/words.txt
-
-and generates the following files in the directory data/lang_bpe:
+This script takes as input `lang_dir`, which should contain::
+
+    - lang_dir/bpe.model,
+    - lang_dir/words.txt
+
+and generates the following files in the directory `lang_dir`:

     - lexicon.txt
     - lexicon_disambig.txt
@@ -17,6 +18,7 @@ and generates the following files in the directory data/lang_bpe:
     - tokens.txt
 """

+import argparse
 from pathlib import Path
 from typing import Dict, List, Tuple
@@ -141,8 +143,22 @@ def generate_lexicon(
     return lexicon, token2id


+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain the bpe.model and words.txt
+        """,
+    )
+
+    return parser.parse_args()
+
+
 def main():
-    lang_dir = Path("data/lang_bpe")
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
     model_file = lang_dir / "bpe.model"

     word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
@@ -189,15 +205,6 @@ def main():
     torch.save(L.as_dict(), lang_dir / "L.pt")
     torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

-    if False:
-        # Just for debugging, will remove it
-        L.labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
-        L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
-        L_disambig.labels_sym = L.labels_sym
-        L_disambig.aux_labels_sym = L.aux_labels_sym
-        L.draw(lang_dir / "L.svg", title="L")
-        L_disambig.draw(lang_dir / "L_disambig.svg", title="L_disambig")
-

 if __name__ == "__main__":
     main()
@@ -1,10 +1,5 @@
 #!/usr/bin/env python3

-"""
-This script takes as input "data/lang/bpe/train.txt"
-and generates "data/lang/bpe/bep.model".
-"""
-
 # You can install sentencepiece via:
 #
 #     pip install sentencepiece
|
|||||||
#
|
#
|
||||||
# Please install a version >=0.1.96
|
# Please install a version >=0.1.96
|
||||||
|
|
||||||
|
import argparse
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=str,
|
||||||
|
help="""Input and output directory.
|
||||||
|
It should contain the training corpus: train.txt.
|
||||||
|
The generated bpe.model is saved to this directory.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocab-size",
|
||||||
|
type=int,
|
||||||
|
help="Vocabulary size for BPE training",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
vocab_size = args.vocab_size
|
||||||
|
lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
model_type = "unigram"
|
model_type = "unigram"
|
||||||
vocab_size = 5000
|
|
||||||
model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
|
model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
|
||||||
train_text = "data/lang_bpe/train.txt"
|
train_text = f"{lang_dir}/train.txt"
|
||||||
character_coverage = 1.0
|
character_coverage = 1.0
|
||||||
input_sentence_size = 100000000
|
input_sentence_size = 100000000
|
||||||
|
|
||||||
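The script above drives SentencePiece with the options visible in this diff (`model_type="unigram"`, a configurable `--vocab-size`, `character_coverage=1.0`, `input_sentence_size=100000000`, transcripts collected into `train.txt`). As a rough, self-contained illustration of such a call (the toy corpus, output names, and tiny vocabulary size are assumptions so that the sketch actually runs; the recipe's full argument list may differ):

```python
import sentencepiece as spm

# A tiny stand-in training corpus; the recipe builds train.txt from transcripts.
with open("train.txt", "w") as f:
    f.write("HELLO WORLD\nTHIS IS A TINY TRAINING CORPUS\nBPE MODELS NEED TEXT\n")

# Only options that appear in the diff are used here.
spm.SentencePieceTrainer.train(
    input="train.txt",
    model_prefix="unigram_30",      # recipe uses f"{lang_dir}/{model_type}_{vocab_size}"
    model_type="unigram",
    vocab_size=30,                  # recipe passes --vocab-size, e.g. 5000
    character_coverage=1.0,
    input_sentence_size=100000000,
)

# Load the result and tokenize a sample sentence.
sp = spm.SentencePieceProcessor(model_file="unigram_30.model")
print(sp.encode("HELLO WORLD", out_type=str))
```

Note that `vocab_size` must stay small for a toy corpus; the recipe's value of 5000 presupposes the full LibriSpeech transcripts.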
@@ -49,10 +68,7 @@ def main():
         eos_id=-1,
     )

-    sp = spm.SentencePieceProcessor(model_file=str(model_file))
-    vocab_size = sp.vocab_size()
-
-    shutil.copyfile(model_file, "data/lang_bpe/bpe.model")
+    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")


 if __name__ == "__main__":
@@ -25,7 +25,7 @@ stop_stage=100
 #   - librispeech-vocab.txt
 #   - librispeech-lexicon.txt
 #
-# - $do_dir/musan
+# - $dl_dir/musan
 #   This directory contains the following directories downloaded from
 #   http://www.openslr.org/17/
 #
@@ -36,8 +36,15 @@ dl_dir=$PWD/download

 . shared/parse_options.sh || exit 1

+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  5000
+)
+
-# All generated files by this script are saved in "data"
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data

 log() {
@@ -50,6 +57,7 @@ log "dl_dir: $dl_dir"

 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: Download LM"
+  [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
   ./local/download_lm.py --out-dir=$dl_dir/lm
 fi
@@ -118,28 +126,34 @@ fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "State 6: Prepare BPE based lang"
-  mkdir -p data/lang_bpe
-  # We reuse words.txt from phone based lexicon
-  # so that the two can share G.pt later.
-  cp data/lang_phone/words.txt data/lang_bpe/

-  if [ ! -f data/lang_bpe/train.txt ]; then
-    log "Generate data for BPE training"
-    files=$(
-      find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"
-      find "data/LibriSpeech/train-clean-360" -name "*.trans.txt"
-      find "data/LibriSpeech/train-other-500" -name "*.trans.txt"
-    )
-    for f in ${files[@]}; do
-      cat $f | cut -d " " -f 2-
-    done > data/lang_bpe/train.txt
-  fi
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp data/lang_phone/words.txt $lang_dir
+
+    if [ ! -f $lang_dir/train.txt ]; then
+      log "Generate data for BPE training"
+      files=$(
+        find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
+        find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
+        find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
+      )
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2-
+      done > $lang_dir/train.txt
+    fi

-  python3 ./local/train_bpe_model.py
+    ./local/train_bpe_model.py \
+      --lang-dir $lang_dir \
+      --vocab-size $vocab_size

-  if [ ! -f data/lang_bpe/L_disambig.pt ]; then
-    ./local/prepare_lang_bpe.py
-  fi
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+    fi
+  done
 fi

 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
@@ -169,5 +183,12 @@ fi

 if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
   log "Stage 8: Compile HLG"
-  python3 ./local/compile_hlg.py
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
 fi
+
+cd data && ln -sfv lang_bpe_5000 lang_bpe
@@ -1,22 +1,2 @@
-## (To be filled in)
-
-It will contain:
-
-- How to run
-- WERs
-
-```bash
-cd $PWD/..
-
-./prepare.sh
-
-./tdnn_lstm_ctc/train.py
-```
-
-If you have 4 GPUs and want to use GPU 1 and GPU 3 for DDP training,
-you can do the following:
-
-```
-export CUDA_VISIBLE_DEVICES="1,3"
-./tdnn_lstm_ctc/train.py --world-size=2
-```
+Will add results later.
@@ -236,7 +236,6 @@ def decode_dataset(
     results = []

     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -264,9 +263,7 @@

         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )
     return results
@@ -328,7 +325,9 @@ def main():

     logging.info(f"device: {device}")

-    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load("data/lang_phone/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@@ -355,7 +354,7 @@ def main():
         torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
     else:
         logging.info("Loading pre-compiled G_4_gram.pt")
-        d = torch.load(params.lm_dir / "G_4_gram.pt")
+        d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
         G = k2.Fsa.from_dict(d).to(device)

     if params.method == "whole-lattice-rescoring":
@@ -1,3 +1,4 @@
 kaldilm
 kaldialign
 sentencepiece>=0.1.96
+tensorboard