From 8469f9ae0a2d7686f04e558fba8ddfb5505109fc Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 21 Aug 2021 09:53:46 +0800
Subject: [PATCH] Refactor asr_datamodule. (#15)

* WIP: Refactor asr_datamodule.

* Fixes after review.

* Minor fixes.
---
 .../ASR/conformer_ctc/asr_datamodule.py      |   1 +
 egs/librispeech/ASR/conformer_ctc/decode.py  |  16 ++-
 egs/librispeech/ASR/conformer_ctc/train.py   |   8 +-
 .../ASR/tdnn_lstm_ctc}/asr_datamodule.py     | 117 +++++++++++++++---
 egs/librispeech/ASR/tdnn_lstm_ctc/decode.py  |  11 +-
 egs/librispeech/ASR/tdnn_lstm_ctc/train.py   |  11 +-
 icefall/checkpoint.py                        |   2 +-
 icefall/dataset/librispeech.py               |  68 ----------
 8 files changed, 128 insertions(+), 106 deletions(-)
 create mode 120000 egs/librispeech/ASR/conformer_ctc/asr_datamodule.py
 rename {icefall/dataset => egs/librispeech/ASR/tdnn_lstm_ctc}/asr_datamodule.py (71%)
 delete mode 100644 icefall/dataset/librispeech.py

diff --git a/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py
new file mode 120000
index 000000000..fa1b8cca3
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/asr_datamodule.py
@@ -0,0 +1 @@
+../tdnn_lstm_ctc/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index c17a8b284..c540b1ea1 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -13,11 +13,11 @@ from typing import Dict, List, Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
@@ -222,7 +222,7 @@ def decode_one_batch(
             use_double_scores=params.use_double_scores,
             scale=params.lattice_score_scale,
         )
-        key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"
+        key = f"no_rescore-scale-{params.lattice_score_scale}-{params.num_paths}"  # noqa
 
         hyps = get_texts(best_path)
         hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
@@ -317,7 +317,11 @@ def decode_dataset(
     results = []
 
     num_cuts = 0
-    tot_num_batches = len(dl)
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
 
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
@@ -346,10 +350,10 @@ def decode_dataset(
         num_cuts += len(batch["supervisions"]["text"])
 
         if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
             logging.info(
-                f"batch {batch_idx}/{tot_num_batches}, cuts processed until now is "
-                f"{num_cuts}"
-                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )
 
     return results
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index d3ea8efb0..d17ee6164 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -13,10 +13,10 @@ import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_value_
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -24,7 +24,6 @@ from transformer import Noam
 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.lexicon import Lexicon
 from icefall.utils import (
@@ -61,9 +60,6 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
-    # TODO: add extra arguments and support DDP training.
-    # Currently, only single GPU training is implemented. Will add
-    # DDP training once single GPU training is finished.
     return parser
 
 
@@ -463,7 +459,7 @@ def train_one_epoch(
 
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_value_(model.parameters(), 5.0)
+        clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
diff --git a/icefall/dataset/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
similarity index 71%
rename from icefall/dataset/asr_datamodule.py
rename to egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 73eef9c31..8d8c7a366 100644
--- a/icefall/dataset/asr_datamodule.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -1,14 +1,16 @@
 import argparse
 import logging
+from functools import lru_cache
 from pathlib import Path
 from typing import List, Union
 
-from lhotse import Fbank, FbankConfig, load_manifest
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest
 from lhotse.dataset import (
     BucketingSampler,
     CutConcatenate,
     CutMix,
     K2SpeechRecognitionDataset,
+    PrecomputedFeatures,
     SingleCutSampler,
     SpecAugment,
 )
@@ -19,7 +21,7 @@ from icefall.dataset.datamodule import DataModule
 from icefall.utils import str2bool
 
 
-class AsrDataModule(DataModule):
+class LibriSpeechAsrDataModule(DataModule):
     """
     DataModule for K2 ASR experiments.
     It assumes there is always one train and valid dataloader,
@@ -47,6 +49,13 @@ class AsrDataModule(DataModule):
             "effective batch sizes, sampling strategies, applied data "
             "augmentations, etc.",
         )
+        group.add_argument(
+            "--full-libri",
+            type=str2bool,
+            default=True,
+            help="When enabled, use 960h LibriSpeech. "
" + "Otherwise, use 100h subset.", + ) group.add_argument( "--feature-dir", type=Path, @@ -77,7 +86,7 @@ class AsrDataModule(DataModule): group.add_argument( "--concatenate-cuts", type=str2bool, - default=True, + default=False, help="When enabled, utterances (cuts) will be concatenated " "to minimize the amount of padding.", ) @@ -104,6 +113,29 @@ class AsrDataModule(DataModule): "extraction. Will drop existing precomputed feature manifests " "if available.", ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) def train_dataloaders(self) -> DataLoader: logging.info("About to get train cuts") @@ -138,9 +170,9 @@ class AsrDataModule(DataModule): ] train = K2SpeechRecognitionDataset( - cuts_train, cut_transforms=transforms, input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) if self.args.on_the_fly_feats: @@ -154,14 +186,13 @@ class AsrDataModule(DataModule): # to be strict (e.g. could be randomized) # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa # Drop feats to be on the safe side. - cuts_train = cuts_train.drop_features() train = K2SpeechRecognitionDataset( - cuts=cuts_train, cut_transforms=transforms, input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=80)) ), input_transforms=input_transforms, + return_cuts=self.args.return_cuts, ) if self.args.bucketing_sampler: @@ -169,9 +200,9 @@ class AsrDataModule(DataModule): train_sampler = BucketingSampler( cuts_train, max_duration=self.args.max_duration, - shuffle=True, + shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, - bucket_method='equal_duration', + bucket_method="equal_duration", drop_last=True, ) else: @@ -179,36 +210,50 @@ class AsrDataModule(DataModule): train_sampler = SingleCutSampler( cuts_train, max_duration=self.args.max_duration, - shuffle=True, + shuffle=self.args.shuffle, ) logging.info("About to create train dataloader") + train_dl = DataLoader( train, sampler=train_sampler, batch_size=None, - num_workers=2, + num_workers=self.args.num_workers, persistent_workers=False, ) + return train_dl def valid_dataloaders(self) -> DataLoader: logging.info("About to get dev cuts") cuts_valid = self.valid_cuts() + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + logging.info("About to create dev dataset") if self.args.on_the_fly_feats: - cuts_valid = cuts_valid.drop_features() validate = K2SpeechRecognitionDataset( - cuts_valid.drop_features(), + cut_transforms=transforms, input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=80)) ), + return_cuts=self.args.return_cuts, ) else: - validate = K2SpeechRecognitionDataset(cuts_valid) + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) valid_sampler = SingleCutSampler( cuts_valid, max_duration=self.args.max_duration, + shuffle=False, ) logging.info("About to create dev dataloader") valid_dl = DataLoader( @@ -218,6 +263,7 @@ class AsrDataModule(DataModule): 
             num_workers=2,
             persistent_workers=False,
         )
+
         return valid_dl
 
     def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
@@ -230,10 +276,12 @@ class AsrDataModule(DataModule):
         for cuts_test in cuts:
             logging.debug("About to create test dataset")
             test = K2SpeechRecognitionDataset(
-                cuts_test,
                 input_strategy=OnTheFlyFeatures(
                     Fbank(FbankConfig(num_mel_bins=80))
-                ),
+                )
+                if self.args.on_the_fly_feats
+                else PrecomputedFeatures(),
+                return_cuts=self.args.return_cuts,
             )
             sampler = SingleCutSampler(
                 cuts_test, max_duration=self.args.max_duration
@@ -248,3 +296,42 @@ class AsrDataModule(DataModule):
             return test_loaders
         else:
             return test_loaders[0]
+
+    @lru_cache()
+    def train_cuts(self) -> CutSet:
+        logging.info("About to get train cuts")
+        cuts_train = load_manifest(
+            self.args.feature_dir / "cuts_train-clean-100.json.gz"
+        )
+        if self.args.full_libri:
+            cuts_train = (
+                cuts_train
+                + load_manifest(
+                    self.args.feature_dir / "cuts_train-clean-360.json.gz"
+                )
+                + load_manifest(
+                    self.args.feature_dir / "cuts_train-other-500.json.gz"
+                )
+            )
+        return cuts_train
+
+    @lru_cache()
+    def valid_cuts(self) -> CutSet:
+        logging.info("About to get dev cuts")
+        cuts_valid = load_manifest(
+            self.args.feature_dir / "cuts_dev-clean.json.gz"
+        ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
+        return cuts_valid
+
+    @lru_cache()
+    def test_cuts(self) -> List[CutSet]:
+        test_sets = ["test-clean", "test-other"]
+        cuts = []
+        for test_set in test_sets:
+            logging.debug("About to get test cuts")
+            cuts.append(
+                load_manifest(
+                    self.args.feature_dir / f"cuts_{test_set}.json.gz"
+                )
+            )
+        return cuts
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
index 9a1aad579..72f39ef40 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py
@@ -10,10 +10,10 @@ from typing import Dict, List, Optional, Tuple
 import k2
 import torch
 import torch.nn as nn
+from asr_datamodule import LibriSpeechAsrDataModule
 from model import TdnnLstm
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.decode import (
     get_lattice,
     nbest_decoding,
@@ -237,6 +237,11 @@ def decode_dataset(
 
     num_cuts = 0
 
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
         texts = batch["supervisions"]["text"]
@@ -262,8 +267,10 @@ def decode_dataset(
         num_cuts += len(batch["supervisions"]["text"])
 
         if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
             logging.info(
-                f"batch {batch_idx}, cuts processed until now is {num_cuts}"
+                f"batch {batch_str}, cuts processed until now is {num_cuts}"
             )
 
     return results
diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
index dbb9f64ec..4adb988a0 100755
--- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
+++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py
@@ -1,7 +1,5 @@
 #!/usr/bin/env python3
 
-# This is just at the very beginning ...
-
 import argparse
 import logging
 from pathlib import Path
@@ -14,16 +12,16 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.optim as optim
+from asr_datamodule import LibriSpeechAsrDataModule
 from lhotse.utils import fix_random_seed
 from model import TdnnLstm
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_value_
+from torch.nn.utils import clip_grad_norm_
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.tensorboard import SummaryWriter
 
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.dist import cleanup_dist, setup_dist
 from icefall.graph_compiler import CtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
@@ -61,9 +59,6 @@ def get_parser():
         help="Should various information be logged in tensorboard.",
     )
 
-    # TODO: add extra arguments and support DDP training.
-    # Currently, only single GPU training is implemented. Will add
-    # DDP training once single GPU training is finished.
     return parser
 
 
@@ -406,7 +401,7 @@ def train_one_epoch(
 
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_value_(model.parameters(), 5.0)
+        clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py
index e45df4fe4..a64ecfcf6 100644
--- a/icefall/checkpoint.py
+++ b/icefall/checkpoint.py
@@ -91,7 +91,7 @@ def load_checkpoint(
     checkpoint.pop("model")
 
     def load(name, obj):
-        s = checkpoint[name]
+        s = checkpoint.get(name, None)
         if obj and s:
             obj.load_state_dict(s)
             checkpoint.pop(name)
diff --git a/icefall/dataset/librispeech.py b/icefall/dataset/librispeech.py
deleted file mode 100644
index 5c18041ed..000000000
--- a/icefall/dataset/librispeech.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import argparse
-import logging
-from functools import lru_cache
-from typing import List
-
-from lhotse import CutSet, load_manifest
-
-from icefall.dataset.asr_datamodule import AsrDataModule
-from icefall.utils import str2bool
-
-
-class LibriSpeechAsrDataModule(AsrDataModule):
-    """
-    LibriSpeech ASR data module. Can be used for 100h subset
-    (``--full-libri false``) or full 960h set.
-    The train and valid cuts for standard Libri splits are
-    concatenated into a single CutSet/DataLoader.
- """ - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) - group = parser.add_argument_group(title="LibriSpeech specific options") - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="When enabled, use 960h LibriSpeech.", - ) - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest( - self.args.feature_dir / "cuts_train-clean-100.json.gz" - ) - if self.args.full_libri: - cuts_train = ( - cuts_train - + load_manifest( - self.args.feature_dir / "cuts_train-clean-360.json.gz" - ) - + load_manifest( - self.args.feature_dir / "cuts_train-other-500.json.gz" - ) - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest( - self.args.feature_dir / "cuts_dev-clean.json.gz" - ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") - return cuts_valid - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - test_sets = ["test-clean", "test-other"] - cuts = [] - for test_set in test_sets: - logging.debug("About to get test cuts") - cuts.append( - load_manifest( - self.args.feature_dir / f"cuts_{test_set}.json.gz" - ) - ) - return cuts