#!/usr/bin/env python3

# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)

# (still a work in progress)

import argparse
import logging
from pathlib import Path

import torch
from conformer import Conformer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.utils import (
    AttributeDict,
    get_texts,
    setup_logger,
    store_transcripts,
    write_error_stats,
)

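# Note: Conformer is imported from the recipe-local conformer.py, not from
# the icefall package itself, so the recipe directory is presumably expected
# to be on sys.path (e.g. by running the script from there).
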
def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "exp_dir": Path("conformer_ctc/exp"),
            "lang_dir": Path("data/lang/bpe"),
            "lm_dir": Path("data/lm"),
            "feature_dim": 80,
            "nhead": 8,
            "attention_dim": 512,
            "num_classes": 5000,
            "subsampling_factor": 4,
            "num_decoder_layers": 6,
            "vgg_frontend": False,
            "is_espnet_structure": True,
            "mmi_loss": False,
            "use_feat_batchnorm": True,
            # Pruning parameters for lattice generation.
            "search_beam": 20,
            "output_beam": 5,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            # Possible values for method:
            # - 1best: take the best path from the decoding lattice
            # - nbest: sample num_paths paths from the lattice and pick the best
            # - nbest-rescoring: rescore the sampled paths with an LM from lm_dir
            # - whole-lattice-rescoring: rescore the whole lattice with an LM
            #   from lm_dir
            "method": "whole-lattice-rescoring",
            # num_paths is used when method is "nbest" or "nbest-rescoring".
            "num_paths": 30,
        }
    )
    return params

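# The search/output beams and active-state bounds in get_params() are the
# pruning knobs of k2.intersect_dense_pruned(). A minimal sketch of the
# lattice-generation call they would eventually drive (hypothetical names:
# `decoding_graph` would be an HLG k2.Fsa and `dense_fsa_vec` a
# k2.DenseFsaVec built from the network output; neither is constructed in
# this script yet):
#
#   lattice = k2.intersect_dense_pruned(
#       decoding_graph,
#       dense_fsa_vec,
#       search_beam=params.search_beam,
#       output_beam=params.output_beam,
#       min_active_states=params.min_active_states,
#       max_active_states=params.max_active_states,
#   )
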
def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=9,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'.",
    )
    return parser

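# Example invocation (the script path is illustrative; exp_dir in
# get_params() suggests it lives under conformer_ctc/):
#
#   ./conformer_ctc/decode.py --epoch 9 --avg 5
#
# This would average the five checkpoints epoch-5.pt .. epoch-9.pt before
# decoding.
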
@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log/log-decode")
    logging.info("Decoding started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

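    # The model hyperparameters below come from get_params() and must match
    # the values used at training time; --epoch/--avg arrive from the command
    # line through params.update(vars(args)) above.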
    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=params.num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=params.vgg_frontend,
        is_espnet_structure=params.is_espnet_structure,
        mmi_loss=params.mmi_loss,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            # Skip negative epoch numbers, which arise when avg > epoch + 1.
            if i >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.load_state_dict(average_checkpoints(filenames))

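    # average_checkpoints() returns a single state_dict whose tensors are the
    # elementwise mean over the loaded checkpoints, roughly:
    #
    #   avg[k] = sum(sd[k] for sd in state_dicts) / len(state_dicts)
    #
    # Averaging the last few epochs typically gives a small WER improvement
    # over any single checkpoint.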
    model.to(device)
    model.eval()

    # All output token IDs, 0..num_classes-1, including the blank.
    token_ids_with_blank = list(range(params.num_classes))
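    # Decoding proper (lattice generation, rescoring, and WER reporting via
    # the currently unused get_texts, store_transcripts and write_error_stats
    # imports) is not implemented yet; see the "work in progress" note at the
    # top of this file.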


if __name__ == "__main__":
    main()