#!/usr/bin/env python3
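"""
Decode the LibriSpeech test sets with a TDNN-LSTM CTC model and an
HLG decoding graph, then write per-utterance error statistics.

A minimal usage sketch (the epoch and averaging window below are the
parser defaults; adjust them to match your training run):

    ./decode.py --epoch 9 --avg 5
"""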
import argparse
import logging
from pathlib import Path
from typing import List, Tuple
import k2
import torch
import torch.nn as nn
from model import TdnnLstm
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
get_texts,
setup_logger,
write_error_stats,
)


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=9,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=5,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
return parser


def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("tdnn_lstm_ctc/exp3/"),
"lang_dir": Path("data/lang"),
"feature_dim": 80,
"subsampling_factor": 3,
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
}
)
return params


@torch.no_grad()
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
HLG: k2.Fsa,
batch: dict,
lexicon: Lexicon,
) -> List[Tuple[List[str], List[str]]]:
"""Decode one batch and return a list of tuples containing
`(ref_words, hyp_words)`.
Args:
params:
It is the return value of :func:`get_params`.
"""
device = HLG.device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is [N, T, C]
feature = feature.permute(0, 2, 1) # now feature is [N, C, T]
nnet_output = model(feature)
# nnet_output is [N, T, C]
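    # nnet_output is assumed to hold log-probabilities (the model is
    # expected to end in log-softmax), which is what k2.DenseFsaVec
    # consumes below.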
supervisions = batch["supervisions"]
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"] // params.subsampling_factor,
supervisions["num_frames"] // params.subsampling_factor,
),
1,
).to(torch.int32)
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
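    # Pruned intersection (beam search) of the decoding graph with the
    # network output; the beams and active-state bounds trade accuracy
    # for speed.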
lattices = k2.intersect_dense_pruned(
HLG,
dense_fsa_vec,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
)
best_paths = k2.shortest_path(lattices, use_double_scores=True)
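    # get_texts() extracts the word IDs (aux_labels) on each best path;
    # map them back to words via the lexicon.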
hyps = get_texts(best_paths)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
texts = supervisions["text"]
results = []
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
results.append((ref_words, hyp_words))
return results


def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
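    # Command-line arguments take precedence over the defaults
    # returned by get_params().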
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log/log-decode")
logging.info("Decoding started")
logging.info(params)
lexicon = Lexicon(params.lang_dir)
max_phone_id = max(lexicon.tokens)
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
HLG = HLG.to(device)
assert HLG.requires_grad is False
model = TdnnLstm(
num_features=params.feature_dim,
num_classes=max_phone_id + 1, # +1 for the blank symbol
subsampling_factor=params.subsampling_factor,
)
if params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
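        # Average the `params.avg` checkpoints ending at `params.epoch`,
        # skipping any negative epoch indices when the window would
        # reach before epoch 0.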
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
            if i >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
model.to(device)
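    # Switch to inference mode; decode_one_batch() additionally runs
    # under @torch.no_grad().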
model.eval()
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
tot_num_cuts = len(test_dl.dataset.cuts)
num_cuts = 0
results = []
for batch_idx, batch in enumerate(test_dl):
this_batch = decode_one_batch(
params=params,
model=model,
HLG=HLG,
batch=batch,
lexicon=lexicon,
)
results.extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
logging.info(
f"batch {batch_idx}, cuts processed until now is "
f"{num_cuts}/{tot_num_cuts} "
f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
)
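        # write_error_stats() computes WER and per-word error counts
        # from the (ref_words, hyp_words) pairs.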
errs_filename = params.exp_dir / f"errs-{test_set}.txt"
with open(errs_filename, "w") as f:
write_error_stats(f, test_set, results)
logging.info("Done!")


if __name__ == "__main__":
main()