mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
updated all conformer_ctc*
recipes to use tokens.txt
in export.py
and pretrained.py
This commit is contained in:
parent
54c023034e
commit
13bcfda1e4
@ -66,6 +66,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokens",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
|
required=True,
|
||||||
help="Path to the tokens.txt.",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -80,10 +80,9 @@ def get_parser():
|
|||||||
default="1best",
|
default="1best",
|
||||||
help="""Decoding method.
|
help="""Decoding method.
|
||||||
Possible values are:
|
Possible values are:
|
||||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
|
||||||
piece model, i.e., lang_dir/bpe.model, to convert
|
to convert tokens to actual words or characters. It needs
|
||||||
word pieces to words. It needs neither a lexicon
|
neither a lexicon nor an n-gram LM.
|
||||||
nor an n-gram LM.
|
|
||||||
(1) 1best - Use the best path as decoding output. Only
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
the transformer encoder output is used for decoding.
|
the transformer encoder output is used for decoding.
|
||||||
We call it HLG decoding.
|
We call it HLG decoding.
|
||||||
@ -254,9 +253,6 @@ def main():
|
|||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
# Load tokens.txt here
|
|
||||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", 0)
|
device = torch.device("cuda", 0)
|
||||||
@ -312,16 +308,19 @@ def main():
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
)
|
)
|
||||||
|
|
||||||
def token_ids_to_words(token_ids: List[int]) -> str:
|
|
||||||
text = ""
|
|
||||||
for i in token_ids:
|
|
||||||
text += token_table[i]
|
|
||||||
return text.replace("▁", " ").strip()
|
|
||||||
|
|
||||||
if params.method == "ctc-decoding":
|
if params.method == "ctc-decoding":
|
||||||
logging.info("Use CTC decoding")
|
logging.info("Use CTC decoding")
|
||||||
max_token_id = params.num_classes - 1
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
|
# Load tokens.txt here
|
||||||
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
modified=params.num_classes > 500,
|
modified=params.num_classes > 500,
|
||||||
|
@ -23,6 +23,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./conformer_ctc2/export.py \
|
./conformer_ctc2/export.py \
|
||||||
--exp-dir ./conformer_ctc2/exp \
|
--exp-dir ./conformer_ctc2/exp \
|
||||||
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -46,6 +47,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decode import get_params
|
from decode import get_params
|
||||||
@ -56,8 +58,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, str2bool
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -123,10 +124,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500",
|
required=True,
|
||||||
help="The lang dir",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -143,14 +144,14 @@ def get_parser():
|
|||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.lang_dir = Path(args.lang_dir)
|
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
max_token_id = max(lexicon.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
|
||||||
|
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
@ -25,7 +25,7 @@ Usage:
|
|||||||
|
|
||||||
./conformer_ctc3/export.py \
|
./conformer_ctc3/export.py \
|
||||||
--exp-dir ./conformer_ctc3/exp \
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
--lang-dir data/lang_bpe_500 \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -36,7 +36,7 @@ It will generates the file: `jit_trace.pt`.
|
|||||||
|
|
||||||
./conformer_ctc3/export.py \
|
./conformer_ctc3/export.py \
|
||||||
--exp-dir ./conformer_ctc3/exp \
|
--exp-dir ./conformer_ctc3/exp \
|
||||||
--lang-dir data/lang_bpe_500 \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -62,6 +62,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_ctc_model, get_params
|
from train import add_model_arguments, get_ctc_model, get_params
|
||||||
@ -72,8 +73,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import num_tokens, str2bool
|
||||||
from icefall.utils import str2bool
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -130,10 +130,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lang-dir",
|
"--tokens",
|
||||||
type=Path,
|
type=str,
|
||||||
default="data/lang_bpe_500",
|
required=True,
|
||||||
help="The lang dir containing word table and LG graph",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -171,9 +171,10 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
# Load tokens.txt here
|
||||||
max_token_id = max(lexicon.tokens)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
|
||||||
|
num_classes = num_tokens(token_table) + 1 # +1 for the blank
|
||||||
params.vocab_size = num_classes
|
params.vocab_size = num_classes
|
||||||
|
|
||||||
if params.streaming_model:
|
if params.streaming_model:
|
||||||
|
@ -24,7 +24,7 @@ Usage (for non-streaming mode):
|
|||||||
(1) ctc-decoding
|
(1) ctc-decoding
|
||||||
./conformer_ctc3/pretrained.py \
|
./conformer_ctc3/pretrained.py \
|
||||||
--nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
|
--nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method ctc-decoding \
|
--method ctc-decoding \
|
||||||
--sample-rate 16000 \
|
--sample-rate 16000 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -71,7 +71,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from decode import get_decoding_params
|
from decode import get_decoding_params
|
||||||
@ -116,11 +115,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="Path to the tokens.txt.",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -129,10 +126,9 @@ def get_parser():
|
|||||||
default="1best",
|
default="1best",
|
||||||
help="""Decoding method.
|
help="""Decoding method.
|
||||||
Possible values are:
|
Possible values are:
|
||||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
|
||||||
piece model, i.e., lang_dir/bpe.model, to convert
|
to convert tokens to actual words or characters. It needs
|
||||||
word pieces to words. It needs neither a lexicon
|
neither a lexicon nor an n-gram LM.
|
||||||
nor an n-gram LM.
|
|
||||||
(1) 1best - Use the best path as decoding output. Only
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
the transformer encoder output is used for decoding.
|
the transformer encoder output is used for decoding.
|
||||||
We call it HLG decoding.
|
We call it HLG decoding.
|
||||||
@ -281,6 +277,7 @@ def main():
|
|||||||
waves = [w.to(device) for w in waves]
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
hyps = []
|
||||||
features = fbank(waves)
|
features = fbank(waves)
|
||||||
feature_lengths = [f.size(0) for f in features]
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
@ -300,10 +297,17 @@ def main():
|
|||||||
|
|
||||||
if params.method == "ctc-decoding":
|
if params.method == "ctc-decoding":
|
||||||
logging.info("Use CTC decoding")
|
logging.info("Use CTC decoding")
|
||||||
bpe_model = spm.SentencePieceProcessor()
|
|
||||||
bpe_model.load(params.bpe_model)
|
|
||||||
max_token_id = params.num_classes - 1
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
|
# Load tokens.txt here
|
||||||
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
modified=False,
|
modified=False,
|
||||||
@ -324,9 +328,9 @@ def main():
|
|||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
lattice=lattice, use_double_scores=params.use_double_scores
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
)
|
)
|
||||||
token_ids = get_texts(best_path)
|
hyp_tokens = get_texts(best_path)
|
||||||
hyps = bpe_model.decode(token_ids)
|
for hyp in hyp_tokens:
|
||||||
hyps = [s.split() for s in hyps]
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method in [
|
elif params.method in [
|
||||||
"1best",
|
"1best",
|
||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
@ -391,16 +395,16 @@ def main():
|
|||||||
)
|
)
|
||||||
best_path = next(iter(best_path_dict.values()))
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
hyp_tokens = get_texts(best_path)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -24,7 +24,7 @@ Usage (for non-streaming mode):
|
|||||||
(1) ctc-decoding
|
(1) ctc-decoding
|
||||||
./conformer_ctc3/pretrained.py \
|
./conformer_ctc3/pretrained.py \
|
||||||
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
--checkpoint conformer_ctc3/exp/pretrained.pt \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method ctc-decoding \
|
--method ctc-decoding \
|
||||||
--sample-rate 16000 \
|
--sample-rate 16000 \
|
||||||
test_wavs/1089-134686-0001.wav
|
test_wavs/1089-134686-0001.wav
|
||||||
@ -67,7 +67,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from decode import get_decoding_params
|
from decode import get_decoding_params
|
||||||
@ -114,11 +113,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="Path to the tokens.txt.",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -127,10 +124,9 @@ def get_parser():
|
|||||||
default="1best",
|
default="1best",
|
||||||
help="""Decoding method.
|
help="""Decoding method.
|
||||||
Possible values are:
|
Possible values are:
|
||||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
|
||||||
piece model, i.e., lang_dir/bpe.model, to convert
|
to convert tokens to actual words or characters. It needs
|
||||||
word pieces to words. It needs neither a lexicon
|
neither a lexicon nor an n-gram LM.
|
||||||
nor an n-gram LM.
|
|
||||||
(1) 1best - Use the best path as decoding output. Only
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
the transformer encoder output is used for decoding.
|
the transformer encoder output is used for decoding.
|
||||||
We call it HLG decoding.
|
We call it HLG decoding.
|
||||||
@ -316,6 +312,7 @@ def main():
|
|||||||
waves = [w.to(device) for w in waves]
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
|
hyps = []
|
||||||
features = fbank(waves)
|
features = fbank(waves)
|
||||||
feature_lengths = [f.size(0) for f in features]
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
@ -348,10 +345,17 @@ def main():
|
|||||||
|
|
||||||
if params.method == "ctc-decoding":
|
if params.method == "ctc-decoding":
|
||||||
logging.info("Use CTC decoding")
|
logging.info("Use CTC decoding")
|
||||||
bpe_model = spm.SentencePieceProcessor()
|
|
||||||
bpe_model.load(params.bpe_model)
|
|
||||||
max_token_id = params.num_classes - 1
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
|
# Load tokens.txt here
|
||||||
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
modified=False,
|
modified=False,
|
||||||
@ -372,9 +376,9 @@ def main():
|
|||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
lattice=lattice, use_double_scores=params.use_double_scores
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
)
|
)
|
||||||
token_ids = get_texts(best_path)
|
hyp_tokens = get_texts(best_path)
|
||||||
hyps = bpe_model.decode(token_ids)
|
for hyp in hyp_tokens:
|
||||||
hyps = [s.split() for s in hyps]
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method in [
|
elif params.method in [
|
||||||
"1best",
|
"1best",
|
||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
@ -439,16 +443,16 @@ def main():
|
|||||||
)
|
)
|
||||||
best_path = next(iter(best_path_dict.values()))
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
hyps = get_texts(best_path)
|
|
||||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
hyp_tokens = get_texts(best_path)
|
||||||
|
for hyp in hyp_tokens:
|
||||||
|
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user