mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
support CTC model export, add jit_pretrained_ctc.py, pretrained_ctc.py
This commit is contained in:
parent
046b6cb6d5
commit
40d2bda318
@ -22,7 +22,6 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from model import Transducer
|
||||
|
||||
from icefall import NgramLm, NgramLmStateCost
|
||||
from icefall.decode import Nbest, one_best_decoding
|
||||
@ -36,10 +35,11 @@ from icefall.utils import (
|
||||
get_texts,
|
||||
get_texts_with_timestamp,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
|
||||
def fast_beam_search_one_best(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
@ -103,7 +103,7 @@ def fast_beam_search_one_best(
|
||||
|
||||
|
||||
def fast_beam_search_nbest_LG(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
@ -229,7 +229,7 @@ def fast_beam_search_nbest_LG(
|
||||
|
||||
|
||||
def fast_beam_search_nbest(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
@ -319,7 +319,7 @@ def fast_beam_search_nbest(
|
||||
|
||||
|
||||
def fast_beam_search_nbest_oracle(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
@ -424,7 +424,7 @@ def fast_beam_search_nbest_oracle(
|
||||
|
||||
|
||||
def fast_beam_search(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
@ -523,7 +523,7 @@ def fast_beam_search(
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
max_sym_per_frame: int,
|
||||
return_timestamps: bool = False,
|
||||
@ -623,7 +623,7 @@ def greedy_search(
|
||||
|
||||
|
||||
def greedy_search_batch(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
return_timestamps: bool = False,
|
||||
@ -914,7 +914,7 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
|
||||
|
||||
|
||||
def modified_beam_search(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
beam: int = 4,
|
||||
@ -1083,7 +1083,7 @@ def modified_beam_search(
|
||||
|
||||
|
||||
def modified_beam_search_lm_rescore(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
LM: LmScorer,
|
||||
@ -1281,7 +1281,7 @@ def modified_beam_search_lm_rescore(
|
||||
|
||||
|
||||
def modified_beam_search_lm_rescore_LODR(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
LM: LmScorer,
|
||||
@ -1497,7 +1497,7 @@ def modified_beam_search_lm_rescore_LODR(
|
||||
|
||||
|
||||
def _deprecated_modified_beam_search(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
return_timestamps: bool = False,
|
||||
@ -1622,7 +1622,7 @@ def _deprecated_modified_beam_search(
|
||||
|
||||
|
||||
def beam_search(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
beam: int = 4,
|
||||
temperature: float = 1.0,
|
||||
@ -1782,7 +1782,7 @@ def beam_search(
|
||||
|
||||
|
||||
def fast_beam_search_with_nbest_rescoring(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
@ -1942,7 +1942,7 @@ def fast_beam_search_with_nbest_rescoring(
|
||||
|
||||
|
||||
def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
decoding_graph: k2.Fsa,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
@ -2133,7 +2133,7 @@ def fast_beam_search_with_nbest_rnn_rescoring(
|
||||
|
||||
|
||||
def modified_beam_search_ngram_rescoring(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ngram_lm: NgramLm,
|
||||
@ -2297,7 +2297,7 @@ def modified_beam_search_ngram_rescoring(
|
||||
|
||||
|
||||
def modified_beam_search_LODR(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
LODR_lm: NgramLm,
|
||||
@ -2568,7 +2568,7 @@ def modified_beam_search_LODR(
|
||||
|
||||
|
||||
def modified_beam_search_lm_shallow_fusion(
|
||||
model: Transducer,
|
||||
model: nn.Module,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
LM: LmScorer,
|
||||
|
@ -20,49 +20,57 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) ctc-decoding
|
||||
./pruned_transducer_stateless7_ctc/ctc_decode.py \
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method ctc-decoding
|
||||
|
||||
(2) 1best
|
||||
./pruned_transducer_stateless7_ctc/ctc_decode.py \
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.8 \
|
||||
--hlg-scale 0.6 \
|
||||
--decoding-method 1best
|
||||
|
||||
(3) nbest
|
||||
./pruned_transducer_stateless7_ctc/ctc_decode.py \
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.8 \
|
||||
--hlg-scale 0.6 \
|
||||
--decoding-method nbest
|
||||
|
||||
(4) nbest-rescoring
|
||||
./pruned_transducer_stateless7_ctc/ctc_decode.py \
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.8 \
|
||||
--hlg-scale 0.6 \
|
||||
--nbest-scale 1.0 \
|
||||
--lm-dir data/lm \
|
||||
--decoding-method nbest-rescoring
|
||||
|
||||
(5) whole-lattice-rescoring
|
||||
./pruned_transducer_stateless7_ctc/ctc_decode.py \
|
||||
./zipformer/ctc_decode.py \
|
||||
--epoch 30 \
|
||||
--avg 15 \
|
||||
--exp-dir ./pruned_transducer_stateless7_ctc/exp \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--max-duration 600 \
|
||||
--hlg-scale 0.8 \
|
||||
--hlg-scale 0.6 \
|
||||
--nbest-scale 1.0 \
|
||||
--lm-dir data/lm \
|
||||
--decoding-method whole-lattice-rescoring
|
||||
"""
|
||||
@ -156,7 +164,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="pruned_transducer_stateless7_ctc/exp",
|
||||
default="zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
@ -220,7 +228,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
default=1.0,
|
||||
help="""The scale to be applied to `lattice.scores`.
|
||||
It's needed if you use any kinds of n-best based rescoring.
|
||||
Used only when "method" is one of the following values:
|
||||
@ -232,7 +240,7 @@ def get_parser():
|
||||
parser.add_argument(
|
||||
"--hlg-scale",
|
||||
type=float,
|
||||
default=0.8,
|
||||
default=0.6,
|
||||
help="""The scale to be applied to `hlg.scores`.
|
||||
""",
|
||||
)
|
||||
|
434
egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
Executable file
434
egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py
Executable file
@ -0,0 +1,434 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--causal 1 \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) ctc-decoding
|
||||
./zipformer/jit_pretrained_ctc.py \
|
||||
--model-filename ./zipformer/exp/jit_script.pt \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--method ctc-decoding \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) 1best
|
||||
./zipformer/jit_pretrained_ctc.py \
|
||||
--model-filename ./zipformer/exp/jit_script.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--method 1best \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) nbest-rescoring
|
||||
./zipformer/jit_pretrained_ctc.py \
|
||||
--model-filename ./zipformer/exp/jit_script.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method nbest-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) whole-lattice-rescoring
|
||||
./zipformer/jit_pretrained_ctc.py \
|
||||
--model-filename ./zipformer/exp/jit_script.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method whole-lattice-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from ctc_decode import get_decoding_params
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import get_params
|
||||
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
type=str,
|
||||
help="""Path to words.txt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG",
|
||||
type=str,
|
||||
help="""Path to HLG.pt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.
|
||||
Used only when method is ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Possible values are:
|
||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
||||
piece model, i.e., lang_dir/bpe.model, to convert
|
||||
word pieces to words. It needs neither a lexicon
|
||||
nor an n-gram LM.
|
||||
(1) 1best - Use the best path as decoding output. Only
|
||||
the transformer encoder output is used for decoding.
|
||||
We call it HLG decoding.
|
||||
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an LM, the path with
|
||||
the highest score is the decoding result.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--G",
|
||||
type=str,
|
||||
help="""An LM for rescoring.
|
||||
Used only when method is
|
||||
whole-lattice-rescoring or nbest-rescoring.
|
||||
It's usually a 4-gram LM.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies the size of n-best list.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=1.3,
|
||||
help="""
|
||||
Used only when method is whole-lattice-rescoring and nbest-rescoring.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="""
|
||||
Used only when method is nbest-rescoring.
|
||||
It specifies the scale for lattice.scores when
|
||||
extracting n-best lists. A smaller value results in
|
||||
more unique number of paths with the risk of missing
|
||||
the best path.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-classes",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""
|
||||
Vocab size in the BPE model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
# add decoding params
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = torch.jit.load(args.model_filename)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(features, feature_lengths)
|
||||
ctc_output = model.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
batch_size = ctc_output.shape[0]
|
||||
supervision_segments = torch.tensor(
|
||||
[
|
||||
[i, 0, feature_lengths[i].item() // params.subsampling_factor]
|
||||
for i in range(batch_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Use CTC decoding")
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(params.bpe_model)
|
||||
max_token_id = params.num_classes - 1
|
||||
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=ctc_output,
|
||||
decoding_graph=H,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
token_ids = get_texts(best_path)
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
hyps = [s.split() for s in hyps]
|
||||
elif params.method in [
|
||||
"1best",
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = G.to(device)
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=ctc_output,
|
||||
decoding_graph=HLG,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.method == "1best":
|
||||
logging.info("Use HLG decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
if params.method == "nbest-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -27,191 +27,6 @@ from icefall.utils import add_sos, make_pad_mask
|
||||
from scaling import ScaledLinear
|
||||
|
||||
|
||||
class Transducer(nn.Module):
|
||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||
"Sequence Transduction with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int, # TODO: remove it
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder_embed:
|
||||
It is a Convolutional 2D subsampling module. It converts
|
||||
an input of shape (N, T, idim) to an output of of shape
|
||||
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
assert hasattr(decoder, "blank_id")
|
||||
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_scale=0.25)
|
||||
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size, initial_scale=0.25)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer loss.
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0
|
||||
|
||||
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
assert torch.all(x_lens > 0)
|
||||
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -120,7 +120,6 @@ from beam_search import (
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from icefall.utils import make_pad_mask
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_model
|
||||
|
||||
@ -318,15 +317,7 @@ def main():
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
# model forward
|
||||
x, x_lens = model.encoder_embed(features, feature_lengths)
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x, x_lens, src_key_padding_mask
|
||||
)
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
|
||||
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
|
453
egs/librispeech/ASR/zipformer/pretrained_ctc.py
Executable file
453
egs/librispeech/ASR/zipformer/pretrained_ctc.py
Executable file
@ -0,0 +1,453 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
- For non-streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
- For streaming model:
|
||||
|
||||
./zipformer/export.py \
|
||||
--exp-dir ./zipformer/exp \
|
||||
--use-ctc 1 \
|
||||
--causal 1 \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 30 \
|
||||
--avg 9
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) ctc-decoding
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--method ctc-decoding \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) 1best
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--method 1best \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) nbest-rescoring
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method nbest-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
|
||||
(4) whole-lattice-rescoring
|
||||
./zipformer/pretrained_ctc.py \
|
||||
--checkpoint ./zipformer/exp/pretrained.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method whole-lattice-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from ctc_decode import get_decoding_params
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_model
|
||||
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
type=str,
|
||||
help="""Path to words.txt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG",
|
||||
type=str,
|
||||
help="""Path to HLG.pt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.
|
||||
Used only when method is ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Possible values are:
|
||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
||||
piece model, i.e., lang_dir/bpe.model, to convert
|
||||
word pieces to words. It needs neither a lexicon
|
||||
nor an n-gram LM.
|
||||
(1) 1best - Use the best path as decoding output. Only
|
||||
the transformer encoder output is used for decoding.
|
||||
We call it HLG decoding.
|
||||
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an LM, the path with
|
||||
the highest score is the decoding result.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--G",
|
||||
type=str,
|
||||
help="""An LM for rescoring.
|
||||
Used only when method is
|
||||
whole-lattice-rescoring or nbest-rescoring.
|
||||
It's usually a 4-gram LM.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""
|
||||
Used only when method is attention-decoder.
|
||||
It specifies the size of n-best list.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=1.3,
|
||||
help="""
|
||||
Used only when method is whole-lattice-rescoring and nbest-rescoring.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="""
|
||||
Used only when method is nbest-rescoring.
|
||||
It specifies the scale for lattice.scores when
|
||||
extracting n-best lists. A smaller value results in
|
||||
more unique number of paths with the risk of missing
|
||||
the best path.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-classes",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""
|
||||
Vocab size in the BPE model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
# add decoding params
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
params.vocab_size = params.num_classes
|
||||
params.blank_id = 0
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.forward_encoder(features, feature_lengths)
|
||||
ctc_output = model.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
batch_size = ctc_output.shape[0]
|
||||
supervision_segments = torch.tensor(
|
||||
[
|
||||
[i, 0, feature_lengths[i].item() // params.subsampling_factor]
|
||||
for i in range(batch_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Use CTC decoding")
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(params.bpe_model)
|
||||
max_token_id = params.num_classes - 1
|
||||
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=ctc_output,
|
||||
decoding_graph=H,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
token_ids = get_texts(best_path)
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
hyps = [s.split() for s in hyps]
|
||||
elif params.method in [
|
||||
"1best",
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = G.to(device)
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=ctc_output,
|
||||
decoding_graph=HLG,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.method == "1best":
|
||||
logging.info("Use HLG decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
if params.method == "nbest-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -44,6 +44,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--full-libri 1 \
|
||||
--max-duration 1000
|
||||
|
||||
It supports training with:
|
||||
- transducer loss (default), with `--use-transducer True --use-ctc False`
|
||||
- ctc loss (not recommended), with `--use-transducer False --use-ctc True`
|
||||
- transducer loss & ctc loss, with `--use-transducer True --use-ctc True`
|
||||
"""
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user