updated all conformer_ctc* recipes to use tokens.txt in export.py and pretrained.py

This commit is contained in:
jinzr 2023-07-13 15:32:43 +08:00
parent 54c023034e
commit 13bcfda1e4
6 changed files with 81 additions and 71 deletions

View File

@ -66,6 +66,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--tokens", "--tokens",
type=str, type=str,
required=True,
help="Path to the tokens.txt.", help="Path to the tokens.txt.",
) )

View File

@ -80,10 +80,9 @@ def get_parser():
default="1best", default="1best",
help="""Decoding method. help="""Decoding method.
Possible values are: Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
piece model, i.e., lang_dir/bpe.model, to convert to convert tokens to actual words or characters. It needs
word pieces to words. It needs neither a lexicon neither a lexicon nor an n-gram LM.
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only (1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding. the transformer encoder output is used for decoding.
We call it HLG decoding. We call it HLG decoding.
@ -254,9 +253,6 @@ def main():
params.update(vars(args)) params.update(vars(args))
logging.info(f"{params}") logging.info(f"{params}")
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda", 0) device = torch.device("cuda", 0)
@ -312,16 +308,19 @@ def main():
dtype=torch.int32, dtype=torch.int32,
) )
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
if params.method == "ctc-decoding": if params.method == "ctc-decoding":
logging.info("Use CTC decoding") logging.info("Use CTC decoding")
max_token_id = params.num_classes - 1 max_token_id = params.num_classes - 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
H = k2.ctc_topo( H = k2.ctc_topo(
max_token=max_token_id, max_token=max_token_id,
modified=params.num_classes > 500, modified=params.num_classes > 500,

View File

@ -23,6 +23,7 @@
Usage: Usage:
./conformer_ctc2/export.py \ ./conformer_ctc2/export.py \
--exp-dir ./conformer_ctc2/exp \ --exp-dir ./conformer_ctc2/exp \
--tokens ./data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -46,6 +47,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import k2
import torch import torch
from conformer import Conformer from conformer import Conformer
from decode import get_params from decode import get_params
@ -56,8 +58,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon from icefall.utils import num_tokens, str2bool
from icefall.utils import str2bool
def get_parser(): def get_parser():
@ -123,10 +124,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lang-dir", "--tokens",
type=str, type=str,
default="data/lang_bpe_500", required=True,
help="The lang dir", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -143,14 +144,14 @@ def get_parser():
def main(): def main():
args = get_parser().parse_args() args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
lexicon = Lexicon(params.lang_dir) # Load tokens.txt here
max_token_id = max(lexicon.tokens) token_table = k2.SymbolTable.from_file(params.tokens)
num_classes = max_token_id + 1 # +1 for the blank
num_classes = num_tokens(token_table) + 1 # +1 for the blank
device = torch.device("cpu") device = torch.device("cpu")
if torch.cuda.is_available(): if torch.cuda.is_available():

View File

@ -25,7 +25,7 @@ Usage:
./conformer_ctc3/export.py \ ./conformer_ctc3/export.py \
--exp-dir ./conformer_ctc3/exp \ --exp-dir ./conformer_ctc3/exp \
--lang-dir data/lang_bpe_500 \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 \ --avg 10 \
--jit-trace 1 --jit-trace 1
@ -36,7 +36,7 @@ It will generates the file: `jit_trace.pt`.
./conformer_ctc3/export.py \ ./conformer_ctc3/export.py \
--exp-dir ./conformer_ctc3/exp \ --exp-dir ./conformer_ctc3/exp \
--lang-dir data/lang_bpe_500 \ --tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \ --epoch 20 \
--avg 10 --avg 10
@ -62,6 +62,7 @@ import argparse
import logging import logging
from pathlib import Path from pathlib import Path
import k2
import torch import torch
from scaling_converter import convert_scaled_to_non_scaled from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params from train import add_model_arguments, get_ctc_model, get_params
@ -72,8 +73,7 @@ from icefall.checkpoint import (
find_checkpoints, find_checkpoints,
load_checkpoint, load_checkpoint,
) )
from icefall.lexicon import Lexicon from icefall.utils import num_tokens, str2bool
from icefall.utils import str2bool
def get_parser(): def get_parser():
@ -130,10 +130,10 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lang-dir", "--tokens",
type=Path, type=str,
default="data/lang_bpe_500", required=True,
help="The lang dir containing word table and LG graph", help="Path to the tokens.txt.",
) )
parser.add_argument( parser.add_argument(
@ -171,9 +171,10 @@ def main():
logging.info(f"device: {device}") logging.info(f"device: {device}")
lexicon = Lexicon(params.lang_dir) # Load tokens.txt here
max_token_id = max(lexicon.tokens) token_table = k2.SymbolTable.from_file(params.tokens)
num_classes = max_token_id + 1 # +1 for the blank
num_classes = num_tokens(token_table) + 1 # +1 for the blank
params.vocab_size = num_classes params.vocab_size = num_classes
if params.streaming_model: if params.streaming_model:

View File

@ -24,7 +24,7 @@ Usage (for non-streaming mode):
(1) ctc-decoding (1) ctc-decoding
./conformer_ctc3/pretrained.py \ ./conformer_ctc3/pretrained.py \
--nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \ --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method ctc-decoding \ --method ctc-decoding \
--sample-rate 16000 \ --sample-rate 16000 \
/path/to/foo.wav \ /path/to/foo.wav \
@ -71,7 +71,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from decode import get_decoding_params from decode import get_decoding_params
@ -116,11 +115,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model. help="Path to the tokens.txt.",
Used only when method is ctc-decoding.
""",
) )
parser.add_argument( parser.add_argument(
@ -129,10 +126,9 @@ def get_parser():
default="1best", default="1best",
help="""Decoding method. help="""Decoding method.
Possible values are: Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
piece model, i.e., lang_dir/bpe.model, to convert to convert tokens to actual words or characters. It needs
word pieces to words. It needs neither a lexicon neither a lexicon nor an n-gram LM.
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only (1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding. the transformer encoder output is used for decoding.
We call it HLG decoding. We call it HLG decoding.
@ -281,6 +277,7 @@ def main():
waves = [w.to(device) for w in waves] waves = [w.to(device) for w in waves]
logging.info("Decoding started") logging.info("Decoding started")
hyps = []
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
@ -300,10 +297,17 @@ def main():
if params.method == "ctc-decoding": if params.method == "ctc-decoding":
logging.info("Use CTC decoding") logging.info("Use CTC decoding")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.bpe_model)
max_token_id = params.num_classes - 1 max_token_id = params.num_classes - 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
H = k2.ctc_topo( H = k2.ctc_topo(
max_token=max_token_id, max_token=max_token_id,
modified=False, modified=False,
@ -324,9 +328,9 @@ def main():
best_path = one_best_decoding( best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores lattice=lattice, use_double_scores=params.use_double_scores
) )
token_ids = get_texts(best_path) hyp_tokens = get_texts(best_path)
hyps = bpe_model.decode(token_ids) for hyp in hyp_tokens:
hyps = [s.split() for s in hyps] hyps.append(token_ids_to_words(hyp))
elif params.method in [ elif params.method in [
"1best", "1best",
"nbest-rescoring", "nbest-rescoring",
@ -391,16 +395,16 @@ def main():
) )
best_path = next(iter(best_path_dict.values())) best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file) word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps] hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
else: else:
raise ValueError(f"Unsupported decoding method: {params.method}") raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")

View File

@ -24,7 +24,7 @@ Usage (for non-streaming mode):
(1) ctc-decoding (1) ctc-decoding
./conformer_ctc3/pretrained.py \ ./conformer_ctc3/pretrained.py \
--checkpoint conformer_ctc3/exp/pretrained.pt \ --checkpoint conformer_ctc3/exp/pretrained.pt \
--bpe-model data/lang_bpe_500/bpe.model \ --tokens data/lang_bpe_500/tokens.txt \
--method ctc-decoding \ --method ctc-decoding \
--sample-rate 16000 \ --sample-rate 16000 \
test_wavs/1089-134686-0001.wav test_wavs/1089-134686-0001.wav
@ -67,7 +67,6 @@ from typing import List
import k2 import k2
import kaldifeat import kaldifeat
import sentencepiece as spm
import torch import torch
import torchaudio import torchaudio
from decode import get_decoding_params from decode import get_decoding_params
@ -114,11 +113,9 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--bpe-model", "--tokens",
type=str, type=str,
help="""Path to bpe.model. help="Path to the tokens.txt.",
Used only when method is ctc-decoding.
""",
) )
parser.add_argument( parser.add_argument(
@ -127,10 +124,9 @@ def get_parser():
default="1best", default="1best",
help="""Decoding method. help="""Decoding method.
Possible values are: Possible values are:
(0) ctc-decoding - Use CTC decoding. It uses a sentence (0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
piece model, i.e., lang_dir/bpe.model, to convert to convert tokens to actual words or characters. It needs
word pieces to words. It needs neither a lexicon neither a lexicon nor an n-gram LM.
nor an n-gram LM.
(1) 1best - Use the best path as decoding output. Only (1) 1best - Use the best path as decoding output. Only
the transformer encoder output is used for decoding. the transformer encoder output is used for decoding.
We call it HLG decoding. We call it HLG decoding.
@ -316,6 +312,7 @@ def main():
waves = [w.to(device) for w in waves] waves = [w.to(device) for w in waves]
logging.info("Decoding started") logging.info("Decoding started")
hyps = []
features = fbank(waves) features = fbank(waves)
feature_lengths = [f.size(0) for f in features] feature_lengths = [f.size(0) for f in features]
@ -348,10 +345,17 @@ def main():
if params.method == "ctc-decoding": if params.method == "ctc-decoding":
logging.info("Use CTC decoding") logging.info("Use CTC decoding")
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(params.bpe_model)
max_token_id = params.num_classes - 1 max_token_id = params.num_classes - 1
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("▁", " ").strip()
H = k2.ctc_topo( H = k2.ctc_topo(
max_token=max_token_id, max_token=max_token_id,
modified=False, modified=False,
@ -372,9 +376,9 @@ def main():
best_path = one_best_decoding( best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores lattice=lattice, use_double_scores=params.use_double_scores
) )
token_ids = get_texts(best_path) hyp_tokens = get_texts(best_path)
hyps = bpe_model.decode(token_ids) for hyp in hyp_tokens:
hyps = [s.split() for s in hyps] hyps.append(token_ids_to_words(hyp))
elif params.method in [ elif params.method in [
"1best", "1best",
"nbest-rescoring", "nbest-rescoring",
@ -439,16 +443,16 @@ def main():
) )
best_path = next(iter(best_path_dict.values())) best_path = next(iter(best_path_dict.values()))
hyps = get_texts(best_path)
word_sym_table = k2.SymbolTable.from_file(params.words_file) word_sym_table = k2.SymbolTable.from_file(params.words_file)
hyps = [[word_sym_table[i] for i in ids] for ids in hyps] hyp_tokens = get_texts(best_path)
for hyp in hyp_tokens:
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
else: else:
raise ValueError(f"Unsupported decoding method: {params.method}") raise ValueError(f"Unsupported decoding method: {params.method}")
s = "\n" s = "\n"
for filename, hyp in zip(params.sound_files, hyps): for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp) s += f"{filename}:\n{hyp}\n\n"
s += f"{filename}:\n{words}\n\n"
logging.info(s) logging.info(s)
logging.info("Decoding Done") logging.info("Decoding Done")