diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 2b2967506..42fa2308e 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -15,15 +15,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Usage: + ./conformer_ctc/ali.py \ + --exp-dir ./conformer_ctc/exp \ + --lang-dir ./data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 \ + --max-duration 300 \ + --dataset train-clean-100 \ + --out-dir data/ali +""" + import argparse import logging from pathlib import Path -from typing import List, Tuple import k2 +import numpy as np import torch from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer +from lhotse import CutSet +from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint @@ -34,7 +48,6 @@ from icefall.utils import ( AttributeDict, encode_supervisions, get_alignments, - save_alignments, setup_logger, ) @@ -75,10 +88,42 @@ def get_parser(): ) parser.add_argument( - "--ali-dir", + "--out-dir", type=str, - default="data/ali_500", - help="The experiment dir", + required=True, + help="""Output directory. + It contains 3 generated files: + + - labels_xxx.h5 + - aux_labels_xxx.h5 + - cuts_xxx.json.gz + + where xxx is the value of `--dataset`. For instance, if + `--dataset` is `train-clean-100`, it will contain 3 files: + + - `labels_train-clean-100.h5` + - `aux_labels_train-clean-100.h5` + - `cuts_train-clean-100.json.gz` + + Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise + alignment. The difference is that labels_xxx.h5 contains repeats. + """, + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset to compute alignments for. + Possible values are: + - test-clean. + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, ) return parser @@ -91,7 +136,9 @@ def get_params() -> AttributeDict: "nhead": 8, "attention_dim": 512, "subsampling_factor": 4, - "num_decoder_layers": 6, + # Set it to 0 since attention decoder + # is not used for computing alignments + "num_decoder_layers": 0, "vgg_frontend": False, "use_feat_batchnorm": True, "output_beam": 10, @@ -105,9 +152,11 @@ def get_params() -> AttributeDict: def compute_alignments( model: torch.nn.Module, dl: torch.utils.data.DataLoader, + labels_writer: FeaturesWriter, + aux_labels_writer: FeaturesWriter, params: AttributeDict, graph_compiler: BpeCtcTrainingGraphCompiler, -) -> List[Tuple[str, List[int]]]: +) -> CutSet: """Compute the framewise alignments of a dataset. Args: @@ -120,9 +169,10 @@ def compute_alignments( graph_compiler: It converts token IDs to decoding graphs. Returns: - Return a list of tuples. Each tuple contains two entries: - - Utterance ID - - Framewise alignments (token IDs) after subsampling + Return a CutSet. Each cut has two custom fields: labels_alignment + and aux_labels_alignment, containing framewise alignments information. + Both are of type `lhotse.array.TemporalArray`. The difference between + the two alignments is that `labels_alignment` contain repeats. """ try: num_batches = len(dl) @@ -131,7 +181,7 @@ def compute_alignments( num_cuts = 0 device = graph_compiler.device - ans = [] + cuts = [] for batch_idx, batch in enumerate(dl): feature = batch["inputs"] @@ -140,11 +190,10 @@ def compute_alignments( feature = feature.to(device) supervisions = batch["supervisions"] + cut_list = supervisions["cut"] - cut_ids = [] - for cut in supervisions["cut"]: - assert len(cut.supervisions) == 1 - cut_ids.append(cut.id) + for cut in cut_list: + assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}" nnet_output, encoder_memory, memory_mask = model(feature, supervisions) # nnet_output is [N, T, C] @@ -156,7 +205,8 @@ def compute_alignments( # In general, new2old is an identity map since lhotse sorts the returned # cuts by duration in descending order new2old = supervision_segments[:, 0].tolist() - cut_ids = [cut_ids[i] for i in new2old] + + cut_list = [cut_list[i] for i in new2old] token_ids = graph_compiler.texts_to_ids(texts) decoding_graph = graph_compiler.compile(token_ids) @@ -178,11 +228,32 @@ def compute_alignments( use_double_scores=params.use_double_scores, ) - ali_ids = get_alignments(best_path) - assert len(ali_ids) == len(cut_ids) - ans += list(zip(cut_ids, ali_ids)) + labels_ali = get_alignments(best_path, kind="labels") + aux_labels_ali = get_alignments(best_path, kind="aux_labels") + assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) + for cut, labels, aux_labels in zip( + cut_list, labels_ali, aux_labels_ali + ): + cut.labels_alignment = labels_writer.store_array( + key=cut.id, + value=np.asarray(labels, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + cut.aux_labels_alignment = aux_labels_writer.store_array( + key=cut.id, + value=np.asarray(aux_labels, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) - num_cuts += len(ali_ids) + cuts += cut_list + + num_cuts += len(cut_list) if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" @@ -191,7 +262,7 @@ def compute_alignments( f"batch {batch_str}, cuts processed until now is {num_cuts}" ) - return ans + return CutSet.from_cuts(cuts) @torch.no_grad() @@ -200,20 +271,35 @@ def main(): LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() - assert args.return_cuts is True - assert args.concatenate_cuts is False - if args.full_libri is False: - print("Changing --full-libri to True") - args.full_libri = True + args.enable_spec_aug = False + args.enable_musan = False + args.return_cuts = True + args.concatenate_cuts = False params = get_params() params.update(vars(args)) - setup_logger(f"{params.exp_dir}/log/ali") + setup_logger(f"{params.exp_dir}/log-ali") - logging.info("Computing alignment - started") + logging.info(f"Computing alignments for {params.dataset} - started") logging.info(params) + out_dir = Path(params.out_dir) + out_dir.mkdir(exist_ok=True) + + out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" + out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" + out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" + + for f in ( + out_labels_ali_filename, + out_aux_labels_ali_filename, + out_manifest_filename, + ): + if f.exists(): + logging.info(f"{f} exists - skipping") + return + lexicon = Lexicon(params.lang_dir) max_token_id = max(lexicon.tokens) num_classes = max_token_id + 1 # +1 for the blank @@ -221,6 +307,7 @@ def main(): device = torch.device("cpu") if torch.cuda.is_available(): device = torch.device("cuda", 0) + logging.info(f"device: {device}") graph_compiler = BpeCtcTrainingGraphCompiler( params.lang_dir, @@ -240,9 +327,12 @@ def main(): vgg_frontend=params.vgg_frontend, use_feat_batchnorm=params.use_feat_batchnorm, ) + model.to(device) if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + load_checkpoint( + f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False + ) else: start = params.epoch - params.avg + 1 filenames = [] @@ -250,61 +340,56 @@ def main(): if start >= 0: filenames.append(f"{params.exp_dir}/epoch-{i}.pt") logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) - model.to(device) model.eval() librispeech = LibriSpeechAsrDataModule(args) + if params.dataset == "test-clean": + test_clean_cuts = librispeech.test_clean_cuts() + dl = librispeech.test_dataloaders(test_clean_cuts) + elif params.dataset == "test-other": + test_other_cuts = librispeech.test_other_cuts() + dl = librispeech.test_dataloaders(test_other_cuts) + elif params.dataset == "train-clean-100": + train_clean_100_cuts = librispeech.train_clean_100_cuts() + dl = librispeech.train_dataloaders(train_clean_100_cuts) + elif params.dataset == "train-clean-360": + train_clean_360_cuts = librispeech.train_clean_360_cuts() + dl = librispeech.train_dataloaders(train_clean_360_cuts) + elif params.dataset == "train-other-500": + train_other_500_cuts = librispeech.train_other_500_cuts() + dl = librispeech.train_dataloaders(train_other_500_cuts) + elif params.dataset == "dev-clean": + dev_clean_cuts = librispeech.dev_clean_cuts() + dl = librispeech.valid_dataloaders(dev_clean_cuts) + else: + assert params.dataset == "dev-other", f"{params.dataset}" + dev_other_cuts = librispeech.dev_other_cuts() + dl = librispeech.valid_dataloaders(dev_other_cuts) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() - test_dl = librispeech.test_dataloaders() # a list - - ali_dir = Path(params.ali_dir) - ali_dir.mkdir(exist_ok=True) - - enabled_datasets = { - "test_clean": test_dl[0], - "test_other": test_dl[1], - "train-960": train_dl, - "valid": valid_dl, - } - # For train-960, it takes about 3 hours 40 minutes, i.e., 3.67 hours to - # compute the alignments if you use --max-duration=500 - # - # There are 960 * 3 = 2880 hours data and it takes only - # 3 hours 40 minutes to get the alignment. - # The RTF is roughly: 3.67 / 2880 = 0.0012743 - # - # At the end, you would see - # 2021-09-28 11:32:46,690 INFO [ali.py:188] batch 21000/?, cuts processed until now is 836270 # noqa - # 2021-09-28 11:33:45,084 INFO [ali.py:188] batch 21100/?, cuts processed until now is 840268 # noqa - for name, dl in enabled_datasets.items(): - logging.info(f"Processing {name}") - if name == "train-960": - logging.info( - f"It will take about 3 hours 40 minutes for {name}, " - "which contains 960 * 3 = 2880 hours of data" + logging.info(f"Processing {params.dataset}") + with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer: + with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer: + cut_set = compute_alignments( + model=model, + dl=dl, + labels_writer=labels_writer, + aux_labels_writer=aux_labels_writer, + params=params, + graph_compiler=graph_compiler, ) - alignments = compute_alignments( - model=model, - dl=dl, - params=params, - graph_compiler=graph_compiler, - ) - num_utt = len(alignments) - alignments = dict(alignments) - assert num_utt == len(alignments) - filename = ali_dir / f"{name}.pt" - save_alignments( - alignments=alignments, - subsampling_factor=params.subsampling_factor, - filename=filename, - ) - logging.info( - f"For dataset {name}, its alignments are saved to {filename}" - ) + + cut_set.to_file(out_manifest_filename) + + logging.info( + f"For dataset {params.dataset}, its alignments with repeats are " + f"saved to {out_labels_ali_filename}, the alignments without repeats " + f"are saved to {out_aux_labels_ali_filename}, and the cut manifest " + f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" + ) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 539590b1b..bcd363df3 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -665,14 +665,17 @@ def main(): logging.info(f"Number of model parameters: {num_param}") librispeech = LibriSpeechAsrDataModule(args) - # CAUTION: `test_sets` is for displaying only. - # If you want to skip test-clean, you have to skip - # it inside the for loop. That is, use - # - # if test_set == 'test-clean': continue - # + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + test_sets = ["test-clean", "test-other"] - for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index ec9b0b7c2..2fbf17a62 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -618,8 +618,16 @@ def run(rank, world_size, args): optimizer.load_state_dict(checkpoints["optimizer"]) librispeech = LibriSpeechAsrDataModule(args) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) scan_pessimistic_batches_for_oom( model=model, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 950eba438..e075a2d03 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -19,7 +19,6 @@ import argparse import logging from functools import lru_cache from pathlib import Path -from typing import List, Union from lhotse import CutSet, Fbank, FbankConfig, load_manifest from lhotse.dataset import ( @@ -34,11 +33,10 @@ from lhotse.dataset import ( from lhotse.dataset.input_strategies import OnTheFlyFeatures from torch.utils.data import DataLoader -from icefall.dataset.datamodule import DataModule from icefall.utils import str2bool -class LibriSpeechAsrDataModule(DataModule): +class LibriSpeechAsrDataModule: """ DataModule for k2 ASR experiments. It assumes there is always one train and valid dataloader, @@ -56,9 +54,11 @@ class LibriSpeechAsrDataModule(DataModule): This class should be derived for specific corpora used in ASR tasks. """ + def __init__(self, args: argparse.Namespace): + self.args = args + @classmethod def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) group = parser.add_argument_group( title="ASR data related options", description="These options are used for the preparation of " @@ -74,7 +74,7 @@ class LibriSpeechAsrDataModule(DataModule): "Otherwise, use 100h subset.", ) group.add_argument( - "--feature-dir", + "--manifest-dir", type=Path, default=Path("data/fbank"), help="Path to directory with train/valid/test cuts.", @@ -154,17 +154,48 @@ class LibriSpeechAsrDataModule(DataModule): "collect the batches.", ) - def train_dataloaders(self) -> DataLoader: - logging.info("About to get train cuts") - cuts_train = self.train_cuts() + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + transforms.append( + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) + ) + else: + logging.info("Disable MUSAN") - logging.info("About to create train dataset") - transforms = [ - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) - ] if self.args.concatenate_cuts: logging.info( f"Using cut concatenation with duration factor " @@ -179,15 +210,25 @@ class LibriSpeechAsrDataModule(DataModule): ) ] + transforms - input_transforms = [ - SpecAugment( - num_frame_masks=2, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" ) - ] + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + logging.info("About to create train dataset") train = K2SpeechRecognitionDataset( cut_transforms=transforms, input_transforms=input_transforms, @@ -243,10 +284,7 @@ class LibriSpeechAsrDataModule(DataModule): return train_dl - def valid_dataloaders(self) -> DataLoader: - logging.info("About to get dev cuts") - cuts_valid = self.valid_cuts() - + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: transforms = [] if self.args.concatenate_cuts: transforms = [ @@ -285,75 +323,63 @@ class LibriSpeechAsrDataModule(DataModule): return valid_dl - def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]: - cuts = self.test_cuts() - is_list = isinstance(cuts, list) - test_loaders = [] - if not is_list: - cuts = [cuts] - - for cuts_test in cuts: - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = BucketingSampler( - cuts_test, max_duration=self.args.max_duration, shuffle=False - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - test_loaders.append(test_dl) - - if is_list: - return test_loaders - else: - return test_loaders[0] - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest( - self.args.feature_dir / "cuts_train-clean-100.json.gz" + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, ) - if self.args.full_libri: - cuts_train = ( - cuts_train - + load_manifest( - self.args.feature_dir / "cuts_train-clean-360.json.gz" - ) - + load_manifest( - self.args.feature_dir / "cuts_train-other-500.json.gz" - ) - ) - return cuts_train + sampler = BucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - cuts_valid = load_manifest( - self.args.feature_dir / "cuts_dev-clean.json.gz" - ) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz") - return cuts_valid + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-100.json.gz" + ) @lru_cache() - def test_cuts(self) -> List[CutSet]: - test_sets = ["test-clean", "test-other"] - cuts = [] - for test_set in test_sets: - logging.debug("About to get test cuts") - cuts.append( - load_manifest( - self.args.feature_dir / f"cuts_{test_set}.json.gz" - ) - ) - return cuts + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-360.json.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-other-500.json.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz") + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz") + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz") + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 636cb9388..9c964c2aa 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -474,14 +474,17 @@ def main(): model.eval() librispeech = LibriSpeechAsrDataModule(args) - # CAUTION: `test_sets` is for displaying only. - # If you want to skip test-clean, you have to skip - # it inside the for loop. That is, use - # - # if test_set == 'test-clean': continue - # + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + test_sets = ["test-clean", "test-other"] - for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( dl=test_dl, params=params, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 99fe170d2..7439e157a 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -532,8 +532,16 @@ def run(rank, world_size, args): scheduler.load_state_dict(checkpoints["scheduler"]) librispeech = LibriSpeechAsrDataModule(args) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index d6ddbf515..dbe3c1315 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -85,6 +85,7 @@ def load_checkpoint( optimizer: Optional[Optimizer] = None, scheduler: Optional[_LRScheduler] = None, scaler: Optional[GradScaler] = None, + strict: bool = False, ) -> Dict[str, Any]: """ TODO: document it @@ -101,9 +102,9 @@ def load_checkpoint( src_key = "{}.{}".format("module", key) dst_state_dict[key] = src_state_dict.pop(src_key) assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict, strict=False) + model.load_state_dict(dst_state_dict, strict=strict) else: - model.load_state_dict(checkpoint["model"], strict=False) + model.load_state_dict(checkpoint["model"], strict=strict) checkpoint.pop("model") diff --git a/icefall/utils.py b/icefall/utils.py index ba9436fa4..1b2f12184 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -224,8 +224,8 @@ def get_texts( return aux_labels.tolist() -def get_alignments(best_paths: k2.Fsa) -> List[List[int]]: - """Extract the token IDs (from best_paths.labels) from the best-path FSAs. +def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: + """Extract labels or aux_labels from the best-path FSAs. Args: best_paths: @@ -233,17 +233,34 @@ def get_alignments(best_paths: k2.Fsa) -> List[List[int]]: containing multiple FSAs, which is expected to be the result of k2.shortest_path (otherwise the returned values won't be meaningful). + kind: + Possible values are: "labels" and "aux_labels". Caution: When it is + "labels", the resulting alignments contain repeats. Returns: Returns a list of lists of int, containing the token sequences we decoded. For `ans[i]`, its length equals to the number of frames after subsampling of the i-th utterance in the batch. + + Example: + When `kind` is `labels`, one possible alignment example is (with + repeats):: + + c c c blk a a blk blk t t t blk blk + + If `kind` is `aux_labels`, the above example changes to:: + + c blk blk blk a blk blk blk t blk blk blk blk + """ + assert kind in ("labels", "aux_labels") # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here - label_shape = best_paths.arcs.shape().remove_axis(1) - # label_shape has axes [fsa][arc] - labels = k2.RaggedTensor(label_shape, best_paths.labels.contiguous()) - labels = labels.remove_values_eq(-1) - return labels.tolist() + token_shape = best_paths.arcs.shape().remove_axis(1) + # token_shape has axes [fsa][arc] + tokens = k2.RaggedTensor( + token_shape, getattr(best_paths, kind).contiguous() + ) + tokens = tokens.remove_values_eq(-1) + return tokens.tolist() def save_alignments( diff --git a/test/test_ali.py b/test/test_ali.py index d8ada33e8..b107a6d80 100755 --- a/test/test_ali.py +++ b/test/test_ali.py @@ -25,199 +25,65 @@ from pathlib import Path -import k2 -import torch -from lhotse import load_manifest +from lhotse import CutSet, load_manifest from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler -from torch.nn.utils.rnn import pad_sequence +from lhotse.dataset.collation import collate_custom_field from torch.utils.data import DataLoader -from icefall.ali import ( - convert_alignments_to_tensor, - load_alignments, - lookup_alignments, -) -from icefall.decode import get_lattice, one_best_decoding -from icefall.lexicon import Lexicon -from icefall.utils import get_texts - ICEFALL_DIR = Path(__file__).resolve().parent.parent egs_dir = ICEFALL_DIR / "egs/librispeech/ASR" lang_dir = egs_dir / "data/lang_bpe_500" -# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz" -# cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz" -# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz" -# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt" - -cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz" -ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt" +cuts_json = egs_dir / "data/ali/cuts_dev-clean.json.gz" def data_exists(): - return ali_filename.exists() and cut_json.exists() and lang_dir.exists() + return cuts_json.exists() and lang_dir.exists() def get_dataloader(): - cuts_train = load_manifest(cut_json) - cuts_train = cuts_train.with_features_path_prefix(egs_dir) - train_sampler = SingleCutSampler( - cuts_train, - max_duration=40, + cuts = load_manifest(cuts_json) + print(cuts[0]) + cuts = cuts.with_features_path_prefix(egs_dir) + sampler = SingleCutSampler( + cuts, + max_duration=10, shuffle=False, ) - train = K2SpeechRecognitionDataset(return_cuts=True) + dataset = K2SpeechRecognitionDataset(return_cuts=True) - train_dl = DataLoader( - train, - sampler=train_sampler, + dl = DataLoader( + dataset, + sampler=sampler, batch_size=None, num_workers=1, persistent_workers=False, ) - return train_dl - - -def test_one_hot(): - a = [1, 3, 2] - b = [1, 0, 4, 2] - c = [torch.tensor(a), torch.tensor(b)] - d = pad_sequence(c, batch_first=True, padding_value=0) - f = torch.nn.functional.one_hot(d, num_classes=5) - e = (1 - f) * -10.0 - expected = torch.tensor( - [ - [ - [-10, 0, -10, -10, -10], - [-10, -10, -10, 0, -10], - [-10, -10, 0, -10, -10], - [0, -10, -10, -10, -10], - ], - [ - [-10, 0, -10, -10, -10], - [0, -10, -10, -10, -10], - [-10, -10, -10, -10, 0], - [-10, -10, 0, -10, -10], - ], - ] - ).to(e.dtype) - assert torch.all(torch.eq(e, expected)) + return dl def test(): - """ - The purpose of this test is to show that we can use pre-computed - alignments to construct a mask, adding it to a randomly generated - nnet_output, to decode the correct transcript from the resulting - nnet_output. - """ if not data_exists(): return - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) dl = get_dataloader() - - subsampling_factor, ali = load_alignments(ali_filename) - ali = convert_alignments_to_tensor(ali, device=device) - - lexicon = Lexicon(lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - word_table = lexicon.word_table - - HLG = k2.Fsa.from_dict( - torch.load(f"{lang_dir}/HLG.pt", map_location=device) - ) - for batch in dl: - features = batch["inputs"] supervisions = batch["supervisions"] - N = features.shape[0] - T = features.shape[1] // subsampling_factor - nnet_output = ( - torch.rand(N, T, num_classes, dtype=torch.float32, device=device) - .softmax(dim=-1) - .log() - ) - cut_ids = [cut.id for cut in supervisions["cut"]] - mask = lookup_alignments( - cut_ids=cut_ids, alignments=ali, num_classes=num_classes - ) - min_len = min(nnet_output.shape[1], mask.shape[1]) - ali_model_scale = 0.8 - - nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :] - - supervisions = batch["supervisions"] - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // subsampling_factor, - supervisions["num_frames"] // subsampling_factor, - ), - 1, - ).to(torch.int32) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=20, - output_beam=8, - min_active_states=30, - max_active_states=10000, - subsampling_factor=subsampling_factor, + cuts = supervisions["cut"] + labels_alignment, labels_alignment_length = collate_custom_field( + CutSet.from_cuts(cuts), "labels_alignment" ) - best_path = one_best_decoding(lattice=lattice, use_double_scores=True) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - hyps = [" ".join(s) for s in hyps] - print(hyps) - print(supervisions["text"]) + ( + aux_labels_alignment, + aux_labels_alignment_length, + ) = collate_custom_field(CutSet.from_cuts(cuts), "aux_labels_alignment") + + print(labels_alignment) + print(aux_labels_alignment) + print(labels_alignment_length) + print(aux_labels_alignment_length) break -def show_cut_ids(): - # The purpose of this function is to check that - # for each utterance in the training set, there is - # a corresponding alignment. - # - # After generating a1.txt and b1.txt - # You can use - # wc -l a1.txt b1.txt - # which should show the same number of lines. - # - # cat a1.txt | sort | uniq > a11.txt - # cat b1.txt | sort | uniq > b11.txt - # - # md5sum a11.txt b11.txt - # which should show the identical hash - # - # diff a11.txt b11.txt - # should print nothing - - subsampling_factor, ali = load_alignments(ali_filename) - with open("a1.txt", "w") as f: - for key in ali: - f.write(f"{key}\n") - - # dl = get_dataloader() - cuts_train = ( - load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz") - + load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz") - + load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz") - ) - - ans = [] - for cut in cuts_train: - ans.append(cut.id) - with open("b1.txt", "w") as f: - for line in ans: - f.write(f"{line}\n") - - if __name__ == "__main__": test()