Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit 1193228a14 (parent e87bfacb91): update the timit recipe and check style
@@ -65,7 +65,7 @@ The command to run the training part is:

   $ export CUDA_VISIBLE_DEVICES="0"
   $ ./tdnn_ligru_ctc/train.py

-By default, it will run ``30`` epochs. Training logs and checkpoints are saved
+By default, it will run ``25`` epochs. Training logs and checkpoints are saved
 in ``tdnn_ligru_ctc/exp``.

 In ``tdnn_ligru_ctc/exp``, you will find the following files:

@@ -221,7 +221,7 @@ After downloading, you will have the following files:

 |   `-- lm
 |       `-- G_4_gram.pt
 |-- exp
-|   `-- pretrained.pt
+|   `-- pretrained_average_9_25.pt
 `-- test_wavs
     |-- FDHC0_SI1559.WAV
     |-- FELC0_SI756.WAV

@@ -319,7 +319,7 @@ To decode with ``1best`` method, we can use:

 ./tdnn_ligru_ctc/pretrained.py
   --method 1best
-  --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_16_25.pt
+  --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/exp/pretrained_average_9_25.pt
   --words-file ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/words.txt
   --HLG ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/data/lang_phone/HLG.pt
   ./tmp-ligru/icefall_asr_timit_tdnn_ligru_ctc/test_waves/FDHC0_SI1559.WAV

@@ -357,7 +357,7 @@ To decode with ``whole-lattice-rescoring`` methond, you can use

 ./tdnn_ligru_ctc/pretrained.py \
   --method whole-lattice-rescoring \
-  --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn-ligru_ctc/exp/pretraind.pt \
+  --checkpoint ./tmp-ligru/icefall_asr_timit_tdnn-ligru_ctc/exp/pretrained_average_9_25.pt \
   --words-file ./tmp-ligru/icefall_asr_timit_tdnn-ligru_ctc/data/lang_phone/words.txt \
   --HLG ./tmp-ligru/icefall_asr_timit_tdnn-ligru_ctc/data/lang_phone/HLG.pt \
   --G ./tmp-ligru/icefall_asr_timit_tdnn-ligru_ctc/data/lm/G_4_gram.pt \
@@ -65,7 +65,7 @@ The command to run the training part is:

   $ export CUDA_VISIBLE_DEVICES="0"
   $ ./tdnn_lstm_ctc/train.py

-By default, it will run ``30`` epochs. Training logs and checkpoints are saved
+By default, it will run ``25`` epochs. Training logs and checkpoints are saved
 in ``tdnn_lstm_ctc/exp``.

 In ``tdnn_lstm_ctc/exp``, you will find the following files:

@@ -219,7 +219,7 @@ After downloading, you will have the following files:

 |   `-- lm
 |       `-- G_4_gram.pt
 |-- exp
-|   `-- pretrained.pt
+|   `-- pretrained_average_16_25.pt
 `-- test_wavs
     |-- FDHC0_SI1559.WAV
     |-- FELC0_SI756.WAV

@@ -264,9 +264,9 @@ The information of the test sound files is listed below:

 .. code-block:: bash

-  $ ffprobe -show_format tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
+  $ ffprobe -show_format tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV

-  Input #0, nistsphere, from 'tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV':
+  Input #0, nistsphere, from 'tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV':
     Metadata:
       database_id     : TIMIT
       database_version: 1.0

@@ -276,9 +276,9 @@ The information of the test sound files is listed below:

     Duration: 00:00:03.40, bitrate: 258 kb/s
     Stream #0:0: Audio: pcm_s16le, 16000 Hz, 1 channels, s16, 256 kb/s

-  $ ffprobe -show_format tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
+  $ ffprobe -show_format tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV

-  Input #0, nistsphere, from 'tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV':
+  Input #0, nistsphere, from 'tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV':
     Metadata:
       database_id     : TIMIT
       database_version: 1.0

@@ -288,9 +288,9 @@ The information of the test sound files is listed below:

     Duration: 00:00:04.19, bitrate: 257 kb/s
     Stream #0:0: Audio: pcm_s16le, 16000 Hz, 1 channels, s16, 256 kb/s

-  $ ffprobe -show_format tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
+  $ ffprobe -show_format tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV

-  Input #0, nistsphere, from 'tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV':
+  Input #0, nistsphere, from 'tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV':
     Metadata:
       database_id     : TIMIT
       database_version: 1.0

@@ -317,12 +317,12 @@ To decode with ``1best`` method, we can use:

 ./tdnn_lstm_ctc/pretrained.py
   --method 1best
-  --checkpoint ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
-  --words-file ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
-  --HLG ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
-  ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
-  ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
-  ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV
+  --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt
+  --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt
+  --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
+  ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
+  ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
+  ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV

 The output is:

@@ -355,14 +355,14 @@ To decode with ``whole-lattice-rescoring`` methond, you can use

 ./tdnn_lstm_ctc/pretrained.py \
   --method whole-lattice-rescoring \
-  --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretraind.pt \
+  --checkpoint ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/exp/pretrained_average_16_25.pt \
   --words-file ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/words.txt \
   --HLG ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt \
   --G ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt \
-  --ngram-lm-scale 0.8 \
-  ./tmp-lstm/icefall_asr_timit_tdnn-lstm_ctc/test_wavs/1089-134686-0001.flac \
-  ./tmp-lstm/icefall_asr_timit_tdnn-lstm_ctc/test_wavs/1221-135766-0001.flac \
-  ./tmp-lstm/icefall_asr_timit_tdnn-lstm_ctc/test_wavs/1221-135766-0002.flac
+  --ngram-lm-scale 0.08 \
+  ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV
+  ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV
+  ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV

 The decoding output is:

@@ -370,20 +370,20 @@ The decoding output is:

 2021-11-08 20:05:22,739 INFO [pretrained.py:169] device: cuda:0
 2021-11-08 20:05:22,739 INFO [pretrained.py:171] Creating model
-2021-11-08 20:05:26,959 INFO [pretrained.py:183] Loading HLG from ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
-2021-11-08 20:05:26,971 INFO [pretrained.py:191] Loading G from ./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt
+2021-11-08 20:05:26,959 INFO [pretrained.py:183] Loading HLG from ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lang_phone/HLG.pt
+2021-11-08 20:05:26,971 INFO [pretrained.py:191] Loading G from ./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/data/lm/G_4_gram.pt
 2021-11-08 20:05:26,977 INFO [pretrained.py:200] Constructing Fbank computer
-2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
+2021-11-08 20:05:26,978 INFO [pretrained.py:210] Reading sound files: ['./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV', './tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV']
 2021-11-08 20:05:26,981 INFO [pretrained.py:216] Decoding started
 2021-11-08 20:05:27,519 INFO [pretrained.py:251] Use HLG decoding + LM rescoring
 2021-11-08 20:05:27,878 INFO [pretrained.py:267]
-./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
+./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FDHC0_SI1559.WAV:
 sil dh ih sh uw l iy v iy z ih sil p r aa sil k s ah m ey dx ih sil w uh dx iy w ih s f iy l ih ng w ih th ih n ih m s eh l f sil jh

-./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV:
+./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FELC0_SI756.WAV:
 sil dh ih sil t ih r iy ih s sil s er r eh m ih sil n ah l ih ng sil k l ey sil r eh sil d w ay sil d aa r sil b ow f sil jh

-./tmp-lstm-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV:
+./tmp-lstm/icefall_asr_timit_tdnn_lstm_ctc/test_waves/FMGD0_SI1564.WAV:
 sil hh ah z sil b ih n iy w ah z sil b ae n ih sil b ay s sil n ey sil k ih l f eh n s ih z eh n dh eh r w er sil g r ey z ih n sil k ae dx l sil
egs/timit/ASR/local/__init__.py (new file, 0 lines)

egs/timit/ASR/local/compile_hlg.py (new file, 155 lines)
@@ -0,0 +1,155 @@
#!/usr/bin/env python3
# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input lang_dir and generates HLG from

    - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt
    - L, the lexicon, built from lang_dir/L_disambig.pt

        Caution: We use a lexicon that contains disambiguation symbols

    - G, the LM, built from data/lm/G_3_gram.fst.txt

The generated HLG is saved in $lang_dir/HLG.pt
"""
import argparse
import logging
from pathlib import Path

import k2
import torch

from icefall.lexicon import Lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        """,
    )

    return parser.parse_args()


def compile_HLG(lang_dir: str) -> k2.Fsa:
    """
    Args:
      lang_dir:
        The language directory, e.g., data/lang_phone.

    Return:
      An FSA representing HLG.
    """
    lexicon = Lexicon(lang_dir)
    max_token_id = max(lexicon.tokens)
    logging.info(f"Building ctc_topo. max_token_id: {max_token_id}")
    H = k2.ctc_topo(max_token_id)
    L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt"))

    if Path("data/lm/G.pt").is_file():
        logging.info("Loading pre-compiled G")
        d = torch.load("data/lm/G.pt")
        G = k2.Fsa.from_dict(d)
    else:
        logging.info("Loading G_3_gram.fst.txt")
        with open("data/lm/G_3_gram.fst.txt") as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
            torch.save(G.as_dict(), "data/lm/G.pt")

    first_token_disambig_id = lexicon.token_table["#0"]
    first_word_disambig_id = lexicon.word_table["#0"]

    L = k2.arc_sort(L)
    G = k2.arc_sort(G)

    logging.info("Intersecting L and G")
    LG = k2.compose(L, G)
    logging.info(f"LG shape: {LG.shape}")

    logging.info("Connecting LG")
    LG = k2.connect(LG)
    logging.info(f"LG shape after k2.connect: {LG.shape}")

    logging.info(type(LG.aux_labels))
    logging.info("Determinizing LG")

    LG = k2.determinize(LG)
    logging.info(type(LG.aux_labels))

    logging.info("Connecting LG after k2.determinize")
    LG = k2.connect(LG)

    logging.info("Removing disambiguation symbols on LG")

    LG.labels[LG.labels >= first_token_disambig_id] = 0

    LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0

    LG = k2.remove_epsilon(LG)
    logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}")

    LG = k2.connect(LG)
    LG.aux_labels = LG.aux_labels.remove_values_eq(0)

    logging.info("Arc sorting LG")
    LG = k2.arc_sort(LG)

    logging.info("Composing H and LG")
    # CAUTION: The name of the inner_labels is fixed
    # to `tokens`. If you want to change it, please
    # also change other places in icefall that are using
    # it.
    HLG = k2.compose(H, LG, inner_labels="tokens")

    logging.info("Connecting LG")
    HLG = k2.connect(HLG)

    logging.info("Arc sorting LG")
    HLG = k2.arc_sort(HLG)
    logging.info(f"HLG.shape: {HLG.shape}")

    return HLG


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    if (lang_dir / "HLG.pt").is_file():
        logging.info(f"{lang_dir}/HLG.pt already exists - skipping")
        return

    logging.info(f"Processing {lang_dir}")

    HLG = compile_HLG(lang_dir)
    logging.info(f"Saving HLG.pt to {lang_dir}")
    torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
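For reference, the saved HLG.pt can be loaded back as a k2 FSA in the same way the recipe's decoding scripts do; a minimal sketch (the path assumes the script was run with --lang-dir data/lang_phone):

# Minimal sketch: load the compiled HLG back for decoding (assumes
# ./local/compile_hlg.py --lang-dir data/lang_phone was run first).
import k2
import torch

HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt"))
print(HLG.shape)  # a single FSA; its arcs carry word-level aux_labels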
egs/timit/ASR/local/compute_fbank_musan.py (new file, 97 lines)
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""

import logging
import os
from pathlib import Path

import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)


def compute_fbank_musan():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")
    num_jobs = min(15, os.cpu_count())
    num_mel_bins = 80

    dataset_parts = (
        "music",
        "speech",
        "noise",
    )
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts, output_dir=src_dir
    )
    assert manifests is not None

    musan_cuts_path = output_dir / "cuts_musan.json.gz"

    if musan_cuts_path.is_file():
        logging.info(f"{musan_cuts_path} already exists - skipping")
        return

    logging.info("Extracting features for Musan")

    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    with get_executor() as ex:  # Initialize the executor only once.
        # create chunks of Musan with duration 5 - 10 seconds
        musan_cuts = (
            CutSet.from_manifests(
                recordings=combine(
                    part["recordings"] for part in manifests.values()
                )
            )
            .cut_into_windows(10.0)
            .filter(lambda c: c.duration > 5)
            .compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/feats_musan",
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomHdf5Writer,
            )
        )
        musan_cuts.to_json(musan_cuts_path)


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    compute_fbank_musan()
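The manifest written here is what the recipe's datamodule consumes for noise augmentation; a short sketch mirroring the CutMix usage that appears in asr_datamodule.py later in this commit:

# Sketch: consume the generated manifest for noise augmentation
# (the same pattern as train_dataloaders() in asr_datamodule.py below).
from lhotse import load_manifest
from lhotse.dataset import CutMix

cuts_musan = load_manifest("data/fbank/cuts_musan.json.gz")
transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]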
egs/timit/ASR/local/prepare_lang.py (new file, 386 lines)
@@ -0,0 +1,386 @@
#!/usr/bin/env python3
# Copyright      2021  Xiaomi Corp.        (authors: Fangjun Kuang
#                                                    Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input a lexicon file "data/lang_phone/lexicon.txt"
consisting of words and tokens (i.e., phones) and does the following:

1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt

2. Generate tokens.txt, the token table mapping a token to a unique integer.

3. Generate words.txt, the word table mapping a word to a unique integer.

4. Generate L.pt, in k2 format. It can be loaded by

        d = torch.load("L.pt")
        lexicon = k2.Fsa.from_dict(d)

5. Generate L_disambig.pt, in k2 format.
"""
import argparse
import math
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import k2
import torch

from icefall.lexicon import read_lexicon, write_lexicon
from icefall.utils import str2bool

Lexicon = List[Tuple[str, List[str]]]


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain a file lexicon.txt.
        Generated files by this script are saved into this directory.
        """,
    )

    parser.add_argument(
        "--debug",
        type=str2bool,
        default=False,
        help="""True for debugging, which will generate
        a visualization of the lexicon FST.

        Caution: If your lexicon contains hundreds of thousands
        of lines, please set it to False!
        """,
    )

    return parser.parse_args()


def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
    """Write a symbol to ID mapping to a file.

    Note:
      No need to implement `read_mapping` as it can be done
      through :func:`k2.SymbolTable.from_file`.

    Args:
      filename:
        Filename to save the mapping.
      sym2id:
        A dict mapping symbols to IDs.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for sym, i in sym2id.items():
            f.write(f"{sym} {i}\n")


def get_tokens(lexicon: Lexicon) -> List[str]:
    """Get tokens from a lexicon.

    Args:
      lexicon:
        It is the return value of :func:`read_lexicon`.
    Returns:
      Return a list of unique tokens.
    """
    ans = set()
    for _, tokens in lexicon:
        ans.update(tokens)

    sorted_ans = list(ans)
    return sorted_ans


def get_words(lexicon: Lexicon) -> List[str]:
    """Get words from a lexicon.

    Args:
      lexicon:
        It is the return value of :func:`read_lexicon`.
    Returns:
      Return a list of unique words.
    """
    ans = set()
    for word, _ in lexicon:
        ans.add(word)
    sorted_ans = sorted(list(ans))
    return sorted_ans


def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]:
    """It adds pseudo-token disambiguation symbols #1, #2 and so on
    at the ends of tokens to ensure that all pronunciations are different,
    and that none is a prefix of another.

    See also add_lex_disambig.pl from kaldi.

    Args:
      lexicon:
        It is returned by :func:`read_lexicon`.
    Returns:
      Return a tuple with two elements:

        - The output lexicon with disambiguation symbols
        - The ID of the max disambiguation symbol that appears
          in the lexicon
    """

    # (1) Work out the count of each token-sequence in the
    # lexicon.
    count = defaultdict(int)
    for _, tokens in lexicon:
        count[" ".join(tokens)] += 1

    # (2) For each left sub-sequence of each token-sequence, note down
    # that it exists (for identifying prefixes of longer strings).
    issubseq = defaultdict(int)
    for _, tokens in lexicon:
        tokens = tokens.copy()
        tokens.pop()
        while tokens:
            issubseq[" ".join(tokens)] = 1
            tokens.pop()

    # (3) For each entry in the lexicon:
    # if the token sequence is unique and is not a
    # prefix of another word, no disambig symbol.
    # Else output #1, or #2, #3, ... if the same token-seq
    # has already been assigned a disambig symbol.
    ans = []

    # We start with #1 since #0 has its own purpose
    first_allowed_disambig = 1
    max_disambig = first_allowed_disambig - 1
    last_used_disambig_symbol_of = defaultdict(int)

    for word, tokens in lexicon:
        tokenseq = " ".join(tokens)
        assert tokenseq != ""
        if issubseq[tokenseq] == 0 and count[tokenseq] == 1:
            ans.append((word, tokens))
            continue

        cur_disambig = last_used_disambig_symbol_of[tokenseq]
        if cur_disambig == 0:
            cur_disambig = first_allowed_disambig
        else:
            cur_disambig += 1

        if cur_disambig > max_disambig:
            max_disambig = cur_disambig
        last_used_disambig_symbol_of[tokenseq] = cur_disambig
        tokenseq += f" #{cur_disambig}"
        ans.append((word, tokenseq.split()))
    return ans, max_disambig


def generate_id_map(symbols: List[str]) -> Dict[str, int]:
    """Generate ID maps, i.e., map a symbol to a unique ID.

    Args:
      symbols:
        A list of unique symbols.
    Returns:
      A dict containing the mapping between symbols and IDs.
    """
    return {sym: i for i, sym in enumerate(symbols)}


def add_self_loops(
    arcs: List[List[Any]], disambig_token: int, disambig_word: int
) -> List[List[Any]]:
    """Adds self-loops to states of an FST to propagate disambiguation symbols
    through it. They are added on each state with non-epsilon output symbols
    on at least one arc out of the state.

    See also fstaddselfloops.pl from Kaldi. One difference is that
    Kaldi uses OpenFst style FSTs and it has multiple final states.
    This function uses k2 style FSTs and it does not need to add self-loops
    to the final state.

    The input label of a self-loop is `disambig_token`, while the output
    label is `disambig_word`.

    Args:
      arcs:
        A list-of-list. The sublist contains
        `[src_state, dest_state, label, aux_label, score]`
      disambig_token:
        It is the token ID of the symbol `#0`.
      disambig_word:
        It is the word ID of the symbol `#0`.

    Return:
      Return new `arcs` containing self-loops.
    """
    states_needs_self_loops = set()
    for arc in arcs:
        src, dst, ilabel, olabel, score = arc
        if olabel != 0:
            states_needs_self_loops.add(src)

    ans = []
    for s in states_needs_self_loops:
        ans.append([s, s, disambig_token, disambig_word, 0])

    return arcs + ans


def lexicon_to_fst(
    lexicon: Lexicon,
    token2id: Dict[str, int],
    word2id: Dict[str, int],
    need_self_loops: bool = False,
) -> k2.Fsa:
    """Convert a lexicon to an FST (in k2 format) with optional silence at
    the beginning and end of each word.

    Args:
      lexicon:
        The input lexicon. See also :func:`read_lexicon`
      token2id:
        A dict mapping tokens to IDs.
      word2id:
        A dict mapping words to IDs.
      need_self_loops:
        If True, add self-loop to states with non-epsilon output symbols
        on at least one arc out of the state. The input label for this
        self loop is `token2id["#0"]` and the output label is `word2id["#0"]`.
    Returns:
      Return an instance of `k2.Fsa` representing the given lexicon.
    """
    pronprob = 1.0
    score = -math.log(pronprob)

    loop_state = 0  # words enter and leave from here
    next_state = 1  # the next un-allocated state, will be incremented as we go.
    arcs = []

    assert token2id["<eps>"] == 0
    assert word2id["<eps>"] == 0

    eps = 0
    for word, tokens in lexicon:
        assert len(tokens) > 0, f"{word} has no pronunciations"
        cur_state = loop_state

        word = word2id[word]
        tokens = [token2id[i] for i in tokens]

        for i in range(len(tokens) - 1):
            w = word if i == 0 else eps
            arcs.append([cur_state, next_state, tokens[i], w, score])

            cur_state = next_state
            next_state += 1

        # now for the last token of this word
        # It has two out-going arcs, one to the loop state,
        # the other one to the sil_state.
        i = len(tokens) - 1
        w = word if i == 0 else eps
        tokens[i] = tokens[i] if i >= 0 else eps
        arcs.append([cur_state, loop_state, tokens[i], w, score])

    if need_self_loops:
        disambig_token = token2id["#0"]
        disambig_word = word2id["#0"]
        arcs = add_self_loops(
            arcs,
            disambig_token=disambig_token,
            disambig_word=disambig_word,
        )

    final_state = next_state
    arcs.append([loop_state, final_state, -1, -1, 0])
    arcs.append([final_state])

    arcs = sorted(arcs, key=lambda arc: arc[0])
    arcs = [[str(i) for i in arc] for arc in arcs]
    arcs = [" ".join(arc) for arc in arcs]
    arcs = "\n".join(arcs)
    fsa = k2.Fsa.from_str(arcs, acceptor=False)
    return fsa


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)
    lexicon_filename = lang_dir / "lexicon.txt"

    lexicon = read_lexicon(lexicon_filename)
    tokens = get_tokens(lexicon)

    words = get_words(lexicon)
    lexicon_disambig, max_disambig = add_disambig_symbols(lexicon)

    for i in range(max_disambig + 1):
        disambig = f"#{i}"
        assert disambig not in tokens
        tokens.append(f"#{i}")

    assert "<eps>" not in tokens
    tokens = ["<eps>"] + tokens

    assert "<eps>" not in words
    assert "#0" not in words
    assert "<s>" not in words
    assert "</s>" not in words

    words = ["<eps>"] + words + ["#0", "<s>", "</s>"]

    token2id = generate_id_map(tokens)
    word2id = generate_id_map(words)

    write_mapping(lang_dir / "tokens.txt", token2id)
    write_mapping(lang_dir / "words.txt", word2id)
    write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)

    L = lexicon_to_fst(
        lexicon,
        token2id=token2id,
        word2id=word2id,
    )

    L_disambig = lexicon_to_fst(
        lexicon_disambig,
        token2id=token2id,
        word2id=word2id,
        need_self_loops=True,
    )
    torch.save(L.as_dict(), lang_dir / "L.pt")
    torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")

    if False:
        # Just for debugging, will remove it
        L.labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
        L.aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
        L_disambig.labels_sym = L.labels_sym
        L_disambig.aux_labels_sym = L.aux_labels_sym
        L.draw(lang_dir / "L.png", title="L")
        L_disambig.draw(lang_dir / "L_disambig.png", title="L_disambig")


if __name__ == "__main__":
    main()
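To see what add_disambig_symbols does, here is a hypothetical toy lexicon (not from the recipe), assuming the functions above are importable:

# Hypothetical toy input: two words share the pronunciation "ah b",
# and "ah" is also a prefix of it, so all three entries get a
# disambiguation symbol appended.
toy = [("abc", ["ah", "b"]), ("abd", ["ah", "b"]), ("a", ["ah"])]
disambig_lexicon, max_disambig = add_disambig_symbols(toy)
# disambig_lexicon == [("abc", ["ah", "b", "#1"]),
#                      ("abd", ["ah", "b", "#2"]),
#                      ("a", ["ah", "#1"])]
# max_disambig == 2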
egs/timit/ASR/local/prepare_lexicon.py (new file, 102 lines)
@@ -0,0 +1,102 @@
#!/usr/bin/env python3
# Copyright      2021  Xiaomi Corp.        (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script takes as input supervisions json dir "data/manifests"
consisting of supervisions_TRAIN.json and does the following:

1. Generate lexicon.txt.

"""
import argparse
import json
import logging
from pathlib import Path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--manifests-dir",
        type=str,
        help="""Input directory.
        """,
    )
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Output directory.
        """,
    )

    return parser.parse_args()


def prepare_lexicon(manifests_dir: str, lang_dir: str):
    """
    Args:
      manifests_dir:
        The manifests directory, e.g., data/manifests.
      lang_dir:
        The language directory, e.g., data/lang_phone.

    Return:
      The lexicon.txt file and the train.text in lang_dir.
    """
    phones = set([])

    supervisions_train = Path(manifests_dir) / "supervisions_TRAIN.json"
    lexicon = Path(lang_dir) / "lexicon.txt"

    logging.info(f"Loading {supervisions_train}!")
    with open(supervisions_train, "r") as load_f:
        load_dicts = json.load(load_f)
        for load_dict in load_dicts:
            text = load_dict["text"]
            # list the phone units and filter the empty item
            phones_list = list(filter(None, text.split()))

            for phone in phones_list:
                if phone not in phones:
                    phones.add(phone)

    with open(lexicon, "w") as f:
        for phone in sorted(phones):
            f.write(phone + " " + phone)
            f.write("\n")
        f.write("<UNK> <UNK>")
        f.write("\n")


def main():
    args = get_args()
    manifests_dir = Path(args.manifests_dir)
    lang_dir = Path(args.lang_dir)

    logging.info("Generating lexicon.txt")
    prepare_lexicon(manifests_dir, lang_dir)


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
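The resulting lexicon.txt is an identity mapping from each phone to itself, with a final <UNK> entry; a sketch of reading it back with the helper that prepare_lang.py above already uses (phone names here are illustrative):

# Sketch: each line of lexicon.txt is "<phone> <phone>", e.g. "aa aa",
# plus a final "<UNK> <UNK>" entry (phone names are illustrative).
from icefall.lexicon import read_lexicon

lexicon = read_lexicon("data/lang_phone/lexicon.txt")
# -> [("aa", ["aa"]), ("ae", ["ae"]), ..., ("<UNK>", ["<UNK>"])]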
@@ -56,6 +56,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   # using: `sudo apt-get install git-lfs && git-lfs install`
   [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
   git clone https://huggingface.co/luomingshuang/timit_lm $dl_dir/lm
+  cd $dl_dir/lm && git lfs pull
 fi

 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then

@@ -124,7 +125,7 @@ fi

 if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   log "Stage 6: Prepare G"
-  # We assume you have install kaldilm, if not, please install
+  # We assume you have installed kaldilm, if not, please install
   # it using: pip install kaldilm

   mkdir -p data/lm
@@ -1,4 +1,5 @@
 # Copyright 2021 Piotr Żelasko
+#           2021 Xiaomi Corp. (authors: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

@@ -310,26 +311,20 @@ class TimitAsrDataModule(DataModule):
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        cuts_train = load_manifest(
-            self.args.feature_dir / "cuts_TRAIN.json.gz"
-        )
+        cuts_train = load_manifest(self.args.feature_dir / "cuts_TRAIN.json.gz")

         return cuts_train

     @lru_cache()
     def valid_cuts(self) -> CutSet:
         logging.info("About to get dev cuts")
-        cuts_valid = load_manifest(
-            self.args.feature_dir / "cuts_DEV.json.gz"
-        )
+        cuts_valid = load_manifest(self.args.feature_dir / "cuts_DEV.json.gz")

         return cuts_valid

     @lru_cache()
     def test_cuts(self) -> CutSet:
         logging.debug("About to get test cuts")
-        cuts_test = load_manifest(
-            self.args.feature_dir / "cuts_TEST.json.gz"
-        )
+        cuts_test = load_manifest(self.args.feature_dir / "cuts_TEST.json.gz")

         return cuts_test
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                       Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

@@ -26,7 +27,7 @@ import k2
 import torch
 import torch.nn as nn
 from asr_datamodule import TimitAsrDataModule
-from model import TdnnLstm
+from model import TdnnLiGRU

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.decode import (

@@ -442,14 +443,13 @@ def main():
     else:
         G = None

-    model = TdnnLstm(
+    model = TdnnLiGRU(
         num_features=params.feature_dim,
         num_classes=max_phone_id + 1,  # +1 for the blank symbol
         subsampling_factor=params.subsampling_factor,
     )
     if params.avg == 1:
         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-        #load_checkpoint(f"tmp/icefall_asr_librispeech_tdnn-lstm_ctc/exp/pretrained.pt", model)
     else:
         start = params.epoch - params.avg + 1
         filenames = []

@@ -470,22 +470,9 @@ def main():
     model.eval()

     timit = TimitAsrDataModule(args)
-    # CAUTION: `test_sets` is for displaying only.
-    # If you want to skip test-clean, you have to skip
-    # it inside the for loop. That is, use
-    #
-    #   if test_set == 'test-clean': continue
-    #
-    #test_sets = ["test-clean", "test-other"]
-    #test_sets = ["test-other"]
-    #for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
-    #if test_set == "test-clean": continue
-    #if test_set == "test-other": break
     test_set = "TEST"
     test_dl = timit.test_dataloaders()

-    #test_set = "TRAIN"
-    #test_dl = timit.train_dataloaders()
     results_dict = decode_dataset(
         dl=test_dl,
         params=params,
@@ -1,4 +1,5 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                       Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

@@ -21,12 +22,10 @@ import torch.nn as nn
 from torch import Tensor
 from typing import Optional

-class TdnnLstm(nn.Module):
+class TdnnLiGRU(nn.Module):
     def __init__(
-        self,
-        num_features: int,
-        num_classes: int,
-        subsampling_factor: int = 3
+        self, num_features: int, num_classes: int, subsampling_factor: int = 3
     ) -> None:
         """
         Args:

@@ -65,7 +64,6 @@ class TdnnLstm(nn.Module):
                 in_channels=512,
                 out_channels=512,
                 kernel_size=3,
-                #stride=self.subsampling_factor,  # stride: subsampling_factor!
                 stride=1,
                 padding=1,
             ),

@@ -75,7 +73,7 @@ class TdnnLstm(nn.Module):
                 in_channels=512,
                 out_channels=512,
                 kernel_size=3,
-                stride=self.subsampling_factor,
+                stride=self.subsampling_factor,  # stride: subsampling_factor!
                 padding=1,
             ),
             nn.ReLU(inplace=True),

@@ -83,21 +81,20 @@ class TdnnLstm(nn.Module):
         )
         self.ligrus = nn.ModuleList(
             [
-                LiGRU(input_shape=[None, None, 512], hidden_size=512, num_layers=1, bidirectional=True, re_init=False)
+                LiGRU(
+                    input_shape=[None, None, 512],
+                    hidden_size=512,
+                    num_layers=1,
+                    bidirectional=True,
+                )
                 for _ in range(4)
             ]
         )
         self.linears = nn.ModuleList(
-            [
-                nn.Linear(in_features=1024, out_features=512)
-                for _ in range(4)
-            ]
+            [nn.Linear(in_features=1024, out_features=512) for _ in range(4)]
         )
         self.bnorms = nn.ModuleList(
-            [
-                nn.BatchNorm1d(num_features=512, affine=False)
-                for _ in range(4)
-            ]
+            [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(4)]
         )
         self.dropout = nn.Dropout(0.2)
         self.linear = nn.Linear(in_features=512, out_features=self.num_classes)

@@ -115,23 +112,22 @@ class TdnnLstm(nn.Module):
         x = x.permute(0, 2, 1)
         for ligru, linear, bnorm in zip(self.ligrus, self.linears, self.bnorms):
             x_new, _ = ligru(x)
-            #print('ligru output shape: ', x_new.shape)
             x_new = linear(x_new)
-            #print('linear output shape: ', x_new.shape)
             x_new = bnorm(x_new.permute(0, 2, 1)).permute(0, 2, 1)
-            # 2, 0, 1
-            #)  # (T, N, C) -> (N, C, T) -> (T, N, C)
+            # (N, T, C) -> (N, C, T) -> (N, T, C)
             x_new = self.dropout(x_new)
             x = x_new + x  # skip connections
-        #x = x.transpose(
-        #    1, 0
-        #)  # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim

         x = self.linear(x)
         x = nn.functional.log_softmax(x, dim=-1)
         return x


 class LiGRU(torch.nn.Module):
-    """ This function implements a Light GRU (liGRU).
+    """This function implements a Light GRU (liGRU).
     This LiGRU model is from speechbrain, please see
     https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/RNN.py

     LiGRU is single-gate GRU model based on batch-norm + relu
     activations + recurrent dropout. For more info see:

@@ -169,9 +165,6 @@ class LiGRU(torch.nn.Module):
         If True, the additive bias b is adopted.
     dropout : float
         It is the dropout factor (must be between 0 and 1).
-    re_init : bool
-        If True, orthogonal initialization is used for the recurrent weights.
-        Xavier initialization is used for the input connection weights.
     bidirectional : bool
         If True, a bidirectional model that scans the sequence both
         right-to-left and left-to-right is used.

@@ -194,7 +187,6 @@ class LiGRU(torch.nn.Module):
         num_layers=1,
         bias=True,
         dropout=0.0,
-        re_init=True,
         bidirectional=False,
     ):
         super().__init__()

@@ -204,7 +196,6 @@ class LiGRU(torch.nn.Module):
         self.normalization = normalization
         self.bias = bias
         self.dropout = dropout
-        self.re_init = re_init
         self.bidirectional = bidirectional
         self.reshape = False

@@ -215,13 +206,9 @@ class LiGRU(torch.nn.Module):
         self.batch_size = input_shape[0]
         self.rnn = self._init_layers()

-        if self.re_init:
-            rnn_init(self.rnn)
-
     def _init_layers(self):
         """Initializes the layers of the liGRU."""
         rnn = torch.nn.ModuleList([])
-        #print('fea_dim: ', self.fea_dim)
         current_dim = self.fea_dim

         for i in range(self.num_layers):

@@ -296,7 +283,7 @@ class LiGRU(torch.nn.Module):


 class LiGRU_Layer(torch.nn.Module):
-    """ This function implements Light-Gated Recurrent Units (ligru) layer.
+    """This function implements Light-Gated Recurrent Units (ligru) layer.

     Arguments
     ---------

@@ -344,11 +331,7 @@ class LiGRU_Layer(torch.nn.Module):
         self.drop_mask_cnt = 0
         self.drop_mask_te = torch.tensor([1.0]).float()
         self.w = nn.Linear(self.input_size, 2 * self.hidden_size, bias=False)

         self.u = nn.Linear(self.hidden_size, 2 * self.hidden_size, bias=False)
-        print(self.batch_size)
-        #if self.bidirectional:
-        #    self.batch_size = self.batch_size * 2

         # Initializing batch norm
         self.normalize = False

@@ -369,9 +352,6 @@ class LiGRU_Layer(torch.nn.Module):
         # Initial state
         self.register_buffer("h_init", torch.zeros(1, self.hidden_size))

-        # Preloading dropout masks (gives some speed improvement)
-        #self._init_drop(self.batch_size)
-
         # Setting the activation function
         if nonlinearity == "tanh":
             self.act = torch.nn.Tanh()

@@ -399,7 +379,6 @@ class LiGRU_Layer(torch.nn.Module):
         self._change_batch_size(x)

         # Feed-forward affine transformations (all steps in parallel)
-        #print(x.shape)
         w = self.w(x)

         # Apply batch normalization

@@ -450,7 +429,6 @@ class LiGRU_Layer(torch.nn.Module):
         """Initializes the recurrent dropout operation. To speed it up,
         the dropout masks are sampled in advance.
         """
-        #self.drop = torch.nn.Dropout(p=self.dropout, inplace=False)
         self.N_drop_masks = 16000
         self.drop_mask_cnt = 0

@@ -497,71 +475,8 @@ class LiGRU_Layer(torch.nn.Module):
         if self.training:
             self.drop_masks = self.drop(
                 torch.ones(
-                    self.N_drop_masks, self.hidden_size, device=x.device,
+                    self.N_drop_masks,
+                    self.hidden_size,
+                    device=x.device,
                 )
             ).data
-
-
-class Linear(torch.nn.Module):
-    """Computes a linear transformation y = wx + b.
-
-    Arguments
-    ---------
-    n_neurons : int
-        It is the number of output neurons (i.e, the dimensionality of the
-        output).
-    input_shape: tuple
-        It is the shape of the input tensor.
-    input_size: int
-        Size of the input tensor.
-    bias : bool
-        If True, the additive bias b is adopted.
-    combine_dims : bool
-        If True and the input is 4D, combine 3rd and 4th dimensions of input.
-
-    Example
-    -------
-    >>> inputs = torch.rand(10, 50, 40)
-    >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
-    >>> output = lin_t(inputs)
-    >>> output.shape
-    torch.Size([10, 50, 100])
-    """
-
-    def __init__(
-        self,
-        n_neurons,
-        input_shape=None,
-        input_size=None,
-        bias=True,
-        combine_dims=False,
-    ):
-        super().__init__()
-        self.combine_dims = combine_dims
-
-        if input_shape is None and input_size is None:
-            raise ValueError("Expected one of input_shape or input_size")
-
-        if input_size is None:
-            input_size = input_shape[-1]
-            if len(input_shape) == 4 and self.combine_dims:
-                input_size = input_shape[2] * input_shape[3]
-
-        # Weights are initialized following pytorch approach
-        self.w = nn.Linear(input_size, n_neurons, bias=bias)
-
-    def forward(self, x):
-        """Returns the linear transformation of input tensor.
-
-        Arguments
-        ---------
-        x : torch.Tensor
-            Input to transform linearly.
-        """
-        if x.ndim == 4 and self.combine_dims:
-            x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])
-        #print(x.shape)
-        #print(self.w)
-        wx = self.w(x)
-
-        return wx
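A hypothetical smoke test for the renamed model, assuming the usual icefall (N, num_features, T) input layout and an illustrative class count:

# Hypothetical smoke test (run inside the recipe directory so that
# `model` is importable; 80 fbank bins and 48 classes are illustrative).
import torch
from model import TdnnLiGRU

model = TdnnLiGRU(num_features=80, num_classes=48, subsampling_factor=3)
model.eval()
x = torch.randn(2, 80, 300)  # assumed (N, C, T) layout
with torch.no_grad():
    y = model(x)  # log-probs, time axis subsampled by a factor of ~3
print(y.shape)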
@@ -26,7 +26,7 @@ import k2
 import kaldifeat
 import torch
 import torchaudio
-from model import TdnnLstm
+from model import TdnnLiGRU
 from torch.nn.utils.rnn import pad_sequence

 from icefall.decode import (

@@ -91,7 +91,7 @@ def get_parser():
     parser.add_argument(
         "--ngram-lm-scale",
         type=float,
-        default=0.8,
+        default=0.1,
         help="""
         Used only when method is whole-lattice-rescoring.
         It specifies the scale for n-gram LM scores.

@@ -169,7 +169,7 @@ def main():
     logging.info(f"device: {device}")

     logging.info("Creating model")
-    model = TdnnLstm(
+    model = TdnnLiGRU(
         num_features=params.feature_dim,
         num_classes=params.num_classes,
         subsampling_factor=params.subsampling_factor,
@@ -30,7 +30,7 @@ import torch.nn as nn
 import torch.optim as optim
 from asr_datamodule import TimitAsrDataModule
 from lhotse.utils import fix_random_seed
-from model import TdnnLstm
+from model import TdnnLiGRU
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_

@@ -81,7 +81,7 @@ def get_parser():
     parser.add_argument(
         "--num-epochs",
         type=int,
-        default=60,
+        default=25,
         help="Number of epochs to train.",
     )

@@ -508,7 +508,7 @@ def run(rank, world_size, args):

     graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)

-    model = TdnnLstm(
+    model = TdnnLiGRU(
         num_features=params.feature_dim,
         num_classes=max_phone_id + 1,  # +1 for the blank symbol
         subsampling_factor=params.subsampling_factor,
egs/timit/ASR/tdnn_lstm_ctc/__init__.py (new file, 0 lines)

egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py (new file, 330 lines)
@ -0,0 +1,330 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# 2021 Xiaomi Corp. (authors: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.dataset.datamodule import DataModule
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class TimitAsrDataModule(DataModule):
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
super().add_arguments(parser)
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--feature-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the BucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
    def train_dataloaders(self) -> DataLoader:
        logging.info("About to get train cuts")
        cuts_train = self.train_cuts()

        logging.info("About to get Musan cuts")
        cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")

        logging.info("About to create train dataset")
        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Cut concatenation should be the first transform in the list,
            # so that if we e.g. mix noise in, it will fill the gaps between
            # different utterances.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = [
            SpecAugment(
                num_frame_masks=2,
                features_mask_size=27,
                num_feature_masks=2,
                frames_mask_size=100,
            )
        ]

        train = K2SpeechRecognitionDataset(
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
        )

        if self.args.on_the_fly_feats:
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
            # Add on-the-fly speed perturbation; since originally it would
            # have increased epoch size by 3, we will apply prob 2/3 and use
            # 3x more epochs.
            # Speed perturbation probably should come first before
            # concatenation, but in principle the transforms order doesn't have
            # to be strict (e.g. could be randomized)
            # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms  # noqa
            # Drop feats to be on the safe side.
            train = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                ),
                input_transforms=input_transforms,
                return_cuts=self.args.return_cuts,
            )

        if self.args.bucketing_sampler:
            logging.info("Using BucketingSampler.")
            train_sampler = BucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                bucket_method="equal_duration",
                drop_last=True,
            )
        else:
            logging.info("Using SingleCutSampler.")
            train_sampler = SingleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )

        return train_dl

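    # Illustrative only (not part of the recipe file): each batch yielded by
    # this dataloader is a dict. With the defaults above (--return-cuts=True),
    # it also carries the cuts used to build it:
    #
    #   batch = next(iter(train_dl))
    #   batch["inputs"].shape            # (N, T, 80) fbank features
    #   batch["supervisions"]["text"]    # list of N transcripts
    #   batch["supervisions"]["cut"]     # lhotse cuts, since return_cuts=True
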
    def valid_dataloaders(self) -> DataLoader:
        logging.info("About to get dev cuts")
        cuts_valid = self.valid_cuts()

        transforms = []
        if self.args.concatenate_cuts:
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                ),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = SingleCutSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
        cuts = self.test_cuts()
        is_list = isinstance(cuts, list)
        test_loaders = []
        if not is_list:
            cuts = [cuts]

        for cuts_test in cuts:
            logging.debug("About to create test dataset")
            test = K2SpeechRecognitionDataset(
                input_strategy=OnTheFlyFeatures(
                    Fbank(FbankConfig(num_mel_bins=80))
                )
                if self.args.on_the_fly_feats
                else PrecomputedFeatures(),
                return_cuts=self.args.return_cuts,
            )
            sampler = SingleCutSampler(
                cuts_test, max_duration=self.args.max_duration
            )
            logging.debug("About to create test dataloader")
            test_dl = DataLoader(
                test, batch_size=None, sampler=sampler, num_workers=1
            )
            test_loaders.append(test_dl)

        if is_list:
            return test_loaders
        else:
            return test_loaders[0]

    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        cuts_train = load_manifest(self.args.feature_dir / "cuts_TRAIN.json.gz")

        return cuts_train

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        cuts_valid = load_manifest(self.args.feature_dir / "cuts_DEV.json.gz")

        return cuts_valid

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.debug("About to get test cuts")
        cuts_test = load_manifest(self.args.feature_dir / "cuts_TEST.json.gz")

        return cuts_test
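
# A minimal usage sketch (assumed; it mirrors how train.py drives this class):
#
#   parser = argparse.ArgumentParser()
#   TimitAsrDataModule.add_arguments(parser)
#   args = parser.parse_args()
#   timit = TimitAsrDataModule(args)
#   train_dl = timit.train_dataloaders()
#   valid_dl = timit.valid_dataloaders()
#   test_dl = timit.test_dataloaders()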
egs/timit/ASR/tdnn_lstm_ctc/decode.py
@ -1,6 +1,5 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
#                                        Mingshuang Luo)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -239,7 +238,8 @@ def decode_one_batch(

    assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]

    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    lm_scale_list = [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09]
    lm_scale_list += [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
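    # Illustrative note (not part of the recipe): the expanded grid above is
    # equivalent to
    #   lm_scale_list = [i / 100 for i in range(10)] + [i / 10 for i in range(1, 21)]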
@ -409,7 +409,7 @@ def main():
    if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 20 seconds.")
            logging.warning("It may take 8 minutes.")
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                first_word_disambig_id = lexicon.word_table["#0"]

@ -469,6 +469,7 @@ def main():
    model.eval()

    timit = TimitAsrDataModule(args)
    test_set = "TEST"
    test_dl = timit.test_dataloaders()
    results_dict = decode_dataset(
        dl=test_dl,
@ -478,7 +479,7 @@ def main():
        lexicon=lexicon,
        G=G,
    )
    test_set = "TEST"

    save_results(
        params=params, test_set_name=test_set, results_dict=results_dict
    )
110 egs/timit/ASR/tdnn_lstm_ctc/model.py Normal file
@ -0,0 +1,110 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn as nn


class TdnnLstm(nn.Module):
    def __init__(
        self, num_features: int, num_classes: int, subsampling_factor: int = 3
    ) -> None:
        """
        Args:
          num_features:
            The input dimension of the model.
          num_classes:
            The output dimension of the model.
          subsampling_factor:
            It reduces the number of output frames by this factor.
        """
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.subsampling_factor = subsampling_factor
        self.tdnn = nn.Sequential(
            nn.Conv1d(
                in_channels=num_features,
                out_channels=512,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512, affine=False),
            nn.Conv1d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512, affine=False),
            nn.Conv1d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512, affine=False),
            nn.Conv1d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                stride=self.subsampling_factor,  # stride: subsampling_factor!
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512, affine=False),
        )
        self.lstms = nn.ModuleList(
            [
                nn.LSTM(input_size=512, hidden_size=512, num_layers=1)
                for _ in range(4)
            ]
        )
        # Note: forward() pairs these with the 4 LSTMs via zip(), which stops
        # at the shorter list, so the 5th batchnorm is never used.
        self.lstm_bnorms = nn.ModuleList(
            [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)]
        )
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(in_features=512, out_features=self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            Its shape is [N, C, T]
        Returns:
          The output tensor has shape [N, T, C]
        """
        x = self.tdnn(x)
        x = x.permute(2, 0, 1)  # (N, C, T) -> (T, N, C) -> how LSTM expects it
        for lstm, bnorm in zip(self.lstms, self.lstm_bnorms):
            x_new, _ = lstm(x)
            x_new = bnorm(x_new.permute(1, 2, 0)).permute(
                2, 0, 1
            )  # (T, N, C) -> (N, C, T) -> (T, N, C)
            x_new = self.dropout(x_new)
            x = x_new + x  # skip connections
        x = x.transpose(
            1, 0
        )  # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim
        x = self.linear(x)
        x = nn.functional.log_softmax(x, dim=-1)
        return x
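
# A quick shape sanity check (illustrative only; num_classes=50 is an assumed
# value). The last Conv1d uses stride=subsampling_factor with no padding, so
# T input frames become floor((T - 3) / 3) + 1 output frames:
#
#   model = TdnnLstm(num_features=80, num_classes=50)
#   x = torch.randn(4, 80, 300)  # (N, C, T)
#   y = model(x)
#   print(y.shape)  # torch.Size([4, 100, 50]), i.e. (N, T', num_classes)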
595 egs/timit/ASR/tdnn_lstm_ctc/train.py Normal file
@ -0,0 +1,595 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple

import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from asr_datamodule import TimitAsrDataModule
from lhotse.utils import fix_random_seed
from model import TdnnLstm
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.graph_compiler import CtcTrainingGraphCompiler
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    MetricsTracker,
    encode_supervisions,
    get_env_info,
    setup_logger,
    str2bool,
)

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs for DDP training.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12354,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--tensorboard",
        type=str2bool,
        default=True,
        help="Should various information be logged in tensorboard.",
    )

    parser.add_argument(
        "--num-epochs",
        type=int,
        default=30,
        help="Number of epochs to train.",
    )

    parser.add_argument(
        "--start-epoch",
        type=int,
        default=0,
        help="""Resume training from this epoch.
        If it is positive, it will load checkpoint from
        tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt
        """,
    )

    return parser


def get_params() -> AttributeDict:
    """Return a dict containing training parameters.

    All training related parameters that are not passed from the commandline
    are saved in the variable `params`.

    Commandline options are merged into `params` after they are parsed, so
    you can also access them via `params`.

    Explanation of options saved in `params`:

        - exp_dir: It specifies the directory where all training related
                   files, e.g., checkpoints, logs, etc., are saved

        - lang_dir: It contains language related input files such as
                    "lexicon.txt"

        - lr: It specifies the initial learning rate

        - feature_dim: The model input dim. It has to match the one used
                       in computing features.

        - weight_decay: The weight_decay for the optimizer.

        - subsampling_factor: The subsampling factor for the model.

        - best_train_loss: Best training loss so far. It is used to select
                           the model that has the lowest training loss. It is
                           updated during the training.

        - best_valid_loss: Best validation loss so far. It is used to select
                           the model that has the lowest validation loss. It is
                           updated during the training.

        - best_train_epoch: It is the epoch that has the best training loss.

        - best_valid_epoch: It is the epoch that has the best validation loss.

        - batch_idx_train: Used to write statistics to tensorboard. It
                           contains the number of batches trained so far across
                           epochs.

        - log_interval: Print training loss if batch_idx % log_interval is 0

        - reset_interval: Reset statistics if batch_idx % reset_interval is 0

        - valid_interval: Run validation if batch_idx % valid_interval is 0

        - beam_size: It is used in k2.ctc_loss

        - reduction: It is used in k2.ctc_loss

        - use_double_scores: It is used in k2.ctc_loss
    """
    params = AttributeDict(
        {
            "exp_dir": Path("tdnn_lstm_ctc/exp"),
            "lang_dir": Path("data/lang_phone"),
            "lr": 1e-3,
            "feature_dim": 80,
            "weight_decay": 5e-4,
            "subsampling_factor": 3,
            "best_train_loss": float("inf"),
            "best_valid_loss": float("inf"),
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
            "log_interval": 10,
            "reset_interval": 200,
            "valid_interval": 1000,
            "beam_size": 10,
            "reduction": "sum",
            "use_double_scores": True,
            "env_info": get_env_info(),
        }
    )

    return params


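# Note: `AttributeDict` (from icefall.utils) supports both attribute- and
# key-style access, which the code below relies on, e.g.:
#
#   params = get_params()
#   assert params.lr == params["lr"] == 1e-3

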
def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> Optional[dict]:
    """Load checkpoint from file.

    If params.start_epoch is positive, it will load the checkpoint from
    `params.start_epoch - 1`. Otherwise, this function does nothing.

    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
    and `best_valid_loss` in `params`.

    Args:
      params:
        The return value of :func:`get_params`.
      model:
        The training model.
      optimizer:
        The optimizer that we are using.
      scheduler:
        The learning rate scheduler we are using.
    Returns:
      Return the saved checkpoint as a dict if one is loaded;
      return None otherwise.
    """
    if params.start_epoch <= 0:
        return None

    filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
    saved_params = load_checkpoint(
        filename,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
    )

    keys = [
        "best_train_epoch",
        "best_valid_epoch",
        "batch_idx_train",
        "best_train_loss",
        "best_valid_loss",
    ]
    for k in keys:
        params[k] = saved_params[k]

    return saved_params


def save_checkpoint(
    params: AttributeDict,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    rank: int = 0,
) -> None:
    """Save model, optimizer, scheduler and training stats to file.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The training model.
    """
    if rank != 0:
        return
    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
    save_checkpoint_impl(
        filename=filename,
        model=model,
        params=params,
        optimizer=optimizer,
        scheduler=scheduler,
        rank=rank,
    )

    if params.best_train_epoch == params.cur_epoch:
        best_train_filename = params.exp_dir / "best-train-loss.pt"
        copyfile(src=filename, dst=best_train_filename)

    if params.best_valid_epoch == params.cur_epoch:
        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
        copyfile(src=filename, dst=best_valid_filename)


def compute_loss(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
    graph_compiler: CtcTrainingGraphCompiler,
    is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute CTC loss given the model and its inputs.

    Args:
      params:
        Parameters for training. See :func:`get_params`.
      model:
        The model for training. It is an instance of TdnnLstm in our case.
      batch:
        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
        for the content in it.
      graph_compiler:
        It is used to build a decoding graph from a ctc topo and training
        transcript. The training transcript is contained in the given `batch`,
        while the ctc topo is built when this compiler is instantiated.
      is_training:
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
    """
    device = graph_compiler.device
    feature = batch["inputs"]
    # at entry, feature is (N, T, C)
    feature = feature.permute(0, 2, 1)  # now feature is (N, C, T)
    assert feature.ndim == 3
    feature = feature.to(device)

    with torch.set_grad_enabled(is_training):
        nnet_output = model(feature)
        # nnet_output is (N, T, C)

    # NOTE: We need `encode_supervisions` to sort sequences with
    # different duration in decreasing order, required by
    # `k2.intersect_dense` called in `k2.ctc_loss`
    supervisions = batch["supervisions"]
    supervision_segments, texts = encode_supervisions(
        supervisions, subsampling_factor=params.subsampling_factor
    )
    decoding_graph = graph_compiler.compile(texts)

    dense_fsa_vec = k2.DenseFsaVec(
        nnet_output,
        supervision_segments,
        allow_truncate=params.subsampling_factor - 1,
    )

    loss = k2.ctc_loss(
        decoding_graph=decoding_graph,
        dense_fsa_vec=dense_fsa_vec,
        output_beam=params.beam_size,
        reduction=params.reduction,
        use_double_scores=params.use_double_scores,
    )

    assert loss.requires_grad == is_training

    info = MetricsTracker()
    info["frames"] = supervision_segments[:, 2].sum().item()
    info["loss"] = loss.detach().cpu().item()

    return loss, info


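# For reference (illustrative values): `supervision_segments` above is an
# int32 tensor with one row per utterance, (sequence_index, start_frame,
# num_frames), sorted by decreasing num_frames as k2.DenseFsaVec requires:
#
#   supervision_segments = torch.tensor(
#       [[0, 0, 100], [1, 0, 80]], dtype=torch.int32
#   )

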
def compute_validation_loss(
    params: AttributeDict,
    model: nn.Module,
    graph_compiler: CtcTrainingGraphCompiler,
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
) -> MetricsTracker:
    """Run the validation process. The validation loss
    is saved in `params.valid_loss`.
    """
    model.eval()

    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(valid_dl):
        loss, loss_info = compute_loss(
            params=params,
            model=model,
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=False,
        )
        assert loss.requires_grad is False

        tot_loss = tot_loss + loss_info

    if world_size > 1:
        tot_loss.reduce(loss.device)

    loss_value = tot_loss["loss"] / tot_loss["frames"]

    if loss_value < params.best_valid_loss:
        params.best_valid_epoch = params.cur_epoch
        params.best_valid_loss = loss_value

    return tot_loss


def train_one_epoch(
    params: AttributeDict,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    graph_compiler: CtcTrainingGraphCompiler,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
) -> None:
    """Train the model for one epoch.

    The training loss from the mean of all frames is saved in
    `params.train_loss`. It runs the validation process every
    `params.valid_interval` batches.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      optimizer:
        The optimizer we are using.
      graph_compiler:
        It is used to convert transcripts to FSAs.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
    """
    model.train()

    tot_loss = MetricsTracker()

    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])

        loss, loss_info = compute_loss(
            params=params,
            model=model,
            batch=batch,
            graph_compiler=graph_compiler,
            is_training=True,
        )
        # summary stats
        tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 5.0, 2.0)
        optimizer.step()

        if batch_idx % params.log_interval == 0:
            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, loss[{loss_info}], "
                f"tot_loss[{tot_loss}], batch size: {batch_size}"
            )
            if tb_writer is not None:
                loss_info.write_summary(
                    tb_writer, "train/current_", params.batch_idx_train
                )
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )

        if batch_idx > 0 and batch_idx % params.valid_interval == 0:
            valid_info = compute_validation_loss(
                params=params,
                model=model,
                graph_compiler=graph_compiler,
                valid_dl=valid_dl,
                world_size=world_size,
            )
            model.train()
            logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}")
            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer,
                    "train/valid_",
                    params.batch_idx_train,
                )

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value

    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss


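# Note: `tot_loss` above is an exponentially-decayed running sum: each batch
# it is scaled by (1 - 1 / reset_interval) before the new stats are added,
# so with reset_interval=200 a batch's contribution halves roughly every
# 200 * ln(2) ~ 139 batches.

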
def run(rank, world_size, args):
    """
    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoints.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()
    """
    params = get_params()
    params.update(vars(args))

    fix_random_seed(42)
    if world_size > 1:
        setup_dist(rank, world_size, params.master_port)

    setup_logger(f"{params.exp_dir}/log/log-train")
    logging.info("Training started")
    logging.info(params)

    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
    else:
        tb_writer = None

    lexicon = Lexicon(params.lang_dir)
    max_phone_id = max(lexicon.tokens)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", rank)

    graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device)

    model = TdnnLstm(
        num_features=params.feature_dim,
        num_classes=max_phone_id + 1,  # +1 for the blank symbol
        subsampling_factor=params.subsampling_factor,
    )

    checkpoints = load_checkpoint_if_available(params=params, model=model)

    model.to(device)
    if world_size > 1:
        model = DDP(model, device_ids=[rank])

    optimizer = optim.AdamW(
        model.parameters(),
        lr=params.lr,
        weight_decay=params.weight_decay,
    )
    scheduler = StepLR(optimizer, step_size=8, gamma=0.8)

    if checkpoints:
        optimizer.load_state_dict(checkpoints["optimizer"])
        scheduler.load_state_dict(checkpoints["scheduler"])

    timit = TimitAsrDataModule(args)
    train_dl = timit.train_dataloaders()
    valid_dl = timit.valid_dataloaders()

    for epoch in range(params.start_epoch, params.num_epochs):
        train_dl.sampler.set_epoch(epoch)

        if epoch > params.start_epoch:
            logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}")

        if tb_writer is not None:
            tb_writer.add_scalar(
                "train/lr",
                scheduler.get_last_lr()[0],
                params.batch_idx_train,
            )
            tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)

        params.cur_epoch = epoch

        train_one_epoch(
            params=params,
            model=model,
            optimizer=optimizer,
            graph_compiler=graph_compiler,
            train_dl=train_dl,
            valid_dl=valid_dl,
            tb_writer=tb_writer,
            world_size=world_size,
        )

        scheduler.step()

        save_checkpoint(
            params=params,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            rank=rank,
        )

    logging.info("Done!")
    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()


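# Note: with StepLR(step_size=8, gamma=0.8) and the default lr=1e-3 used in
# run() above, the learning rate follows lr(epoch) = 1e-3 * 0.8 ** (epoch // 8):
# 1e-3 for epochs 0-7, 8e-4 for epochs 8-15, 6.4e-4 for epochs 16-23, ...

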
def main():
    parser = get_parser()
    TimitAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    world_size = args.world_size
    assert world_size >= 1
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)


if __name__ == "__main__":
    main()