add inference script with a pretrained model

This commit is contained in:
marcoyang 2024-03-20 18:41:36 +08:00
parent 1921692d52
commit 9c4db1b3fb

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) # Copyright 2024 Xiaomi Corp. (authors: Xiaoyu Yang)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@ -21,7 +21,6 @@ You can generate the checkpoint with the following command:
Note: This is a example for librispeech dataset, if you are using different Note: This is a example for librispeech dataset, if you are using different
dataset, you should change the argument values according to your dataset. dataset, you should change the argument values according to your dataset.
- For non-streaming model:
./zipformer/export.py \ ./zipformer/export.py \
--exp-dir ./zipformer/exp \ --exp-dir ./zipformer/exp \
@ -29,75 +28,10 @@ dataset, you should change the argument values according to your dataset.
--epoch 30 \ --epoch 30 \
--avg 9 --avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
Usage of this script: Usage of this script:
- For non-streaming model:
(1) greedy search
./zipformer/pretrained.py \ ./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \ --checkpoint ./zipformer/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
- For streaming model:
(1) greedy search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./zipformer/pretrained.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--causal 1 \
--chunk-size 16 \
--left-context-frames 128 \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
/path/to/foo.wav \ /path/to/foo.wav \
/path/to/bar.wav /path/to/bar.wav
@ -109,6 +43,7 @@ Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
import argparse import argparse
import csv
import logging import logging
import math import math
from typing import List from typing import List
@ -117,11 +52,6 @@ import k2
import kaldifeat import kaldifeat
import torch import torch
import torchaudio import torchaudio
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
from export import num_tokens from export import num_tokens
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
@ -144,20 +74,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--tokens", "--label-dict",
type=str, type=str,
help="""Path to tokens.txt.""", help="""class_labels_indices.csv.""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
) )
parser.add_argument( parser.add_argument(
@ -177,55 +96,6 @@ def get_parser():
help="The sample rate of the input sound file", help="The sample rate of the input sound file",
) )
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser) add_model_arguments(parser)
return parser return parser
@ -263,12 +133,6 @@ def main():
params.update(vars(args)) params.update(vars(args))
token_table = k2.SymbolTable.from_file(params.tokens)
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1
logging.info(f"{params}") logging.info(f"{params}")
device = torch.device("cpu") device = torch.device("cpu")
@ -277,14 +141,6 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
logging.info("Creating model") logging.info("Creating model")
model = get_model(params) model = get_model(params)
@ -296,6 +152,15 @@ def main():
model.to(device) model.to(device)
model.eval() model.eval()
# get the label dictionary
label_dict = {}
with open(params.label_dict, "r") as f:
reader = csv.reader(f, delimiter=",")
for i, row in enumerate(reader):
if i == 0:
continue
label_dict[int(row[0])] = row[2]
logging.info("Constructing Fbank computer") logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions() opts = kaldifeat.FbankOptions()
opts.device = device opts.device = device
@ -320,57 +185,15 @@ def main():
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device) feature_lengths = torch.tensor(feature_lengths, device=device)
# model forward # model forward and predict the audio events
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths) encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
logits = model.forward_audio_tagging(encoder_out, encoder_out_lens)
hyps = [] results = []
msg = f"Using {params.method}" for i, logit in enumerate(logits):
logging.info(msg) topk_prob, topk_index = logit.sigmoid().topk(5)
topk_labels = [label_dict[index.item()] for index in topk_index]
def token_ids_to_words(token_ids: List[int]) -> str: print(f"Top 5 predicted labels of the {i} th audio are {topk_labels} with probability of {topk_prob.tolist()}")
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
raise ValueError(f"Unsupported method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")