Update pruned_transducer_stateless7 for aishell and librispeech

jinzr 2023-07-04 08:55:44 +08:00
parent aede8a8ed1
commit 67acaf9431
6 changed files with 69 additions and 1072 deletions

View File

@@ -1,321 +0,0 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--lang-dir data/lang_char \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note: the `cpu` in the name `cpu_jit.pt` means that the parameters are on CPU
when the model is loaded into Python. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `pruned_transducer_stateless7/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./pruned_transducer_stateless7/decode.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--lang-dir data/lang_char
Check ./pretrained.py for its usage.
Note: If you don't want to train a model from scratch, we have
provided one for you. You can get it at
https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
with the following commands:
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21
# You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21/exp
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train2 import add_model_arguments, get_params, get_transducer_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_char",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir)
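    # In the char-based lexicon <blk> has ID 0, and the vocabulary must
    # cover every token ID, hence `max(lexicon.tokens) + 1` below.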
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
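            # average_checkpoints_with_averaged_model() reconstructs the mean
            # of the checkpoints in (start, epoch] from the running averages
            # stored in the two files below, which is why only the models for
            # `epoch-avg` and `epoch` need to be loaded.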
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptable.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()
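For reference, the two artifacts produced above are loaded in different ways;
a minimal sketch, assuming the default exp dir (the `load_state_dict` line is
left commented because the model object must first be built, e.g. via
get_transducer_model()):

    import torch

    # cpu_jit.pt is a self-contained TorchScript module.
    model = torch.jit.load("pruned_transducer_stateless7/exp/cpu_jit.pt")
    model.eval()

    # pretrained.pt only stores {"model": state_dict}.
    state = torch.load(
        "pruned_transducer_stateless7/exp/pretrained.pt", map_location="cpu"
    )
    # model.load_state_dict(state["model"])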

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/export.py

View File

@@ -46,7 +46,7 @@ for how to use the exported models outside of icefall.
 ./pruned_transducer_stateless7/export.py \
   --exp-dir ./pruned_transducer_stateless7/exp \
-  --tokens data/lang_bpe_500/tokens.txt \
+  --tokens data/lang_char/tokens.txt \
   --epoch 20 \
   --avg 10

@@ -66,7 +66,7 @@ you can do:
   --avg 1 \
   --max-duration 600 \
   --decoding-method greedy_search \
-  --tokens data/lang_bpe_500/tokens.txt \
+  --tokens data/lang_char/tokens.txt

 Check ./pretrained.py for its usage.

@@ -89,11 +89,10 @@ from pathlib import Path
 import re

 import k2
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from scaling_converter import convert_scaled_to_non_scaled
-from train import add_model_arguments, get_params, get_transducer_model
+from train2 import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,

View File

@@ -1,348 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
Usage of this script:
(1) greedy search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
./pruned_transducer_stateless7/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.lexicon import Lexicon
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=1,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
params.blank_id = 0
params.vocab_size = max(lexicon.tokens) + 1
token_table = lexicon.token_table
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
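    # Some decoding routines in beam_search.py read model.device, so attach
    # it to the model explicitly.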
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
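    # math.log(1e-10) is a very small log-Fbank value, so the padding frames
    # are effectively silence.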
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
if params.method == "fast_beam_search":
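        # The trivial graph accepts any sequence over token IDs
        # 1..vocab_size-1, i.e. decoding is not constrained by an external LM.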
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
    else:
        hyp_tokens = []
        for i in range(num_waves):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(f"Unsupported method: {params.method}")
            # greedy_search()/beam_search() return a single token-ID list,
            # so collect them to match the batched methods above.
            hyp_tokens.append(hyp)
    hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py

View File

@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 #
-# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
+#                                            Zengrui Jin)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

@@ -26,7 +27,7 @@ Usage:
 ./pruned_transducer_stateless7/export.py \
   --exp-dir ./pruned_transducer_stateless7/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9 \
   --jit 1

@@ -45,7 +46,7 @@ for how to use the exported models outside of icefall.
 ./pruned_transducer_stateless7/export.py \
   --exp-dir ./pruned_transducer_stateless7/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 20 \
   --avg 10

@@ -65,7 +66,7 @@ you can do:
   --avg 1 \
   --max-duration 600 \
   --decoding-method greedy_search \
-  --bpe-model data/lang_bpe_500/bpe.model
+  --tokens data/lang_bpe_500/tokens.txt \

 Check ./pretrained.py for its usage.

@@ -85,8 +86,9 @@ with the following commands:
 import argparse
 import logging
 from pathlib import Path
+import re

-import sentencepiece as spm
+import k2
 import torch
 import torch.nn as nn
 from scaling_converter import convert_scaled_to_non_scaled

@@ -101,6 +103,26 @@ from icefall.checkpoint import (
 from icefall.utils import str2bool


+def num_tokens(
+    token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
+) -> int:
+    """Return the number of tokens excluding those from
+    disambiguation symbols.
+
+    Caution:
+      0 is not a token ID so it is excluded from the return value.
+    """
+    symbols = token_table.symbols
+    ans = []
+    for s in symbols:
+        if not disambig_pattern.match(s):
+            ans.append(token_table[s])
+    num_tokens = len(ans)
+    if 0 in ans:
+        num_tokens -= 1
+    return num_tokens
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter

@@ -155,10 +177,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(

@@ -198,12 +219,12 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+
+    # Load id of the <blk> token and the vocab size
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

@@ -292,7 +313,7 @@ def main():
     model.to("cpu")
     model.eval()

-    if params.jit is True:
+    if params.jit:
         convert_scaled_to_non_scaled(model, inplace=True)
         # We won't use the forward() method of the model in C++, so just ignore
         # it here.
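To make the token counting in `num_tokens()` above concrete, a toy sketch
(the symbol table below is invented for illustration; a real tokens.txt
stores one `<symbol> <id>` pair per line):

    import re

    # Stand-in for k2.SymbolTable lookups on a tiny vocabulary.
    toy = {"<blk>": 0, "<unk>": 1, "a": 2, "b": 3, "#0": 4}
    disambig = re.compile(r"^#\d+$")

    ans = [i for s, i in toy.items() if not disambig.match(s)]  # [0, 1, 2, 3]
    # Dropping ID 0 (the blank) leaves 3 tokens, and vocab_size adds the
    # blank back: 3 + 1 = 4, covering token IDs 0..3.
    print(len(ans) - (1 if 0 in ans else 0) + 1)  # -> 4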

View File

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+#                                       Zengrui Jin)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #

@@ -29,7 +30,7 @@ Usage of this script:
 (1) greedy search
 ./pruned_transducer_stateless7/pretrained.py \
   --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method greedy_search \
   /path/to/foo.wav \
   /path/to/bar.wav

@@ -37,7 +38,7 @@ Usage of this script:
 (2) beam search
 ./pruned_transducer_stateless7/pretrained.py \
   --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -46,7 +47,7 @@ Usage of this script:
 (3) modified beam search
 ./pruned_transducer_stateless7/pretrained.py \
   --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method modified_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -55,7 +56,7 @@ Usage of this script:
 (4) fast beam search
 ./pruned_transducer_stateless7/pretrained.py \
   --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method fast_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -75,7 +76,6 @@ from typing import List
 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import (

@@ -87,6 +87,7 @@ from beam_search import (
 )
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model
+from export import num_tokens

 from icefall.utils import str2bool

@@ -106,9 +107,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.""",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(

@@ -225,13 +226,13 @@ def main():
     params.update(vars(args))

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+
+    # Load id of the <blk> token and the vocab size
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(f"{params}")

@@ -286,6 +287,12 @@ def main():
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
     if params.method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
         hyp_tokens = fast_beam_search_one_best(

@@ -297,8 +304,8 @@ def main():
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     elif params.method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,

@@ -307,16 +314,16 @@ def main():
             beam=params.beam_size,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     else:
         for i in range(num_waves):
             # fmt: off

@@ -337,12 +344,11 @@ def main():
             else:
                 raise ValueError(f"Unsupported method: {params.method}")

-            hyps.append(sp.decode(hyp).split())
+            hyps.append(token_ids_to_words(hyp))

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
+        s += f"{filename}:\n{hyp}\n\n"
     logging.info(s)

     logging.info("Decoding Done")

View File

@@ -1,362 +0,0 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Zengrui Jin)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10
Usage of this script:
(1) greedy search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
(2) beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) modified beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(4) fast beam search
./pruned_transducer_stateless7/pretrained.py \
--checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by
./pruned_transducer_stateless7/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from export_unified import num_tokens
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --method is fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert (
sample_rate == expected_sample_rate
), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()