2021-08-18 19:31:06 +08:00

158 lines
4.1 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import logging
import k2
import torch
import torchaudio
from conformer import Conformer
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
def get_parser():
    """Build the command-line parser for this script.

    All four options are required strings:

    * ``--checkpoint``: path to the model checkpoint.
    * ``--words-file``: path to ``words.txt``.
    * ``--hlg``: path to ``HLG.pt``.
    * ``--sound-file``: path to the audio file to transcribe.

    Returns:
      An ``argparse.ArgumentParser`` whose help output also shows defaults.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        # Adjacent literals are concatenated verbatim, so each fragment
        # except the last must end with a space.
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--words-file",
        type=str,
        required=True,
        help="Path to words.txt",
    )

    parser.add_argument(
        "--hlg", type=str, required=True, help="Path to HLG.pt."
    )

    parser.add_argument(
        "--sound-file",
        type=str,
        required=True,
        help="The input sound file to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav, flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    return parser
def get_params() -> AttributeDict:
    """Return the fixed configuration used by this script.

    The returned mapping holds the arguments passed to the ``Conformer``
    constructor, the expected feature dimension and sample rate, and the
    options forwarded to ``get_lattice`` / ``one_best_decoding``.
    ``main`` later merges the parsed command-line arguments into it.
    """
    settings = dict(
        feature_dim=80,
        nhead=8,
        num_classes=5000,
        sample_freq=16000,
        attention_dim=512,
        subsampling_factor=4,
        num_decoder_layers=6,
        vgg_frontend=False,
        is_espnet_structure=True,
        mmi_loss=False,
        use_feat_batchnorm=True,
        search_beam=20,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        use_double_scores=True,
    )
    return AttributeDict(settings)
def main():
    """Decode one sound file with a pretrained Conformer and log the result.

    Steps, in order:

    1. Parse the command line and merge it into the defaults from
       ``get_params()``.
    2. Build a ``Conformer``, load its weights from ``--checkpoint`` and
       switch it to evaluation mode.
    3. Load the decoding graph from ``--hlg``.
    4. Load ``--sound-file``, compute fbank features and run the model.
    5. Build a lattice with ``get_lattice``, take the one-best path, map
       the resulting IDs through ``--words-file`` and log the words.

    Raises:
      ValueError: If the sample rate of the sound file differs from
        ``params.sample_freq``.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"device: {device}")

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=params.num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=params.vgg_frontend,
        is_espnet_structure=params.is_espnet_structure,
        mmi_loss=params.mmi_loss,
        use_feat_batchnorm=params.use_feat_batchnorm,
    )

    # Load onto the CPU first so the checkpoint can be read regardless of
    # the device it was saved from; the model is moved afterwards.
    checkpoint = torch.load(params.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.to(device)
    # Inference only: dropout / batch-norm layers must not run in
    # training mode.
    model.eval()

    # map_location="cpu" for the same reason as the checkpoint above.
    HLG = k2.Fsa.from_dict(torch.load(params.hlg, map_location="cpu"))
    HLG = HLG.to(device)

    wave, sample_freq = torchaudio.load(params.sound_file)
    # Explicit check rather than `assert`, which is stripped under `-O`.
    if sample_freq != params.sample_freq:
        raise ValueError(
            f"Expected sample rate {params.sample_freq}, "
            f"but {params.sound_file} has sample rate {sample_freq}"
        )
    wave = wave.to(device)

    features = torchaudio.compliance.kaldi.fbank(
        waveform=wave,
        num_mel_bins=params.feature_dim,
        snip_edges=False,
        sample_frequency=params.sample_freq,
    )
    # Add a leading dimension of size 1 (presumably the batch axis the
    # model expects -- confirm against Conformer.forward).
    features = features.unsqueeze(0)

    # No gradients are needed for decoding.
    with torch.no_grad():
        nnet_output, _, _ = model(features)

    # A single segment covering the whole output: [0, 0, shape[1]].
    # Presumably (sequence index, start frame, number of frames) -- see
    # the k2.DenseFsaVec documentation.
    supervision_segments = torch.tensor(
        [[0, 0, nnet_output.shape[1]]], dtype=torch.int32
    )

    lattice = get_lattice(
        nnet_output=nnet_output,
        HLG=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    best_path = one_best_decoding(
        lattice=lattice, use_double_scores=params.use_double_scores
    )

    hyps = get_texts(best_path)
    word_sym_table = k2.SymbolTable.from_file(params.words_file)
    # Translate each ID sequence into its symbols from the words file.
    hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
    logging.info(hyps)
if __name__ == "__main__":
    # Configure root logging (timestamp, level, source location, message)
    # before running the script entry point.
    log_format = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    main()