Merge f4bf9e4505d047decbc1aec2a46c78c9b4aaf608 into abd9437e6d5419a497707748eb935e50976c3b7b

This commit is contained in:
Charlie_Tang 2025-06-27 11:32:06 +00:00 committed by GitHub
commit 556eebaeae

@@ -0,0 +1,207 @@
#!/usr/bin/env python3
# Author: Haoyu Tang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from lhotse.features.io import LilcomChunkyWriter
from lhotse.features.base import store_feature_array
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from conformer import Conformer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import AttributeDict, setup_logger


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=49,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="The lang dir",
)
return parser
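
# Example invocation (a sketch only: the script name and paths are
# illustrative, and the remaining dataloader options are added by
# AishellAsrDataModule.add_arguments):
#
#   python3 ./path/to/this_script.py \
#     --epoch 49 \
#     --avg 20 \
#     --exp-dir conformer_ctc/exp \
#     --lang-dir data/lang_char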


def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"subsampling_factor": 4,
"feature_dim": 80,
"nhead": 4,
"attention_dim": 512,
"num_encoder_layers": 12,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
}
)
return params
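
# Note: these hyper-parameters describe the Conformer built in main() and must
# match the configuration used to train the checkpoints in --exp-dir,
# otherwise loading the averaged state_dict will fail.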


def generate_ctc_label_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
device: torch.device,
):
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
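    # In the icefall conformer_ctc recipe, nnet_output is the log-softmax CTC
    # output of shape (N, T', C), where T' is roughly T // subsampling_factor.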
return nnet_output


def generate_ctc_label_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
device: torch.device,
output_path: str,
):
with LilcomChunkyWriter(output_path) as writer:
for batch_idx, batch in enumerate(dl):
nnet_output = generate_ctc_label_batch(
params=params,
model=model,
batch=batch,
device=device,
)
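            # store_feature_array writes the whole padded batch output
            # (N, T', C) as a single compressed array and returns a storage
            # key, which is discarded here.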
store_feature_array(
nnet_output.cpu().detach().numpy(),
writer,
)
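
# A minimal sketch (not part of this script) of how the stored arrays might be
# read back later, assuming the storage keys returned by store_feature_array
# were recorded somewhere (e.g. alongside batch or cut IDs):
#
#   from lhotse.features.io import LilcomChunkyReader
#
#   reader = LilcomChunkyReader("conformer_ctc/exp/ctc-label-train.lca")
#   nnet_output = reader.read(storage_key)  # numpy array of shape (N, T', C)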


@torch.no_grad()
def main():
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
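    # Disable data augmentation so the generated CTC labels correspond to the
    # unperturbed training features.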
args.enable_spec_aug = False
args.enable_musan = False
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-ctc-label/log-decode")
logging.info("CTC label generation started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_encoder_layers=params.num_encoder_layers,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
        for i in range(start, params.epoch + 1):
            # Guard against negative epoch indices when --avg > --epoch + 1.
            if i >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
aishell = AishellAsrDataModule(args)
train_cuts = aishell.train_cuts()
train_dl = aishell.train_dataloaders(train_cuts)
train_sets = ["train"]
train_dls = [train_dl]
for train_set, train_dl in zip(train_sets, train_dls):
generate_ctc_label_dataset(
dl=train_dl,
params=params,
model=model,
device=device,
output_path=os.path.join(args.exp_dir, f"ctc-label-{train_set}.lca"),
)
logging.info("Done!")


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
main()