2022-11-17 09:42:17 -05:00

375 lines
11 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from conformer import Conformer
from torch.nn.utils.rnn import pad_sequence
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_attention_decoder,
)
from icefall.utils import AttributeDict, get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--tokens-file",
type=str,
help="Path to tokens.txt" "Used only when method is ctc-decoding",
)
parser.add_argument(
"--words-file",
type=str,
help="Path to words.txt" "Used when method is NOT ctc-decoding",
)
parser.add_argument(
"--HLG",
type=str,
help="Path to HLG.pt." "Used when method is NOT ctc-decoding",
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use ctc decoding. It maps the tokens ids to tokens
using the token symbol table directly.
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
(2) attention-decoder - Extract n paths from the rescored
lattice and use the transformer attention decoder for
rescoring.
We call it HLG decoding + n-gram LM rescoring + attention
decoder rescoring.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""
Used only when method is attention-decoder.
It specifies the size of n-best list.""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=0.3,
help="""
Used only when method is attention-decoder.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--attention-decoder-scale",
type=float,
default=0.9,
help="""
Used only when method is attention-decoder.
It specifies the scale for attention decoder scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=0.5,
help="""
Used only when method is attention-decoder.
It specifies the scale for lattice.scores when
extracting n-best lists. A smaller value results in
more unique number of paths with the risk of missing
the best path.
""",
)
parser.add_argument(
"--sos-id",
type=int,
default=1,
help="""
Used only when method is attention-decoder.
It specifies ID for the SOS token.
""",
)
parser.add_argument(
"--eos-id",
type=int,
default=1,
help="""
Used only when method is attention-decoder.
It specifies ID for the EOS token.
""",
)
parser.add_argument(
"--num_classes",
type=int,
default=4336,
help="The Vocab size.",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"subsampling_factor": 4,
"feature_dim": 80,
"nhead": 4,
"attention_dim": 512,
"num_decoder_layers": 6,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for deocding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
logging.info(f"{params}")
if args.method != "attention-decoder":
# to save memory as the attention decoder
# will not be used
params.num_decoder_layers = 0
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=params.num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
# Note: We don't use key padding mask for attention during decoding
with torch.no_grad():
nnet_output, memory, memory_key_padding_mask = model(features)
batch_size = nnet_output.shape[0]
supervision_segments = torch.tensor(
[[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
dtype=torch.int32,
)
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
token_sym_table = k2.SymbolTable.from_file(params.tokens_file)
max_token_id = params.num_classes - 1
H = k2.ctc_topo(
max_token=max_token_id,
modified=True,
device=device,
)
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=H,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
hyps = [[token_sym_table[i] for i in ids] for ids in token_ids]
elif params.method in ["1best", "attention-decoder"]:
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
lattice = get_lattice(
nnet_output=nnet_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
elif params.method == "attention-decoder":
logging.info("Use HLG + attention decoder rescoring")
best_path_dict = rescore_with_attention_decoder(
lattice=lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=params.sos_id,
eos_id=params.eos_id,
nbest_scale=params.nbest_scale,
ngram_lm_scale=params.ngram_lm_scale,
attention_scale=params.attention_decoder_scale,
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()