diff --git a/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py b/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py
index 8290e71d1..d3eab87a9 100644
--- a/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py
+++ b/egs/librispeech/ASR/conformer_mmi/asr_datamodule.py
@@ -162,7 +162,9 @@ class LibriSpeechAsrDataModule(DataModule):
         cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")

         logging.info("About to create train dataset")
-        transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
+        transforms = [
+            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+        ]
         if self.args.concatenate_cuts:
             logging.info(
                 f"Using cut concatenation with duration factor "
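A note on preserve_id=True: the alignment lookup added to train.py below is keyed by cut ID, so the MUSAN-mixed cuts must keep the IDs of the original cuts; to my understanding that is what Lhotse's preserve_id=True asks CutMix to do. A minimal sketch of the dependency, with a made-up cut ID and alignment:

    import torch

    # Alignments are stored in a dict keyed by cut ID (see icefall/ali.py below).
    ali = {"cut-0001": torch.tensor([1, 3, 2])}  # hypothetical cut ID

    # The lookup in compute_loss() only succeeds if augmentation does not
    # rewrite the cut ID.
    assert "cut-0001" in ali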
""" device = graph_compiler.device feature = batch["inputs"] @@ -323,6 +342,30 @@ def compute_loss( supervisions, subsampling_factor=params.subsampling_factor ) + if ali is not None and params.batch_idx_train < 4000: + cut_ids = [cut.id for cut in supervisions["cut"]] + + # As encode_supervisions reorders cuts, we need + # also to reorder cut IDs here + new2old = supervision_segments[:, 0].tolist() + cut_ids = [cut_ids[i] for i in new2old] + + # Check that new2old is just a permutation, + # i.e., each cut contains only one utterance + new2old.sort() + assert new2old == torch.arange(len(new2old)).tolist() + mask = lookup_alignments( + cut_ids=cut_ids, + alignments=ali, + num_classes=nnet_output.shape[2], + ).to(nnet_output) + + min_len = min(nnet_output.shape[1], mask.shape[1]) + ali_scale = 500.0 / (params.batch_idx_train + 500) + + nnet_output = nnet_output.clone() + nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] + loss_fn = LFMMILoss( graph_compiler=graph_compiler, use_pruned_intersect=params.use_pruned_intersect, @@ -377,6 +420,7 @@ def compute_validation_loss( graph_compiler: MmiTrainingGraphCompiler, valid_dl: torch.utils.data.DataLoader, world_size: int = 1, + ali: Optional[Dict[str, torch.Tensor]] = None, ) -> None: """Run the validation process. The validation loss is saved in `params.valid_loss`. @@ -394,6 +438,7 @@ def compute_validation_loss( batch=batch, graph_compiler=graph_compiler, is_training=False, + ali=ali, ) assert loss.requires_grad is False assert mmi_loss.requires_grad is False @@ -435,6 +480,8 @@ def train_one_epoch( graph_compiler: MmiTrainingGraphCompiler, train_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader, + train_ali: Optional[Dict[str, torch.Tensor]], + valid_ali: Optional[Dict[str, torch.Tensor]], tb_writer: Optional[SummaryWriter] = None, world_size: int = 1, ) -> None: @@ -457,6 +504,10 @@ def train_one_epoch( Dataloader for the training dataset. valid_dl: Dataloader for the validation dataset. + train_ali: + Precomputed alignments for the training set. + valid_ali: + Precomputed alignments for the validation set. tb_writer: Writer to write log messages to tensorboard. 
diff --git a/icefall/ali.py b/icefall/ali.py
new file mode 100644
index 000000000..c3e4b2662
--- /dev/null
+++ b/icefall/ali.py
@@ -0,0 +1,142 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Tuple
+
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def save_alignments(
+    alignments: Dict[str, List[int]],
+    subsampling_factor: int,
+    filename: str,
+) -> None:
+    """Save alignments to a file.
+
+    Args:
+      alignments:
+        A dict containing alignments. Keys of the dict are utterance IDs and
+        values are the corresponding framewise alignments after subsampling.
+      subsampling_factor:
+        The subsampling factor of the model.
+      filename:
+        Path to save the alignments.
+    Returns:
+      Return None.
+    """
+    ali_dict = {
+        "subsampling_factor": subsampling_factor,
+        "alignments": alignments,
+    }
+    torch.save(ali_dict, filename)
+
+
+def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
+    """Load alignments from a file.
+
+    Args:
+      filename:
+        Path to the file containing alignment information.
+        The file should be saved by :func:`save_alignments`.
+    Returns:
+      Return a tuple containing:
+        - subsampling_factor: The subsampling_factor used to compute
+          the alignments.
+        - alignments: A dict containing utterance IDs and their corresponding
+          framewise alignments, after subsampling.
+    """
+    ali_dict = torch.load(filename)
+    subsampling_factor = ali_dict["subsampling_factor"]
+    alignments = ali_dict["alignments"]
+    return subsampling_factor, alignments
+
+
+def convert_alignments_to_tensor(
+    alignments: Dict[str, List[int]], device: torch.device
+) -> Dict[str, torch.Tensor]:
+    """Convert alignments from list of int to a 1-D torch.Tensor.
+
+    Args:
+      alignments:
+        A dict containing alignments. Keys are utterance IDs and
+        values are their corresponding frame-wise alignments.
+      device:
+        The device to move the alignments to.
+    Returns:
+      Return a dict using 1-D torch.Tensor to store the alignments.
+      The dtype of the tensors is `torch.int64`. We choose `torch.int64`
+      because `torch.nn.functional.one_hot` requires that.
+    """
+    ans = {}
+    for utt_id, ali in alignments.items():
+        ali = torch.tensor(ali, dtype=torch.int64, device=device)
+        ans[utt_id] = ali
+    return ans
+
+
+def lookup_alignments(
+    cut_ids: List[str],
+    alignments: Dict[str, torch.Tensor],
+    num_classes: int,
+    log_score: float = -10,
+) -> torch.Tensor:
+    """Return a mask constructed from the alignments of a list of cut IDs.
+
+    The returned mask is a 3-D tensor of shape (N, T, C). In each row of the
+    mask, i.e., for each frame, positions not corresponding to the alignment
+    are filled with `log_score`, while the position specified by the
+    alignment is filled with 0. For instance, if the alignments of two
+    utterances are
+
+        [ [1, 3, 2], [1, 0, 4, 2] ],
+
+    num_classes is 5, and log_score is -10, then the returned mask is
+
+        [
+            [[-10, 0, -10, -10, -10],
+             [-10, -10, -10, 0, -10],
+             [-10, -10, 0, -10, -10],
+             [0, -10, -10, -10, -10]],
+            [[-10, 0, -10, -10, -10],
+             [0, -10, -10, -10, -10],
+             [-10, -10, -10, -10, 0],
+             [-10, -10, 0, -10, -10]]
+        ]
+
+    Note: We pad the alignment of the first utterance with 0.
+
+    Args:
+      cut_ids:
+        A list of utterance IDs.
+      alignments:
+        A dict containing alignments. The keys are utterance IDs and the values
+        are framewise alignments.
+      num_classes:
+        One plus the maximum token ID that appears in the alignments.
+      log_score:
+        Positions in the returned tensor not corresponding to the alignments
+        are filled with this value.
+    Returns:
+      Return a 3-D torch.float32 tensor of shape (N, T, C).
+    """
+    # We assume all utterances have their alignments.
+    ali = [alignments[cut_id] for cut_id in cut_ids]
+    padded_ali = pad_sequence(ali, batch_first=True, padding_value=0)
+    padded_one_hot = torch.nn.functional.one_hot(
+        padded_ali,
+        num_classes=num_classes,
+    )
+    mask = (1 - padded_one_hot) * float(log_score)
+    return mask
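Taken together, the helpers in icefall/ali.py are intended to be used roughly as follows (a sketch: the utterance IDs, alignment values, and file name are made up for illustration):

    import torch

    from icefall.ali import (
        convert_alignments_to_tensor,
        load_alignments,
        lookup_alignments,
        save_alignments,
    )

    # Framewise token alignments after subsampling, keyed by utterance/cut ID.
    ali = {"utt-1": [1, 3, 2], "utt-2": [1, 0, 4, 2]}
    save_alignments(ali, subsampling_factor=4, filename="ali.pt")

    subsampling_factor, ali = load_alignments("ali.pt")
    assert subsampling_factor == 4

    ali = convert_alignments_to_tensor(ali, device=torch.device("cpu"))
    mask = lookup_alignments(
        cut_ids=["utt-1", "utt-2"], alignments=ali, num_classes=5
    )
    print(mask.shape)  # torch.Size([2, 4, 5])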
diff --git a/test/test_ali.py b/test/test_ali.py
new file mode 100755
index 000000000..e8516e6dc
--- /dev/null
+++ b/test/test_ali.py
@@ -0,0 +1,223 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Run this file using one of the following two ways:
+# (1) python3 ./test/test_ali.py
+# (2) pytest ./test/test_ali.py
+
+# The purpose of this file is to show that if we build a mask
+# from alignments and add it to a randomly generated nnet_output,
+# we can decode the correct transcript.
+
+from pathlib import Path
+
+import k2
+import torch
+from lhotse import load_manifest
+from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import DataLoader
+
+from icefall.ali import (
+    convert_alignments_to_tensor,
+    load_alignments,
+    lookup_alignments,
+)
+from icefall.decode import get_lattice, one_best_decoding
+from icefall.lexicon import Lexicon
+from icefall.utils import get_texts
+
+ICEFALL_DIR = Path(__file__).resolve().parent.parent
+egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
+lang_dir = egs_dir / "data/lang_bpe_500"
+
+# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
+cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
+# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
+ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
+
+# cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
+# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
+
+
+def data_exists():
+    return ali_filename.exists() and cut_json.exists() and lang_dir.exists()
+
+
+def get_dataloader():
+    cuts_train = load_manifest(cut_json)
+    cuts_train = cuts_train.with_features_path_prefix(egs_dir)
+    train_sampler = SingleCutSampler(
+        cuts_train,
+        max_duration=200,
+        shuffle=False,
+    )
+
+    train = K2SpeechRecognitionDataset(return_cuts=True)
+
+    train_dl = DataLoader(
+        train,
+        sampler=train_sampler,
+        batch_size=None,
+        num_workers=1,
+        persistent_workers=False,
+    )
+    return train_dl
+
+
+def test_one_hot():
+    a = [1, 3, 2]
+    b = [1, 0, 4, 2]
+    c = [torch.tensor(a), torch.tensor(b)]
+    d = pad_sequence(c, batch_first=True, padding_value=0)
+    f = torch.nn.functional.one_hot(d, num_classes=5)
+    e = (1 - f) * -10.0
+    expected = torch.tensor(
+        [
+            [
+                [-10, 0, -10, -10, -10],
+                [-10, -10, -10, 0, -10],
+                [-10, -10, 0, -10, -10],
+                [0, -10, -10, -10, -10],
+            ],
+            [
+                [-10, 0, -10, -10, -10],
+                [0, -10, -10, -10, -10],
+                [-10, -10, -10, -10, 0],
+                [-10, -10, 0, -10, -10],
+            ],
+        ]
+    ).to(e.dtype)
+    assert torch.all(torch.eq(e, expected))
+
+
+def test():
+    """
+    The purpose of this test is to show that we can use pre-computed
+    alignments to construct a mask, add it to a randomly generated
+    nnet_output, and decode the correct transcript from the resulting
+    nnet_output.
+    """
+    if not data_exists():
+        return
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    dl = get_dataloader()
+
+    subsampling_factor, ali = load_alignments(ali_filename)
+    ali = convert_alignments_to_tensor(ali, device=device)
+
+    lexicon = Lexicon(lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+    word_table = lexicon.word_table
+
+    HLG = k2.Fsa.from_dict(
+        torch.load(f"{lang_dir}/HLG.pt", map_location=device)
+    )
+
+    for batch in dl:
+        features = batch["inputs"]
+        supervisions = batch["supervisions"]
+        N = features.shape[0]
+        T = features.shape[1] // subsampling_factor
+        nnet_output = (
+            torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
+            .softmax(dim=-1)
+            .log()
+        )
+        cut_ids = [cut.id for cut in supervisions["cut"]]
+        mask = lookup_alignments(
+            cut_ids=cut_ids, alignments=ali, num_classes=num_classes
+        )
+        min_len = min(nnet_output.shape[1], mask.shape[1])
+        ali_model_scale = 0.8
+
+        nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
+
+        supervisions = batch["supervisions"]
+
+        supervision_segments = torch.stack(
+            (
+                supervisions["sequence_idx"],
+                supervisions["start_frame"] // subsampling_factor,
+                supervisions["num_frames"] // subsampling_factor,
+            ),
+            1,
+        ).to(torch.int32)
+
+        lattice = get_lattice(
+            nnet_output=nnet_output,
+            HLG=HLG,
+            supervision_segments=supervision_segments,
+            search_beam=20,
+            output_beam=8,
+            min_active_states=30,
+            max_active_states=10000,
+            subsampling_factor=subsampling_factor,
+        )
+
+        best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        hyps = [" ".join(s) for s in hyps]
+        print(hyps)
+        print(supervisions["text"])
+        break
+
+
+def show_cut_ids():
+    # The purpose of this function is to check that
+    # for each utterance in the training set, there is
+    # a corresponding alignment.
+    #
+    # After generating a1.txt and b1.txt, you can use
+    #   wc -l a1.txt b1.txt
+    # which should show the same number of lines.
+    #
+    #   cat a1.txt | sort | uniq > a11.txt
+    #   cat b1.txt | sort | uniq > b11.txt
+    #
+    #   md5sum a11.txt b11.txt
+    # which should show identical hashes.
+    #
+    #   diff a11.txt b11.txt
+    # should print nothing.
+
+    subsampling_factor, ali = load_alignments(ali_filename)
+    with open("a1.txt", "w") as f:
+        for key in ali:
+            f.write(f"{key}\n")
+
+    # dl = get_dataloader()
+    cuts_train = (
+        load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
+        + load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
+        + load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
+    )
+
+    ans = []
+    for cut in cuts_train:
+        ans.append(cut.id)
+    with open("b1.txt", "w") as f:
+        for line in ans:
+            f.write(f"{line}\n")
+
+
+if __name__ == "__main__":
+    test()