updates for the zipformer_mmi and transducer_stateless recipes

jinzr 2023-07-24 23:48:36 +08:00
parent e0e8db3c91
commit 0816be86ae
16 changed files with 181 additions and 141 deletions
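
In short, every export.py and pretrained.py touched below stops loading a SentencePiece bpe.model and instead reads data/lang_bpe_500/tokens.txt via k2, so the --bpe-model flag is renamed to --tokens. A minimal sketch of the new loading pattern, assembled from the hunks that follow (the path is illustrative):

import k2

from icefall.utils import num_tokens

# tokens.txt maps each token string to its integer ID.
token_table = k2.SymbolTable.from_file("data/lang_bpe_500/tokens.txt")

blank_id = token_table["<blk>"]  # <blk> is defined in local/train_bpe_model.py
unk_id = token_table["<unk>"]
vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

The decoding side gains a small token_ids_to_words helper (see transducer/pretrained.py below) so that hypotheses no longer need sp.decode().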

View File

@@ -28,7 +28,7 @@ for sym in 1 2 3; do
     --method greedy_search \
     --max-sym-per-frame $sym \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav
@@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
     --method $method \
     --beam-size 4 \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav

View File

@@ -37,7 +37,7 @@ log "Export to torchscript model"
 ./zipformer_mmi/export.py \
   --exp-dir $repo/exp \
   --use-averaged-model false \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   --epoch 99 \
   --avg 1 \
   --jit 1
@@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescoring-4-gram; do
     --method $method \
     --checkpoint $repo/exp/pretrained.pt \
     --lang-dir $repo/data/lang_bpe_500 \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav

View File

@@ -28,7 +28,7 @@ for sym in 1 2 3; do
     --method greedy_search \
     --max-sym-per-frame $sym \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav
@@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
     --method $method \
     --beam-size 4 \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav

View File

@@ -28,7 +28,7 @@ for sym in 1 2 3; do
     --method greedy_search \
     --max-sym-per-frame $sym \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav
@@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
     --method $method \
     --beam-size 4 \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav

View File

@@ -28,7 +28,7 @@ for sym in 1 2 3; do
     --method greedy_search \
     --max-sym-per-frame $sym \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav
@@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
     --method $method \
     --beam-size 4 \
     --checkpoint $repo/exp/pretrained.pt \
-    --bpe-model $repo/data/lang_bpe_500/bpe.model \
+    --tokens $repo/data/lang_bpe_500/tokens.txt \
     $repo/test_wavs/1089-134686-0001.wav \
     $repo/test_wavs/1221-135766-0001.wav \
     $repo/test_wavs/1221-135766-0002.wav

View File

@@ -27,7 +27,7 @@ log "Beam search decoding"
   --method beam_search \
   --beam-size 4 \
   --checkpoint $repo/exp/pretrained.pt \
-  --bpe-model $repo/data/lang_bpe_500/bpe.model \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
   $repo/test_wavs/1089-134686-0001.wav \
   $repo/test_wavs/1221-135766-0001.wav \
   $repo/test_wavs/1221-135766-0002.wav

transducer/export.py (View File)

@@ -22,7 +22,7 @@
 Usage:
 ./transducer/export.py \
   --exp-dir ./transducer/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 34 \
   --avg 11
@@ -46,7 +46,7 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
 from conformer import Conformer
 from decoder import Decoder
@@ -55,7 +55,7 @@ from model import Transducer

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool


 def get_parser():
@@ -90,10 +90,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(
@@ -191,12 +191,14 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

transducer/pretrained.py (View File)

@@ -19,7 +19,7 @@ Usage:
 ./transducer/pretrained.py \
   --checkpoint ./transducer/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method greedy_search \
   /path/to/foo.wav \
   /path/to/bar.wav \
@@ -36,8 +36,8 @@ import logging
 import math
 from typing import List

+import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import beam_search, greedy_search
@@ -48,7 +48,7 @@ from model import Transducer
 from torch.nn.utils.rnn import pad_sequence

 from icefall.env import get_env_info
-from icefall.utils import AttributeDict
+from icefall.utils import AttributeDict, num_tokens


 def get_parser():
@@ -66,11 +66,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.
-        Used only when method is ctc-decoding.
-        """,
+        help="Path to tokens.txt.",
     )

     parser.add_argument(
@@ -204,12 +202,14 @@ def main():
     params.update(vars(args))

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(f"{params}")
@@ -257,6 +257,12 @@ def main():
         x=features, x_lens=feature_lengths
     )

+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
     num_waves = encoder_out.size(0)
     hyps = []
     for i in range(num_waves):
@@ -272,12 +278,11 @@ def main():
         else:
             raise ValueError(f"Unsupported method: {params.method}")

-        hyps.append(sp.decode(hyp).split())
+        hyps.append(token_ids_to_words(hyp))

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
+        s += f"{filename}:\n{hyp}\n\n"

     logging.info(s)
     logging.info("Decoding Done")

transducer_stateless/export.py (View File)

@@ -22,7 +22,7 @@
 Usage:
 ./transducer_stateless/export.py \
   --exp-dir ./transducer_stateless/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 20 \
   --avg 10
@@ -46,7 +46,7 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
 import torch.nn as nn
 from conformer import Conformer
@@ -56,7 +56,7 @@ from model import Transducer

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool


 def get_parser():
@@ -91,10 +91,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(
@@ -191,12 +191,14 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

transducer_stateless/pretrained.py (View File)

@@ -20,7 +20,7 @@ Usage:
 (1) greedy search
 ./transducer_stateless/pretrained.py \
   --checkpoint ./transducer_stateless/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method greedy_search \
   --max-sym-per-frame 1 \
   /path/to/foo.wav \
@@ -29,7 +29,7 @@ Usage:
 (2) beam search
 ./transducer_stateless/pretrained.py \
   --checkpoint ./transducer_stateless/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
@@ -38,7 +38,7 @@ Usage:
 (3) modified beam search
 ./transducer_stateless/pretrained.py \
   --checkpoint ./transducer_stateless/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method modified_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
@@ -47,7 +47,7 @@ Usage:
 (4) fast beam search
 ./transducer_stateless/pretrained.py \
   --checkpoint ./transducer_stateless/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method fast_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
@@ -67,7 +67,6 @@ from typing import List

 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import (
@@ -80,6 +79,8 @@ from beam_search import (
 from torch.nn.utils.rnn import pad_sequence
 from train import get_params, get_transducer_model

+from icefall.utils import num_tokens
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -96,9 +97,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.""",
+        help="""Path to tokens.txt.""",
     )

     parser.add_argument(
@@ -213,12 +214,14 @@ def main():
     params.update(vars(args))

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(f"{params}")
@@ -273,6 +276,12 @@ def main():
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
     if params.method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
         hyp_list = fast_beam_search_one_best(
@@ -318,12 +327,11 @@ def main():
             raise ValueError(f"Unsupported method: {params.method}")

         hyp_list.append(hyp)

-    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
+    hyps = [token_ids_to_words(hyp) for hyp in hyp_list]

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
+        s += f"{filename}:\n{hyp}\n\n"

     logging.info(s)
     logging.info("Decoding Done")

View File

@@ -22,7 +22,7 @@
 Usage:
 ./transducer_stateless2/export.py \
   --exp-dir ./transducer_stateless2/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 20 \
   --avg 10
@@ -46,12 +46,12 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
 from train import get_params, get_transducer_model

 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool


 def get_parser():
@@ -86,10 +86,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt",
     )

     parser.add_argument(
@@ -123,12 +123,14 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

transducer_stateless2/pretrained.py (View File)

@@ -20,7 +20,7 @@ Usage:
 (1) greedy search
 ./transducer_stateless2/pretrained.py \
   --checkpoint ./transducer_stateless2/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method greedy_search \
   --max-sym-per-frame 1 \
   /path/to/foo.wav \
@@ -29,7 +29,7 @@ Usage:
 (2) beam search
 ./transducer_stateless2/pretrained.py \
   --checkpoint ./transducer_stateless2/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
@@ -38,7 +38,7 @@ Usage:
 (3) modified beam search
 ./transducer_stateless2/pretrained.py \
   --checkpoint ./transducer_stateless2/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method modified_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
@@ -47,7 +47,7 @@ Usage:
 (4) fast beam search
 ./transducer_stateless2/pretrained.py \
   --checkpoint ./transducer_stateless2/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method fast_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
@@ -67,7 +67,6 @@ from typing import List

 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import (
@@ -80,6 +79,8 @@ from beam_search import (
 from torch.nn.utils.rnn import pad_sequence
 from train import get_params, get_transducer_model

+from icefall.utils import num_tokens
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -96,9 +97,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.""",
+        help="""Path to tokens.txt.""",
     )

     parser.add_argument(
@@ -213,12 +214,14 @@ def main():
     params.update(vars(args))

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(f"{params}")
@@ -273,6 +276,12 @@ def main():
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
     if params.method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
         hyp_list = fast_beam_search_one_best(
@@ -318,12 +327,11 @@ def main():
             raise ValueError(f"Unsupported method: {params.method}")

         hyp_list.append(hyp)

-    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
+    hyps = [token_ids_to_words(hyp) for hyp in hyp_list]

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
+        s += f"{filename}:\n{hyp}\n\n"

     logging.info(s)
     logging.info("Decoding Done")

transducer_stateless_multi_datasets/export.py (View File)

@@ -22,7 +22,7 @@
 Usage:
 ./transducer_stateless_multi_datasets/export.py \
   --exp-dir ./transducer_stateless_multi_datasets/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 20 \
   --avg 10
@@ -47,7 +47,7 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
 import torch.nn as nn
 from conformer import Conformer
@@ -57,7 +57,7 @@ from model import Transducer

 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import AttributeDict, num_tokens, str2bool


 def get_parser():
@@ -92,10 +92,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(
@@ -192,12 +192,14 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

transducer_stateless_multi_datasets/pretrained.py (View File)

@@ -20,7 +20,7 @@ Usage:
 (1) greedy search
 ./transducer_stateless_multi_datasets/pretrained.py \
   --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method greedy_search \
   --max-sym-per-frame 1 \
   /path/to/foo.wav \
@@ -29,7 +29,7 @@ Usage:
 (2) beam search
 ./transducer_stateless_multi_datasets/pretrained.py \
   --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
@@ -38,7 +38,7 @@ Usage:
 (3) modified beam search
 ./transducer_stateless_multi_datasets/pretrained.py \
   --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method modified_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
@@ -47,7 +47,7 @@ Usage:
 (4) fast beam search
 ./transducer_stateless_multi_datasets/pretrained.py \
   --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method fast_beam_search \
   --beam-size 4 \
   /path/to/foo.wav \
@@ -67,7 +67,6 @@ from typing import List

 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from beam_search import (
@@ -80,6 +79,8 @@ from beam_search import (
 from torch.nn.utils.rnn import pad_sequence
 from train import get_params, get_transducer_model

+from icefall.utils import num_tokens
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -96,9 +97,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.""",
+        help="""Path to tokens.txt.""",
     )

     parser.add_argument(
@@ -213,12 +214,14 @@ def main():
     params.update(vars(args))

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(f"{params}")
@@ -273,6 +276,12 @@ def main():
         msg += f" with beam size {params.beam_size}"
     logging.info(msg)

+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
     if params.method == "fast_beam_search":
         decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
         hyp_list = fast_beam_search_one_best(
@@ -318,12 +327,11 @@ def main():
             raise ValueError(f"Unsupported method: {params.method}")

         hyp_list.append(hyp)

-    hyps = [sp.decode(hyp).split() for hyp in hyp_list]
+    hyps = [token_ids_to_words(hyp) for hyp in hyp_list]

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
+        s += f"{filename}:\n{hyp}\n\n"

     logging.info(s)
     logging.info("Decoding Done")

zipformer_mmi/export.py (View File)

@@ -26,7 +26,7 @@ Usage:
 ./zipformer_mmi/export.py \
   --exp-dir ./zipformer_mmi/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9 \
   --jit 1
@@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
 ./zipformer_mmi/export.py \
   --exp-dir ./zipformer_mmi/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 20 \
   --avg 10
@@ -86,7 +86,7 @@ import argparse
 import logging
 from pathlib import Path

-import sentencepiece as spm
+import k2
 import torch
 from scaling_converter import convert_scaled_to_non_scaled
 from train import add_model_arguments, get_ctc_model, get_params
@@ -97,7 +97,7 @@ from icefall.checkpoint import (
     find_checkpoints,
     load_checkpoint,
 )
-from icefall.utils import str2bool
+from icefall.utils import num_tokens, str2bool


 def get_parser():
@@ -154,10 +154,10 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_bpe_500/tokens.txt",
+        help="Path to the tokens.txt.",
     )

     parser.add_argument(
@@ -190,12 +190,14 @@ def main():
     logging.info(f"device: {device}")

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(params)

zipformer_mmi/pretrained.py (View File)

@@ -21,7 +21,7 @@ You can generate the checkpoint with the following command:
 ./zipformer_mmi/export.py \
   --exp-dir ./zipformer_mmi/exp \
-  --bpe-model data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 20 \
   --avg 10
@@ -30,14 +30,14 @@ Usage of this script:
 (1) 1best
 ./zipformer_mmi/pretrained.py \
   --checkpoint ./zipformer_mmi/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --method 1best \
   /path/to/foo.wav \
   /path/to/bar.wav
 (2) nbest
 ./zipformer_mmi/pretrained.py \
   --checkpoint ./zipformer_mmi/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --nbest-scale 1.2 \
   --method nbest \
   /path/to/foo.wav \
@@ -45,7 +45,7 @@ Usage of this script:
 (3) nbest-rescoring-LG
 ./zipformer_mmi/pretrained.py \
   --checkpoint ./zipformer_mmi/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --nbest-scale 1.2 \
   --method nbest-rescoring-LG \
   /path/to/foo.wav \
@@ -53,7 +53,7 @@ Usage of this script:
 (4) nbest-rescoring-3-gram
 ./zipformer_mmi/pretrained.py \
   --checkpoint ./zipformer_mmi/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --nbest-scale 1.2 \
   --method nbest-rescoring-3-gram \
   /path/to/foo.wav \
@@ -61,7 +61,7 @@ Usage of this script:
 (5) nbest-rescoring-4-gram
 ./zipformer_mmi/pretrained.py \
   --checkpoint ./zipformer_mmi/exp/pretrained.pt \
-  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --tokens data/lang_bpe_500/tokens.txt \
   --nbest-scale 1.2 \
   --method nbest-rescoring-4-gram \
   /path/to/foo.wav \
@@ -83,7 +83,6 @@ from typing import List

 import k2
 import kaldifeat
-import sentencepiece as spm
 import torch
 import torchaudio
 from decode import get_decoding_params
@@ -97,7 +96,7 @@ from icefall.decode import (
     one_best_decoding,
 )
 from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
-from icefall.utils import get_texts
+from icefall.utils import get_texts, num_tokens


 def get_parser():
@@ -115,9 +114,9 @@ def get_parser():
     )

     parser.add_argument(
-        "--bpe-model",
+        "--tokens",
         type=str,
-        help="""Path to bpe.model.""",
+        help="""Path to tokens.txt.""",
     )

     parser.add_argument(
@@ -247,13 +246,14 @@ def main():
     params.update(get_decoding_params())
     params.update(vars(args))

-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
+    # Load tokens.txt here
+    token_table = k2.SymbolTable.from_file(params.tokens)

+    # Load id of the <blk> token and the vocab size
     # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.unk_id = sp.piece_to_id("<unk>")
-    params.vocab_size = sp.get_piece_size()
+    params.blank_id = token_table["<blk>"]
+    params.unk_id = token_table["<unk>"]
+    params.vocab_size = num_tokens(token_table) + 1  # +1 for <blk>

     logging.info(f"{params}")
@@ -298,8 +298,6 @@ def main():
     features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
     feature_lengths = torch.tensor(feature_lengths, device=device)

-    bpe_model = spm.SentencePieceProcessor()
-    bpe_model.load(str(params.lang_dir / "bpe.model"))
-
     mmi_graph_compiler = MmiTrainingGraphCompiler(
         params.lang_dir,
         uniq_filename="lexicon.txt",
@@ -313,6 +311,12 @@ def main():
     if not hasattr(HP, "lm_scores"):
         HP.lm_scores = HP.scores.clone()

+    def token_ids_to_words(token_ids: List[int]) -> str:
+        text = ""
+        for i in token_ids:
+            text += token_table[i]
+        return text.replace("▁", " ").strip()
+
     method = params.method
     assert method in (
         "1best",
@@ -390,14 +394,11 @@ def main():
     #
     # token_ids is a list-of-list of IDs
     token_ids = get_texts(best_path)
-    # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
-    hyps = bpe_model.decode(token_ids)
-    # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
-    hyps = [s.split() for s in hyps]
+    hyps = [token_ids_to_words(ids) for ids in token_ids]

     s = "\n"
     for filename, hyp in zip(params.sound_files, hyps):
-        words = " ".join(hyp)
-        s += f"{filename}:\n{words}\n\n"
+        s += f"{filename}:\n{hyp}\n\n"

     logging.info(s)
     logging.info("Decoding Done")