updated the lstm_transducer_stateless recipes

also revoked previous changes in conformer_ctc3/jit_pretrained.py
This commit is contained in:
jinzr 2023-07-23 00:51:51 +08:00
parent 96f8904ce7
commit 696024abab
9 changed files with 165 additions and 142 deletions

View File

@ -24,7 +24,7 @@ Usage (for non-streaming mode):
(1) ctc-decoding
./conformer_ctc3/pretrained.py \
--nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
--tokens data/lang_bpe_500/tokens.txt \
--bpe-model data/lang_bpe_500/bpe.model \
--method ctc-decoding \
--sample-rate 16000 \
/path/to/foo.wav \
@ -71,6 +71,7 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from decode import get_decoding_params
@ -115,9 +116,11 @@ def get_parser():
)
parser.add_argument(
"--tokens",
"--bpe-model",
type=str,
help="Path to the tokens.txt.",
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
@ -126,9 +129,10 @@ def get_parser():
default="1best",
help="""Decoding method.
Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
to convert tokens to actual words or characters. It needs
neither a lexicon nor an n-gram LM.
(0) ctc-decoding - Use CTC decoding. It uses a sentence
piece model, i.e., lang_dir/bpe.model, to convert
word pieces to words. It needs neither a lexicon
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding.
We call it HLG decoding.
@ -277,7 +281,6 @@ def main():
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
hyps = []
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
@ -297,17 +300,10 @@ def main():
if params.method == "ctc-decoding":
logging.info("Use CTC decoding")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.bpe_model)
max_token_id = params.num_classes - 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
H = k2.ctc_topo(
max_token=max_token_id,
modified=False,
@ -328,9 +324,9 @@ def main():
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
token_ids = get_texts(best_path)
hyps = bpe_model.decode(token_ids)
hyps = [s.split() for s in hyps]
elif params.method in [
"1best",
"nbest-rescoring",
@ -395,16 +391,16 @@ def main():
)
best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
else:
raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
s += f"{filename}:\n{hyp}\n\n"
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -26,7 +26,7 @@ Usage:
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 35 \
--avg 10 \
--jit-trace 1
@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
./lstm_transducer_stateless/export.py \
--exp-dir ./lstm_transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 35 \
--avg 10
@ -79,7 +79,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -91,7 +91,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -148,10 +148,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -266,12 +266,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./lstm_transducer_stateless/pretrained.py \
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -66,7 +66,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -79,6 +78,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -95,9 +96,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -214,13 +215,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -275,6 +277,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -286,8 +294,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -296,16 +304,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -326,12 +334,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -29,7 +29,7 @@ popd
./lstm_transducer_stateless2/export-for-ncnn.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
@ -49,7 +49,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_transducer_model
@ -60,7 +60,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -106,10 +106,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -221,12 +221,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -28,7 +28,7 @@ popd
2. Export the model to ONNX
./lstm_transducer_stateless2/export-onnx.py \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--use-averaged-model 0 \
--epoch 99 \
--avg 1 \
@ -52,8 +52,8 @@ import logging
from pathlib import Path
from typing import Dict, Optional, Tuple
import k2
import onnx
import sentencepiece as spm
import torch
import torch.nn as nn
from decoder import Decoder
@ -68,7 +68,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import setup_logger, str2bool
from icefall.utils import num_tokens, setup_logger, str2bool
def get_parser():
@ -125,10 +125,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -437,12 +437,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -27,7 +27,7 @@ Usage:
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 35 \
--avg 10 \
--jit-trace 1
@ -39,7 +39,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
./lstm_transducer_stateless2/export.py \
--exp-dir ./lstm_transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 35 \
--avg 10
@ -80,7 +80,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -92,7 +92,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -149,10 +149,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -267,12 +267,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -46,7 +46,7 @@ Usage:
(4) fast beam search
./lstm_transducer_stateless2/pretrained.py \
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens ./data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -69,7 +69,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -82,6 +81,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -98,9 +99,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -217,13 +218,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -278,6 +280,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -289,8 +297,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -299,16 +307,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -329,12 +337,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -26,7 +26,7 @@ Usage:
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 40 \
--avg 20 \
--jit-trace 1
@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
./lstm_transducer_stateless3/export.py \
--exp-dir ./lstm_transducer_stateless3/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 40 \
--avg 20
@ -79,7 +79,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
@ -91,7 +91,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -148,10 +148,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to tokens.txt.",
)
parser.add_argument(
@ -266,12 +266,13 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
# Load id of the <blk> token and the vocab size, <blk> is
# defined in local/train_bpe_model.py
params.blank_id = token_table["<blk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav
@ -28,7 +28,7 @@ Usage:
(2) beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -37,7 +37,7 @@ Usage:
(3) modified beam search
./lstm_transducer_stateless3/pretrained.py \
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -79,6 +79,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -95,9 +97,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -214,13 +216,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -275,6 +278,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_tokens = fast_beam_search_one_best(
@ -286,8 +295,8 @@ def main():
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
@ -296,16 +305,16 @@ def main():
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
for hyp in hyp_tokens:
hyps.append(token_ids_to_words(hyp))
else:
for i in range(num_waves):
# fmt: off
@ -326,12 +335,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")