mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
* add the zipformer codes, copied from branch from_dan_scaled_adam_exp1119 * support model export with torch.jit.script * update RESULTS.md * support exporting streaming model with torch.jit.script * add results of streaming models, with some minor changes * update README.md * add CI test * update k2 version in requirements-ci.txt * update pyproject.toml
383 lines
10 KiB
Python
Executable File
383 lines
10 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
|
#
|
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""
|
|
This script loads a checkpoint and uses it to decode waves.
|
|
You can generate the checkpoint with one of the following commands:
|
|
|
|
- For non-streaming model:
|
|
|
|
./zipformer/export.py \
|
|
--exp-dir ./zipformer/exp \
|
|
--bpe-model data/lang_bpe_500/bpe.model \
|
|
--epoch 30 \
|
|
--avg 9
|
|
|
|
- For streaming model:
|
|
|
|
./zipformer/export.py \
|
|
--exp-dir ./zipformer/exp \
|
|
--causal 1 \
|
|
--bpe-model data/lang_bpe_500/bpe.model \
|
|
--epoch 30 \
|
|
--avg 9
|
|
|
|
Usage of this script:
|
|
|
|
- For non-streaming model:
|
|
|
|
(1) greedy search
|
|
./zipformer/pretrained.py \
|
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
|
--method greedy_search \
|
|
/path/to/foo.wav \
|
|
/path/to/bar.wav
|
|
|
|
(2) modified beam search
|
|
./zipformer/pretrained.py \
|
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
|
--method modified_beam_search \
|
|
/path/to/foo.wav \
|
|
/path/to/bar.wav
|
|
|
|
(3) fast beam search
|
|
./zipformer/pretrained.py \
|
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
|
--method fast_beam_search \
|
|
/path/to/foo.wav \
|
|
/path/to/bar.wav
|
|
|
|
- For streaming model:
|
|
|
|
(1) greedy search
|
|
./zipformer/pretrained.py \
|
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
|
--causal 1 \
|
|
--chunk-size 16 \
|
|
--left-context-frames 128 \
|
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
|
--method greedy_search \
|
|
/path/to/foo.wav \
|
|
/path/to/bar.wav
|
|
|
|
(2) modified beam search
|
|
./zipformer/pretrained.py \
|
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
|
--causal 1 \
|
|
--chunk-size 16 \
|
|
--left-context-frames 128 \
|
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
|
--method modified_beam_search \
|
|
/path/to/foo.wav \
|
|
/path/to/bar.wav
|
|
|
|
(3) fast beam search
|
|
./zipformer/pretrained.py \
|
|
--checkpoint ./zipformer/exp/pretrained.pt \
|
|
--causal 1 \
|
|
--chunk-size 16 \
|
|
--left-context-frames 128 \
|
|
--bpe-model ./data/lang_bpe_500/bpe.model \
|
|
--method fast_beam_search \
|
|
/path/to/foo.wav \
|
|
/path/to/bar.wav
|
|
|
|
|
|
You can also pass `./zipformer/exp/epoch-xx.pt` to --checkpoint instead.
|
|
|
|
Note: ./zipformer/exp/pretrained.pt is generated by ./zipformer/export.py
|
|
"""
|
|
|
|
|
|
import argparse
|
|
import logging
|
|
import math
|
|
from typing import List
|
|
|
|
import k2
|
|
import kaldifeat
|
|
import sentencepiece as spm
|
|
import torch
|
|
import torchaudio
|
|
from beam_search import (
|
|
fast_beam_search_one_best,
|
|
greedy_search_batch,
|
|
modified_beam_search,
|
|
)
|
|
from icefall.utils import make_pad_mask
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
from train import add_model_arguments, get_params, get_transducer_model
|
|
|
|
|
|
def get_parser():
    """Build the command-line parser for this decoding script.

    The parser covers the checkpoint and BPE model paths, the decoding
    method and its tuning knobs, the input sound files, and all model
    hyper-parameters registered by ``add_model_arguments`` (from train.py).

    Returns:
      An ``argparse.ArgumentParser`` whose help output shows default values.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    # main() always loads this file with SentencePiece, so a missing value
    # would only fail later with an opaque error; let argparse report it.
    parser.add_argument(
        "--bpe-model",
        type=str,
        required=True,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to match --sample-rate (16kHz by default).",
    )

    parser.add_argument(
        "--sample-rate",
        type=int,
        default=16000,
        help="The sample rate of the input sound file",
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --method is modified_beam_search.""",
    )

    parser.add_argument(
        "--beam",
        type=float,
        default=4,
        help="""A floating point value to calculate the cutoff score during beam
        search (i.e., `cutoff = max-score - beam`), which is the same as the
        `beam` in Kaldi.
        Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-contexts",
        type=int,
        default=4,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--max-states",
        type=int,
        default=8,
        help="""Used only when --method is fast_beam_search""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame. Used only when
        --method is greedy_search.
        """,
    )

    # Model hyper-parameters must match the ones used in training.
    add_model_arguments(parser)

    return parser
|
|
|
|
|
|
def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.

    Only the first channel of each file is kept.

    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors, one per file, in the same
      order as ``filenames``.
    Raises:
      ValueError:
        If a file's sample rate differs from ``expected_sample_rate``.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        # Explicit check rather than ``assert`` so the validation of
        # user-supplied files still runs under ``python -O``.
        if sample_rate != expected_sample_rate:
            raise ValueError(
                f"{f}: expected sample rate: {expected_sample_rate}. "
                f"Given: {sample_rate}"
            )
        # We use only the first channel
        ans.append(wave[0])
    return ans
|
|
|
|
|
|
@torch.no_grad()
def main():
    """Decode the sound files given on the command line and log transcripts.

    Steps: parse arguments, build the transducer model and load the
    checkpoint, compute fbank features for every input file, run the
    encoder on the padded batch, decode with the method chosen by
    ``--method``, and log one transcript per input file.

    Raises:
      ValueError: on invalid causal-mode arguments or an unsupported
        decoding method / option combination.
    """
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    if params.causal:
        # Explicit checks (not ``assert``) so they still run under ``python -O``.
        if "," in params.chunk_size:
            raise ValueError("chunk_size should be one value in decoding.")
        if "," in params.left_context_frames:
            raise ValueError(
                "left_context_frames should be one value in decoding."
            )

    logging.info("Creating model")
    model = get_transducer_model(params)

    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    # strict=False tolerates key mismatches; report them rather than hiding
    # them so a wrong or partial checkpoint does not go unnoticed.
    missing_keys, unexpected_keys = model.load_state_dict(
        checkpoint["model"], strict=False
    )
    if missing_keys:
        logging.warning(f"Missing keys in checkpoint: {missing_keys}")
    if unexpected_keys:
        logging.warning(f"Unexpected keys in checkpoint: {unexpected_keys}")
    model.to(device)
    model.eval()

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    # Pad with log(1e-10), i.e. a very small log-energy, so padded frames
    # look like silence in the log-mel domain.
    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
    feature_lengths = torch.tensor(feature_lengths, device=device)

    # model forward
    x, x_lens = model.encoder_embed(features, feature_lengths)

    src_key_padding_mask = make_pad_mask(x_lens)
    x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

    encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
    encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

    logging.info(f"Using {params.method}")

    if params.method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
        hyp_tokens = fast_beam_search_one_best(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
        )
    elif params.method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
    elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
    elif params.method == "greedy_search":
        # Only the batched greedy search (one symbol per frame) is wired up here.
        raise ValueError(
            "greedy_search in this script supports only --max-sym-per-frame 1, "
            f"given {params.max_sym_per_frame}"
        )
    else:
        raise ValueError(f"Unsupported method: {params.method}")

    # Split each decoded transcript into words (also normalizes whitespace).
    hyps = [hyp.split() for hyp in sp.decode(hyp_tokens)]

    s = "\n" + "".join(
        f"{filename}:\n{' '.join(hyp)}\n\n"
        for filename, hyp in zip(params.sound_files, hyps)
    )
    logging.info(s)

    logging.info("Decoding Done")
|
|
|
|
|
|
if __name__ == "__main__":
    # Log lines look like: "<time> INFO [pretrained.py:123] message".
    log_format = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=log_format, level=logging.INFO)
    main()
|