updates for the zipformer_mmi and transducer_stateless recipes

This commit is contained in:
jinzr 2023-07-24 23:48:36 +08:00
parent f13f0b990e
commit 71351291a8
16 changed files with 181 additions and 141 deletions

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -37,7 +37,7 @@ log "Export to torchscript model"
./zipformer_mmi/export.py \
--exp-dir $repo/exp \
--use-averaged-model false \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
--epoch 99 \
--avg 1 \
--jit 1
@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor
--method $method \
--checkpoint $repo/exp/pretrained.pt \
--lang-dir $repo/data/lang_bpe_500 \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -28,7 +28,7 @@ for sym in 1 2 3; do
--method greedy_search \
--max-sym-per-frame $sym \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
--method $method \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -27,7 +27,7 @@ log "Beam search decoding"
--method beam_search \
--beam-size 4 \
--checkpoint $repo/exp/pretrained.pt \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--tokens $repo/data/lang_bpe_500/tokens.txt \
$repo/test_wavs/1089-134686-0001.wav \
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

View File

@ -22,7 +22,7 @@
Usage:
./transducer/export.py \
--exp-dir ./transducer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 34 \
--avg 11
@ -46,7 +46,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from conformer import Conformer
from decoder import Decoder
@ -55,7 +55,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@ -90,10 +90,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -191,12 +191,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -19,7 +19,7 @@ Usage:
./transducer/pretrained.py \
--checkpoint ./transducer/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
/path/to/foo.wav \
/path/to/bar.wav \
@ -36,8 +36,8 @@ import logging
import math
from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import beam_search, greedy_search
@ -48,7 +48,7 @@ from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
from icefall.utils import AttributeDict, num_tokens
def get_parser():
@ -66,11 +66,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
help="Path to tokens.txt.",
)
parser.add_argument(
@ -204,12 +202,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -257,6 +257,12 @@ def main():
x=features, x_lens=feature_lengths
)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
num_waves = encoder_out.size(0)
hyps = []
for i in range(num_waves):
@ -272,12 +278,11 @@ def main():
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
hyps.append(token_ids_to_words(hyp))
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./transducer_stateless/export.py \
--exp-dir ./transducer_stateless/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -46,7 +46,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from conformer import Conformer
@ -56,7 +56,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@ -91,10 +91,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -191,12 +191,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
@ -29,7 +29,7 @@ Usage:
(2) beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -38,7 +38,7 @@ Usage:
(3) modified beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -47,7 +47,7 @@ Usage:
(4) fast beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -67,7 +67,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -80,6 +79,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -96,9 +97,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -213,12 +214,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -273,6 +276,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
@ -318,12 +327,11 @@ def main():
raise ValueError(f"Unsupported method: {params.method}")
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./transducer_stateless2/export.py \
--exp-dir ./transducer_stateless2/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -46,12 +46,12 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -86,10 +86,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
@ -123,12 +123,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
@ -29,7 +29,7 @@ Usage:
(2) beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -38,7 +38,7 @@ Usage:
(3) modified beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -47,7 +47,7 @@ Usage:
(4) fast beam search
./transducer_stateless2/pretrained.py \
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -67,7 +67,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -80,6 +79,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -96,9 +97,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -213,12 +214,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -273,6 +276,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
@ -318,12 +327,11 @@ def main():
raise ValueError(f"Unsupported method: {params.method}")
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -22,7 +22,7 @@
Usage:
./transducer_stateless_multi_datasets/export.py \
--exp-dir ./transducer_stateless_multi_datasets/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -47,7 +47,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
import torch.nn as nn
from conformer import Conformer
@ -57,7 +57,7 @@ from model import Transducer
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
from icefall.utils import AttributeDict, num_tokens, str2bool
def get_parser():
@ -92,10 +92,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -192,12 +192,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -20,7 +20,7 @@ Usage:
(1) greedy search
./transducer_stateless_multi_datasets/pretrained.py \
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
@ -29,7 +29,7 @@ Usage:
(2) beam search
./transducer_stateless_multi_datasets/pretrained.py \
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -38,7 +38,7 @@ Usage:
(3) modified beam search
./transducer_stateless_multi_datasets/pretrained.py \
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -47,7 +47,7 @@ Usage:
(4) fast beam search
./transducer_stateless_multi_datasets/pretrained.py \
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method fast_beam_search \
--beam-size 4 \
/path/to/foo.wav \
@ -67,7 +67,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from beam_search import (
@ -80,6 +79,8 @@ from beam_search import (
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.utils import num_tokens
def get_parser():
parser = argparse.ArgumentParser(
@ -96,9 +97,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -213,12 +214,14 @@ def main():
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -273,6 +276,12 @@ def main():
msg += f" with beam size {params.beam_size}"
logging.info(msg)
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
if params.method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
hyp_list = fast_beam_search_one_best(
@ -318,12 +327,11 @@ def main():
raise ValueError(f"Unsupported method: {params.method}")
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")

View File

@ -26,7 +26,7 @@ Usage:
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -86,7 +86,7 @@ import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import k2
import torch
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_ctc_model, get_params
@ -97,7 +97,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
from icefall.utils import num_tokens, str2bool
def get_parser():
@ -154,10 +154,10 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt.",
)
parser.add_argument(
@ -190,12 +190,14 @@ def main():
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(params)

View File

@ -21,7 +21,7 @@ You can generate the checkpoint with the following command:
./zipformer_mmi/export.py \
--exp-dir ./zipformer_mmi/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 20 \
--avg 10
@ -30,14 +30,14 @@ Usage of this script:
(1) 1best
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--method 1best \
/path/to/foo.wav \
/path/to/bar.wav
(2) nbest
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--nbest-scale 1.2 \
--method nbest \
/path/to/foo.wav \
@ -45,7 +45,7 @@ Usage of this script:
(3) nbest-rescoring-LG
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--nbest-scale 1.2 \
--method nbest-rescoring-LG \
/path/to/foo.wav \
@ -53,7 +53,7 @@ Usage of this script:
(4) nbest-rescoring-3-gram
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--nbest-scale 1.2 \
--method nbest-rescoring-3-gram \
/path/to/foo.wav \
@ -61,7 +61,7 @@ Usage of this script:
(5) nbest-rescoring-4-gram
./zipformer_mmi/pretrained.py \
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--tokens data/lang_bpe_500/tokens.txt \
--nbest-scale 1.2 \
--method nbest-rescoring-4-gram \
/path/to/foo.wav \
@ -83,7 +83,6 @@ from typing import List
import k2
import kaldifeat
import sentencepiece as spm
import torch
import torchaudio
from decode import get_decoding_params
@ -97,7 +96,7 @@ from icefall.decode import (
one_best_decoding,
)
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
from icefall.utils import get_texts
from icefall.utils import get_texts, num_tokens
def get_parser():
@ -115,9 +114,9 @@ def get_parser():
)
parser.add_argument(
"--bpe-model",
"--tokens",
type=str,
help="""Path to bpe.model.""",
help="""Path to tokens.txt.""",
)
parser.add_argument(
@ -247,13 +246,14 @@ def main():
params.update(get_decoding_params())
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# Load tokens.txt here
token_table = k2.SymbolTable.from_file(params.tokens)
# Load id of the <blk> token and the vocab size
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
params.blank_id = token_table["<blk>"]
params.unk_id = token_table["<unk>"]
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
logging.info(f"{params}")
@ -298,8 +298,6 @@ def main():
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
feature_lengths = torch.tensor(feature_lengths, device=device)
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
mmi_graph_compiler = MmiTrainingGraphCompiler(
params.lang_dir,
uniq_filename="lexicon.txt",
@ -313,6 +311,12 @@ def main():
if not hasattr(HP, "lm_scores"):
HP.lm_scores = HP.scores.clone()
def token_ids_to_words(token_ids: List[int]) -> str:
text = ""
for i in token_ids:
text += token_table[i]
return text.replace("", " ").strip()
method = params.method
assert method in (
"1best",
@ -390,14 +394,11 @@ def main():
#
# token_ids is a lit-of-list of IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
hyps = [token_ids_to_words(ids) for ids in token_ids]
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
s += f"{filename}:\n{hyp}\n\n"
logging.info(s)
logging.info("Decoding Done")