diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py
new file mode 100755
index 000000000..26a333e67
--- /dev/null
+++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py
@@ -0,0 +1,209 @@
+#!/usr/bin/env python3
+# Author: Haoyu Tang
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import os
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from conformer import Conformer
+from lhotse.features.base import store_feature_array
+from lhotse.features.io import LilcomChunkyWriter
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, setup_logger
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=49,
+        help="It specifies the checkpoint to use for decoding. "
+        "Note: Epoch counts from 0.",
+    )
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=20,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch'.",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc/exp",
+        help="The experiment dir",
+    )
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="The lang dir",
+    )
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    params = AttributeDict(
+        {
+            # parameters for conformer
+            "subsampling_factor": 4,
+            "feature_dim": 80,
+            "nhead": 4,
+            "attention_dim": 512,
+            "num_encoder_layers": 12,
+            "num_decoder_layers": 6,
+            "vgg_frontend": False,
+            "use_feat_batchnorm": True,
+        }
+    )
+    return params
+
+
+def generate_ctc_label_batch(
+    params: AttributeDict,
+    model: nn.Module,
+    batch: dict,
+    device: torch.device,
+):
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # At entry, feature is (N, T, C).
+
+    supervisions = batch["supervisions"]
+    # nnet_output is (N, T', num_classes) log-probabilities, where T' is T
+    # reduced by the subsampling factor; the attention-decoder outputs
+    # (memory, memory_key_padding_mask) are not needed for CTC labels.
+    nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
+    return nnet_output
+
+
+def generate_ctc_label_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    device: torch.device,
+    output_path: str,
+):
+    with LilcomChunkyWriter(output_path) as writer:
+        for batch_idx, batch in enumerate(dl):
+            nnet_output = generate_ctc_label_batch(
+                params=params,
+                model=model,
+                batch=batch,
+                device=device,
+            )
+            # store_feature_array() lilcom-compresses the array into the
+            # .lca archive and returns a storage key; keep the key if you
+            # need to look up this batch again later.
+            store_feature_array(
+                nnet_output.cpu().detach().numpy(),
+                writer,
+            )
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    AishellAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    # Disable SpecAugment and MUSAN so the CTC labels are computed on
+    # clean, unperturbed features.
+    args.enable_spec_aug = False
+    args.enable_musan = False
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    setup_logger(f"{params.exp_dir}/log-ctc-label/log-decode")
+    logging.info("CTC label generation started")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=params.vgg_frontend,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    if params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.load_state_dict(average_checkpoints(filenames, device=device))
+
+    model.to(device)
+    model.eval()
+    num_param = sum(p.numel() for p in model.parameters())
+    logging.info(f"Number of model parameters: {num_param}")
+
+    aishell = AishellAsrDataModule(args)
+    train_cuts = aishell.train_cuts()
+    train_dl = aishell.train_dataloaders(train_cuts)
+
+    train_sets = ["train"]
+    train_dls = [train_dl]
+
+    for train_set, train_dl in zip(train_sets, train_dls):
+        generate_ctc_label_dataset(
+            dl=train_dl,
+            params=params,
+            model=model,
+            device=device,
+            output_path=os.path.join(args.exp_dir, f"ctc-label-{train_set}.lca"),
+        )
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()