mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
updated the lstm_transducer_stateless
recipes
also revoked previous changes in conformer_ctc3/jit_pretrained.py
This commit is contained in:
parent
96f8904ce7
commit
696024abab
@ -24,7 +24,7 @@ Usage (for non-streaming mode):
|
|||||||
(1) ctc-decoding
|
(1) ctc-decoding
|
||||||
./conformer_ctc3/pretrained.py \
|
./conformer_ctc3/pretrained.py \
|
||||||
--nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
|
--nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \
|
||||||
--tokens data/lang_bpe_500/tokens.txt \
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
--method ctc-decoding \
|
--method ctc-decoding \
|
||||||
--sample-rate 16000 \
|
--sample-rate 16000 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -71,6 +71,7 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from decode import get_decoding_params
|
from decode import get_decoding_params
|
||||||
@ -115,9 +116,11 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokens",
|
"--bpe-model",
|
||||||
type=str,
|
type=str,
|
||||||
help="Path to the tokens.txt.",
|
help="""Path to bpe.model.
|
||||||
|
Used only when method is ctc-decoding.
|
||||||
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -126,9 +129,10 @@ def get_parser():
|
|||||||
default="1best",
|
default="1best",
|
||||||
help="""Decoding method.
|
help="""Decoding method.
|
||||||
Possible values are:
|
Possible values are:
|
||||||
(0) ctc-decoding - Use CTC decoding. It uses a tokens.txt file
|
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
||||||
to convert tokens to actual words or characters. It needs
|
piece model, i.e., lang_dir/bpe.model, to convert
|
||||||
neither a lexicon nor an n-gram LM.
|
word pieces to words. It needs neither a lexicon
|
||||||
|
nor an n-gram LM.
|
||||||
(1) 1best - Use the best path as decoding output. Only
|
(1) 1best - Use the best path as decoding output. Only
|
||||||
the transformer encoder output is used for decoding.
|
the transformer encoder output is used for decoding.
|
||||||
We call it HLG decoding.
|
We call it HLG decoding.
|
||||||
@ -277,7 +281,6 @@ def main():
|
|||||||
waves = [w.to(device) for w in waves]
|
waves = [w.to(device) for w in waves]
|
||||||
|
|
||||||
logging.info("Decoding started")
|
logging.info("Decoding started")
|
||||||
hyps = []
|
|
||||||
features = fbank(waves)
|
features = fbank(waves)
|
||||||
feature_lengths = [f.size(0) for f in features]
|
feature_lengths = [f.size(0) for f in features]
|
||||||
|
|
||||||
@ -297,17 +300,10 @@ def main():
|
|||||||
|
|
||||||
if params.method == "ctc-decoding":
|
if params.method == "ctc-decoding":
|
||||||
logging.info("Use CTC decoding")
|
logging.info("Use CTC decoding")
|
||||||
|
bpe_model = spm.SentencePieceProcessor()
|
||||||
|
bpe_model.load(params.bpe_model)
|
||||||
max_token_id = params.num_classes - 1
|
max_token_id = params.num_classes - 1
|
||||||
|
|
||||||
# Load tokens.txt here
|
|
||||||
token_table = k2.SymbolTable.from_file(params.tokens)
|
|
||||||
|
|
||||||
def token_ids_to_words(token_ids: List[int]) -> str:
|
|
||||||
text = ""
|
|
||||||
for i in token_ids:
|
|
||||||
text += token_table[i]
|
|
||||||
return text.replace("▁", " ").strip()
|
|
||||||
|
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
modified=False,
|
modified=False,
|
||||||
@ -328,9 +324,9 @@ def main():
|
|||||||
best_path = one_best_decoding(
|
best_path = one_best_decoding(
|
||||||
lattice=lattice, use_double_scores=params.use_double_scores
|
lattice=lattice, use_double_scores=params.use_double_scores
|
||||||
)
|
)
|
||||||
hyp_tokens = get_texts(best_path)
|
token_ids = get_texts(best_path)
|
||||||
for hyp in hyp_tokens:
|
hyps = bpe_model.decode(token_ids)
|
||||||
hyps.append(token_ids_to_words(hyp))
|
hyps = [s.split() for s in hyps]
|
||||||
elif params.method in [
|
elif params.method in [
|
||||||
"1best",
|
"1best",
|
||||||
"nbest-rescoring",
|
"nbest-rescoring",
|
||||||
@ -395,16 +391,16 @@ def main():
|
|||||||
)
|
)
|
||||||
best_path = next(iter(best_path_dict.values()))
|
best_path = next(iter(best_path_dict.values()))
|
||||||
|
|
||||||
|
hyps = get_texts(best_path)
|
||||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||||
hyp_tokens = get_texts(best_path)
|
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||||
for hyp in hyp_tokens:
|
|
||||||
hyps.append(" ".join([word_sym_table[i] for i in hyp]))
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
s += f"{filename}:\n{hyp}\n\n"
|
words = " ".join(hyp)
|
||||||
|
s += f"{filename}:\n{words}\n\n"
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./lstm_transducer_stateless/export.py \
|
./lstm_transducer_stateless/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 35 \
|
--epoch 35 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
|
|||||||
|
|
||||||
./lstm_transducer_stateless/export.py \
|
./lstm_transducer_stateless/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless/exp \
|
--exp-dir ./lstm_transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 35 \
|
--epoch 35 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -79,7 +79,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
@ -91,7 +91,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -148,10 +148,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -266,12 +266,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./lstm_transducer_stateless/pretrained.py \
|
./lstm_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./lstm_transducer_stateless/pretrained.py \
|
./lstm_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -37,7 +37,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./lstm_transducer_stateless/pretrained.py \
|
./lstm_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./lstm_transducer_stateless/pretrained.py \
|
./lstm_transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -66,7 +66,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -79,6 +78,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -95,9 +96,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -214,13 +215,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -275,6 +277,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -286,8 +294,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -296,16 +304,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -326,12 +334,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -29,7 +29,7 @@ popd
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export-for-ncnn.py \
|
./lstm_transducer_stateless2/export-for-ncnn.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
@ -49,7 +49,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
@ -60,7 +60,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -106,10 +106,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -221,12 +221,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ popd
|
|||||||
2. Export the model to ONNX
|
2. Export the model to ONNX
|
||||||
|
|
||||||
./lstm_transducer_stateless2/export-onnx.py \
|
./lstm_transducer_stateless2/export-onnx.py \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--use-averaged-model 0 \
|
--use-averaged-model 0 \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
@ -52,8 +52,8 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
import onnx
|
import onnx
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
@ -68,7 +68,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import setup_logger, str2bool
|
from icefall.utils import num_tokens, setup_logger, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -125,10 +125,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -437,12 +437,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ Usage:
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 35 \
|
--epoch 35 \
|
||||||
--avg 10 \
|
--avg 10 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -39,7 +39,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
|
|||||||
|
|
||||||
./lstm_transducer_stateless2/export.py \
|
./lstm_transducer_stateless2/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless2/exp \
|
--exp-dir ./lstm_transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 35 \
|
--epoch 35 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -80,7 +80,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
@ -92,7 +92,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -149,10 +149,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -267,12 +267,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./lstm_transducer_stateless2/pretrained.py \
|
./lstm_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./lstm_transducer_stateless2/pretrained.py \
|
./lstm_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -37,7 +37,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./lstm_transducer_stateless2/pretrained.py \
|
./lstm_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -46,7 +46,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./lstm_transducer_stateless2/pretrained.py \
|
./lstm_transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens ./data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -69,7 +69,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -82,6 +81,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -98,9 +99,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -217,13 +218,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -278,6 +280,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -289,8 +297,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -299,16 +307,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -329,12 +337,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./lstm_transducer_stateless3/export.py \
|
./lstm_transducer_stateless3/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless3/exp \
|
--exp-dir ./lstm_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 40 \
|
--epoch 40 \
|
||||||
--avg 20 \
|
--avg 20 \
|
||||||
--jit-trace 1
|
--jit-trace 1
|
||||||
@ -38,7 +38,7 @@ It will generate 3 files: `encoder_jit_trace.pt`,
|
|||||||
|
|
||||||
./lstm_transducer_stateless3/export.py \
|
./lstm_transducer_stateless3/export.py \
|
||||||
--exp-dir ./lstm_transducer_stateless3/exp \
|
--exp-dir ./lstm_transducer_stateless3/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 40 \
|
--epoch 40 \
|
||||||
--avg 20
|
--avg 20
|
||||||
|
|
||||||
@ -79,7 +79,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
@ -91,7 +91,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -148,10 +148,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -266,12 +266,13 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# Load id of the <blk> token and the vocab size, <blk> is
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
# defined in local/train_bpe_model.py
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.blank_id = token_table["<blk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./lstm_transducer_stateless3/pretrained.py \
|
./lstm_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
@ -28,7 +28,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./lstm_transducer_stateless3/pretrained.py \
|
./lstm_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -37,7 +37,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./lstm_transducer_stateless3/pretrained.py \
|
./lstm_transducer_stateless3/pretrained.py \
|
||||||
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
--checkpoint ./lstm_transducer_stateless3/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -79,6 +79,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import add_model_arguments, get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -95,9 +97,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -214,13 +216,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -275,6 +278,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_tokens = fast_beam_search_one_best(
|
hyp_tokens = fast_beam_search_one_best(
|
||||||
@ -286,8 +295,8 @@ def main():
|
|||||||
max_contexts=params.max_contexts,
|
max_contexts=params.max_contexts,
|
||||||
max_states=params.max_states,
|
max_states=params.max_states,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "modified_beam_search":
|
elif params.method == "modified_beam_search":
|
||||||
hyp_tokens = modified_beam_search(
|
hyp_tokens = modified_beam_search(
|
||||||
model=model,
|
model=model,
|
||||||
@ -296,16 +305,16 @@ def main():
|
|||||||
beam=params.beam_size,
|
beam=params.beam_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
encoder_out_lens=encoder_out_lens,
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in hyp_tokens:
|
||||||
hyps.append(hyp.split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
else:
|
else:
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@ -326,12 +335,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user