This commit is contained in:
yifanyang 2023-09-01 17:19:10 +08:00
parent d44a434fed
commit a887c0156a
12 changed files with 1530 additions and 56 deletions

View File

@ -97,6 +97,7 @@ Usage:
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
@ -122,7 +123,7 @@ from beam_search import (
)
from train import add_model_arguments, get_model, get_params
from icefall import LmScorer, NgramLm
from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
@ -215,6 +216,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- modified_beam_search_LODR
- fast_beam_search
- fast_beam_search_nbest
- fast_beam_search_nbest_oracle
@ -251,7 +253,7 @@ def get_parser():
type=float,
default=0.01,
help="""
Used only when --decoding_method is fast_beam_search_nbest_LG.
Used only when --decoding-method is fast_beam_search_nbest_LG.
It specifies the scale for n-gram LM scores.
""",
)
@ -285,7 +287,7 @@ def get_parser():
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
Used only when --decoding-method is greedy_search""",
)
parser.add_argument(
@ -347,6 +349,27 @@ def get_parser():
help="ID of the backoff symbol in the ngram LM",
)
parser.add_argument(
"--context-score",
type=float,
default=2,
help="""
The bonus score given to each token of the context-biasing words/phrases.
Used only when --decoding-method is modified_beam_search or
modified_beam_search_LODR.
""",
)
parser.add_argument(
"--context-file",
type=str,
default="",
help="""
The path to the context-biasing list, with one word/phrase per line.
Used only when --decoding-method is modified_beam_search or
modified_beam_search_LODR.
""",
)
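A minimal sketch (not part of this commit; the file name and score value are illustrative) of how these two options are consumed further down in this script: each line of --context-file holds one biasing word/phrase, and the BPE-encoded phrases are compiled into a ContextGraph whose per-token bonus is --context-score.
contexts = [line.strip() for line in open("hotwords.txt")]  # hypothetical biasing list
context_graph = ContextGraph(2.0)  # 2.0 stands in for --context-score
context_graph.build(sp.encode(contexts))  # sp is the loaded sentencepiece model
# The graph is then passed through decode_dataset()/decode_one_batch() into
# modified_beam_search and modified_beam_search_LODR.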
add_model_arguments(parser)
return parser
@ -359,6 +382,7 @@ def decode_one_batch(
batch: dict,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
@ -388,7 +412,7 @@ def decode_one_batch(
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
LM:
A neural network language model.
@ -493,6 +517,7 @@ def decode_one_batch(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -515,6 +540,7 @@ def decode_one_batch(
LODR_lm=ngram_lm,
LODR_lm_scale=ngram_lm_scale,
LM=LM,
context_graph=context_graph,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
@ -578,16 +604,22 @@ def decode_one_batch(
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
return {key: hyps}
elif params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
):
ans = dict()
assert ans_dict is not None
for key, hyps in ans_dict.items():
hyps = [sp.decode(hyp).split() for hyp in hyps]
ans[f"beam_size_{params.beam_size}_{key}"] = hyps
return ans
elif "modified_beam_search" in params.decoding_method:
prefix = f"beam_size_{params.beam_size}"
if params.decoding_method in (
"modified_beam_search_lm_rescore",
"modified_beam_search_lm_rescore_LODR",
):
ans = dict()
assert ans_dict is not None
for key, hyps in ans_dict.items():
hyps = [sp.decode(hyp).split() for hyp in hyps]
ans[f"{prefix}_{key}"] = hyps
return ans
else:
if params.has_contexts:
prefix += f"-context-score-{params.context_score}"
return {prefix: hyps}
else:
return {f"beam_size_{params.beam_size}": hyps}
@ -599,6 +631,7 @@ def decode_dataset(
sp: spm.SentencePieceProcessor,
word_table: Optional[k2.SymbolTable] = None,
decoding_graph: Optional[k2.Fsa] = None,
context_graph: Optional[ContextGraph] = None,
LM: Optional[LmScorer] = None,
ngram_lm=None,
ngram_lm_scale: float = 0.0,
@ -618,7 +651,7 @@ def decode_dataset(
The word symbol table.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
only when --decoding-method is fast_beam_search, fast_beam_search_nbest,
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
@ -649,6 +682,7 @@ def decode_dataset(
model=model,
sp=sp,
decoding_graph=decoding_graph,
context_graph=context_graph,
word_table=word_table,
batch=batch,
LM=LM,
@ -744,6 +778,11 @@ def main():
)
params.res_dir = params.exp_dir / params.decoding_method
if os.path.exists(params.context_file):
params.has_contexts = True
else:
params.has_contexts = False
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
@ -770,6 +809,12 @@ def main():
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
elif "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
if params.decoding_method in (
"modified_beam_search",
"modified_beam_search_LODR",
):
if params.has_contexts:
params.suffix += f"-context-score-{params.context_score}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -952,6 +997,18 @@ def main():
decoding_graph = None
word_table = None
if "modified_beam_search" in params.decoding_method:
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
else:
context_graph = None
else:
context_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -976,6 +1033,7 @@ def main():
sp=sp,
word_table=word_table,
decoding_graph=decoding_graph,
context_graph=context_graph,
LM=LM,
ngram_lm=ngram_lm,
ngram_lm_scale=ngram_lm_scale,

View File

@ -79,12 +79,12 @@ class DecodeStream(object):
self.pad_length = 7 + 2 * 3
if params.decoding_method == "greedy_search":
self.hyp = [params.blank_id] * params.context_size
self.hyp = [-1] * (params.context_size - 1) + [params.blank_id]
elif params.decoding_method == "modified_beam_search":
self.hyps = HypothesisList()
self.hyps.add(
Hypothesis(
ys=[params.blank_id] * params.context_size,
ys=[-1] * (params.context_size - 1) + [params.blank_id],
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
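# With the default --context-size of 2, the initial decoder context therefore
# changes from [blank_id, blank_id] to [-1, blank_id]; the -1 appears to serve
# as a padding value that cannot collide with any real token ID.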

View File

@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
@ -74,7 +73,6 @@ import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@ -86,7 +84,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():

View File

@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp
@ -71,7 +70,6 @@ import onnx
import torch
import torch.nn as nn
from decoder import Decoder
from export import num_tokens
from onnxruntime.quantization import QuantType, quantize_dynamic
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_model, get_params
@ -83,7 +81,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
from icefall.utils import make_pad_mask, num_tokens, str2bool
def get_parser():

View File

@ -160,7 +160,6 @@ with the following commands:
import argparse
import logging
import re
from pathlib import Path
from typing import List, Tuple
@ -176,27 +175,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, str2bool
def num_tokens(
token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$")
) -> int:
"""Return the number of tokens excluding those from
disambiguation symbols.
Caution:
0 is not a token ID so it is excluded from the return value.
"""
symbols = token_table.symbols
ans = []
for s in symbols:
if not disambig_pattern.match(s):
ans.append(token_table[s])
num_tokens = len(ans)
if 0 in ans:
num_tokens -= 1
return num_tokens
from icefall.utils import make_pad_mask, num_tokens, str2bool
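A brief usage sketch (the path is illustrative) of the helper after its move into icefall.utils; it is called exactly as before, e.g. to size the output layer:
import k2
from icefall.utils import num_tokens

token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")
vocab_size = num_tokens(token_table) + 1  # +1 for the blank symbol with ID 0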
def get_parser():
@ -487,6 +466,8 @@ def main():
device=device,
)
)
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg

View File

@ -410,10 +410,20 @@ def main():
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
if params.method == "ctc-decoding":
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
elif params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
]:
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -33,7 +33,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp

View File

@ -19,7 +19,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained.pt"
cd exp
@ -29,7 +28,7 @@ popd
2. Export the model to ONNX
./zipformer/export-onnx-streaming.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \

View File

@ -31,7 +31,6 @@ GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)
pushd $repo
git lfs pull --include "data/lang_bpe_500/tokens.txt"
git lfs pull --include "exp/pretrained.pt"
cd exp

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/pretrained.py

View File

@ -0,0 +1,455 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads a checkpoint and uses it to decode waves.
You can generate the checkpoint with the following command:
- For non-streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
- For streaming model:
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--use-ctc 1 \
--causal 1 \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
Usage of this script:
(1) ctc-decoding
./zipformer/pretrained_ctc.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--tokens data/lang_bpe_500/tokens.txt \
--method ctc-decoding \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(2) 1best
./zipformer/pretrained_ctc.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
--method 1best \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(3) nbest-rescoring
./zipformer/pretrained_ctc.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
--G data/lm/G_4_gram.pt \
--method nbest-rescoring \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
(4) whole-lattice-rescoring
./zipformer/pretrained_ctc.py \
--checkpoint ./zipformer/exp/pretrained.pt \
--HLG data/lang_bpe_500/HLG.pt \
--words-file data/lang_bpe_500/words.txt \
--G data/lm/G_4_gram.pt \
--method whole-lattice-rescoring \
--sample-rate 16000 \
/path/to/foo.wav \
/path/to/bar.wav
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from ctc_decode import get_decoding_params
from export import num_tokens
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
from icefall.decode import (
get_lattice,
one_best_decoding,
rescore_with_n_best_list,
rescore_with_whole_lattice,
)
from icefall.utils import get_texts
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
)
parser.add_argument(
"--words-file",
type=str,
help="""Path to words.txt.
Used only when method is not ctc-decoding.
""",
)
parser.add_argument(
"--HLG",
type=str,
help="""Path to HLG.pt.
Used only when method is not ctc-decoding.
""",
)
parser.add_argument(
"--tokens",
type=str,
help="""Path to tokens.txt.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
"--method",
type=str,
default="1best",
help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a token table,
i.e., lang_dir/tokens.txt, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as the decoding output. Only
the encoder output is used for decoding.
We call it HLG decoding.
(2) nbest-rescoring - Extract n paths from the decoding lattice,
rescore them with an LM, and use the path with
the highest score as the decoding result.
We call it HLG decoding + nbest n-gram LM rescoring.
(3) whole-lattice-rescoring - Use an LM to rescore the
decoding lattice and then use 1best to decode the
rescored lattice.
We call it HLG decoding + whole-lattice n-gram LM rescoring.
""",
)
parser.add_argument(
"--G",
type=str,
help="""An LM for rescoring.
Used only when method is
whole-lattice-rescoring or nbest-rescoring.
It's usually a 4-gram LM.
""",
)
parser.add_argument(
"--num-paths",
type=int,
default=100,
help="""
Used only when method is nbest-rescoring.
It specifies the size of the n-best list.""",
)
parser.add_argument(
"--ngram-lm-scale",
type=float,
default=1.3,
help="""
Used only when method is whole-lattice-rescoring or nbest-rescoring.
It specifies the scale for n-gram LM scores.
(Note: You need to tune it on a dataset.)
""",
)
parser.add_argument(
"--nbest-scale",
type=float,
default=1.0,
help="""
Used only when method is nbest-rescoring.
It specifies the scale for lattice.scores when
extracting n-best lists. A smaller value results in
more unique paths, at the risk of missing
the best path.
""",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
add_model_arguments(parser)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float = 16000
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0].contiguous())
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
# add decoding params
params.update(get_decoding_params())
params.update(vars(args))
token_table = k2.SymbolTable.from_file(params.tokens)
params.vocab_size = num_tokens(token_table) + 1 # +1 for blank
params.blank_id = token_table["<blk>"]
assert params.blank_id == 0
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_model(params)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
ctc_output = model.ctc_output(encoder_out) # (N, T, C)
batch_size = ctc_output.shape[0]
supervision_segments = torch.tensor(
[
[i, 0, feature_lengths[i].item() // params.subsampling_factor]
for i in range(batch_size)
],
dtype=torch.int32,
)
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
max_token_id = params.vocab_size - 1
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
device=device,
)
lattice = get_lattice(
nnet_output=ctc_output,
decoding_graph=H,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
token_ids = get_texts(best_path)
hyps = [[token_table[i] for i in ids] for ids in token_ids]
elif params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
]:
logging.info(f"Loading HLG from {params.HLG}")
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
HLG = HLG.to(device)
if not hasattr(HLG, "lm_scores"):
# For whole-lattice-rescoring and attention-decoder
HLG.lm_scores = HLG.scores.clone()
if params.method in [
"nbest-rescoring",
"whole-lattice-rescoring",
]:
logging.info(f"Loading G from {params.G}")
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
G = G.to(device)
if params.method == "whole-lattice-rescoring":
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
# G.lm_scores is used to replace HLG.lm_scores during
# LM rescoring.
G.lm_scores = G.scores.clone()
lattice = get_lattice(
nnet_output=ctc_output,
decoding_graph=HLG,
supervision_segments=supervision_segments,
search_beam=params.search_beam,
output_beam=params.output_beam,
min_active_states=params.min_active_states,
max_active_states=params.max_active_states,
subsampling_factor=params.subsampling_factor,
)
if params.method == "1best":
logging.info("Use HLG decoding")
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
if params.method == "nbest-rescoring":
logging.info("Use HLG decoding + LM rescoring")
best_path_dict = rescore_with_n_best_list(
lattice=lattice,
G=G,
num_paths=params.num_paths,
lm_scale_list=[params.ngram_lm_scale],
nbest_scale=params.nbest_scale,
)
best_path = next(iter(best_path_dict.values()))
elif params.method == "whole-lattice-rescoring":
logging.info("Use HLG decoding + LM rescoring")
best_path_dict = rescore_with_whole_lattice(
lattice=lattice,
G_with_epsilon_loops=G,
lm_scale_list=[params.ngram_lm_scale],
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
if params.method == "ctc-decoding":
for filename, hyp in zip(params.sound_files, hyps):
words = "".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
elif params.method in [
"1best",
"nbest-rescoring",
"whole-lattice-rescoring",
]:
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
words = words.replace("▁", " ").strip()
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/scaling.py

View File

@ -0,0 +1,104 @@
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file replaces various modules in a model.
Specifically, Balancer, Whiten, Dropout3 and ScaleGrad are replaced with
identity operators; for ONNX export, SwooshL and SwooshR are replaced with
their ONNX-friendly variants and CompactRelPositionalEncoding is scripted.
"""
import copy
from typing import List, Tuple
import torch
import torch.nn as nn
from scaling import (
Balancer,
Dropout3,
ScaleGrad,
SwooshL,
SwooshLOnnx,
SwooshR,
SwooshROnnx,
Whiten,
)
from zipformer import CompactRelPositionalEncoding
# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa
# get_submodule was added to nn.Module at v1.9.0
def get_submodule(model, target):
if target == "":
return model
atoms: List[str] = target.split(".")
mod: torch.nn.Module = model
for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no " "attribute `" + item + "`"
)
mod = getattr(mod, item)
if not isinstance(mod, torch.nn.Module):
raise AttributeError("`" + item + "` is not " "an nn.Module")
return mod
def convert_scaled_to_non_scaled(
model: nn.Module,
inplace: bool = False,
is_pnnx: bool = False,
is_onnx: bool = False,
):
"""
Args:
model:
The model to be converted.
inplace:
If True, the input model is modified inplace.
If False, the input model is copied and we modify the copied version.
is_pnnx:
True if we are going to export the model for PNNX.
is_onnx:
True if we are going to export the model for ONNX.
Return:
Return a model without scaled layers.
"""
if not inplace:
model = copy.deepcopy(model)
d = {}
for name, m in model.named_modules():
if isinstance(m, (Balancer, Dropout3, ScaleGrad, Whiten)):
d[name] = nn.Identity()
elif is_onnx and isinstance(m, SwooshR):
d[name] = SwooshROnnx()
elif is_onnx and isinstance(m, SwooshL):
d[name] = SwooshLOnnx()
elif is_onnx and isinstance(m, CompactRelPositionalEncoding):
# We want to recreate the positional encoding vector when
# the input changes, so we have to use torch.jit.script()
# instead of torch.jit.trace()
d[name] = torch.jit.script(m)
for k, v in d.items():
if "." in k:
parent, child = k.rsplit(".", maxsplit=1)
setattr(get_submodule(model, parent), child, v)
else:
setattr(model, k, v)
return model
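A minimal usage sketch (the flag values and the surrounding setup are assumptions; the ONNX export scripts above import this helper the same way): the conversion is applied to a trained model right before export.
from scaling_converter import convert_scaled_to_non_scaled

# `model` is assumed to be a trained zipformer built by train.get_model()
# with its checkpoint already loaded.
model = convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True)
model.eval()
# Balancer/Whiten/Dropout3/ScaleGrad are now nn.Identity(), the Swoosh
# activations use their ONNX-friendly variants, and the model is ready to
# be traced/exported.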

View File

@ -1 +0,0 @@
../../../librispeech/ASR/zipformer/streaming_beam_search.py

View File

@ -0,0 +1,876 @@
#!/usr/bin/env python3
# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang,
# Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--causal 1 \
--chunk-size 32 \
--left-context-frames 256 \
--exp-dir ./zipformer/exp \
--decoding-method greedy_search \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import sentencepiece as spm
import torch
from asr_datamodule import LibriSpeechAsrDataModule
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
add_model_arguments(parser)
return parser
def get_init_states(
model: nn.Module,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> List[torch.Tensor]:
"""
Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6]
is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2).
states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
"""
states = model.encoder.get_init_states(batch_size, device)
embed_states = model.encoder_embed.get_init_states(batch_size, device)
states.append(embed_states)
processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device)
states.append(processed_lens)
return states
def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]:
"""Stack list of zipformer states that correspond to separate utterances
into a single emformer state, so that it can be used as an input for
zipformer when those utterances are formed into a batch.
Args:
state_list:
Each element in state_list corresponds to the internal state
of the zipformer model for a single utterance. For element-n,
state_list[n] is a list of cached tensors of all encoder layers. For layer-i,
state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1,
cached_val2, cached_conv1, cached_conv2).
state_list[n][-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
state_list[n][-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Note:
It is the inverse of :func:`unstack_states`.
"""
batch_size = len(state_list)
assert (len(state_list[0]) - 2) % 6 == 0, len(state_list[0])
tot_num_layers = (len(state_list[0]) - 2) // 6
batch_states = []
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key = torch.cat(
[state_list[i][layer_offset] for i in range(batch_size)], dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn = torch.cat(
[state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1 = torch.cat(
[state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2 = torch.cat(
[state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1 = torch.cat(
[state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2 = torch.cat(
[state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0
)
batch_states += [
cached_key,
cached_nonlin_attn,
cached_val1,
cached_val2,
cached_conv1,
cached_conv2,
]
cached_embed_left_pad = torch.cat(
[state_list[i][-2] for i in range(batch_size)], dim=0
)
batch_states.append(cached_embed_left_pad)
processed_lens = torch.cat(
[state_list[i][-1] for i in range(batch_size)], dim=0
)
batch_states.append(processed_lens)
return batch_states
def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]:
"""Unstack the zipformer state corresponding to a batch of utterances
into a list of states, where the i-th entry is the state from the i-th
utterance in the batch.
Note:
It is the inverse of :func:`stack_states`.
Args:
batch_states: A list of cached tensors of all encoder layers. For layer-i,
batch_states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2,
cached_conv1, cached_conv2).
batch_states[-2] is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
batch_states[-1] is processed_lens of shape (batch,), which records the number
of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch.
Returns:
state_list: A list of lists. Each element in state_list corresponds to the internal state
of the zipformer model for a single utterance.
"""
assert (len(batch_states) - 2) % 6 == 0, len(batch_states)
tot_num_layers = (len(batch_states) - 2) // 6
processed_lens = batch_states[-1]
batch_size = processed_lens.shape[0]
state_list = [[] for _ in range(batch_size)]
for layer in range(tot_num_layers):
layer_offset = layer * 6
# cached_key: (left_context_len, batch_size, key_dim)
cached_key_list = batch_states[layer_offset].chunk(
chunks=batch_size, dim=1
)
# cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim)
cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk(
chunks=batch_size, dim=1
)
# cached_val1: (left_context_len, batch_size, value_dim)
cached_val1_list = batch_states[layer_offset + 2].chunk(
chunks=batch_size, dim=1
)
# cached_val2: (left_context_len, batch_size, value_dim)
cached_val2_list = batch_states[layer_offset + 3].chunk(
chunks=batch_size, dim=1
)
# cached_conv1: (#batch, channels, left_pad)
cached_conv1_list = batch_states[layer_offset + 4].chunk(
chunks=batch_size, dim=0
)
# cached_conv2: (#batch, channels, left_pad)
cached_conv2_list = batch_states[layer_offset + 5].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i] += [
cached_key_list[i],
cached_nonlin_attn_list[i],
cached_val1_list[i],
cached_val2_list[i],
cached_conv1_list[i],
cached_conv2_list[i],
]
cached_embed_left_pad_list = batch_states[-2].chunk(
chunks=batch_size, dim=0
)
for i in range(batch_size):
state_list[i].append(cached_embed_left_pad_list[i])
processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0)
for i in range(batch_size):
state_list[i].append(processed_lens_list[i])
return state_list
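A small sketch (model and device are assumed to exist, e.g. from get_model() and torch.device) of the inverse relationship the two docstrings describe:
streams = [get_init_states(model, batch_size=1, device=device) for _ in range(3)]
batched = stack_states(streams)      # one batched state covering 3 utterances
recovered = unstack_states(batched)  # back to 3 per-utterance state lists
assert len(recovered) == 3 and len(recovered[0]) == len(streams[0])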
def streaming_forward(
features: Tensor,
feature_lens: Tensor,
model: nn.Module,
states: List[Tensor],
chunk_size: int,
left_context_len: int,
) -> Tuple[Tensor, Tensor, List[Tensor]]:
"""
Returns encoder outputs, output lengths, and updated states.
"""
cached_embed_left_pad = states[-2]
(
x,
x_lens,
new_cached_embed_left_pad,
) = model.encoder_embed.streaming_forward(
x=features,
x_lens=feature_lens,
cached_left_pad=cached_embed_left_pad,
)
assert x.size(1) == chunk_size, (x.size(1), chunk_size)
src_key_padding_mask = make_pad_mask(x_lens)
# processed_mask is used to mask out initial states
processed_mask = torch.arange(left_context_len, device=x.device).expand(
x.size(0), left_context_len
)
processed_lens = states[-1] # (batch,)
# (batch, left_context_size)
processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1)
# Update processed lengths
new_processed_lens = processed_lens + x_lens
# (batch, left_context_size + chunk_size)
src_key_padding_mask = torch.cat(
[processed_mask, src_key_padding_mask], dim=1
)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_states = states[:-2]
(
encoder_out,
encoder_out_lens,
new_encoder_states,
) = model.encoder.streaming_forward(
x=x,
x_lens=x_lens,
states=encoder_states,
src_key_padding_mask=src_key_padding_mask,
)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
new_states = new_encoder_states + [
new_cached_embed_left_pad,
new_processed_lens,
]
return encoder_out, encoder_out_lens, new_states
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A list of DecodeStream, each belonging to an utterance.
Returns:
Return a list containing the indexes of the finished DecodeStreams.
"""
device = model.device
chunk_size = int(params.chunk_size)
left_context_len = int(params.left_context_frames)
features = []
feature_lens = []
states = []
processed_lens = [] # Used in fast-beam-search
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# Make sure the length after encoder_embed is at least 1.
# The encoder_embed subsamples features to (T - 7) // 2 frames
# The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling
tail_length = chunk_size * 2 + 7 + 2 * 3
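# e.g. with --chunk-size 32: tail_length = 64 + 7 + 6 = 77 input feature frames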
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
states = stack_states(states)
encoder_out, encoder_out_lens, new_states = streaming_forward(
features=features,
feature_lens=feature_lens,
model=model,
states=states,
chunk_size=chunk_size,
left_context_len=left_context_len,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(
model=model, encoder_out=encoder_out, streams=decode_streams
)
elif params.decoding_method == "fast_beam_search":
processed_lens = torch.tensor(processed_lens, device=device)
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
states = unstack_states(new_states)
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse CutSet containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG. Used
only when --decoding-method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 100
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = get_init_states(
model=model, batch_size=1, device=device
)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1]"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=30)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
return {key: decode_results}
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=True
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir
/ f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
assert params.causal, params.causal
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(
params.exp_dir, iteration=-params.iter
)[: params.avg + 1]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
test_clean_cuts = librispeech.test_clean_cuts()
test_other_cuts = librispeech.test_other_cuts()
test_sets = ["test-clean", "test-other"]
test_cuts = [test_clean_cuts, test_other_cuts]
for test_set, test_cut in zip(test_sets, test_cuts):
results_dict = decode_dataset(
cuts=test_cut,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()