updated the lstm_transducer_stateless recipes

also reverted previous changes in conformer_ctc3/jit_pretrained.py
jinzr 2023-07-23 00:51:51 +08:00
parent 96f8904ce7
commit 696024abab
9 changed files with 165 additions and 142 deletions

conformer_ctc3/jit_pretrained.py

@@ -24,7 +24,7 @@ Usage (for non-streaming mode):
 (1) ctc-decoding
 ./conformer_ctc3/pretrained.py \
   --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
-  --tokens data/lang_bpe_500/tokens.txt \
+  --bpe-model data/lang_bpe_500/bpe.model \
   --method ctc-decoding \
   --sample-rate 16000 \
   /path/to/foo.wav \

@@ -71,6 +71,7 @@ from typing import List

 import k2
 import kaldifeat
+import sentencepiece as spm
 import torch
 import torchaudio
 from decode import get_decoding_params

@@ -115,9 +116,11 @@ def get_parser():
     )

     parser.add_argument(
-        "--tokens",
+        "--bpe-model",
         type=str,
-        help="Path to the tokens.txt.",
+        help="""Path to bpe.model.
+        Used only when method is ctc-decoding.
+        """,
     )

     parser.add_argument(

@@ -126,9 +129,10 @@ def get_parser():
         default="1best",
         help="""Decoding method.
         Possible values are:
-        (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
-            to convert tokens to actual words or characters. It needs
-            neither a lexicon nor an n-gram LM.
+        (0) ctc-decoding - Use CTC decoding. It uses a sentence
+            piece model, i.e., lang_dir/bpe.model, to convert
+            word pieces to words. It needs neither a lexicon
+            nor an n-gram LM.
         (1) 1best - Use the best path as decoding output. Only
             the transformer encoder output is used for decoding.
             We call it HLG decoding.

@@ -277,7 +281,6 @@ def main():
     waves = [w.to(device) for w in waves]

     logging.info("Decoding started")
-    hyps = []
     features = fbank(waves)

     feature_lengths = [f.size(0) for f in features]

@@ -297,17 +300,10 @@ def main():
     if params.method == "ctc-decoding":
         logging.info("Use CTC decoding")
+        bpe_model = spm.SentencePieceProcessor()
+        bpe_model.load(params.bpe_model)
         max_token_id = params.num_classes - 1
-        # Load tokens.txt here
-        token_table = k2.SymbolTable.from_file(params.tokens)
-
-        def token_ids_to_words(token_ids: List[int]) -> str:
-            text = ""
-            for i in token_ids:
-                text += token_table[i]
-            return text.replace("▁", " ").strip()
-
         H = k2.ctc_topo(
             max_token=max_token_id,
             modified=False,

@@ -328,9 +324,9 @@ def main():
         best_path = one_best_decoding(
             lattice=lattice, use_double_scores=params.use_double_scores
         )
-        hyp_tokens = get_texts(best_path)
-        for hyp in hyp_tokens:
-            hyps.append(token_ids_to_words(hyp))
+        token_ids = get_texts(best_path)
+        hyps = bpe_model.decode(token_ids)
+        hyps = [s.split() for s in hyps]
     elif params.method in [
         "1best",
         "nbest-rescoring",

@@ -395,16 +391,16 @@ def main():
         )
         best_path = next(iter(best_path_dict.values()))

+        hyps = get_texts(best_path)
         word_sym_table = k2.SymbolTable.from_file(params.words_file)
-        hyp_tokens = get_texts(best_path)
-        for hyp in hyp_tokens:
-            hyps.append(" ".join([word_sym_table[i] for i in hyp]))
+        hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
     else:
         raise ValueError(f"Unsupported decoding method: {params.method}")

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        s += f"{filename}:\n{hyp}\n\n"
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"

     logging.info(s)
     logging.info("Decoding Done")

lstm_transducer_stateless/export.py

@@ -26,7 +26,7 @@ Usage:

 ./lstm_transducer_stateless/export.py \
   --exp-dir ./lstm_transducer_stateless/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 35 \
   --avg 10 \
   --jit-trace 1

@@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,

 ./lstm_transducer_stateless/export.py \
   --exp-dir ./lstm_transducer_stateless/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 35 \
   --avg 10

@@ -79,7 +79,7 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
 import torch.nn as nn
 from scaling_converter import convert_scaled_to_non_scaled

@@ -91,7 +91,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool


 def get_parser():

@@ -148,10 +148,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(

@@ -266,12 +266,13 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    # Load id of the <blk> token and the vocab size, <blk> is
+    # defined in local/train_bpe_model.py
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)
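
The export recipes now derive the blank ID and vocabulary size from tokens.txt rather than from bpe.model, so sentencepiece is no longer needed at export time. A minimal sketch of the new lookup; the file contents shown in the comment are illustrative of a typical lang_bpe_500/tokens.txt, not copied from one:

import k2
from icefall.utils import num_tokens

# tokens.txt stores one "<symbol> <id>" pair per line, e.g.:
#   <blk> 0
#   <sos/eos> 1
#   <unk> 2
#   ▁THE 3
token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

blank_id = token_table["<blk>"]           # integer ID of the blank token
vocab_size = num_tokens(token_table) + 1  # num_tokens() excludes <blk>, hence +1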

lstm_transducer_stateless/pretrained.py

@@ -20,7 +20,7 @@ Usage:
 (1) greedy search
 ./lstm_transducer_stateless/pretrained.py \
   --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method greedy_search \
   /path/to/foo.wav \
   /path/to/bar.wav

@@ -28,7 +28,7 @@ Usage:
 (2) beam search
 ./lstm_transducer_stateless/pretrained.py \
   --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -37,7 +37,7 @@ Usage:
 (3) modified beam search
 ./lstm_transducer_stateless/pretrained.py \
   --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method modified_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -46,7 +46,7 @@ Usage:
 (4) fast beam search
 ./lstm_transducer_stateless/pretrained.py \
   --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method fast_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -66,7 +66,6 @@ from typing import List

 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import (

@@ -79,6 +78,8 @@ from beam_search import (
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall.utils import num_tokens
+

 def get_parser():
     parser = argparse.ArgumentParser(

@@ -95,9 +96,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.""",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(

@@ -214,13 +215,14 @@ def main():
     params.update(vars(args))

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(f"{params}")

@@ -275,6 +277,12 @@ def main():
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
     if params.method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
         hyp_tokens = fast_beam_search_one_best(

@@ -286,8 +294,8 @@ def main():
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     elif params.method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,

@@ -296,16 +304,16 @@ def main():
             beam=params.beam_size,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     else:
         for i in range(num_waves):
             # fmt: off

@@ -326,12 +334,11 @@ def main():
             else:
                 raise ValueError(f"Unsupported method: {params.method}")

-            hyps.append(sp.decode(hyp).split())
+            hyps.append(token_ids_to_words(hyp))

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
+        s += f"{filename}:\n{hyp}\n\n"

     logging.info(s)
     logging.info("Decoding Done")

lstm_transducer_stateless2/export-for-ncnn.py

@@ -29,7 +29,7 @@ popd

 ./lstm_transducer_stateless2/export-for-ncnn.py \
   --exp-dir $repo/exp \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --epoch 99 \
   --avg 1 \
   --use-averaged-model 0 \

@@ -49,7 +49,7 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_params, get_transducer_model

@@ -60,7 +60,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import setup_logger, str2bool
+from icefall.utils import num_tokens, setup_logger, str2bool


 def get_parser():

@@ -106,10 +106,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(

@@ -221,12 +221,13 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    # Load id of the <blk> token and the vocab size, <blk> is
+    # defined in local/train_bpe_model.py
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

lstm_transducer_stateless2/export-onnx.py

@@ -28,7 +28,7 @@ popd
 2. Export the model to ONNX

 ./lstm_transducer_stateless2/export-onnx.py \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --use-averaged-model 0 \
   --epoch 99 \
   --avg 1 \

@@ -52,8 +52,8 @@ import logging
 from pathlib import Path
 from typing import Dict, Optional, Tuple

+import k2
 import onnx
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from decoder import Decoder

@@ -68,7 +68,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import setup_logger, str2bool
+from icefall.utils import num_tokens, setup_logger, str2bool


 def get_parser():

@@ -125,10 +125,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(

@@ -437,12 +437,13 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    # Load id of the <blk> token and the vocab size, <blk> is
+    # defined in local/train_bpe_model.py
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

lstm_transducer_stateless2/export.py

@@ -27,7 +27,7 @@ Usage:

 ./lstm_transducer_stateless2/export.py \
   --exp-dir ./lstm_transducer_stateless2/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --epoch 35 \
   --avg 10 \
   --jit-trace 1

@@ -39,7 +39,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,

 ./lstm_transducer_stateless2/export.py \
   --exp-dir ./lstm_transducer_stateless2/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --epoch 35 \
   --avg 10

@@ -80,7 +80,7 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
 import torch.nn as nn
 from scaling_converter import convert_scaled_to_non_scaled

@@ -92,7 +92,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool


 def get_parser():

@@ -149,10 +149,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(

@@ -267,12 +267,13 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    # Load id of the <blk> token and the vocab size, <blk> is
+    # defined in local/train_bpe_model.py
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

lstm_transducer_stateless2/pretrained.py

@@ -20,7 +20,7 @@ Usage:
 (1) greedy search
 ./lstm_transducer_stateless2/pretrained.py \
   --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method greedy_search \
   /path/to/foo.wav \
   /path/to/bar.wav

@@ -28,7 +28,7 @@ Usage:
 (2) beam search
 ./lstm_transducer_stateless2/pretrained.py \
   --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -37,7 +37,7 @@ Usage:
 (3) modified beam search
 ./lstm_transducer_stateless2/pretrained.py \
   --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method modified_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -46,7 +46,7 @@ Usage:
 (4) fast beam search
 ./lstm_transducer_stateless2/pretrained.py \
   --checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens ./data/lang_bpe_500/tokens.txt \
   --method fast_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -69,7 +69,6 @@ from typing import List

 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import (

@@ -82,6 +81,8 @@ from beam_search import (
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall.utils import num_tokens
+

 def get_parser():
     parser = argparse.ArgumentParser(

@@ -98,9 +99,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.""",
+        help="""Path to tokens.txt.""",
     )

     parser.add_argument(

@@ -217,13 +218,14 @@ def main():
     params.update(vars(args))

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(f"{params}")

@@ -278,6 +280,12 @@ def main():
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
     if params.method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
         hyp_tokens = fast_beam_search_one_best(

@@ -289,8 +297,8 @@ def main():
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     elif params.method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,

@@ -299,16 +307,16 @@ def main():
             beam=params.beam_size,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     else:
         for i in range(num_waves):
             # fmt: off

@@ -329,12 +337,11 @@ def main():
             else:
                 raise ValueError(f"Unsupported method: {params.method}")

-            hyps.append(sp.decode(hyp).split())
+            hyps.append(token_ids_to_words(hyp))

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
+        s += f"{filename}:\n{hyp}\n\n"

     logging.info(s)
     logging.info("Decoding Done")

lstm_transducer_stateless3/export.py

@@ -26,7 +26,7 @@ Usage:

 ./lstm_transducer_stateless3/export.py \
   --exp-dir ./lstm_transducer_stateless3/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 40 \
   --avg 20 \
   --jit-trace 1

@@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,

 ./lstm_transducer_stateless3/export.py \
   --exp-dir ./lstm_transducer_stateless3/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 40 \
   --avg 20

@@ -79,7 +79,7 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
 import torch.nn as nn
 from scaling_converter import convert_scaled_to_non_scaled

@@ -91,7 +91,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool


 def get_parser():

@@ -148,10 +148,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to tokens.txt.",
     )

     parser.add_argument(

@@ -266,12 +266,13 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    # Load id of the <blk> token and the vocab size, <blk> is
+    # defined in local/train_bpe_model.py
+    params.blank_id = token_table["<blk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

lstm_transducer_stateless3/pretrained.py

@@ -20,7 +20,7 @@ Usage:
 (1) greedy search
 ./lstm_transducer_stateless3/pretrained.py \
   --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method greedy_search \
   /path/to/foo.wav \
   /path/to/bar.wav

@@ -28,7 +28,7 @@ Usage:
 (2) beam search
 ./lstm_transducer_stateless3/pretrained.py \
   --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -37,7 +37,7 @@ Usage:
 (3) modified beam search
 ./lstm_transducer_stateless3/pretrained.py \
   --checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method modified_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \

@@ -79,6 +79,8 @@ from beam_search import (
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall.utils import num_tokens
+

 def get_parser():
     parser = argparse.ArgumentParser(

@@ -95,9 +97,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.""",
+        help="""Path to tokens.txt.""",
     )

     parser.add_argument(

@@ -214,13 +216,14 @@ def main():
     params.update(vars(args))

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)
+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(f"{params}")

@@ -275,6 +278,12 @@ def main():
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
     if params.method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
         hyp_tokens = fast_beam_search_one_best(

@@ -286,8 +295,8 @@ def main():
             max_contexts=params.max_contexts,
             max_states=params.max_states,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     elif params.method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,

@@ -296,16 +305,16 @@ def main():
             beam=params.beam_size,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
-        for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+        for hyp in hyp_tokens:
+            hyps.append(token_ids_to_words(hyp))
     else:
         for i in range(num_waves):
             # fmt: off

@@ -326,12 +335,11 @@ def main():
             else:
                 raise ValueError(f"Unsupported method: {params.method}")

-            hyps.append(sp.decode(hyp).split())
+            hyps.append(token_ids_to_words(hyp))

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
+        s += f"{filename}:\n{hyp}\n\n"

     logging.info(s)
     logging.info("Decoding Done")