Fix conflicts

This commit is contained in:
pkufool 2021-08-19 17:38:19 +08:00
commit d2ae1ba060
12 changed files with 236 additions and 196 deletions

View File

@ -1 +1,60 @@
Working in progress.
# Table of Contents
- [Installation](#installation)
* [Install k2](#install-k2)
* [Install lhotse](#install-lhotse)
* [Install icefall](#install-icefall)
- [Run recipes](#run-recipes)
## Installation
`icefall` depends on [k2][k2] for FSA operations and [lhotse][lhotse] for
data preparation. To use `icefall`, you have to install its dependencies first.
The following subsections describe how to set up the environment.
CAUTION: There are various ways to set up the environment. What we describe
here is just one of them.
### Install k2
Please refer to [k2's installation documentation][k2-install] to install k2.
If you run into any issues while installing k2, please open an issue at
<https://github.com/k2-fsa/k2/issues>.
### Install lhotse
Please refer to [lhotse's installation documentation][lhotse-install] to install
lhotse.
### Install icefall
`icefall` is a set of Python scripts. All you need to do is clone the
repository, install its requirements, and set the environment variable `PYTHONPATH`:
```bash
cd $HOME/open-source
git clone https://github.com/k2-fsa/icefall
cd icefall
pip install -r requirements.txt
export PYTHONPATH=$HOME/open-source/icefall:$PYTHONPATH
```
To verify `icefall` was installed successfully, you can run:
```bash
python3 -c "import icefall; print(icefall.__file__)"
```
It should print the path to `icefall`.
## Run recipes
At present, only the LibriSpeech recipe is provided. Please
follow [egs/librispeech/ASR/README.md][LibriSpeech] to run it.
[LibriSpeech]: egs/librispeech/ASR/README.md
[k2-install]: https://k2.readthedocs.io/en/latest/installation/index.html#
[k2]: https://github.com/k2-fsa/k2
[lhotse]: https://github.com/lhotse-speech/lhotse
[lhotse-install]: https://lhotse.readthedocs.io/en/latest/getting-started.html#installation
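
If the `PYTHONPATH` export above is wrong, the failure usually shows up only later as an `ImportError`. Below is a minimal sanity-check sketch (not part of icefall; the helper script is hypothetical) that imports all three dependencies described above and prints where Python found them:

```python
#!/usr/bin/env python3
# Hypothetical helper: verify the icefall environment described above.
# It checks that k2, lhotse and icefall are importable and shows the paths
# Python resolved them to, which helps debug PYTHONPATH mistakes.

import torch

import k2
import lhotse
import icefall


def main():
    print("torch:", torch.__version__, "CUDA available:", torch.cuda.is_available())
    print("k2:", k2.__file__)
    print("lhotse:", lhotse.__file__)
    print("icefall:", icefall.__file__)


if __name__ == "__main__":
    main()
```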

View File

@ -1,121 +1,64 @@
-Run `./prepare.sh` to prepare the data.
-Run `./xxx_train.py` (to be added) to train a model.
-## Conformer-CTC
-Results of the pre-trained model from
-`<https://huggingface.co/GuoLiyong/snowfall_bpe_model/tree/main/exp-duration-200-feat_batchnorm-bpe-lrfactor5.0-conformer-512-8-noam>`
-are given below
-### HLG - no LM rescoring
-(output beam size is 8)
-#### 1-best decoding
-```
-[test-clean-no_rescore] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ]
-[test-other-no_rescore] %WER 7.03% [3682 / 52343, 220 ins, 1024 del, 2438 sub ]
-```
-#### n-best decoding
-For n=100,
-```
-[test-clean-no_rescore-100] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ]
-[test-other-no_rescore-100] %WER 7.14% [3737 / 52343, 275 ins, 1020 del, 2442 sub ]
-```
-For n=200,
-```
-[test-clean-no_rescore-200] %WER 3.16% [1660 / 52576, 125 ins, 378 del, 1157 sub ]
-[test-other-no_rescore-200] %WER 7.04% [3684 / 52343, 228 ins, 1012 del, 2444 sub ]
-```
-### HLG - with LM rescoring
-#### Whole lattice rescoring
-```
-[test-clean-lm_scale_0.8] %WER 2.77% [1456 / 52576, 150 ins, 210 del, 1096 sub ]
-[test-other-lm_scale_0.8] %WER 6.23% [3262 / 52343, 246 ins, 635 del, 2381 sub ]
-```
-WERs of different LM scales are:
-```
-For test-clean, WER of different settings are:
-lm_scale_0.8 2.77 best for test-clean
-lm_scale_0.9 2.87
-lm_scale_1.0 3.06
-lm_scale_1.1 3.34
-lm_scale_1.2 3.71
-lm_scale_1.3 4.18
-lm_scale_1.4 4.8
-lm_scale_1.5 5.48
-lm_scale_1.6 6.08
-lm_scale_1.7 6.79
-lm_scale_1.8 7.49
-lm_scale_1.9 8.14
-lm_scale_2.0 8.82
-For test-other, WER of different settings are:
-lm_scale_0.8 6.23 best for test-other
-lm_scale_0.9 6.37
-lm_scale_1.0 6.62
-lm_scale_1.1 6.99
-lm_scale_1.2 7.46
-lm_scale_1.3 8.13
-lm_scale_1.4 8.84
-lm_scale_1.5 9.61
-lm_scale_1.6 10.32
-lm_scale_1.7 11.17
-lm_scale_1.8 12.12
-lm_scale_1.9 12.93
-lm_scale_2.0 13.77
-```
-#### n-best LM rescoring
-n = 100
-```
-[test-clean-lm_scale_0.8] %WER 2.79% [1469 / 52576, 149 ins, 212 del, 1108 sub ]
-[test-other-lm_scale_0.8] %WER 6.36% [3329 / 52343, 259 ins, 666 del, 2404 sub ]
-```
-WERs of different LM scales are:
-```
-For test-clean, WER of different settings are:
-lm_scale_0.8 2.79 best for test-clean
-lm_scale_0.9 2.89
-lm_scale_1.0 3.03
-lm_scale_1.1 3.28
-lm_scale_1.2 3.52
-lm_scale_1.3 3.78
-lm_scale_1.4 4.04
-lm_scale_1.5 4.24
-lm_scale_1.6 4.45
-lm_scale_1.7 4.58
-lm_scale_1.8 4.7
-lm_scale_1.9 4.8
-lm_scale_2.0 4.92
-For test-other, WER of different settings are:
-lm_scale_0.8 6.36 best for test-other
-lm_scale_0.9 6.45
-lm_scale_1.0 6.64
-lm_scale_1.1 6.92
-lm_scale_1.2 7.25
-lm_scale_1.3 7.59
-lm_scale_1.4 7.88
-lm_scale_1.5 8.13
-lm_scale_1.6 8.36
-lm_scale_1.7 8.54
-lm_scale_1.8 8.71
-lm_scale_1.9 8.88
-lm_scale_2.0 9.02
-```
+## Data preparation
+If you want to use `./prepare.sh` to download everything for you,
+you can just run
+```
+./prepare.sh
+```
+If you have pre-downloaded the LibriSpeech dataset, please
+read `./prepare.sh` and modify it to point to the location
+of your dataset so that it won't re-download it. After modification,
+please run
+```
+./prepare.sh
+```
+The script `./prepare.sh` prepares features, the lexicon, LMs, etc.
+All generated files are saved in the folder `./data`.
+**HINT:** `./prepare.sh` supports the options `--stage` and `--stop-stage`.
+## TDNN-LSTM CTC training
+The folder `tdnn_lstm_ctc` contains scripts for CTC training
+with TDNN-LSTM models.
+Pre-configured parameters for training and decoding are set in the function
+`get_params()` within `tdnn_lstm_ctc/train.py`
+and `tdnn_lstm_ctc/decode.py`.
+Parameters that can be passed from the command line can be found by running
+```
+./tdnn_lstm_ctc/train.py --help
+./tdnn_lstm_ctc/decode.py --help
+```
+If you have 4 GPUs on a machine and want to use GPUs 0, 2 and 3 for
+multi-GPU training, you can run
+```
+export CUDA_VISIBLE_DEVICES="0,2,3"
+./tdnn_lstm_ctc/train.py \
+  --master-port 12345 \
+  --world-size 3
+```
+If you want to decode by averaging the checkpoints `epoch-8.pt`,
+`epoch-9.pt` and `epoch-10.pt`, you can run
+```
+./tdnn_lstm_ctc/decode.py \
+  --epoch 10 \
+  --avg 3
+```
+## Conformer CTC training
+The folder `conformer_ctc` contains scripts for CTC training
+with conformer models. The steps for running training and
+decoding are similar to those of `tdnn_lstm_ctc`.
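
For readers unfamiliar with the `--epoch 10 --avg 3` flags in the new README text above, the sketch below shows the general idea of checkpoint averaging: the parameters stored in `epoch-8.pt`, `epoch-9.pt` and `epoch-10.pt` are averaged element-wise. This is not icefall's own implementation; the experiment directory and the `"model"` key in each checkpoint are assumptions.

```python
#!/usr/bin/env python3
# Minimal checkpoint-averaging sketch (NOT icefall's implementation).
# Assumption: each epoch-*.pt file is a dict with a "model" entry holding a
# state_dict, as is common for PyTorch training scripts.

from pathlib import Path
from typing import List

import torch


def average_checkpoints(filenames: List[Path]) -> dict:
    """Return the element-wise average of the "model" state dicts."""
    n = len(filenames)
    avg = torch.load(filenames[0], map_location="cpu")["model"]
    for f in filenames[1:]:
        state = torch.load(f, map_location="cpu")["model"]
        for k in avg:
            avg[k] += state[k]
    for k in avg:
        # Integer buffers (e.g. batch counters) are left as summed values here.
        if avg[k].is_floating_point():
            avg[k] /= n
    return avg


if __name__ == "__main__":
    exp_dir = Path("tdnn_lstm_ctc/exp")  # assumed experiment directory
    epoch, avg = 10, 3
    files = [exp_dir / f"epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)]
    averaged = average_checkpoints(files)
    print(f"Averaged {len(files)} checkpoints; {len(averaged)} tensors in result")
```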

View File

@ -316,6 +316,7 @@ def decode_dataset(
             logging.info(
                 f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
                 f"{num_cuts}"
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )
     return results
@ -398,7 +399,9 @@ def main():
     sos_id = graph_compiler.sos_id
     eos_id = graph_compiler.eos_id
-    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@ -429,7 +432,7 @@ def main():
         torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
     else:
         logging.info("Loading pre-compiled G_4_gram.pt")
-        d = torch.load(params.lm_dir / "G_4_gram.pt")
+        d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
         G = k2.Fsa.from_dict(d).to(device)
     if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
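
The `map_location="cpu"` additions above follow a common pattern: deserialize to CPU first, then move tensors to the target device. Without it, `torch.load` tries to restore tensors onto the device they were saved from (e.g. `cuda:1`), which fails on machines that lack that GPU. A small self-contained sketch of the pattern (the file name is illustrative only):

```python
#!/usr/bin/env python3
# Sketch of the load-to-CPU-then-move pattern used in the change above.

import torch


def load_to_device(path: str, device: torch.device) -> dict:
    # map_location="cpu" makes the load independent of the saving machine's GPUs.
    d = torch.load(path, map_location="cpu")
    return {k: v.to(device) if torch.is_tensor(v) else v for k, v in d.items()}


if __name__ == "__main__":
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    torch.save({"weights": torch.randn(3, 3)}, "example.pt")  # illustrative file
    d = load_to_device("example.pt", device)
    print({k: v.device for k, v in d.items()})
```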

View File

@ -17,6 +17,7 @@ from conformer import Conformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_value_
+from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@ -127,13 +128,13 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc/exp_new"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
             "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "start_epoch": 0,
-            "num_epochs": 50,
+            "num_epochs": 20,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,

View File

@ -4,12 +4,9 @@
 import math
 from typing import Dict, List, Optional, Tuple
-import k2
 import torch
 import torch.nn as nn
 from subsampling import Conv2dSubsampling, VggSubsampling
-from icefall.utils import get_texts
 from torch.nn.utils.rnn import pad_sequence
 # Note: TorchScript requires Dict/List/etc. to be fully typed.

View File

@ -1,18 +1,18 @@
 #!/usr/bin/env python3
 """
-This script compiles HLG from
+This script takes as input lang_dir and generates HLG from
-- H, the ctc topology, built from tokens contained in lexicon.txt
+- H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
-- L, the lexicon, built from L_disambig.pt
+- L, the lexicon, built from lang_dir/L_disambig.pt
   Caution: We use a lexicon that contains disambiguation symbols
 - G, the LM, built from data/lm/G_3_gram.fst.txt
-The generated HLG is saved in data/lm/HLG.pt (phone based)
-or data/lm/HLG_bpe.pt (BPE based)
+The generated HLG is saved in $lang_dir/HLG.pt
 """
+import argparse
 import logging
 from pathlib import Path
@ -22,11 +22,23 @@ import torch
 from icefall.lexicon import Lexicon
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        """,
+    )
+    return parser.parse_args()
 def compile_HLG(lang_dir: str) -> k2.Fsa:
     """
     Args:
       lang_dir:
-        The language directory, e.g., data/lang_phone or data/lang_bpe.
+        The language directory, e.g., data/lang_phone or data/lang_bpe_5000.
     Return:
         An FSA representing HLG.
@ -104,17 +116,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
 def main():
-    for d in ["data/lang_phone", "data/lang_bpe"]:
-        d = Path(d)
-        logging.info(f"Processing {d}")
-        if (d / "HLG.pt").is_file():
-            logging.info(f"{d}/HLG.pt already exists - skipping")
-            continue
-        HLG = compile_HLG(d)
-        logging.info(f"Saving HLG.pt to {d}")
-        torch.save(HLG.as_dict(), f"{d}/HLG.pt")
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+    if (lang_dir / "HLG.pt").is_file():
+        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
+        return
+    logging.info(f"Processing {lang_dir}")
+    HLG = compile_HLG(lang_dir)
+    logging.info(f"Saving HLG.pt to {lang_dir}")
+    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")
 if __name__ == "__main__":
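
Besides the new `--lang-dir` argument, `compile_hlg.py` keeps the "skip if the output already exists" convention: the compiled graph is cached as `HLG.pt` and recompiled only when that file is missing. A generic sketch of that caching pattern (the `build_artifact` helper and directory name below are stand-ins, not the real HLG composition):

```python
#!/usr/bin/env python3
# Sketch of the cache-as-.pt-and-skip pattern used by compile_hlg.py.

import logging
from pathlib import Path

import torch


def build_artifact() -> dict:
    # Placeholder for an expensive build step (e.g. composing H, L and G).
    return {"dummy": torch.arange(10)}


def compile_once(lang_dir: Path, name: str = "HLG.pt") -> None:
    out = lang_dir / name
    if out.is_file():
        logging.info(f"{out} already exists - skipping")
        return
    lang_dir.mkdir(parents=True, exist_ok=True)
    logging.info(f"Processing {lang_dir}")
    torch.save(build_artifact(), out)
    logging.info(f"Saved {out}")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    compile_once(Path("data/lang_demo"))  # hypothetical directory
    compile_once(Path("data/lang_demo"))  # second call is a no-op
```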

View File

@ -3,12 +3,13 @@
 # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
 """
-This script takes as inputs the following two files:
-    - data/lang_bpe/bpe.model,
-    - data/lang_bpe/words.txt
-and generates the following files in the directory data/lang_bpe:
+This script takes as input `lang_dir`, which should contain::
+    - lang_dir/bpe.model,
+    - lang_dir/words.txt
+and generates the following files in the directory `lang_dir`:
     - lexicon.txt
     - lexicon_disambig.txt
@ -17,6 +18,7 @@ and generates the following files in the directory data/lang_bpe:
     - tokens.txt
 """
+import argparse
 from pathlib import Path
 from typing import Dict, List, Tuple
@ -141,8 +143,22 @@ def generate_lexicon(
     return lexicon, token2id
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain the bpe.model and words.txt
+        """,
+    )
+    return parser.parse_args()
 def main():
-    lang_dir = Path("data/lang_bpe")
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
     model_file = lang_dir / "bpe.model"
     word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt")
@ -189,15 +205,6 @@ def main():
     torch.save(L.as_dict(), lang_dir / "L.pt")
     torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
-    if False:
-        # Just for debugging, will remove it
-        L.labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
-        L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
-        L_disambig.labels_sym = L.labels_sym
-        L_disambig.aux_labels_sym = L.aux_labels_sym
-        L.draw(lang_dir / "L.svg", title="L")
-        L_disambig.draw(lang_dir / "L_disambig.svg", title="L_disambig")
 if __name__ == "__main__":
     main()
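
For context on what `prepare_lang_bpe.py` produces: a BPE lexicon maps each word from `words.txt` to its sentencepiece segmentation, one line per word. The rough sketch below mirrors that idea using the sentencepiece API, but it is not the script's actual implementation; the directory layout and the two example words are assumptions.

```python
#!/usr/bin/env python3
# Rough sketch of BPE lexicon generation: word -> "piece1 piece2 ...".

from pathlib import Path
from typing import List, Tuple

import sentencepiece as spm


def generate_lexicon(model_file: str, words: List[str]) -> List[Tuple[str, List[str]]]:
    sp = spm.SentencePieceProcessor()
    sp.load(model_file)
    # out_type=str returns the BPE pieces themselves rather than their ids.
    return [(w, sp.encode(w, out_type=str)) for w in words]


if __name__ == "__main__":
    lang_dir = Path("data/lang_bpe_5000")  # assumed layout from prepare.sh
    words = ["HELLO", "WORLD"]             # normally read from words.txt
    for word, pieces in generate_lexicon(str(lang_dir / "bpe.model"), words):
        print(word, " ".join(pieces))
```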

View File

@ -1,10 +1,5 @@
 #!/usr/bin/env python3
-"""
-This script takes as input "data/lang/bpe/train.txt"
-and generates "data/lang/bpe/bep.model".
-"""
 # You can install sentencepiece via:
 #
 # pip install sentencepiece
@ -14,17 +9,41 @@ and generates "data/lang/bpe/bep.model".
 #
 # Please install a version >=0.1.96
+import argparse
 import shutil
 from pathlib import Path
 import sentencepiece as spm
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        help="""Input and output directory.
+        It should contain the training corpus: train.txt.
+        The generated bpe.model is saved to this directory.
+        """,
+    )
+    parser.add_argument(
+        "--vocab-size",
+        type=int,
+        help="Vocabulary size for BPE training",
+    )
+    return parser.parse_args()
 def main():
+    args = get_args()
+    vocab_size = args.vocab_size
+    lang_dir = Path(args.lang_dir)
     model_type = "unigram"
-    vocab_size = 5000
-    model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}"
+    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
-    train_text = "data/lang_bpe/train.txt"
+    train_text = f"{lang_dir}/train.txt"
     character_coverage = 1.0
     input_sentence_size = 100000000
@ -49,10 +68,7 @@ def main():
         eos_id=-1,
     )
-    sp = spm.SentencePieceProcessor(model_file=str(model_file))
-    vocab_size = sp.vocab_size()
-    shutil.copyfile(model_file, "data/lang_bpe/bpe.model")
+    shutil.copyfile(model_file, f"{lang_dir}/bpe.model")
 if __name__ == "__main__":
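
A condensed sketch of what `train_bpe_model.py` now does with its `--lang-dir` and `--vocab-size` arguments: train a unigram sentencepiece model on `lang_dir/train.txt` and copy the result to `lang_dir/bpe.model`. The parameter names follow the script above; the call assumes `train.txt` has already been generated (stage 6 of `prepare.sh`) and omits the script's other trainer options.

```python
#!/usr/bin/env python3
# Condensed BPE-training sketch; assumes lang_dir/train.txt already exists.

import shutil
from pathlib import Path

import sentencepiece as spm


def train_bpe(lang_dir: Path, vocab_size: int) -> Path:
    model_type = "unigram"
    model_prefix = lang_dir / f"{model_type}_{vocab_size}"
    spm.SentencePieceTrainer.train(
        input=str(lang_dir / "train.txt"),
        model_prefix=str(model_prefix),
        model_type=model_type,
        vocab_size=vocab_size,
        character_coverage=1.0,
    )
    model_file = model_prefix.with_suffix(".model")
    shutil.copyfile(model_file, lang_dir / "bpe.model")
    return model_file


if __name__ == "__main__":
    train_bpe(Path("data/lang_bpe_5000"), vocab_size=5000)
```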

View File

@ -25,7 +25,7 @@ stop_stage=100
 # - librispeech-vocab.txt
 # - librispeech-lexicon.txt
 #
-# - $do_dir/musan
+# - $dl_dir/musan
 # This directory contains the following directories downloaded from
 # http://www.openslr.org/17/
 #
@ -36,8 +36,15 @@ dl_dir=$PWD/download
 . shared/parse_options.sh || exit 1
+# vocab size for sentence piece models.
+# It will generate data/lang_bpe_xxx,
+# data/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  5000
+)
-# All generated files by this script are saved in "data"
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
 mkdir -p data
 log() {
@ -50,6 +57,7 @@ log "dl_dir: $dl_dir"
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: Download LM"
+  [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
   ./local/download_lm.py --out-dir=$dl_dir/lm
 fi
@ -118,28 +126,34 @@ fi
 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "Stage 6: Prepare BPE based lang"
-  mkdir -p data/lang_bpe
-  # We reuse words.txt from phone based lexicon
-  # so that the two can share G.pt later.
-  cp data/lang_phone/words.txt data/lang_bpe/
-  if [ ! -f data/lang_bpe/train.txt ]; then
-    log "Generate data for BPE training"
-    files=$(
-      find "data/LibriSpeech/train-clean-100" -name "*.trans.txt"
-      find "data/LibriSpeech/train-clean-360" -name "*.trans.txt"
-      find "data/LibriSpeech/train-other-500" -name "*.trans.txt"
-    )
-    for f in ${files[@]}; do
-      cat $f | cut -d " " -f 2-
-    done > data/lang_bpe/train.txt
-  fi
-  python3 ./local/train_bpe_model.py
-  if [ ! -f data/lang_bpe/L_disambig.pt ]; then
-    ./local/prepare_lang_bpe.py
-  fi
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp data/lang_phone/words.txt $lang_dir
+    if [ ! -f $lang_dir/train.txt ]; then
+      log "Generate data for BPE training"
+      files=$(
+        find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
+        find "$dl_dir/LibriSpeech/train-clean-360" -name "*.trans.txt"
+        find "$dl_dir/LibriSpeech/train-other-500" -name "*.trans.txt"
+      )
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2-
+      done > $lang_dir/train.txt
+    fi
+    ./local/train_bpe_model.py \
+      --lang-dir $lang_dir \
+      --vocab-size $vocab_size
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+    fi
+  done
 fi
 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
@ -169,5 +183,12 @@ fi
 if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
   log "Stage 8: Compile HLG"
-  python3 ./local/compile_hlg.py
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
 fi
+cd data && ln -sfv lang_bpe_5000 lang_bpe

View File

@ -1,22 +1,2 @@
-## (To be filled in)
-It will contain:
-- How to run
-- WERs
-```bash
-cd $PWD/..
-./prepare.sh
-./tdnn_lstm_ctc/train.py
-```
-If you have 4 GPUs and want to use GPU 1 and GPU 3 for DDP training,
-you can do the following:
-```
-export CUDA_VISIBLE_DEVICES="1,3"
-./tdnn_lstm_ctc/train.py --world-size=2
-```
+Will add results later.
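
The DDP instructions removed above (setting `CUDA_VISIBLE_DEVICES` and passing `--world-size`) rest on a standard PyTorch pattern: one process is spawned per visible device, and each process joins the same process group under its own rank. The sketch below illustrates that pattern generically with the CPU/gloo backend; it is not icefall's actual `tdnn_lstm_ctc/train.py`, and the port number is an assumption.

```python
#!/usr/bin/env python3
# Generic DDP spawning sketch: world_size processes, one rank each.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def run(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12354"  # analogous to a --master-port flag
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DDP(torch.nn.Linear(10, 2))  # CPU + gloo keeps the sketch runnable
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()  # gradients are all-reduced across the world_size processes

    print(f"rank {rank}/{world_size} finished, loss={loss.item():.4f}")
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2  # e.g. CUDA_VISIBLE_DEVICES="1,3" would give world_size 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
```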

View File

@ -236,7 +236,6 @@ def decode_dataset(
     results = []
     num_cuts = 0
-    tot_num_cuts = len(dl.dataset.cuts)
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@ -264,9 +263,7 @@
         if batch_idx % 100 == 0:
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is "
-                f"{num_cuts}/{tot_num_cuts} "
-                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
+                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
             )
     return results
@ -328,7 +325,9 @@ def main():
     logging.info(f"device: {device}")
-    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
+    HLG = k2.Fsa.from_dict(
+        torch.load("data/lang_phone/HLG.pt", map_location="cpu")
+    )
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
@ -355,7 +354,7 @@ def main():
         torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
     else:
         logging.info("Loading pre-compiled G_4_gram.pt")
-        d = torch.load(params.lm_dir / "G_4_gram.pt")
+        d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
         G = k2.Fsa.from_dict(d).to(device)
     if params.method == "whole-lattice-rescoring":

View File

@ -1,3 +1,4 @@
 kaldilm
 kaldialign
 sentencepiece>=0.1.96
+tensorboard
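
The new `tensorboard` requirement matches the `SummaryWriter` import in `conformer_ctc/train.py` above: `torch.utils.tensorboard` needs the tensorboard package at runtime. A minimal usage sketch (the log directory name and fake loss curve are illustrative only):

```python
#!/usr/bin/env python3
# Minimal SummaryWriter sketch showing why tensorboard is required.

import math

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="conformer_ctc/exp/tensorboard")

for step in range(100):
    fake_loss = 5.0 * math.exp(-0.05 * step)  # stand-in for a training loss
    writer.add_scalar("train/loss", fake_loss, step)

writer.close()
# Inspect with:  tensorboard --logdir conformer_ctc/exp/tensorboard
```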