From 4b96af27774de7ba954b6bede77bfb9b26b3fe45 Mon Sep 17 00:00:00 2001 From: qmpzzpmq <405691733@qq.com> Date: Mon, 15 Aug 2022 14:42:54 +0000 Subject: [PATCH 1/9] ctc label generate --- .../ASR/conformer_ctc/generate_CTC_label.py | 222 ++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100755 egs/aishell/ASR/conformer_ctc/generate_CTC_label.py diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py new file mode 100755 index 000000000..e560f846f --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from collections import defaultdict +import os +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from pdb import set_trace + +import k2 +from lhotse.features.io import LilcomChunkyWriter +from lhotse.features.base import store_feature_array +import torch +import torch.nn as nn +from asr_datamodule import AishellAsrDataModule +from conformer import Conformer + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=49, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "feature_dim": 80, + "nhead": 4, + "attention_dim": 512, + "num_encoder_layers": 12, + "num_decoder_layers": 6, + "vgg_frontend": False, + "use_feat_batchnorm": True, + } + ) + return params + +def generate_ctc_label_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, + device: torch.device, +): + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + return nnet_output + +def generate_ctc_label_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + device: torch.device, + output_path: str, +): + set_trace() + with LilcomChunkyWriter(output_path) as writer: + for batch_idx, batch in enumerate(dl): + nnet_output = generate_ctc_label_batch( + params=params, + model=model, + batch=batch, + device=device, + ) + store_feature_array( + nnet_output.cpu().detach().numpy(), + writer, + ) + +@torch.no_grad() +def main(): + parser = get_parser() + AishellAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-ctc-label/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_encoder_layers=params.num_encoder_layers, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + aishell = AishellAsrDataModule(args) + test_cuts = aishell.test_cuts() + test_dl = aishell.test_dataloaders(test_cuts) + + test_sets = ["test"] + test_dls = [test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + generate_ctc_label_dataset( + dl=test_dl, + params=params, + model=model, + device=device, + output_path=os.path.join(args.exp_dir, f"ctc-label-{test_set}.lca"), + ) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From de1078072d774cc6673c1c6ccbf86a7b6f86ffc8 Mon Sep 17 00:00:00 2001 From: qmpzzpmq <405691733@qq.com> Date: Mon, 15 Aug 2022 14:45:35 +0000 Subject: [PATCH 2/9] linting --- egs/aishell/ASR/conformer_ctc/generate_CTC_label.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py index e560f846f..49818f1fc 100755 --- a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py +++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py @@ -23,7 +23,6 @@ from collections import defaultdict import os from pathlib import Path from typing import Dict, List, Optional, Tuple -from pdb import set_trace import k2 from lhotse.features.io import LilcomChunkyWriter @@ -128,7 +127,6 @@ def generate_ctc_label_dataset( device: torch.device, output_path: str, ): - set_trace() with LilcomChunkyWriter(output_path) as writer: for batch_idx, batch in enumerate(dl): nnet_output = generate_ctc_label_batch( From 483acdb4f477070a6cd57ce63ef1a2d184108f93 Mon Sep 17 00:00:00 2001 From: qmpzzpmq <405691733@qq.com> Date: Mon, 15 Aug 2022 14:54:34 +0000 Subject: [PATCH 3/9] cleaning --- egs/aishell/ASR/conformer_ctc/generate_CTC_label.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py index 49818f1fc..4868be2d2 100755 --- a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py +++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py @@ -32,22 +32,11 @@ import torch.nn as nn from asr_datamodule import AishellAsrDataModule from conformer import Conformer -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - nbest_oracle, - one_best_decoding, - rescore_with_attention_decoder, -) -from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, - get_texts, setup_logger, - store_transcripts, write_error_stats, ) From bf5c07381f7bfc06678209dfcfd95e2b594387da Mon Sep 17 00:00:00 2001 From: qmpzzpmq <405691733@qq.com> Date: Mon, 15 Aug 2022 14:57:06 +0000 Subject: [PATCH 4/9] author --- egs/aishell/ASR/conformer_ctc/generate_CTC_label.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py index 4868be2d2..ef04d7f7b 100755 --- a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py +++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py @@ -1,9 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, -# Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors +# Author: Haoyu Tang # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 208ac69519b1fef7cdfa6a090940680c05e99374 Mon Sep 17 00:00:00 2001 From: qmpzzpmq <405691733@qq.com> Date: Mon, 15 Aug 2022 14:59:57 +0000 Subject: [PATCH 5/9] blank line --- egs/aishell/ASR/conformer_ctc/generate_CTC_label.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py index ef04d7f7b..83f74b580 100755 --- a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py +++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py @@ -90,6 +90,7 @@ def get_params() -> AttributeDict: ) return params + def generate_ctc_label_batch( params: AttributeDict, model: nn.Module, @@ -105,6 +106,7 @@ def generate_ctc_label_batch( nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) return nnet_output + def generate_ctc_label_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, @@ -125,6 +127,7 @@ def generate_ctc_label_dataset( writer, ) + @torch.no_grad() def main(): parser = get_parser() From 6da35058fab4e7ca76636d01152ffd6a01712dd3 Mon Sep 17 00:00:00 2001 From: qmpzzpmq <405691733@qq.com> Date: Mon, 15 Aug 2022 15:11:00 +0000 Subject: [PATCH 6/9] data config update --- .../ASR/conformer_ctc/generate_CTC_label.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py index 83f74b580..a41921caf 100755 --- a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py +++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py @@ -133,6 +133,8 @@ def main(): parser = get_parser() AishellAsrDataModule.add_arguments(parser) args = parser.parse_args() + args.enable_spec_aug = False + args.enable_musan = False args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -183,19 +185,19 @@ def main(): logging.info(f"Number of model parameters: {num_param}") aishell = AishellAsrDataModule(args) - test_cuts = aishell.test_cuts() - test_dl = aishell.test_dataloaders(test_cuts) + train_cuts = aishell.train_cuts() + train_dl = aishell.train_dataloaders(train_cuts) - test_sets = ["test"] - test_dls = [test_dl] + train_sets = ["train"] + train_dls = [train_dl] - for test_set, test_dl in zip(test_sets, test_dls): + for train_set, train_dl in zip(train_sets, train_dls): generate_ctc_label_dataset( - dl=test_dl, + dl=train_dl, params=params, model=model, device=device, - output_path=os.path.join(args.exp_dir, f"ctc-label-{test_set}.lca"), + output_path=os.path.join(args.exp_dir, f"ctc-label-{train_set}.lca"), ) logging.info("Done!") From 74ee6e613013c4dd6a096bbee97048329b12cb12 Mon Sep 17 00:00:00 2001 From: qmpzzpmq <405691733@qq.com> Date: Mon, 15 Aug 2022 15:12:45 +0000 Subject: [PATCH 7/9] linting --- egs/aishell/ASR/conformer_ctc/generate_CTC_label.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py index a41921caf..7dbb1bb75 100755 --- a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py +++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py @@ -105,7 +105,7 @@ def generate_ctc_label_batch( supervisions = batch["supervisions"] nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) return nnet_output - + def generate_ctc_label_dataset( dl: torch.utils.data.DataLoader, From 32cae7bbcf9138c703907345c8fd685d8768cfe9 Mon Sep 17 00:00:00 2001 From: qmpzzpmq <405691733@qq.com> Date: Mon, 15 Aug 2022 15:14:13 +0000 Subject: [PATCH 8/9] linting --- egs/aishell/ASR/conformer_ctc/generate_CTC_label.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py index 7dbb1bb75..e008a658a 100755 --- a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py +++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py @@ -15,12 +15,9 @@ import argparse import logging -from collections import defaultdict import os from pathlib import Path -from typing import Dict, List, Optional, Tuple -import k2 from lhotse.features.io import LilcomChunkyWriter from lhotse.features.base import store_feature_array import torch From f4bf9e4505d047decbc1aec2a46c78c9b4aaf608 Mon Sep 17 00:00:00 2001 From: qmpzzpmq <405691733@qq.com> Date: Mon, 15 Aug 2022 15:18:13 +0000 Subject: [PATCH 9/9] wording --- egs/aishell/ASR/conformer_ctc/generate_CTC_label.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py index e008a658a..26a333e67 100755 --- a/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py +++ b/egs/aishell/ASR/conformer_ctc/generate_CTC_label.py @@ -139,7 +139,7 @@ def main(): params.update(vars(args)) setup_logger(f"{params.exp_dir}/log-ctc-label/log-decode") - logging.info("Decoding started") + logging.info("CTC label generation started") logging.info(params) lexicon = Lexicon(params.lang_dir)