From f344b07e7c0feb40f0cf53857d2019066deac9a7 Mon Sep 17 00:00:00 2001
From: Pingfeng Luo
Date: Wed, 17 Nov 2021 15:50:18 +0800
Subject: [PATCH] update aishell recipe with master branch and fix some bugs

---
 docs/source/recipes/index.rst                 |   2 -
 egs/aishell/ASR/RESULTS.md                    |  42 +-
 .../ASR/conformer_ctc/asr_datamodule.py       | 336 +++++++++-
 egs/aishell/ASR/conformer_ctc/conformer.py    |   5 +-
 egs/aishell/ASR/conformer_ctc/decode.py       | 318 +++++++--
 egs/aishell/ASR/conformer_ctc/pretrained.py   | 208 ++++--
 egs/aishell/ASR/conformer_ctc/subsampling.py  |  32 +-
 .../ASR/conformer_ctc/test_subsampling.py     |  49 --
 .../ASR/conformer_ctc/test_transformer.py     | 105 ---
 egs/aishell/ASR/conformer_ctc/train.py        | 367 +++++------
 egs/aishell/ASR/conformer_ctc/transformer.py  |  90 +--
 egs/aishell/ASR/local/compile_hlg.py          |   2 +-
 .../ASR/local/compute_fbank_aishell.py        |   3 +-
 egs/aishell/ASR/local/compute_fbank_musan.py  |   3 +-
 egs/aishell/ASR/prepare.sh                    |   3 +-
 egs/aishell/ASR/tdnn_lstm_ctc/README.md       |   4 -
 egs/aishell/ASR/tdnn_lstm_ctc/__init__.py     |   0
 .../ASR/tdnn_lstm_ctc/asr_datamodule.py       | 335 ----------
 egs/aishell/ASR/tdnn_lstm_ctc/decode.py       | 399 ------------
 egs/aishell/ASR/tdnn_lstm_ctc/model.py        | 103 ---
 egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py   | 231 -------
 egs/aishell/ASR/tdnn_lstm_ctc/train.py        | 616 ------------------
 22 files changed, 1015 insertions(+), 2238 deletions(-)
 mode change 120000 => 100644 egs/aishell/ASR/conformer_ctc/asr_datamodule.py
 delete mode 100755 egs/aishell/ASR/conformer_ctc/test_subsampling.py
 delete mode 100644 egs/aishell/ASR/conformer_ctc/test_transformer.py
 delete mode 100644 egs/aishell/ASR/tdnn_lstm_ctc/README.md
 delete mode 100644 egs/aishell/ASR/tdnn_lstm_ctc/__init__.py
 delete mode 100644 egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
 delete mode 100755 egs/aishell/ASR/tdnn_lstm_ctc/decode.py
 delete mode 100644 egs/aishell/ASR/tdnn_lstm_ctc/model.py
 delete mode 100644 egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py
 delete mode 100755 egs/aishell/ASR/tdnn_lstm_ctc/train.py

diff --git a/docs/source/recipes/index.rst b/docs/source/recipes/index.rst
index ab81e5875..36f8dfc39 100644
--- a/docs/source/recipes/index.rst
+++ b/docs/source/recipes/index.rst
@@ -15,5 +15,3 @@ We may add recipes for other tasks as well in the future.
 
    yesno
    librispeech
-
-   aishell
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 5cbe5d213..7d821a672 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -1,22 +1,22 @@
 ## Results
 
-### Aishell training results (Conformer-CTC)
-#### 2021-09-13
+### AIShell training results (Conformer-CTC)
+#### 2021-11-17
 
-(Wei Kang): Result of https://github.com/k2-fsa/icefall/pull/30
+(Pingfeng Luo): Result of https://github.com/k2-fsa/icefall/pull/30
 
-Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc
+Pretrained model is available at https://huggingface.co/pfluo/icefall_aishell_model
+The tensorboard log for training is available at https://tensorboard.dev/experiment/zsw6Hn6EQlG8I7HqEkiQpw
 
-The best decoding results (CER) are listed below, we got this results by averaging models from epoch 23 to 40, and using `attention-decoder` decoder with num_paths equals to 100.
+The best decoding results (CER) are listed below. We got these results by averaging the models from epochs 30 to 49 and using the `attention-decoder` method with num_paths equal to 100.
 
 ||test|
 |--|--|
-|CER| 4.74% |
-
-To get more unique paths, we scaled the lattice.scores with 0.5 (see https://github.com/k2-fsa/icefall/pull/10#discussion_r690951662 for more details), we searched the lm_score_scale and attention_score_scale for best results, the scales that produced the CER above are also listed below.
+|CER| 4.38% |
 
 ||lm_scale|attention_scale|
 |--|--|--|
-|test|0.3|0.9|
+|test|0.6|1.2|
 
 You can use the following commands to reproduce our results:
 
 ```bash
 git clone https://github.com/k2-fsa/icefall
 cd icefall
 
 cd egs/aishell/ASR
 ./prepare.sh
 
-export CUDA_VISIBLE_DEVICES="0,1"
-python conformer_ctc/train.py --bucketing-sampler False \
+export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+python3 conformer_ctc/train.py --bucketing-sampler False \
                               --concatenate-cuts False \
                               --max-duration 200 \
-                              --world-size 2
+                              --world-size 8
 
-python conformer_ctc/decode.py --lattice-score-scale 0.5 \
-                               --epoch 40 \
-                               --avg 18 \
+python3 conformer_ctc/decode.py --nbest-scale 0.5 \
+                                --epoch 49 \
+                                --avg 20 \
                                --method attention-decoder \
                                --max-duration 50 \
                                --num-paths 100
 ```
-
-### Aishell training results (Tdnn-Lstm)
-#### 2021-09-13
-
-(Wei Kang): Result of phone based Tdnn-Lstm model, https://github.com/k2-fsa/icefall/pull/30
-
-Pretrained model is available at https://huggingface.co/pkufool/icefall_asr_aishell_conformer_ctc_lstm_ctc
-
-The best decoding results (CER) are listed below, we got this results by averaging models from epoch 19 to 8, and using `1best` decoding method.
-
-||test|
-|--|--|
-|CER| 10.16% |
-
diff --git a/egs/aishell/ASR/conformer_ctc/asr_datamodule.py b/egs/aishell/ASR/conformer_ctc/asr_datamodule.py
deleted file mode 120000
index fa1b8cca3..000000000
--- a/egs/aishell/ASR/conformer_ctc/asr_datamodule.py
+++ /dev/null
@@ -1 +0,0 @@
-../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/conformer_ctc/asr_datamodule.py b/egs/aishell/ASR/conformer_ctc/asr_datamodule.py
new file mode 100644
index 000000000..9dede6288
--- /dev/null
+++ b/egs/aishell/ASR/conformer_ctc/asr_datamodule.py
@@ -0,0 +1,335 @@
+# Copyright 2021 Piotr Żelasko
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import List, Union
+
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest
+from lhotse.dataset import (
+    BucketingSampler,
+    CutConcatenate,
+    CutMix,
+    K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
+    SingleCutSampler,
+    SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from torch.utils.data import DataLoader
+
+from icefall.dataset.datamodule import DataModule
+from icefall.utils import str2bool
+
+
+class AishellAsrDataModule(DataModule):
+    """
+    DataModule for k2 ASR experiments.
+    It assumes there is always one train and valid dataloader,
+    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+    and test-other).
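+    In this AIShell recipe there is a single "test" set, but
+    test_cuts() still returns a list, so more test sets can be added
+    later.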
+ + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + """ + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + super().add_arguments(parser) + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--feature-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + def train_dataloaders(self) -> DataLoader: + logging.info("About to get train cuts") + cuts_train = self.train_cuts() + + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") + + logging.info("About to create train dataset") + transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. 
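+            # Illustration of the resulting order (CutMix was created
+            # above): transforms == [CutConcatenate(...), CutMix(...)],
+            # i.e. cuts are concatenated first and noise is mixed in
+            # afterwards, so the gaps are filled with noise as well.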
+ transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [ + SpecAugment( + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ] + + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using BucketingSampler.") + train_sampler = BucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self) -> DataLoader: + logging.info("About to get dev cuts") + cuts_valid = self.valid_cuts() + + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = SingleCutSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: + cuts = self.test_cuts() + is_list = isinstance(cuts, list) + test_loaders = [] + if not is_list: + cuts = [cuts] + + for cuts_test in cuts: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = SingleCutSampler( + cuts_test, max_duration=self.args.max_duration + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, batch_size=None, sampler=sampler, 
num_workers=1 + ) + test_loaders.append(test_dl) + + if is_list: + return test_loaders + else: + return test_loaders[0] + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest( + self.args.feature_dir / "cuts_train.json.gz" + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest( + self.args.feature_dir / "cuts_dev.json.gz" + ) + return cuts_valid + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + test_sets = ["test"] + cuts = [] + for test_set in test_sets: + logging.debug("About to get test cuts") + cuts.append( + load_manifest( + self.args.feature_dir / f"cuts_{test_set}.json.gz" + ) + ) + return cuts diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py index 5b136f40e..b19b94db1 100644 --- a/egs/aishell/ASR/conformer_ctc/conformer.py +++ b/egs/aishell/ASR/conformer_ctc/conformer.py @@ -40,7 +40,6 @@ class Conformer(Transformer): cnn_module_kernel (int): Kernel size of convolution module normalize_before (bool): whether to use layer_norm before the first block. vgg_frontend (bool): whether to use vgg frontend. - use_feat_batchnorm(bool): whether to use batch-normalize the input. """ def __init__( @@ -99,7 +98,7 @@ class Conformer(Transformer): """ Args: x: - The model input. Its shape is [N, T, C]. + The model input. Its shape is (N, T, C). supervisions: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -453,7 +452,6 @@ class RelPositionMultiheadAttention(nn.Module): self._reset_parameters() - def _reset_parameters(self) -> None: nn.init.xavier_uniform_(self.in_proj.weight) nn.init.constant_(self.in_proj.bias, 0.0) @@ -683,7 +681,6 @@ class RelPositionMultiheadAttention(nn.Module): _b = _b[_start:] v = nn.functional.linear(value, _w, _b) - if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 diff --git a/egs/aishell/ASR/conformer_ctc/decode.py b/egs/aishell/ASR/conformer_ctc/decode.py index 20a8f7b3a..e288967f8 100755 --- a/egs/aishell/ASR/conformer_ctc/decode.py +++ b/egs/aishell/ASR/conformer_ctc/decode.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, -# Fangjun Kuang, -# Wei Kang) +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# Copyright 2021 (Author: Pingfeng Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -25,6 +24,7 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 +import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import AishellAsrDataModule @@ -38,10 +38,13 @@ from icefall.decode import ( nbest_oracle, one_best_decoding, rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_whole_lattice, ) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + get_env_info, get_texts, setup_logger, store_transcripts, @@ -77,13 +80,22 @@ def get_parser(): default="attention-decoder", help="""Decoding method. Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. - (1) 1best. Extract the best path from the decoding lattice as the decoding result. - (2) nbest. 
Extract n paths from the decoding lattice; the path with the highest score is the decoding result. - - (3) attention-decoder. Extract n paths from the lattice, - the path with the highest score is the decoding result. - - (4) nbest-oracle. Its WER is the lower bound of any n-best + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. """, @@ -95,18 +107,18 @@ def get_parser(): default=100, help="""Number of paths for n-best based decoding method. Used only when "method" is one of the following values: - nbest, attention-decoder, and nbest-oracle + nbest, nbest-rescoring, attention-decoder, and nbest-oracle """, ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help="""The scale to be applied to `lattice.scores`. It's needed if you use any kinds of n-best based rescoring. Used only when "method" is one of the following values: - nbest, attention-decoder, and nbest-oracle + nbest, nbest-rescoring, attention-decoder, and nbest-oracle A smaller value results in more unique paths. """, ) @@ -122,17 +134,39 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The LM dir. + It should contain either G_3_gram.pt or G_3_gram.fst.txt + """, + ) + return parser def get_params() -> AttributeDict: params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_char"), - "lm_dir": Path("data/lm"), # parameters for conformer "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, "feature_dim": 80, "nhead": 4, "attention_dim": 512, @@ -146,6 +180,7 @@ def get_params() -> AttributeDict: "min_active_states": 30, "max_active_states": 10000, "use_double_scores": True, + "env_info": get_env_info(), } ) return params @@ -154,21 +189,23 @@ def get_params() -> AttributeDict: def decode_one_batch( params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], batch: dict, word_table: k2.SymbolTable, sos_id: int, eos_id: int, -) -> Dict[str, List[List[int]]]: + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: - key: It indicates the setting used for decoding. For example, - if decoding method is 1best, the key is the string `no_rescore`. - If attention rescoring is used, the key is the string - `ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the - value of `lm_scale` and `attention_scale`. An example key is - `ngram_lm_scale_0.7_attention_scale_0.5` + if no rescoring is used, the key is the string `no_rescore`. 
+ If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` - value: It contains the decoding result. `len(value)` equals to batch size. `value[i]` is the decoding result for the i-th utterance in the given batch. @@ -178,12 +215,18 @@ def decode_one_batch( - params.method is "1best", it uses 1best decoding without LM rescoring. - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "attention-decoder", it uses attention rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. model: The neural model. HLG: - The decoding graph. + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. batch: It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation @@ -194,20 +237,27 @@ def decode_one_batch( The token ID of the SOS. eos_id: The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. Returns: Return the decoding result. See above description for the format of the returned dict. """ - device = HLG.device + if HLG is not None: + device = HLG.device + else: + device = H.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) supervisions = batch["supervisions"] nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) supervision_segments = torch.stack( ( @@ -218,9 +268,17 @@ def decode_one_batch( 1, ).to(torch.int32) + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + lattice = get_lattice( nnet_output=nnet_output, - HLG=HLG, + decoding_graph=decoding_graph, supervision_segments=supervision_segments, search_beam=params.search_beam, output_beam=params.output_beam, @@ -229,18 +287,41 @@ def decode_one_batch( subsampling_factor=params.subsampling_factor, ) + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + if params.method == "nbest-oracle": # Note: You can also pass rescored lattices to it. # We choose the HLG decoded lattice for speed reasons # as HLG decoding is faster and the oracle WER - # is slightly worse than that of rescored lattices. - return nbest_oracle( + # is only slightly worse than that of rescored lattices. 
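+        # nbest-oracle picks, among the num_paths paths sampled from
+        # the lattice, the one closest to the reference transcript, so
+        # its CER is a lower bound for any n-best rescoring method
+        # using the same num_paths and nbest_scale.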
+ best_path = nbest_oracle( lattice=lattice, num_paths=params.num_paths, ref_texts=supervisions["text"], word_table=word_table, - scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, + oov="", ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} if params.method in ["1best", "nbest"]: if params.method == "1best": @@ -253,31 +334,70 @@ def decode_one_batch( lattice=lattice, num_paths=params.num_paths, use_double_scores=params.use_double_scores, - scale=params.lattice_score_scale, + nbest_scale=params.nbest_scale, ) - key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}" # noqa + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa hyps = get_texts(best_path) hyps = [[word_table[i] for i in ids] for ids in hyps] return {key: hyps} - assert params.method == "attention-decoder" + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + # TODO: pass `lattice` instead of `rescored_lattice` to + # `rescore_with_attention_decoder` + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" - best_path_dict = rescore_with_attention_decoder( - lattice=lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=sos_id, - eos_id=eos_id, - scale=params.lattice_score_scale, - ) ans = dict() - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + for lm_scale in lm_scale_list: + ans[f"{lm_scale}"] = [[] * lattice.shape[0]] return ans @@ -285,11 +405,14 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - HLG: k2.Fsa, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], word_table: k2.SymbolTable, sos_id: int, eos_id: int, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. Args: @@ -300,18 +423,26 @@ def decode_dataset( model: The neural model. 
HLG: - The decoding graph. + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. word_table: It is the word symbol table. sos_id: The token ID for SOS. eos_id: The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. Returns: - Return a dict, whose key may be "no-rescore" if the decoding method is - 1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention - rescoring is used. Its value is a list of tuples. Each tuple contains two - elements: The first is the reference transcript, and the second is the + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the predicted result. """ results = [] @@ -331,14 +462,18 @@ def decode_dataset( params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, batch=batch, word_table=word_table, + G=G, sos_id=sos_id, eos_id=eos_id, ) for lm_scale, hyps in hyps_dict.items(): this_batch = [] + assert len(hyps) == len(texts) for hyp_words, ref_text in zip(hyps, texts): ref_words = ref_text.split() this_batch.append((ref_words, hyp_words)) @@ -411,6 +546,9 @@ def main(): parser = get_parser() AishellAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) params = get_params() params.update(vars(args)) @@ -438,14 +576,70 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) - HLG = HLG.to(device) - assert HLG.requires_grad is False + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ): + if not (params.lm_dir / "G_3_gram.pt").is_file(): + logging.info("Loading G_3_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_3_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. 
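+            # (Assumption: the attribute name "dummy" carries no
+            # meaning; any placeholder attribute should work, as long
+            # as the saved dict is loadable from C++.)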
+ G["dummy"] = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_3_gram.pt") + else: + logging.info("Loading pre-compiled G_3_gram.pt") + d = torch.load(params.lm_dir / "G_3_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None model = Conformer( num_features=params.feature_dim, @@ -468,7 +662,8 @@ def main(): if start >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") @@ -496,7 +691,10 @@ def main(): params=params, model=model, HLG=HLG, + H=H, + bpe_model=bpe_model, word_table=lexicon.word_table, + G=G, sos_id=sos_id, eos_id=eos_id, ) diff --git a/egs/aishell/ASR/conformer_ctc/pretrained.py b/egs/aishell/ASR/conformer_ctc/pretrained.py index 846681f00..e7b8c2cc8 100755 --- a/egs/aishell/ASR/conformer_ctc/pretrained.py +++ b/egs/aishell/ASR/conformer_ctc/pretrained.py @@ -24,6 +24,7 @@ from typing import List import k2 import kaldifeat +import sentencepiece as spm import torch import torchaudio from conformer import Conformer @@ -33,8 +34,9 @@ from icefall.decode import ( get_lattice, one_best_decoding, rescore_with_attention_decoder, + rescore_with_whole_lattice, ) -from icefall.utils import AttributeDict, get_texts +from icefall.utils import AttributeDict, get_env_info, get_texts def get_parser(): @@ -54,12 +56,25 @@ def get_parser(): parser.add_argument( "--words-file", type=str, - required=True, - help="Path to words.txt", + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, ) parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, ) parser.add_argument( @@ -68,10 +83,18 @@ def get_parser(): default="1best", help="""Decoding method. Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. (1) 1best - Use the best path as decoding output. Only the transformer encoder output is used for decoding. We call it HLG decoding. - (2) attention-decoder - Extract n paths from the rescored + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + (3) attention-decoder - Extract n paths from the rescored lattice and use the transformer attention decoder for rescoring. We call it HLG decoding + n-gram LM rescoring + attention @@ -79,6 +102,16 @@ def get_parser(): """, ) + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or attention-decoder. + It's usually a 4-gram LM. 
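+        For this recipe, the G_3_gram.pt that conformer_ctc/decode.py
+        writes to data/lm can also be passed here (an assumption: any
+        G saved via G.as_dict() should load).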
+ """, + ) + parser.add_argument( "--num-paths", type=int, @@ -93,7 +126,7 @@ def get_parser(): type=float, default=0.3, help=""" - Used only when method is attention-decoder. + Used only when method is whole-lattice-rescoring and attention-decoder. It specifies the scale for n-gram LM scores. (Note: You need to tune it on a dataset.) """, @@ -111,7 +144,7 @@ def get_parser(): ) parser.add_argument( - "--lattice-score-scale", + "--nbest-scale", type=float, default=0.5, help=""" @@ -125,7 +158,7 @@ def get_parser(): parser.add_argument( "--sos-id", - type=float, + type=int, default=1, help=""" Used only when method is attention-decoder. @@ -133,9 +166,18 @@ def get_parser(): """, ) + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + parser.add_argument( "--eos-id", - type=float, + type=int, default=1, help=""" Used only when method is attention-decoder. @@ -163,13 +205,13 @@ def get_params() -> AttributeDict: "num_classes": 4336, # parameters for conformer "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, "feature_dim": 80, "nhead": 4, "attention_dim": 512, "num_decoder_layers": 6, - "vgg_frontend": False, - "use_feat_batchnorm": True, - # parameters for deocding + # parameters for decoding "search_beam": 20, "output_beam": 8, "min_active_states": 30, @@ -209,7 +251,13 @@ def main(): args = parser.parse_args() params = get_params() + if args.method != "attention-decoder": + # to save memory as the attention decoder + # will not be used + params.num_decoder_layers = 0 + params.update(vars(args)) + params["env_info"] = get_env_info() logging.info(f"{params}") device = torch.device("cpu") @@ -231,17 +279,10 @@ def main(): ) checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) + model.load_state_dict(checkpoint["model"], strict=False) model.to(device) model.eval() - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - logging.info("Constructing Fbank computer") opts = kaldifeat.FbankOptions() opts.device = device @@ -275,41 +316,108 @@ def main(): dtype=torch.int32, ) - lattice = get_lattice( - nnet_output=nnet_output, - HLG=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) - if params.method == "1best": - logging.info("Use HLG decoding") best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - elif params.method == "attention-decoder": - logging.info("Use HLG + attention decoder 
rescoring") - best_path_dict = rescore_with_attention_decoder( - lattice=lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=params.sos_id, - eos_id=params.eos_id, - scale=params.lattice_score_scale, - ngram_lm_scale=params.ngram_lm_scale, - attention_scale=params.attention_decoder_scale, - ) - best_path = next(iter(best_path_dict.values())) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + nbest_scale=params.nbest_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") s = "\n" for filename, hyp in zip(params.sound_files, hyps): diff --git a/egs/aishell/ASR/conformer_ctc/subsampling.py b/egs/aishell/ASR/conformer_ctc/subsampling.py index 720ed6c22..542fb0364 100644 --- a/egs/aishell/ASR/conformer_ctc/subsampling.py +++ b/egs/aishell/ASR/conformer_ctc/subsampling.py @@ -22,8 +22,8 @@ import torch.nn as nn class Conv2dSubsampling(nn.Module): """Convolutional 2D subsampling (to 1/4 length). 
- Convert an input of shape [N, T, idim] to an output - with shape [N, T', odim], where + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 It is based on @@ -34,10 +34,10 @@ class Conv2dSubsampling(nn.Module): """ Args: idim: - Input dim. The input shape is [N, T, idim]. + Input dim. The input shape is (N, T, idim). Caution: It requires: T >=7, idim >=7 odim: - Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) """ assert idim >= 7 super().__init__() @@ -58,18 +58,18 @@ class Conv2dSubsampling(nn.Module): Args: x: - Its shape is [N, T, idim]. + Its shape is (N, T, idim). Returns: - Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) """ - # On entry, x is [N, T, idim] - x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W] + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) x = self.conv(x) - # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2] + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) b, c, t, f = x.size() x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape [N, ((T-1)//2 - 1))//2, odim] + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) return x @@ -80,8 +80,8 @@ class VggSubsampling(nn.Module): This paper is not 100% explicit so I am guessing to some extent, and trying to compare with other VGG implementations. - Convert an input of shape [N, T, idim] to an output - with shape [N, T', odim], where + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 """ @@ -93,10 +93,10 @@ class VggSubsampling(nn.Module): Args: idim: - Input dim. The input shape is [N, T, idim]. + Input dim. The input shape is (N, T, idim). Caution: It requires: T >=7, idim >=7 odim: - Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) """ super().__init__() @@ -149,10 +149,10 @@ class VggSubsampling(nn.Module): Args: x: - Its shape is [N, T, idim]. + Its shape is (N, T, idim). Returns: - Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) """ x = x.unsqueeze(1) x = self.layers(x) diff --git a/egs/aishell/ASR/conformer_ctc/test_subsampling.py b/egs/aishell/ASR/conformer_ctc/test_subsampling.py deleted file mode 100755 index e3361d0c9..000000000 --- a/egs/aishell/ASR/conformer_ctc/test_subsampling.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from subsampling import Conv2dSubsampling -from subsampling import VggSubsampling -import torch - - -def test_conv2d_subsampling(): - N = 3 - odim = 2 - - for T in range(7, 19): - for idim in range(7, 20): - model = Conv2dSubsampling(idim=idim, odim=odim) - x = torch.empty(N, T, idim) - y = model(x) - assert y.shape[0] == N - assert y.shape[1] == ((T - 1) // 2 - 1) // 2 - assert y.shape[2] == odim - - -def test_vgg_subsampling(): - N = 3 - odim = 2 - - for T in range(7, 19): - for idim in range(7, 20): - model = VggSubsampling(idim=idim, odim=odim) - x = torch.empty(N, T, idim) - y = model(x) - assert y.shape[0] == N - assert y.shape[1] == ((T - 1) // 2 - 1) // 2 - assert y.shape[2] == odim diff --git a/egs/aishell/ASR/conformer_ctc/test_transformer.py b/egs/aishell/ASR/conformer_ctc/test_transformer.py deleted file mode 100644 index b90215274..000000000 --- a/egs/aishell/ASR/conformer_ctc/test_transformer.py +++ /dev/null @@ -1,105 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from transformer import ( - Transformer, - encoder_padding_mask, - generate_square_subsequent_mask, - decoder_padding_mask, - add_sos, - add_eos, -) - -from torch.nn.utils.rnn import pad_sequence - - -def test_encoder_padding_mask(): - supervisions = { - "sequence_idx": torch.tensor([0, 1, 2]), - "start_frame": torch.tensor([0, 0, 0]), - "num_frames": torch.tensor([18, 7, 13]), - } - - max_len = ((18 - 1) // 2 - 1) // 2 - mask = encoder_padding_mask(max_len, supervisions) - expected_mask = torch.tensor( - [ - [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, - [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, - [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_transformer(): - num_features = 40 - num_classes = 87 - model = Transformer(num_features=num_features, num_classes=num_classes) - - N = 31 - - for T in range(7, 30): - x = torch.rand(N, T, num_features) - y, _, _ = model(x) - assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) - - -def test_generate_square_subsequent_mask(): - s = 5 - mask = generate_square_subsequent_mask(s) - inf = float("inf") - expected_mask = torch.tensor( - [ - [0.0, -inf, -inf, -inf, -inf], - [0.0, 0.0, -inf, -inf, -inf], - [0.0, 0.0, 0.0, -inf, -inf], - [0.0, 0.0, 0.0, 0.0, -inf], - [0.0, 0.0, 0.0, 0.0, 0.0], - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_decoder_padding_mask(): - x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] - y = pad_sequence(x, batch_first=True, padding_value=-1) - mask = decoder_padding_mask(y, ignore_id=-1) - expected_mask = torch.tensor( - [ - [False, False, True], - [False, True, True], - [False, False, False], - ] - ) - assert torch.all(torch.eq(mask, expected_mask)) - - -def test_add_sos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_sos(x, sos_id=0) - 
expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] - assert y == expected_y - - -def test_add_eos(): - x = [[1, 2], [3], [2, 5, 8]] - y = add_eos(x, eos_id=0) - expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] - assert y == expected_y diff --git a/egs/aishell/ASR/conformer_ctc/train.py b/egs/aishell/ASR/conformer_ctc/train.py index 3c54fc42a..9ea28cf6c 100755 --- a/egs/aishell/ASR/conformer_ctc/train.py +++ b/egs/aishell/ASR/conformer_ctc/train.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang # Wei Kang) +# Copyright 2021 (authors: Pinfeng Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -21,16 +22,16 @@ import argparse import logging from pathlib import Path from shutil import copyfile -from typing import Optional +from typing import Optional, Tuple import k2 import torch -import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import AishellAsrDataModule from conformer import Conformer from lhotse.utils import fix_random_seed +from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter @@ -43,7 +44,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, + MetricsTracker, encode_supervisions, + get_env_info, setup_logger, str2bool, ) @@ -92,6 +95,35 @@ def get_parser(): """, ) + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + return parser @@ -99,18 +131,16 @@ def get_params() -> AttributeDict: """Return a dict containing training parameters. All training related parameters that are not passed from the commandline - is saved in the variable `params`. + are saved in the variable `params`. Commandline options are merged into `params` after they are parsed, so you can also access them via `params`. Explanation of options saved in `params`: - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. - best_valid_loss: Best validation loss so far. It is used to select the model that has the lowest validation loss. It is @@ -130,66 +160,60 @@ def get_params() -> AttributeDict: - valid_interval: Run validation if batch_idx % valid_interval is 0 + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Whether to do batch normalization for the + input features. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + - beam_size: It is used in k2.ctc_loss - reduction: It is used in k2.ctc_loss - use_double_scores: It is used in k2.ctc_loss - - att_rate: The proportion of label smoothing loss, final loss will be - (1 - att_rate) * ctc_loss + att_rate * label_smoothing_loss - - - subsampling_factor: The subsampling factor for the model. - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - attention_dim: Attention dimension. - - - nhead: Number of heads in multi-head attention. - Must satisfy attention_dim // nhead == 0. - - - num_encoder_layers: Number of attention encoder layers. - - - num_decoder_layers: Number of attention decoder layers. - - - use_feat_batchnorm: Whether to do normalization in the input layer. - - weight_decay: The weight_decay for the optimizer. - - lr_factor: The lr_factor for the optimizer. + - lr_factor: The lr_factor for Noam optimizer. - - warm_step: The warm_step for the optimizer. + - warm_step: The warm_step for Noam optimizer. """ params = AttributeDict( { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_char"), "best_train_loss": float("inf"), "best_valid_loss": float("inf"), "best_train_epoch": -1, "best_valid_epoch": -1, "batch_idx_train": 0, - "log_interval": 10, + "log_interval": 50, "reset_interval": 200, "valid_interval": 3000, - # parameters for k2.ctc_loss - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - "att_rate": 0.7, # parameters for conformer - "subsampling_factor": 4, "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, "attention_dim": 512, "nhead": 4, "num_encoder_layers": 12, "num_decoder_layers": 6, - "use_feat_batchnorm": True, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + "att_rate": 0.7, # parameters for Noam "weight_decay": 1e-5, "lr_factor": 5.0, "warm_step": 36000, + "env_info": get_env_info(), } ) @@ -289,7 +313,7 @@ def compute_loss( batch: dict, graph_compiler: CharCtcTrainingGraphCompiler, is_training: bool, -): +) -> Tuple[Tensor, MetricsTracker]: """ Compute CTC loss given the model and its inputs. 
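The `--att-rate` option added above controls how `compute_loss()` mixes the two objectives. A minimal sketch of the interpolation, with made-up loss values (illustrative only, not from a real run):

```python
import torch

# Hypothetical summed losses for one batch (reduction="sum").
ctc_loss = torch.tensor(350.0)
att_loss = torch.tensor(410.0)
att_rate = 0.7  # default in get_params()

# The same interpolation as in compute_loss():
loss = (1.0 - att_rate) * ctc_loss + att_rate * att_loss
print(loss)  # tensor(392.)
```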
@@ -312,14 +336,14 @@ def compute_loss( """ device = graph_compiler.device feature = batch["inputs"] - # at entry, feature is [N, T, C] + # at entry, feature is (N, T, C) assert feature.ndim == 3 feature = feature.to(device) supervisions = batch["supervisions"] with torch.set_grad_enabled(is_training): nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is [N, T, C] + # nnet_output is (N, T, C) # NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by @@ -348,36 +372,41 @@ def compute_loss( if params.att_rate != 0.0: with torch.set_grad_enabled(is_training): - if hasattr(model, "module"): - att_loss = model.module.decoder_forward( - encoder_memory, - memory_mask, - token_ids=token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - else: - att_loss = model.decoder_forward( - encoder_memory, - memory_mask, - token_ids=token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss else: loss = ctc_loss att_loss = torch.tensor([0]) - # train_frames and valid_frames are used for printing. - if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() - assert loss.requires_grad == is_training - return loss, ctc_loss.detach(), att_loss.detach() + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + return loss, info def compute_validation_loss( @@ -386,18 +415,14 @@ def compute_validation_loss( graph_compiler: CharCtcTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, -) -> None: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. 
- """ +) -> MetricsTracker: + """Run the validation process.""" model.eval() - tot_loss = 0.0 - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - tot_frames = 0.0 + tot_loss = MetricsTracker() + for batch_idx, batch in enumerate(valid_dl): - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, @@ -405,36 +430,17 @@ def compute_validation_loss( is_training=False, ) assert loss.requires_grad is False - assert ctc_loss.requires_grad is False - assert att_loss.requires_grad is False - - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - - tot_ctc_loss += ctc_loss.detach().cpu().item() - tot_att_loss += att_loss.detach().cpu().item() - - tot_frames += params.valid_frames + tot_loss = tot_loss + loss_info if world_size > 1: - s = torch.tensor( - [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames], - device=loss.device, - ) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_ctc_loss = s[1] - tot_att_loss = s[2] - tot_frames = s[3] + tot_loss.reduce(loss.device) - params.valid_loss = tot_loss / tot_frames - params.valid_ctc_loss = tot_ctc_loss / tot_frames - params.valid_att_loss = tot_att_loss / tot_frames - - if params.valid_loss < params.best_valid_loss: + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss + params.best_valid_loss = loss_value + + return tot_loss def train_one_epoch( @@ -473,24 +479,21 @@ def train_one_epoch( """ model.train() - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 + tot_loss = MetricsTracker() - tot_frames = 0.0 # sum of frames over all batches - params.tot_loss = 0.0 - params.tot_frames = 0.0 for batch_idx, batch in enumerate(train_dl): params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) - loss, ctc_loss, att_loss = compute_loss( + loss, loss_info = compute_loss( params=params, model=model, batch=batch, graph_compiler=graph_compiler, is_training=True, ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. 
@@ -500,75 +503,26 @@ def train_one_epoch( clip_grad_norm_(model.parameters(), 5.0, 2.0) optimizer.step() - loss_cpu = loss.detach().cpu().item() - ctc_loss_cpu = ctc_loss.detach().cpu().item() - att_loss_cpu = att_loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_ctc_loss += ctc_loss_cpu - tot_att_loss += att_loss_cpu - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - - tot_avg_loss = tot_loss / tot_frames - tot_avg_ctc_loss = tot_ctc_loss / tot_frames - tot_avg_att_loss = tot_att_loss / tot_frames - if batch_idx % params.log_interval == 0: logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, " - f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, " - f"total avg att loss: {tot_avg_att_loss:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" ) + if batch_idx % params.log_interval == 0: + if tb_writer is not None: - tb_writer.add_scalar( - "train/current_ctc_loss", - ctc_loss_cpu / params.train_frames, - params.batch_idx_train, + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train ) - tb_writer.add_scalar( - "train/current_att_loss", - att_loss_cpu / params.train_frames, - params.batch_idx_train, + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train ) - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_ctc_loss", - tot_avg_ctc_loss, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_att_loss", - tot_avg_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0.0 # sum of losses over all batches - tot_ctc_loss = 0.0 - tot_att_loss = 0.0 - - tot_frames = 0.0 # sum of frames over all batches if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( + logging.info("Computing validation loss") + valid_info = compute_validation_loss( params=params, model=model, graph_compiler=graph_compiler, @@ -576,33 +530,14 @@ def train_one_epoch( world_size=world_size, ) model.train() - logging.info( - f"Epoch {params.cur_epoch}, " - f"valid ctc loss {params.valid_ctc_loss:.4f}," - f"valid att loss {params.valid_att_loss:.4f}," - f"valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") if tb_writer is not None: - tb_writer.add_scalar( - "train/valid_ctc_loss", - params.valid_ctc_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/valid_att_loss", - params.valid_att_loss, - params.batch_idx_train, - ) - tb_writer.add_scalar( - "train/valid_loss", - params.valid_loss, - params.batch_idx_train, + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train ) - params.train_loss = params.tot_loss / params.tot_frames - + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value if params.train_loss < params.best_train_loss: 
params.best_train_epoch = params.cur_epoch params.best_train_loss = params.train_loss @@ -685,6 +620,14 @@ def run(rank, world_size, args): train_dl = aishell.train_dataloaders() valid_dl = aishell.valid_dataloaders() + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) @@ -725,11 +668,51 @@ def run(rank, world_size, args): cleanup_dist() +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + def main(): parser = get_parser() AishellAsrDataModule.add_arguments(parser) args = parser.parse_args() - + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) world_size = args.world_size assert world_size >= 1 diff --git a/egs/aishell/ASR/conformer_ctc/transformer.py b/egs/aishell/ASR/conformer_ctc/transformer.py index 88b10b23d..c9666362f 100644 --- a/egs/aishell/ASR/conformer_ctc/transformer.py +++ b/egs/aishell/ASR/conformer_ctc/transformer.py @@ -83,8 +83,8 @@ class Transformer(nn.Module): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - # self.encoder_embed converts the input of shape [N, T, num_classes] - # to the shape [N, T//subsampling_factor, d_model]. + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). # That is, it does two things simultaneously: # (1) subsampling: T -> T//subsampling_factor # (2) embedding: num_classes -> d_model @@ -162,7 +162,7 @@ class Transformer(nn.Module): """ Args: x: - The input tensor. Its shape is [N, T, C]. + The input tensor. Its shape is (N, T, C). supervision: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -171,17 +171,17 @@ class Transformer(nn.Module): Returns: Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is [N, T, C] - - Encoder output with shape [T, N, C]. It can be used as key and + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and value for the decoder. - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is [N, T]. + memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. 
""" if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) encoder_memory, memory_key_padding_mask = self.run_encoder( x, supervision ) @@ -195,7 +195,7 @@ class Transformer(nn.Module): Args: x: - The model input. Its shape is [N, T, C]. + The model input. Its shape is (N, T, C). supervisions: Supervision in lhotse format. See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa @@ -206,8 +206,8 @@ class Transformer(nn.Module): padding mask for the decoder. Returns: Return a tuple with two tensors: - - The encoder output, with shape [T, N, C] - - encoder padding mask, with shape [N, T]. + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). The mask is None if `supervisions` is None. It is used as memory key padding mask in the decoder. """ @@ -225,17 +225,18 @@ class Transformer(nn.Module): Args: x: The output tensor from the transformer encoder. - Its shape is [T, N, C] + Its shape is (T, N, C) Returns: Return a tensor that can be used for CTC decoding. - Its shape is [N, T, C] + Its shape is (N, T, C) """ x = self.encoder_output_layer(x) x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) return x + @torch.jit.export def decoder_forward( self, memory: torch.Tensor, @@ -247,7 +248,7 @@ class Transformer(nn.Module): """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -264,11 +265,15 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) @@ -301,18 +306,19 @@ class Transformer(nn.Module): return decoder_loss + @torch.jit.export def decoder_nll( self, memory: torch.Tensor, memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], + token_ids: List[torch.Tensor], sos_id: int, eos_id: int, ) -> torch.Tensor: """ Args: memory: - It's the output of the encoder with shape [T, N, C] + It's the output of the encoder with shape (T, N, C) memory_key_padding_mask: The padding mask from the encoder. token_ids: @@ -328,14 +334,23 @@ class Transformer(nn.Module): """ # The common part between this function and decoder_forward could be # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. 
+ # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) @@ -649,25 +664,25 @@ class PositionalEncoding(nn.Module): self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout) - self.pe = None + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) def extend_pe(self, x: torch.Tensor) -> None: """Extend the time t in the positional encoding if required. - The shape of `self.pe` is [1, T1, d_model]. The shape of the input x - is [N, T, d_model]. If T > T1, then we change the shape of self.pe - to [N, T, d_model]. Otherwise, nothing is done. + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. Args: x: - It is a tensor of shape [N, T, C]. + It is a tensor of shape (N, T, C). Returns: Return None. """ if self.pe is not None: if self.pe.size(1) >= x.size(1): - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) + self.pe = self.pe.to(dtype=x.dtype, device=x.device) return pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) @@ -678,7 +693,7 @@ class PositionalEncoding(nn.Module): pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) - # Now pe is of shape [1, T, d_model], where T is x.size(1) + # Now pe is of shape (1, T, d_model), where T is x.size(1) self.pe = pe.to(device=x.device, dtype=x.dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -687,10 +702,10 @@ class PositionalEncoding(nn.Module): Args: x: - Its shape is [N, T, C] + Its shape is (N, T, C) Returns: - Return a tensor of shape [N, T, C] + Return a tensor of shape (N, T, C) """ self.extend_pe(x) x = x * self.xscale + self.pe[:, : x.size(1), :] @@ -972,10 +987,7 @@ def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: Return a new list-of-list, where each sublist starts with SOS ID. """ - ans = [] - for utt in token_ids: - ans.append([sos_id] + utt) - return ans + return [[sos_id] + utt for utt in token_ids] def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: @@ -992,7 +1004,9 @@ def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: Return a new list-of-list, where each sublist ends with EOS ID. 
""" - ans = [] - for utt in token_ids: - ans.append(utt + [eos_id]) - return ans + return [utt + [eos_id] for utt in token_ids] + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/aishell/ASR/local/compile_hlg.py b/egs/aishell/ASR/local/compile_hlg.py index 407fb7d88..098d5d6a3 100755 --- a/egs/aishell/ASR/local/compile_hlg.py +++ b/egs/aishell/ASR/local/compile_hlg.py @@ -103,7 +103,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: LG.labels[LG.labels >= first_token_disambig_id] = 0 assert isinstance(LG.aux_labels, k2.RaggedTensor) - LG.aux_labels.data[LG.aux_labels.data >= first_word_disambig_id] = 0 + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 LG = k2.remove_epsilon(LG) logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py index 77293f772..d170931db 100755 --- a/egs/aishell/ASR/local/compute_fbank_aishell.py +++ b/egs/aishell/ASR/local/compute_fbank_aishell.py @@ -25,6 +25,7 @@ The generated fbank features are saved in data/fbank. import logging import os +import argparse from pathlib import Path import torch @@ -43,7 +44,7 @@ torch.set_num_interop_threads(1) def compute_fbank_aishell(num_mel_bins: int = 80): src_dir = Path("data/manifests") - output_dir = Path("data/fbank40") + output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) dataset_parts = ( diff --git a/egs/aishell/ASR/local/compute_fbank_musan.py b/egs/aishell/ASR/local/compute_fbank_musan.py index 0b97fb8c5..4b0120283 100755 --- a/egs/aishell/ASR/local/compute_fbank_musan.py +++ b/egs/aishell/ASR/local/compute_fbank_musan.py @@ -25,6 +25,7 @@ The generated fbank features are saved in data/fbank. import logging import os +import argparse from pathlib import Path import torch @@ -43,7 +44,7 @@ torch.set_num_interop_threads(1) def compute_fbank_musan(num_mel_bins: int = 80): src_dir = Path("data/manifests") - output_dir = Path("data/fbank40") + output_dir = Path("data/fbank") num_jobs = min(15, os.cpu_count()) dataset_parts = ( diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index f70e89c65..31e41a5bc 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -27,7 +27,6 @@ stop_stage=10 # - music # - noise # - speech - dl_dir=$PWD/download . shared/parse_options.sh || exit 1 @@ -88,7 +87,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # We assume that you have downloaded the aishell corpus # to $dl_dir/aishell mkdir -p data/manifests - lhotse prepare aishell -j $nj $dl_dir/aishell data/manifests + lhotse prepare aishell $dl_dir/aishell data/manifests fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/README.md b/egs/aishell/ASR/tdnn_lstm_ctc/README.md deleted file mode 100644 index a2d80a785..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/README.md +++ /dev/null @@ -1,4 +0,0 @@ - -Please visit - -for how to run this recipe. 
diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/__init__.py b/egs/aishell/ASR/tdnn_lstm_ctc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py deleted file mode 100644 index 9dede6288..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ /dev/null @@ -1,335 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import List, Union - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest -from lhotse.dataset import ( - BucketingSampler, - CutConcatenate, - CutMix, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SingleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool - - -class AishellAsrDataModule(DataModule): - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - """ - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--feature-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. 
You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - def train_dataloaders(self) -> DataLoader: - logging.info("About to get train cuts") - cuts_train = self.train_cuts() - - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") - - logging.info("About to create train dataset") - transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [ - SpecAugment( - num_frame_masks=2, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ] - - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. 
- train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using BucketingSampler.") - train_sampler = BucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - bucket_method="equal_duration", - drop_last=True, - ) - else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self) -> DataLoader: - logging.info("About to get dev cuts") - cuts_valid = self.valid_cuts() - - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = SingleCutSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: - cuts = self.test_cuts() - is_list = isinstance(cuts, list) - test_loaders = [] - if not is_list: - cuts = [cuts] - - for cuts_test in cuts: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = SingleCutSampler( - cuts_test, max_duration=self.args.max_duration - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, batch_size=None, sampler=sampler, num_workers=1 - ) - test_loaders.append(test_dl) - - if is_list: - return test_loaders - else: - return test_loaders[0] - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest( - self.args.feature_dir / "cuts_train.json.gz" - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest( - self.args.feature_dir / "cuts_dev.json.gz" - ) - return cuts_valid - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - test_sets = ["test"] - cuts = [] - for test_set in test_sets: - logging.debug("About to get test cuts") - cuts.append( - load_manifest( - self.args.feature_dir / f"cuts_{test_set}.json.gz" - ) - ) - return cuts diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py b/egs/aishell/ASR/tdnn_lstm_ctc/decode.py deleted file mode 100755 index 568da3811..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/decode.py +++ /dev/null @@ 
-1,399 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import AishellAsrDataModule -from model import TdnnLstm - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - one_best_decoding, - rescore_with_attention_decoder, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=19, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Supported values are: - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - """, - ) - parser.add_argument( - "--num-paths", - type=int, - default=30, - help="""Number of paths for n-best based decoding method. - Used only when "method" is nbest. - """, - ) - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. - """, - ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "exp_dir": Path("tdnn_lstm_ctc/exp/"), - "lang_dir": Path("data/lang_phone"), - "lm_dir": Path("data/lm"), - # parameters for tdnn_lstm_ctc - "subsampling_factor": 3, - "feature_dim": 80, - # parameters for decoding - "search_beam": 20, - "output_beam": 7, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - batch: dict, - lexicon: Lexicon, -) -> Dict[str, List[List[int]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if the decoding method is 1best, the key is the string - `no_rescore`. 
If the decoding method is nbest, the key is the - string `no_rescore-xxx`, xxx is the num_paths. - - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - model: - The neural model. - HLG: - The decoding graph. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - lexicon: - It contains word symbol table. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = HLG.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is [N, T, C] - - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - - nnet_output = model(feature) - # nnet_output is [N, T, C] - - supervisions = batch["supervisions"] - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - lattice = get_lattice( - nnet_output=nnet_output, - HLG=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - assert params.method in ["1best", "nbest"] - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - ) - key = f"no_rescore-{params.num_paths}" - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - lexicon: Lexicon, -) -> Dict[str, List[Tuple[List[int], List[int]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. - lexicon: - It contains word symbol table. - Returns: - Return a dict, whose key may be "no-rescore" if decoding method is 1best, - or it may be "no-rescoer-100" if decoding method is nbest. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" 
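Editor's note: although this tdnn_lstm_ctc decode script is being deleted, the `supervision_segments` layout it builds in `decode_one_batch` (one row per utterance: sequence index, start frame, and frame count, all divided by the subsampling factor) is the same layout used elsewhere in icefall decoding. A minimal sketch with toy frame counts, assuming `subsampling_factor=3` as in the removed model:

```python
import torch

# One row per utterance: (sequence_idx, start_frame, num_frames), with frame
# counts reduced by the model's subsampling factor.
subsampling_factor = 3
supervisions = {
    "sequence_idx": torch.tensor([0, 1]),
    "start_frame": torch.tensor([0, 0]),
    "num_frames": torch.tensor([300, 243]),  # toy frame counts
}
supervision_segments = torch.stack(
    (
        supervisions["sequence_idx"],
        supervisions["start_frame"] // subsampling_factor,
        supervisions["num_frames"] // subsampling_factor,
    ),
    1,
).to(torch.int32)
print(supervision_segments)
# tensor([[  0,   0, 100],
#         [  1,   0,  81]], dtype=torch.int32)
```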
- - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - batch=batch, - lexicon=lexicon, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - # We compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((list("".join(res[0])), list("".join(res[1])))) - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results_char) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"cer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tCER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, CER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) - HLG = HLG.to(device) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, 
f"{params.exp_dir}/pretrained.pt" - ) - - model.to(device) - model.eval() - - aishell = AishellAsrDataModule(args) - # CAUTION: `test_sets` is for displaying only. - # If you want to skip test-clean, you have to skip - # it inside the for loop. That is, use - # - # if test_set == 'test-clean': continue - # - test_sets = ["test"] - for test_set, test_dl in zip(test_sets, aishell.test_dataloaders()): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - lexicon=lexicon, - ) - - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/model.py b/egs/aishell/ASR/tdnn_lstm_ctc/model.py deleted file mode 100644 index 5e04c11b4..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/model.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class TdnnLstm(nn.Module): - def __init__( - self, num_features: int, num_classes: int, subsampling_factor: int = 3 - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - num_classes: - The output dimension of the model. - subsampling_factor: - It reduces the number of output frames by this factor. - """ - super().__init__() - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - self.tdnn = nn.Sequential( - nn.Conv1d( - in_channels=num_features, - out_channels=500, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=500, affine=False), - nn.Conv1d( - in_channels=500, - out_channels=500, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=500, affine=False), - nn.Conv1d( - in_channels=500, - out_channels=500, - kernel_size=3, - stride=self.subsampling_factor, # stride: subsampling_factor! 
- padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=500, affine=False), - ) - self.lstms = nn.ModuleList( - [ - nn.LSTM(input_size=500, hidden_size=500, num_layers=1) - for _ in range(5) - ] - ) - self.lstm_bnorms = nn.ModuleList( - [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] - ) - self.dropout = nn.Dropout(0.2) - self.linear = nn.Linear(in_features=500, out_features=self.num_classes) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - Its shape is [N, C, T] - - Returns: - The output tensor has shape [N, T, C] - """ - x = self.tdnn(x) - x = x.permute(2, 0, 1) # (N, C, T) -> (T, N, C) -> how LSTM expects it - for lstm, bnorm in zip(self.lstms, self.lstm_bnorms): - x_new, _ = lstm(x) - x_new = bnorm(x_new.permute(1, 2, 0)).permute( - 2, 0, 1 - ) # (T, N, C) -> (N, C, T) -> (T, N, C) - x_new = self.dropout(x_new) - x = x_new + x # skip connections - x = x.transpose( - 1, 0 - ) # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim - x = self.linear(x) - x = nn.functional.log_softmax(x, dim=-1) - return x diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py b/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py deleted file mode 100644 index 8421dd3ea..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/pretrained.py +++ /dev/null @@ -1,231 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from model import TdnnLstm -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import ( - get_lattice, - one_best_decoding, -) -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Use the best path as decoding output. Only the transformer encoder - output is used for decoding. We call it HLG decoding. - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. 
" - "The sample rate has to be 16kHz.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 80, - "subsampling_factor": 3, - "num_classes": 220, - "sample_rate": 16000, - "search_beam": 20, - "output_beam": 7, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=params.num_classes, - subsampling_factor=params.subsampling_factor, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - model.to(device) - model.eval() - - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) - features = features.permute(0, 2, 1) # now features is [N, C, T] - - with torch.no_grad(): - nnet_output = model(features) - # nnet_output is [N, T, C] - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - HLG=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - assert(params.method == "1best") - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in 
zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/train.py b/egs/aishell/ASR/tdnn_lstm_ctc/train.py deleted file mode 100755 index 410f07c53..000000000 --- a/egs/aishell/ASR/tdnn_lstm_ctc/train.py +++ /dev/null @@ -1,616 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional - -import k2 -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from asr_datamodule import AishellAsrDataModule -from lhotse.utils import fix_random_seed -from model import TdnnLstm -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.optim.lr_scheduler import StepLR -from torch.utils.tensorboard import SummaryWriter - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.graph_compiler import CtcTrainingGraphCompiler -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=20, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - is saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. 
- - Explanation of options saved in `params`: - - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval` is 0 - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - """ - params = AttributeDict( - { - "exp_dir": Path("tdnn_lstm_ctc/exp_lr1e-4"), - "lang_dir": Path("data/lang_phone"), - "lr": 1e-4, - "feature_dim": 80, - "weight_decay": 5e-4, - "subsampling_factor": 3, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 10, - "reset_interval": 200, - "valid_interval": 1000, - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. 
- """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: CtcTrainingGraphCompiler, - is_training: bool, -): - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of TdnnLstm in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - assert feature.ndim == 3 - feature = feature.to(device) - - with torch.set_grad_enabled(is_training): - nnet_output = model(feature) - # nnet_output is [N, T, C] - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervisions = batch["supervisions"] - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - decoding_graph = graph_compiler.compile(texts) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - assert loss.requires_grad == is_training - - # train_frames and valid_frames are used for printing. - if is_training: - params.train_frames = supervision_segments[:, 2].sum().item() - else: - params.valid_frames = supervision_segments[:, 2].sum().item() - - return loss - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> None: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. 
- """ - model.eval() - - tot_loss = 0.0 - tot_frames = 0.0 - for batch_idx, batch in enumerate(valid_dl): - loss = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - - loss_cpu = loss.detach().cpu().item() - tot_loss += loss_cpu - tot_frames += params.valid_frames - - if world_size > 1: - s = torch.tensor([tot_loss, tot_frames], device=loss.device) - dist.all_reduce(s, op=dist.ReduceOp.SUM) - s = s.cpu().tolist() - tot_loss = s[0] - tot_frames = s[1] - - params.valid_loss = tot_loss / tot_frames - - if params.valid_loss < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = params.valid_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = 0.0 # reset after params.reset_interval of batches - tot_frames = 0.0 # reset after params.reset_interval of batches - - params.tot_loss = 0.0 - params.tot_frames = 0.0 - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. 
- - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - loss_cpu = loss.detach().cpu().item() - - tot_frames += params.train_frames - tot_loss += loss_cpu - tot_avg_loss = tot_loss / tot_frames - - params.tot_frames += params.train_frames - params.tot_loss += loss_cpu - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, batch {batch_idx}, " - f"batch avg loss {loss_cpu/params.train_frames:.4f}, " - f"total avg loss: {tot_avg_loss:.4f}, " - f"batch size: {batch_size}" - ) - if tb_writer is not None: - tb_writer.add_scalar( - "train/current_loss", - loss_cpu / params.train_frames, - params.batch_idx_train, - ) - - tb_writer.add_scalar( - "train/tot_avg_loss", - tot_avg_loss, - params.batch_idx_train, - ) - - if batch_idx > 0 and batch_idx % params.reset_interval == 0: - tot_loss = 0 - tot_frames = 0 - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info( - f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f}," - f" best valid loss: {params.best_valid_loss:.4f} " - f"best valid epoch: {params.best_valid_epoch}" - ) - - params.train_loss = params.tot_loss / params.tot_frames - - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. 
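One optimization step as performed above, isolated into a runnable sketch (model, data, and loss are stand-ins): gradients of the sum-reduced loss are clipped to a global 2-norm of 5.0 before the AdamW update.

```
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(10, 10)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-4, weight_decay=5e-4
)

loss = model(torch.randn(4, 10)).sum()  # dummy "sum"-reduced loss
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)  # max_norm=5.0, norm_type=2.0
optimizer.step()
```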
- args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(42) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = optim.AdamW( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - scheduler = StepLR(optimizer, step_size=8, gamma=0.1) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - scheduler.load_state_dict(checkpoints["scheduler"]) - - aishell = AishellAsrDataModule(args) - train_dl = aishell.train_dataloaders() - valid_dl = aishell.valid_dataloaders() - - for epoch in range(params.start_epoch, params.num_epochs): - train_dl.sampler.set_epoch(epoch) - - if epoch > params.start_epoch: - logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}") - - if tb_writer is not None: - tb_writer.add_scalar( - "train/lr", - scheduler.get_last_lr()[0], - params.batch_idx_train, - ) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - scheduler.step() - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - logging.info("Done!") - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - main()
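Finally, the single-node multi-GPU launch in `main()` follows the standard `mp.spawn` pattern: one worker process per GPU, with the rank passed in automatically. This toy sketch shows only the control flow; the worker prints its rank instead of training.

```
import torch.multiprocessing as mp


def run(rank, world_size, args):
    # mp.spawn supplies `rank`; the remaining arguments come from `args=`.
    print(f"worker {rank} of {world_size}, args={args}")


def main(world_size=2):
    if world_size > 1:
        mp.spawn(run, args=(world_size, None), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=None)


if __name__ == "__main__":
    main()
```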