Update docs, pretrained.py & results

pkufool 2021-11-16 12:32:51 +08:00
parent 943244642f
commit cbc5557c87
8 changed files with 137 additions and 79 deletions


@ -97,13 +97,17 @@ Configurable options
shows you the training options that can be passed from the commandline. shows you the training options that can be passed from the commandline.
The following options are used quite often: The following options are used quite often:
- ``--exp-dir``
The experiment folder to save logs and model checkpoints,
default ``./conformer_ctc/exp``.
- ``--num-epochs`` - ``--num-epochs``
It is the number of epochs to train. For instance, It is the number of epochs to train. For instance,
``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs ``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt`` and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
in the folder ``./conformer_ctc/exp``. in the folder set with ``--exp-dir``.
- ``--start-epoch`` - ``--start-epoch``
@ -174,7 +178,7 @@ Pre-configured options
~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~
There are some training options, e.g., weight decay, There are some training options, e.g., weight decay,
number of warmup steps, results dir, etc, number of warmup steps, etc,
that are not passed from the commandline. that are not passed from the commandline.
They are pre-configured by the function ``get_params()`` in They are pre-configured by the function ``get_params()`` in
`conformer_ctc/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/conformer_ctc/train.py>`_ `conformer_ctc/train.py <https://github.com/k2-fsa/icefall/blob/master/egs/aishell/ASR/conformer_ctc/train.py>`_
@ -192,8 +196,8 @@ them, please modify ``./conformer_ctc/train.py`` directly.
Training logs Training logs
~~~~~~~~~~~~~ ~~~~~~~~~~~~~
Training logs and checkpoints are saved in ``conformer_ctc/exp``. Training logs and checkpoints are saved in the folder set by ``--exp-dir``
You will find the following files in that directory: (default ``conformer_ctc/exp``). You will find the following files in that directory:
- ``epoch-0.pt``, ``epoch-1.pt``, ... - ``epoch-0.pt``, ``epoch-1.pt``, ...
@ -223,10 +227,10 @@ You will find the following files in that directory:
To stop uploading, press Ctrl-C. To stop uploading, press Ctrl-C.
New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/qvNrx6JIQAaN5Ly3uQotrg/ New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/WE1DocDqRRCOSAgmGyClhg/
[2021-09-12T16:41:16] Started scanning logdir. [2021-11-16T10:51:46] Started scanning logdir.
[2021-09-12T16:42:17] Total uploaded: 125346 scalars, 0 tensors, 0 binary objects [2021-11-16T10:52:32] Total uploaded: 111606 scalars, 0 tensors, 0 binary objects
Listening for new data in logdir... Listening for new data in logdir...
Note there is a URL in the above output, click it and you will see Note there is a URL in the above output, click it and you will see
@ -236,7 +240,7 @@ You will find the following files in that directory:
:width: 600 :width: 600
:alt: TensorBoard screenshot :alt: TensorBoard screenshot
:align: center :align: center
:target: https://tensorboard.dev/experiment/qvNrx6JIQAaN5Ly3uQotrg/ :target: https://tensorboard.dev/experiment/WE1DocDqRRCOSAgmGyClhg/
TensorBoard screenshot. TensorBoard screenshot.
@ -307,9 +311,9 @@ The commonly used options are:
.. code-block:: .. code-block::
$ cd egs/aishell/ASR $ cd egs/aishell/ASR
$ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --lattice-score-scale 0.5 $ ./conformer_ctc/decode.py --method attention-decoder --max-duration 30 --nbest-scale 0.5
- ``--lattice-score-scale`` - ``--nbest-scale``
It is used to scale down lattice scores so that there are more unique It is used to scale down lattice scores so that there are more unique
paths for rescoring. paths for rescoring.
@ -403,7 +407,7 @@ After downloading, you will have the following files:
- ``exp/pretrained.pt`` - ``exp/pretrained.pt``
It contains pre-trained model parameters, obtained by averaging It contains pre-trained model parameters, obtained by averaging
checkpoints from ``epoch-18.pt`` to ``epoch-40.pt``. checkpoints from ``epoch-25.pt`` to ``epoch-84.pt``.
Note: We have removed optimizer ``state_dict`` to reduce file size. Note: We have removed optimizer ``state_dict`` to reduce file size.
- ``test_waves/*.wav`` - ``test_waves/*.wav``
@ -483,7 +487,7 @@ The command to run HLG decoding is:
--method 1best \ --method 1best \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \ ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \ ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav
The output is given below: The output is given below:
@ -527,7 +531,7 @@ The command to run HLG decoding + attention decoder rescoring is:
--method attention-decoder \ --method attention-decoder \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \ ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0121.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \ ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0122.wav \
./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav ./tmp/icefall_asr_aishell_conformer_ctc/test_wavs/BAC009S0764W0123.wav
The output is below: The output is below:

Binary file not shown (before: 544 KiB; after: 308 KiB).


@ -1,16 +1,16 @@
## Results ## Results
### Aishell training results (Conformer-CTC) ### Aishell training results (Conformer-CTC)
#### 2021-09-13 #### 2021-11-16
(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/30 (Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/30
Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc
The best decoding results (CER) are listed below. We got these results by averaging models from epoch 23 to 40 and using the `attention-decoder` method with num_paths equal to 100. The best decoding results (CER) are listed below. We got these results by averaging models from epoch 25 to 84 and using the `attention-decoder` method with num_paths equal to 100.
||test| ||test|
|--|--| |--|--|
|CER| 4.74% | |CER| 4.26% |
To get more unique paths, we scaled lattice.scores by 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details). We searched over lm_score_scale and attention_score_scale for the best results; the scales that produced the CER above are listed below. To get more unique paths, we scaled lattice.scores by 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details). We searched over lm_score_scale and attention_score_scale for the best results; the scales that produced the CER above are listed below.
@ -27,17 +27,18 @@ cd icefall
cd egs/aishell/ASR cd egs/aishell/ASR
./prepare.sh ./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1" export CUDA_VISIBLE_DEVICES="0,1,2,3"
python conformer_ctc/train.py --bucketing-sampler False \ python conformer_ctc/train.py --bucketing-sampler True \
--concatenate-cuts False \
--max-duration 200 \ --max-duration 200 \
--world-size 2 --start-epoch 0 \
--num-epoch 90 \
--world-size 4
python conformer_ctc/decode.py --lattice-score-scale 0.5 \ python conformer_ctc/decode.py --nbest-scale 0.5 \
--epoch 40 \ --epoch 84 \
--avg 18 \ --avg 25 \
--method attention-decoder \ --method attention-decoder \
--max-duration 50 \ --max-duration 20 \
--num-paths 100 --num-paths 100
``` ```
@ -53,4 +54,3 @@ The best decoding results (CER) are listed below, we got this results by averagi
||test| ||test|
|--|--| |--|--|
|CER| 10.16% | |CER| 10.16% |
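
Note on model averaging: the results above come from averaging several epoch checkpoints. The sketch below shows element-wise averaging on toy parameter dicts (plain lists stand in for tensors). The helper names `average_checkpoints` and `epochs_to_average` are hypothetical, and the rule that `--avg N` selects the last N checkpoints ending at `--epoch` is an assumption that this diff does not confirm.

```python
def average_checkpoints(state_dicts):
    """Element-wise mean of parameter dicts that share the same keys."""
    if not state_dicts:
        raise ValueError("need at least one checkpoint")
    n = len(state_dicts)
    averaged = {}
    for name in state_dicts[0]:
        columns = zip(*(sd[name] for sd in state_dicts))
        averaged[name] = [sum(col) / n for col in columns]
    return averaged


def epochs_to_average(epoch, avg):
    """Assumed rule: the last `avg` checkpoints ending at `epoch`."""
    start = max(epoch - avg + 1, 0)
    return list(range(start, epoch + 1))


# Toy "checkpoints": each maps a parameter name to a flat list of floats.
checkpoints = [
    {"w": [1.0, 2.0], "b": [0.0]},
    {"w": [3.0, 4.0], "b": [1.0]},
    {"w": [5.0, 6.0], "b": [2.0]},
]
print(average_checkpoints(checkpoints))
print(epochs_to_average(epoch=84, avg=25)[:3], "...")
```

Under that assumed rule, `--epoch 84 --avg 25` would cover epochs 60 through 84, which is worth checking against the "epoch 25 to 84" wording above.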


@ -77,6 +77,8 @@ def get_parser():
default="attention-decoder", default="attention-decoder",
help="""Decoding method. help="""Decoding method.
Supported values are: Supported values are:
- (0) ctc-decoding. Use CTC decoding. It maps the token ids to
tokens using the token symbol table directly.
- (1) 1best. Extract the best path from the decoding lattice as the - (1) 1best. Extract the best path from the decoding lattice as the
decoding result. decoding result.
- (2) nbest. Extract n paths from the decoding lattice; the path - (2) nbest. Extract n paths from the decoding lattice; the path
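
Note on `ctc-decoding`: the help text says token ids are mapped to tokens through the token symbol table. A minimal sketch of that lookup follows. It assumes `tokens.txt` holds one `<symbol> <id>` pair per line; the table contents and the helper names `load_symbol_table` and `ids_to_tokens` are hypothetical, and the real script uses `k2.SymbolTable` rather than a plain dict.

```python
import io


def load_symbol_table(lines):
    """Parse "<symbol> <id>" lines into an id -> symbol dict."""
    table = {}
    for line in lines:
        fields = line.split()
        if len(fields) != 2:
            continue  # skip blank or malformed lines
        symbol, idx = fields
        table[int(idx)] = symbol
    return table


def ids_to_tokens(token_ids, table):
    """Map each hypothesis (a list of ids) to a list of symbols."""
    return [[table[i] for i in ids] for ids in token_ids]


# Hypothetical tokens.txt content; the real file comes from data preparation.
tokens_txt = io.StringIO("<blk> 0\n你 1\n好 2\n吗 3\n")
table = load_symbol_table(tokens_txt)
hyps = ids_to_tokens([[1, 2], [1, 2, 3]], table)
print(["".join(h) for h in hyps])
```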


@ -34,7 +34,7 @@ from icefall.decode import (
one_best_decoding, one_best_decoding,
rescore_with_attention_decoder, rescore_with_attention_decoder,
) )
from icefall.utils import AttributeDict, get_texts from icefall.utils import AttributeDict, get_env_info, get_texts
def get_parser(): def get_parser():
@ -52,14 +52,21 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--words-file", "--tokens-file",
type=str, type=str,
required=True, help="Path to tokens.txt. " "Used only when method is ctc-decoding.",
help="Path to words.txt",
) )
parser.add_argument( parser.add_argument(
"--HLG", type=str, required=True, help="Path to HLG.pt." "--words-file",
type=str,
help="Path to words.txt. " "Used when method is NOT ctc-decoding.",
)
parser.add_argument(
"--HLG",
type=str,
help="Path to HLG.pt. " "Used when method is NOT ctc-decoding.",
) )
parser.add_argument( parser.add_argument(
@ -68,6 +75,8 @@ def get_parser():
default="1best", default="1best",
help="""Decoding method. help="""Decoding method.
Possible values are: Possible values are:
(0) ctc-decoding - Use CTC decoding. It maps the token ids to tokens
using the token symbol table directly.
(1) 1best - Use the best path as decoding output. Only (1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding. the transformer encoder output is used for decoding.
We call it HLG decoding. We call it HLG decoding.
@ -111,7 +120,7 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lattice-score-scale", "--nbest-scale",
type=float, type=float,
default=0.5, default=0.5,
help=""" help="""
@ -125,7 +134,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--sos-id", "--sos-id",
type=float, type=int,
default=1, default=1,
help=""" help="""
Used only when method is attention-decoder. Used only when method is attention-decoder.
@ -135,7 +144,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--eos-id", "--eos-id",
type=float, type=int,
default=1, default=1,
help=""" help="""
Used only when method is attention-decoder. Used only when method is attention-decoder.
@ -143,6 +152,13 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--num_classes",
type=int,
default=4336,
help="The Vocab size.",
)
parser.add_argument( parser.add_argument(
"sound_files", "sound_files",
type=str, type=str,
@ -160,7 +176,6 @@ def get_params() -> AttributeDict:
params = AttributeDict( params = AttributeDict(
{ {
"sample_rate": 16000, "sample_rate": 16000,
"num_classes": 4336,
# parameters for conformer # parameters for conformer
"subsampling_factor": 4, "subsampling_factor": 4,
"feature_dim": 80, "feature_dim": 80,
@ -175,6 +190,7 @@ def get_params() -> AttributeDict:
"min_active_states": 30, "min_active_states": 30,
"max_active_states": 10000, "max_active_states": 10000,
"use_double_scores": True, "use_double_scores": True,
"env_info": get_env_info(),
} }
) )
return params return params
@ -212,6 +228,11 @@ def main():
params.update(vars(args)) params.update(vars(args))
logging.info(f"{params}") logging.info(f"{params}")
if args.method != "attention-decoder":
# to save memory as the attention decoder
# will not be used
params.num_decoder_layers = 0
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
@ -231,17 +252,10 @@ def main():
) )
checkpoint = torch.load(args.checkpoint, map_location="cpu") checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"]) model.load_state_dict(checkpoint["model"], strict=False)
model.to(device) model.to(device)
model.eval() model.eval()
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
logging.info("Constructing Fbank computer") logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions() opts = kaldifeat.FbankOptions()
opts.device = device opts.device = device
@ -275,41 +289,79 @@ def main():
dtype=torch.int32, dtype=torch.int32,
) )
lattice = get_lattice( if params.method == "ctc-decoding":
nnet_output=nnet_output, logging.info("Use CTC decoding")
HLG=HLG, token_sym_table = k2.SymbolTable.from_file(params.tokens_file)
supervision_segments=supervision_segments, max_token_id = params.num_classes - 1
search_beam=params.search_beam,
output_beam=params.output_beam, H = k2.ctc_topo(
min_active_states=params.min_active_states, max_token=max_token_id,
max_active_states=params.max_active_states, modified=False,
subsampling_factor=params.subsampling_factor, device=device,
) )
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=H,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding( best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores lattice=lattice, use_double_scores=params.use_double_scores
) )
elif params.method == "attention-decoder": token_ids = get_texts(best_path)
logging.info("Use HLG + attention decoder rescoring") hyps = [[token_sym_table[i] for i in ids] for ids in token_ids]
best_path_dict = rescore_with_attention_decoder( hyps = [s.split() for s in hyps]
lattice=lattice, elif params.method in ["1best", "attention-decoder"]:
num_paths=params.num_paths, logging.info(f"Loading HLG from {params.HLG}")
model=model, HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
memory=memory, HLG = HLG.to(device)
memory_key_padding_mask=memory_key_padding_mask, if not hasattr(HLG, "lm_scores"):
sos_id=params.sos_id, # For whole-lattice-rescoring and attention-decoder
eos_id=params.eos_id, HLG.lm_scores = HLG.scores.clone()
scale=params.lattice_score_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path) lattice = get_lattice(
word_sym_table = k2.SymbolTable.from_file(params.words_file) nnet_output=nnet_output,
hyps = [[word_sym_table[i] for i in ids] for ids in hyps] HLG=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
elif params.method == "attention-decoder":
logging.info("Use HLG + attention decoder rescoring")
best_path_dict = rescore_with_attention_decoder(
lattice=lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
scale=params.nbest_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
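
Note on the optional file arguments: this change drops `required=True` from `--words-file` and `--HLG`, so a missing file is only noticed when it is first used. The sketch below shows one way to check the arguments per method right after parsing. `check_args` and `make_parser` are hypothetical helpers, not part of the commit; the option names mirror the diff above.

```python
import argparse


def check_args(args):
    """Return an error message if a file needed by args.method is missing."""
    if args.method == "ctc-decoding":
        needed = {"--tokens-file": args.tokens_file}
    else:
        needed = {"--words-file": args.words_file, "--HLG": args.HLG}
    missing = [flag for flag, value in needed.items() if value is None]
    if missing:
        return f"--method {args.method} requires {', '.join(missing)}"
    return None


def make_parser():
    # Only the options relevant to this sketch; names mirror the diff above.
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, default="1best")
    parser.add_argument("--tokens-file", type=str)
    parser.add_argument("--words-file", type=str)
    parser.add_argument("--HLG", type=str)
    return parser


args = make_parser().parse_args(["--method", "ctc-decoding"])
print(check_args(args))
```

In a script, a non-`None` message could be passed to `parser.error()` so the user sees the usual argparse usage text.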


@ -23,6 +23,7 @@ It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank. The generated fbank features are saved in data/fbank.
""" """
import argparse
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
@ -43,7 +44,7 @@ torch.set_num_interop_threads(1)
def compute_fbank_aishell(num_mel_bins: int = 80): def compute_fbank_aishell(num_mel_bins: int = 80):
src_dir = Path("data/manifests") src_dir = Path("data/manifests")
output_dir = Path("data/fbank40") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(15, os.cpu_count())
dataset_parts = ( dataset_parts = (
@ -106,4 +107,3 @@ if __name__ == "__main__":
args = get_args() args = get_args()
compute_fbank_aishell(num_mel_bins=args.num_mel_bins) compute_fbank_aishell(num_mel_bins=args.num_mel_bins)


@ -23,6 +23,7 @@ It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank. The generated fbank features are saved in data/fbank.
""" """
import argparse
import logging import logging
import os import os
from pathlib import Path from pathlib import Path
@ -43,7 +44,7 @@ torch.set_num_interop_threads(1)
def compute_fbank_musan(num_mel_bins: int = 80): def compute_fbank_musan(num_mel_bins: int = 80):
src_dir = Path("data/manifests") src_dir = Path("data/manifests")
output_dir = Path("data/fbank40") output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count()) num_jobs = min(15, os.cpu_count())
dataset_parts = ( dataset_parts = (
@ -86,6 +87,7 @@ def compute_fbank_musan(num_mel_bins: int = 80):
) )
musan_cuts.to_json(musan_cuts_path) musan_cuts.to_json(musan_cuts_path)
def get_args(): def get_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
@ -106,4 +108,3 @@ if __name__ == "__main__":
logging.basicConfig(format=formatter, level=logging.INFO) logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args() args = get_args()
compute_fbank_musan(num_mel_bins=args.num_mel_bins) compute_fbank_musan(num_mel_bins=args.num_mel_bins)


@ -69,7 +69,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
# |-- lexicon.txt # |-- lexicon.txt
# `-- speaker.info # `-- speaker.info
if [ ! -d $dl_dir/aishell/wav ]; then if [ ! -d $dl_dir/aishell/data_aishell/wav ]; then
lhotse download aishell $dl_dir lhotse download aishell $dl_dir
fi fi
@ -133,7 +133,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt | cat $dl_dir/aishell/data_aishell/transcript/aishell_transcript_v0.8.txt |
cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > data/lang_char/text cut -d " " -f 2- | sed -e 's/[ \t\r\n]*//g' > data/lang_char/text
if [ ! -f data/lang_char/L_disambig.pt ]; then if [ ! -f data/lang_char/L_disambig.pt ]; then
./local/prepare_char.py ./local/prepare_char.py
fi fi
@ -160,4 +160,3 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
./local/compile_hlg.py --lang-dir data/lang_phone ./local/compile_hlg.py --lang-dir data/lang_phone
./local/compile_hlg.py --lang-dir data/lang_char ./local/compile_hlg.py --lang-dir data/lang_char
fi fi