diff --git a/egs/librispeech/ASR/conformer_ctc/code_indices.py b/egs/librispeech/ASR/conformer_ctc/code_indices.py
new file mode 100755
index 000000000..f664ee6e5
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/code_indices.py
@@ -0,0 +1,313 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (author: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+from quantization import Quantizer
+
+import torch
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse.features.io import NumpyHdf5Writer
+from lhotse import CutSet
+
+from icefall.checkpoint import load_checkpoint
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=34,
+        help="It specifies the checkpoint to use for decoding. "
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=1,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--data-dir",
+        type=Path,
+        default="./data/",
+        help="The data dir to store extracted codebook indexes",
+    )
+
+    parser.add_argument(
+        "--mem-dir",
+        type=Path,
+        default="conformer_ctc/exp/mem",
+        help="The dir to load the trained quantizer from",
+    )
+
+    parser.add_argument(
+        "--quantizer-id",
+        type=str,
+        default=None,
+        help="The id of the trained quantizer (part of the quantizer filename)",
+    )
+
+    parser.add_argument(
+        "--bytes-per-frame",
+        type=int,
+        default=4,
+        help="The number of bytes used to quantize each memory embedding",
+    )
+
+    parser.add_argument(
+        "--memory-embedding-dim",
+        type=int,
+        default=512,
+        help="Dim of the memory embeddings used to train the quantizer",
+    )
+
+    parser.add_argument(
+        "--pretrained-model",
+        type=Path,
+        default=None,
+        help="Use a pretrained model, e.g. a model downloaded from the model zoo",
+    )
+
+    parser.add_argument(
+        "--model-id",
+        type=str,
+        default=None,
+        help="A short string indicating which model the embeddings come from, "
+        "e.g. 
icefall or wav2vec2", + ) + + parser.add_argument( + "--mem-layer", + type=int, + default=None, + help="which layer to extract memory embedding" + "Set this manully to avoid mistake.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "subsampling_factor": 4, + "num_decoder_layers": 6, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "output_beam": 10, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def compute_codeindices( + model: torch.nn.Module, + dl: torch.utils.data.DataLoader, + quantizer: None, + params: AttributeDict, + writer: None, +) -> List[Tuple[str, List[int]]]: + """Compute the framewise alignments of a dataset. + + Args: + model: + The neural network model. + dl: + Dataloader containing the dataset. + params: + Parameters for computing memory. + Returns: + Return a list of tuples. Each tuple contains two entries: + - Utterance ID + - memory embeddings + """ + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + num_cuts = 0 + + device = params.device + cuts = [] + total_frames = 0 + for batch_idx, batch in enumerate(dl): + feature = batch["inputs"] + + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + + _, encoder_memory, memory_mask = model(feature, supervisions) + codebook_indices = quantizer.encode(encoder_memory, as_bytes=True) + + # [T, N, C] --> [N, T, C] + codebook_indices = codebook_indices.transpose(0, 1).to("cpu").numpy() + + # for idx, cut in enumerate(cut_ids): + cut_list = supervisions["cut"] + assert len(cut_list) == codebook_indices.shape[0] + num_cuts += len(cut_list) + assert all(supervisions["start_frame"] == 0) + for idx, cut in enumerate(cut_list): + num_frames = ( + ((supervisions["num_frames"][idx] - 3) // 2 + 1) - 3 + ) // 2 + 1 + cut.codebook_indices = writer.store_array( + key=cut.id, + value=codebook_indices[idx][:num_frames], + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + total_frames += num_frames + + cuts += cut_list + logging.info( + f"processed {total_frames} frames and {num_cuts} cuts; {batch_idx} of {num_batches}" # noqa: E501 + ) + return CutSet.from_cuts(cuts) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert args.return_cuts is True + assert args.concatenate_cuts is False + assert args.quantizer_id is not None + assert args.model_id is not None + assert args.mem_layer is not None + assert args.pretrained_model is not None + assert args.subset in ["clean-100", "clean-360", "other-500"] + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log/mem") + + logging.info("Computing memory embedings- started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + quantizer_fn = ( + params.mem_dir + / f"{params.mem_layer}layer-{params.quantizer_id}-bytes_per_frame_{params.bytes_per_frame}-quantizer.pt" # 
noqa: E501 + ) + + quantizer = Quantizer( + dim=params.memory_embedding_dim, + num_codebooks=args.bytes_per_frame, + codebook_size=256, + ) + quantizer.load_state_dict(torch.load(quantizer_fn)) + quantizer = quantizer.to("cuda") + + load_checkpoint(f"{params.pretrained_model}", model) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + params["device"] = device + + model.to(device) + model.eval() + + librispeech = LibriSpeechAsrDataModule(args) + + train_dl = librispeech.train_dataloaders() + + cdidx_dir = ( + Path(params.data_dir) + / f"{args.model_id}-{args.mem_layer}layer-{args.quantizer_id}-bytes_per_frame-{args.bytes_per_frame}" # noqa: E501 + ) + cdidx_dir.mkdir(exist_ok=True) + + with NumpyHdf5Writer( + cdidx_dir + / f"{args.model_id}-{args.mem_layer}layer-cdidx_train-{args.subset}" + ) as writer: + cut_set = compute_codeindices( + model=model, + dl=train_dl, + quantizer=quantizer, + params=params, + writer=writer, + ) + cut_set.to_json(cdidx_dir / f"cuts_train-{args.subset}.json.gz") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 73c60b2d0..db17f4f0d 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -23,6 +23,9 @@ from typing import Optional, Tuple import torch from torch import Tensor, nn from transformer import Supervisions, Transformer, encoder_padding_mask +from prediction import JointCodebookPredictor +from ckpnt_prediction import JointCodebookLoss +from powerful_prediction import Powerful_JointCodebookLoss class CodeIndicesNet(nn.Module): @@ -51,18 +54,9 @@ class CodeIndicesNet(nn.Module): self.num_codebooks = num_codebooks self.quantizer_dim = quantizer_dim - def forward(self, memory): - """ - Args: - memory: - memory embeddings, with shape[T, N, C] - output: - shape [N, T, num_codebooks*quantizer_dim] - """ - x = self.linear1(memory) - return x - - def loss(self, memory: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + def forward( + self, memory: torch.Tensor, target: torch.Tensor + ) -> torch.Tensor: """ Args: memory: @@ -75,12 +69,14 @@ class CodeIndicesNet(nn.Module): actually it's the sum of num_codebooks CE losses """ - memory = memory.transpose(0, 1) # T, N, C --> N, T, C - x = self.forward(memory) + x = self.linear1(memory) x = x.reshape(-1, self.quantizer_dim) target = target.reshape(-1) + assert ( + x.shape[0] == target.shape[0] + ), f"x.shape: {x.shape} while target.shape: {target.shape}" ret = self.ce(x, target) - return ret + return -ret, None class Conformer(Transformer): @@ -115,6 +111,9 @@ class Conformer(Transformer): normalize_before: bool = True, vgg_frontend: bool = False, use_feat_batchnorm: bool = False, + use_codebook_loss: bool = False, + num_codebooks: int = 4, + predictor: str = "predictor", # "simple_linear", "predictor", "ckpnt_predictor, powerful" ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -150,7 +149,27 @@ class Conformer(Transformer): # and throws an error without this change. 
self.after_norm = identity - self.cdidxnet = CodeIndicesNet() + if use_codebook_loss: + assert predictor in [ + "powerful", + "predictor", + "ckpnt_predictor", + "simple_linear", + ] + if predictor == "predictor": + self.cdidxnet = JointCodebookPredictor( + predictor_dim=512, num_codebooks=num_codebooks + ) + elif predictor == "ckpnt_predictor": + self.cdidxnet = JointCodebookLoss( + predictor_channels=512, num_codebooks=num_codebooks + ) + elif predictor == "simple_linear": + self.cdidxnet = CodeIndicesNet(num_codebooks=num_codebooks) + elif predictor == "powerful": + self.cdidxnet = Powerful_JointCodebookLoss( + predictor_channels=512, num_codebooks=num_codebooks + ) def run_encoder( self, x: Tensor, supervisions: Optional[Supervisions] = None diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index ed2da7b76..d46c781da 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -499,10 +499,10 @@ def save_results( enable_log = True test_set_wers = dict() for key, results in results_dict.items(): + result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-" recog_path = ( params.exp_dir - / f"epoch-{params.epoch}-avg-{params.avg}- \ - recogs-{test_set_name}-{key}.txt" + / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt" ) store_transcripts(filename=recog_path, texts=results) if enable_log: @@ -512,8 +512,7 @@ def save_results( # ref/hyp pairs. errs_filename = ( params.exp_dir - / f"epoch-{params.epoch}-avg-{params.avg}- \ - errs-{test_set_name}-{key}.txt" + / f"{result_file_prefix}errs-{test_set_name}-{key}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -528,9 +527,7 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.exp_dir - / f"epoch-{params.epoch}-avg-{params.avg}- \ - wer-summary-{test_set_name}.txt" + params.exp_dir / f"{result_file_prefix}wer-summary-{test_set_name}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/conformer_ctc/memory_embedding.py b/egs/librispeech/ASR/conformer_ctc/memory_embedding.py new file mode 100755 index 000000000..51b63569e --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/memory_embedding.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (author: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Tuple
+
+import torch
+from asr_datamodule import LibriSpeechAsrDataModule
+from conformer import Conformer
+from lhotse.features.io import NumpyHdf5Writer
+from lhotse import CutSet
+
+from icefall.checkpoint import load_checkpoint
+from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+)
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=34,
+        help="It specifies the checkpoint to use for decoding. "
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=1,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'. ",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="The lang dir",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--mem-dir",
+        type=str,
+        default="conformer_ctc/exp/mem",
+        help="The dir to store extracted memory embeddings",
+    )
+
+    parser.add_argument(
+        "--num-utts",
+        type=int,
+        default=1000,
+        help="Number of utterances to extract memory embeddings from",
+    )
+
+    parser.add_argument(
+        "--mem-layer",
+        type=int,
+        default=None,
+        help="Which encoder layer to extract memory embeddings from",
+    )
+    parser.add_argument(
+        "--pretrained-model",
+        type=Path,
+        default=None,
+        help="Use a pretrained model, e.g. a model downloaded from the model zoo",
+    )
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            "feature_dim": 80,
+            "nhead": 8,
+            "attention_dim": 512,
+            "subsampling_factor": 4,
+            "num_decoder_layers": 6,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+            "output_beam": 10,
+            "use_double_scores": True,
+            "env_info": get_env_info(),
+        }
+    )
+    return params
+
+
+def compute_memory(
+    model: torch.nn.Module,
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    writer: None,
+) -> List[Tuple[str, List[int]]]:
+    """Compute memory embeddings for a dataset.
+
+    Args:
+      model:
+        The neural network model.
+      dl:
+        Dataloader containing the dataset.
+      params:
+        Parameters for computing memory.
+    Returns:
+      Return a list of tuples.
Each tuple contains two entries: + - Utterance ID + - memory embeddings + """ + num_cuts = 0 + + device = params.device + cuts = [] + total_frames = 0 + for batch_idx, batch in enumerate(dl): + feature = batch["inputs"] + + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + + _, encoder_memory, memory_mask = model(feature, supervisions) + + # [T, N, C] --> [N, T, C] + encoder_memory = encoder_memory.transpose(0, 1).to("cpu").numpy() + + cut_list = supervisions["cut"] + assert len(cut_list) == encoder_memory.shape[0] + assert all(supervisions["start_frame"] == 0) + for idx, cut in enumerate(cut_list): + num_frames = supervisions["num_frames"][idx] + cut.encoder_memory = writer.store_array( + key=cut.id, + value=encoder_memory[idx][:num_frames], + ) + total_frames += num_frames + + cuts += cut_list + num_cuts += len(cut_list) + logging.info(f"processed {total_frames} frames and {num_cuts} cuts.") + if len(cuts) > params.num_utts: + break + return CutSet.from_cuts(cuts) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + assert args.return_cuts is True + assert args.concatenate_cuts is False + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log/mem") + + logging.info("Computing memory embedings- started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + assert params.pretrained_model is not None + load_checkpoint(f"{params.pretrained_model}", model) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + params["device"] = device + + model.to(device) + model.eval() + + librispeech = LibriSpeechAsrDataModule(args) + + test_dl = librispeech.test_dataloaders() # a list + + mem_dir = Path(params.mem_dir) + mem_dir.mkdir(exist_ok=True) + + enabled_datasets = { + "test_clean": test_dl[0], + } + + mem_storage = mem_dir / f"{args.mem_layer}layer-memory_embeddings" + mem_manifest = mem_dir / f"{args.mem_layer}layer-memory_manifest.json" + with NumpyHdf5Writer(mem_storage) as writer: + for name, dl in enabled_datasets.items(): + cut_set = compute_memory( + model=model, + dl=dl, + params=params, + writer=writer, + ) + cut_set.to_json(mem_manifest) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc/quantizer_train.py b/egs/librispeech/ASR/conformer_ctc/quantizer_train.py new file mode 100755 index 000000000..df4b8b0bf --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/quantizer_train.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (author: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+from lhotse import load_manifest
+from lhotse.dataset import (
+    BucketingSampler,
+    K2SpeechRecognitionDataset,
+)
+from torch.utils.data import DataLoader
+from icefall.utils import setup_logger
+import torch
+import quantization
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--bytes-per-frame",
+        type=int,
+        default=4,
+        help="The number of bytes used to quantize each memory embedding. "
+        "Usually it equals the number of codebooks.",
+    )
+
+    parser.add_argument(
+        "--memory-embedding-dim",
+        type=int,
+        default=1024,
+        help="Dim of the memory embeddings used to train the quantizer",
+    )
+
+    parser.add_argument(
+        "--mem-dir",
+        type=Path,
+        default="conformer_ctc/exp/mem",
+        help="The dir containing the extracted memory embeddings",
+    )
+
+    parser.add_argument(
+        "--output-layer-index",
+        type=int,
+        default=None,
+        help="Which layer the memory embeddings were extracted from. "
+        "Specify this manually every time in case of mistakes.",
+    )
+
+    return parser
+
+
+def initialize_memory_dataloader(
+    mem_dir: Path = None, output_layer_index: int = None
+):
+    assert mem_dir is not None
+    assert output_layer_index is not None
+    mem_manifest_file = (
+        mem_dir / f"{output_layer_index}layer-memory_manifest.json"
+    )
+    assert os.path.isfile(
+        mem_manifest_file
+    ), f"{mem_manifest_file} does not exist."
+ cuts = load_manifest(mem_manifest_file) + dataset = K2SpeechRecognitionDataset(return_cuts=True) + max_duration = 1 + sampler = BucketingSampler( + cuts, + max_duration=max_duration, + shuffle=False, + ) + dl = DataLoader(dataset, batch_size=None, sampler=sampler, num_workers=4) + return dl + + +def main(): + parser = get_parser() + args = parser.parse_args() + assert args.output_layer_index is not None + setup_logger(f"{args.mem_dir}/log/quantizer_train") + trainer = quantization.QuantizerTrainer( + dim=args.memory_embedding_dim, + bytes_per_frame=args.bytes_per_frame, + device=torch.device("cuda"), + ) + dl = initialize_memory_dataloader(args.mem_dir, args.output_layer_index) + num_cuts = 0 + done_flag = False + epoch = 0 + while not trainer.done(): + for batch in dl: + cuts = batch["supervisions"]["cut"] + embeddings = torch.cat( + [ + torch.from_numpy(c.load_custom("encoder_memory")) + for c in cuts + ] + ) + embeddings = embeddings.to("cuda") + num_cuts += len(cuts) + trainer.step(embeddings) + if trainer.done(): + done_flag = True + break + if done_flag: + break + else: + epoch += 1 + dl = initialize_memory_dataloader( + args.mem_dir, args.output_layer_index + ) + quantizer = trainer.get_quantizer() + quantizer_fn = ( + f"{args.output_layer_index}layer-" + + quantizer.get_id() + + f"-bytes_per_frame_{args.bytes_per_frame}-quantizer.pt" + ) + quantizer_fn = args.mem_dir / quantizer_fn + torch.save(quantizer.state_dict(), quantizer_fn) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 8a9bcfa8b..2e7b50762 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -30,6 +30,7 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer +from lhotse.cut import MonoCut from lhotse.utils import fix_random_seed from lhotse.dataset.collation import collate_custom_field from torch import Tensor @@ -65,6 +66,13 @@ def get_parser(): help="Number of GPUs for DDP training.", ) + parser.add_argument( + "--bytes-per-frame", + type=int, + default=4, + help="number of code books", + ) + parser.add_argument( "--master-port", type=int, @@ -79,6 +87,13 @@ def get_parser(): help="Should various information be logged in tensorboard.", ) + parser.add_argument( + "--predictor", + type=str, + default=None, + help="simple_linear predictor ckpnt_predictor", + ) + parser.add_argument( "--num-epochs", type=int, @@ -103,6 +118,7 @@ def get_parser(): help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved + Note: no tailing "/". """, ) @@ -128,7 +144,7 @@ def get_parser(): parser.add_argument( "--codebook-weight", type=float, - default=0.1, + default=0.3, help="""The weight of code book loss. Note: Currently rate of ctc_loss + rate of att_loss = 1.0 codebook_weight is independent with previous two. @@ -142,6 +158,14 @@ def get_parser(): help="The lr_factor for Noam optimizer", ) + parser.add_argument( + "--model-id", + type=str, + default=None, + help="a short str to introduce which models the embeddings come from" + "e.g. 
icefall or wav2vec2", + ) + return parser @@ -406,27 +430,42 @@ def compute_loss( ) loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - if params.codebook_weight != 0.0: + if params.codebook_weight > 0.0 and is_training: cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. + cuts_pre_mixed = [ + c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts + ] codebook_indices, codebook_indices_lens = collate_custom_field( - cuts, "codebook_indices", pad_value=-100 + cuts_pre_mixed, "codebook_indices", pad_value=-100 ) + # import pdb; pdb.set_trace() assert ( codebook_indices.shape[0] == encoder_memory.shape[1] ) # N: batch_size - assert ( - codebook_indices.shape[1] == encoder_memory.shape[0] - ) # T: num frames + + if "wav2vec" == params.model_id: + # frame rate of wav2vec codebooks_indices is 50 + # while for conformer is 25 + t_expected = encoder_memory.shape[0] * 2 + assert codebook_indices.shape[1] >= t_expected + codebook_indices = codebook_indices[:, 0:t_expected:2, :] + encoder_memory = encoder_memory.transpose(0, 1) # T, N, C --> N, T, C codebook_indices = codebook_indices.to(encoder_memory.device).long() - codebook_loss = mmodel.cdidxnet.loss( - encoder_memory, target=codebook_indices - ) + if ( + params.predictor == "ckpnt_predictor" + or params.predictor == "powerful" + ): + codebook_loss = mmodel.cdidxnet(encoder_memory, codebook_indices) + else: + total_logprob, _ = mmodel.cdidxnet(encoder_memory, codebook_indices) + codebook_loss = -total_logprob loss += params.codebook_weight * codebook_loss - else: + + if params.codebook_weight == 0.0 and params.att_rate == 0.0: loss = ctc_loss att_loss = torch.tensor([0]) @@ -438,7 +477,7 @@ def compute_loss( if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() - if params.codebook_weight != 0.0: + if params.codebook_weight > 0.0 and is_training: info["codebook_loss"] = codebook_loss.detach().cpu().item() info["loss"] = loss.detach().cpu().item() @@ -633,6 +672,9 @@ def run(rank, world_size, args): num_decoder_layers=params.num_decoder_layers, vgg_frontend=False, use_feat_batchnorm=params.use_feat_batchnorm, + use_codebook_loss=True if params.codebook_weight > 0.0 else False, + num_codebooks=params.bytes_per_frame, + predictor=params.predictor, ) checkpoints = load_checkpoint_if_available(params=params, model=model) @@ -747,7 +789,12 @@ def main(): parser = get_parser() LibriSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) + if 0.0 != args.codebook_weight: + assert -1 == args.time_warp_factor + assert not args.exp_dir.endswith("/") + args.exp_dir = Path( + f"{args.exp_dir}-time_warp_factor{args.time_warp_factor}-bytes_per_frame{args.bytes_per_frame}-cdweight{args.codebook_weight}-predictor{args.predictor}-maxduration{args.max_duration}" # noqa: E501 + ) args.lang_dir = Path(args.lang_dir) world_size = args.world_size diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 950eba438..aabb2804c 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -31,7 +31,7 @@ from lhotse.dataset import ( SingleCutSampler, SpecAugment, ) -from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures from torch.utils.data import DataLoader from icefall.dataset.datamodule import DataModule @@ -73,6 
+73,21 @@ class LibriSpeechAsrDataModule(DataModule):
             help="When enabled, use 960h LibriSpeech. "
             "Otherwise, use 100h subset.",
         )
+        parser.add_argument(
+            "--subset",
+            type=str,
+            default=None,
+            help="Which subset to extract codebook indexes from: "
+            "clean-100, clean-360 or other-500",
+        )
+
+        group.add_argument(
+            "--enable-augmentation",
+            type=str2bool,
+            default=True,
+            help="Set to False to disable all augmentation. "
+            "Used when extracting codebook indexes.",
+        )
         group.add_argument(
             "--feature-dir",
             type=Path,
@@ -100,6 +115,13 @@ class LibriSpeechAsrDataModule(DataModule):
             help="The number of buckets for the BucketingSampler"
             "(you might want to increase it for larger datasets).",
         )
+        group.add_argument(
+            "--time-warp-factor",
+            type=int,
+            default=80,
+            help="Set to None or less than 1 to disable time warping. "
+            "See lhotse.dataset.signal_transforms for details.",
+        )
         group.add_argument(
             "--concatenate-cuts",
             type=str2bool,
@@ -154,7 +176,16 @@ class LibriSpeechAsrDataModule(DataModule):
             "collect the batches.",
         )
 
+        group.add_argument(
+            "--input-strategy",
+            type=str,
+            default="PrecomputedFeatures",
+            help="The input strategy for the dataset: "
+            "AudioSamples or PrecomputedFeatures.",
+        )
+
     def train_dataloaders(self) -> DataLoader:
+        logging.info(f"enable-augmentation: {self.args.enable_augmentation}")
         logging.info("About to get train cuts")
         cuts_train = self.train_cuts()
 
@@ -181,6 +212,7 @@ class LibriSpeechAsrDataModule(DataModule):
 
         input_transforms = [
            SpecAugment(
+                time_warp_factor=self.args.time_warp_factor,
                num_frame_masks=2,
                features_mask_size=27,
                num_feature_masks=2,
@@ -189,12 +221,21 @@ class LibriSpeechAsrDataModule(DataModule):
         ]
 
         train = K2SpeechRecognitionDataset(
-            cut_transforms=transforms,
-            input_transforms=input_transforms,
+            input_strategy=AudioSamples()
+            if self.args.input_strategy == "AudioSamples"
+            else PrecomputedFeatures(),
+            cut_transforms=transforms
+            if self.args.enable_augmentation
+            else None,
+            input_transforms=input_transforms
+            if self.args.enable_augmentation
+            else None,
             return_cuts=self.args.return_cuts,
         )
 
         if self.args.on_the_fly_feats:
+            assert self.args.enable_augmentation
+            # self.args.enable_augmentation==False is only tested with precomputed features.  # noqa
            # NOTE: the PerturbSpeed transform should be added only if we
            # remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would @@ -222,7 +263,7 @@ class LibriSpeechAsrDataModule(DataModule): shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, bucket_method="equal_duration", - drop_last=True, + drop_last=True if self.args.enable_augmentation else False, ) else: logging.info("Using SingleCutSampler.") @@ -294,14 +335,20 @@ class LibriSpeechAsrDataModule(DataModule): for cuts_test in cuts: logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) + if self.args.input_strategy == "AudioSamples": + test = K2SpeechRecognitionDataset( + input_strategy=AudioSamples(), + return_cuts=self.args.return_cuts, + ) + else: + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, ) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) sampler = BucketingSampler( cuts_test, max_duration=self.args.max_duration, shuffle=False ) @@ -322,19 +369,26 @@ class LibriSpeechAsrDataModule(DataModule): @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get train cuts") - cuts_train = load_manifest( - self.args.feature_dir / "cuts_train-clean-100.json.gz" - ) if self.args.full_libri: + assert self.args.subset is None + cuts_train = load_manifest( + self.args.feature_dir / "cuts_train-clean-100.json" + ) cuts_train = ( cuts_train + load_manifest( - self.args.feature_dir / "cuts_train-clean-360.json.gz" + self.args.feature_dir / "cuts_train-clean-360.json" ) + load_manifest( - self.args.feature_dir / "cuts_train-other-500.json.gz" + self.args.feature_dir / "cuts_train-other-500.json" ) ) + if self.args.subset is not None: + assert not self.args.full_libri + assert self.args.subset in ["clean-100", "clean-360", "other-500"] + cuts_train = load_manifest( + self.args.feature_dir / f"cuts_train-{self.args.subset}.json.gz" + ) return cuts_train @lru_cache()
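A note on the frame-count arithmetic used in compute_codeindices() (code_indices.py above): codebook indices are stored per encoder frame, and the Conformer front end subsamples the 10 ms input frames by a factor of 4, which is why store_array() is called with frame_shift=0.04 and why the number of stored frames is derived from supervisions["num_frames"] via the nested (n - 3) // 2 + 1 expression. Below is a minimal standalone sketch of that arithmetic (not part of the patch; the kernel-size-3 / stride-2 reading of the formula is an assumption based on the expression in the patch):

def subsampled_frame_count(num_frames: int) -> int:
    # Each of the two assumed stride-2, kernel-3 convolutions maps
    # n input frames to (n - 3) // 2 + 1 output frames.
    out = (num_frames - 3) // 2 + 1  # first subsampling layer
    out = (out - 3) // 2 + 1  # second subsampling layer
    return out


# Example: 100 feature frames (1 s at a 10 ms frame shift) map to 24 encoder
# frames, consistent with the 0.04 s frame_shift passed to writer.store_array().
assert subsampled_frame_count(100) == 24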