mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
updates for the zipformer_mmi
and transducer_stateless
recipes
This commit is contained in:
parent
e0e8db3c91
commit
0816be86ae
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -37,7 +37,7 @@ log "Export to torchscript model"
|
|||||||
./zipformer_mmi/export.py \
|
./zipformer_mmi/export.py \
|
||||||
--exp-dir $repo/exp \
|
--exp-dir $repo/exp \
|
||||||
--use-averaged-model false \
|
--use-averaged-model false \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 99 \
|
--epoch 99 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor
|
|||||||
--method $method \
|
--method $method \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--lang-dir $repo/data/lang_bpe_500 \
|
--lang-dir $repo/data/lang_bpe_500 \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -28,7 +28,7 @@ for sym in 1 2 3; do
|
|||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame $sym \
|
--max-sym-per-frame $sym \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do
|
|||||||
--method $method \
|
--method $method \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -27,7 +27,7 @@ log "Beam search decoding"
|
|||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
--checkpoint $repo/exp/pretrained.pt \
|
--checkpoint $repo/exp/pretrained.pt \
|
||||||
--bpe-model $repo/data/lang_bpe_500/bpe.model \
|
--tokens $repo/data/lang_bpe_500/tokens.txt \
|
||||||
$repo/test_wavs/1089-134686-0001.wav \
|
$repo/test_wavs/1089-134686-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0001.wav \
|
$repo/test_wavs/1221-135766-0001.wav \
|
||||||
$repo/test_wavs/1221-135766-0002.wav
|
$repo/test_wavs/1221-135766-0002.wav
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./transducer/export.py \
|
./transducer/export.py \
|
||||||
--exp-dir ./transducer/exp \
|
--exp-dir ./transducer/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 34 \
|
--epoch 34 \
|
||||||
--avg 11
|
--avg 11
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
@ -55,7 +55,7 @@ from model import Transducer
|
|||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -90,10 +90,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -191,12 +191,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -19,7 +19,7 @@ Usage:
|
|||||||
|
|
||||||
./transducer/pretrained.py \
|
./transducer/pretrained.py \
|
||||||
--checkpoint ./transducer/exp/pretrained.pt \
|
--checkpoint ./transducer/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav \
|
/path/to/bar.wav \
|
||||||
@ -36,8 +36,8 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search
|
||||||
@ -48,7 +48,7 @@ from model import Transducer
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict
|
from icefall.utils import AttributeDict, num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -66,11 +66,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.
|
help="Path to tokens.txt.",
|
||||||
Used only when method is ctc-decoding.
|
|
||||||
""",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -204,12 +202,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -257,6 +257,12 @@ def main():
|
|||||||
x=features, x_lens=feature_lengths
|
x=features, x_lens=feature_lengths
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
num_waves = encoder_out.size(0)
|
num_waves = encoder_out.size(0)
|
||||||
hyps = []
|
hyps = []
|
||||||
for i in range(num_waves):
|
for i in range(num_waves):
|
||||||
@ -272,12 +278,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
|
|
||||||
hyps.append(sp.decode(hyp).split())
|
hyps.append(token_ids_to_words(hyp))
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./transducer_stateless/export.py \
|
./transducer_stateless/export.py \
|
||||||
--exp-dir ./transducer_stateless/exp \
|
--exp-dir ./transducer_stateless/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -56,7 +56,7 @@ from model import Transducer
|
|||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -91,10 +91,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -191,12 +191,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame 1 \
|
--max-sym-per-frame 1 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -29,7 +29,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -38,7 +38,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -47,7 +47,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./transducer_stateless/pretrained.py \
|
./transducer_stateless/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -67,7 +67,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -80,6 +79,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -96,9 +97,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -213,12 +214,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -273,6 +276,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_list = fast_beam_search_one_best(
|
hyp_list = fast_beam_search_one_best(
|
||||||
@ -318,12 +327,11 @@ def main():
|
|||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
hyp_list.append(hyp)
|
hyp_list.append(hyp)
|
||||||
|
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./transducer_stateless2/export.py \
|
./transducer_stateless2/export.py \
|
||||||
--exp-dir ./transducer_stateless2/exp \
|
--exp-dir ./transducer_stateless2/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -46,12 +46,12 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -86,10 +86,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -123,12 +123,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame 1 \
|
--max-sym-per-frame 1 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -29,7 +29,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -38,7 +38,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -47,7 +47,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./transducer_stateless2/pretrained.py \
|
./transducer_stateless2/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless2/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -67,7 +67,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -80,6 +79,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -96,9 +97,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -213,12 +214,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -273,6 +276,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_list = fast_beam_search_one_best(
|
hyp_list = fast_beam_search_one_best(
|
||||||
@ -318,12 +327,11 @@ def main():
|
|||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
hyp_list.append(hyp)
|
hyp_list.append(hyp)
|
||||||
|
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
Usage:
|
Usage:
|
||||||
./transducer_stateless_multi_datasets/export.py \
|
./transducer_stateless_multi_datasets/export.py \
|
||||||
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
--exp-dir ./transducer_stateless_multi_datasets/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -47,7 +47,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -57,7 +57,7 @@ from model import Transducer
|
|||||||
|
|
||||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.utils import AttributeDict, str2bool
|
from icefall.utils import AttributeDict, num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -92,10 +92,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -192,12 +192,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ Usage:
|
|||||||
(1) greedy search
|
(1) greedy search
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method greedy_search \
|
--method greedy_search \
|
||||||
--max-sym-per-frame 1 \
|
--max-sym-per-frame 1 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -29,7 +29,7 @@ Usage:
|
|||||||
(2) beam search
|
(2) beam search
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method beam_search \
|
--method beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -38,7 +38,7 @@ Usage:
|
|||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method modified_beam_search \
|
--method modified_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -47,7 +47,7 @@ Usage:
|
|||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./transducer_stateless_multi_datasets/pretrained.py \
|
./transducer_stateless_multi_datasets/pretrained.py \
|
||||||
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
--checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method fast_beam_search \
|
--method fast_beam_search \
|
||||||
--beam-size 4 \
|
--beam-size 4 \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -67,7 +67,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
@ -80,6 +79,8 @@ from beam_search import (
|
|||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
from icefall.utils import num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -96,9 +97,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -213,12 +214,14 @@ def main():
|
|||||||
|
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -273,6 +276,12 @@ def main():
|
|||||||
msg += f" with beam size {params.beam_size}"
|
msg += f" with beam size {params.beam_size}"
|
||||||
logging.info(msg)
|
logging.info(msg)
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
if params.method == "fast_beam_search":
|
if params.method == "fast_beam_search":
|
||||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||||
hyp_list = fast_beam_search_one_best(
|
hyp_list = fast_beam_search_one_best(
|
||||||
@ -318,12 +327,11 @@ def main():
|
|||||||
raise ValueError(f"Unsupported method: {params.method}")
|
raise ValueError(f"Unsupported method: {params.method}")
|
||||||
hyp_list.append(hyp)
|
hyp_list.append(hyp)
|
||||||
|
|
||||||
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
|
hyps = [token_ids_to_words(hyp) for hyp in hyp_list]
|
||||||
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
@ -26,7 +26,7 @@ Usage:
|
|||||||
|
|
||||||
./zipformer_mmi/export.py \
|
./zipformer_mmi/export.py \
|
||||||
--exp-dir ./zipformer_mmi/exp \
|
--exp-dir ./zipformer_mmi/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 9 \
|
--avg 9 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -45,7 +45,7 @@ for how to use the exported models outside of icefall.
|
|||||||
|
|
||||||
./zipformer_mmi/export.py \
|
./zipformer_mmi/export.py \
|
||||||
--exp-dir ./zipformer_mmi/exp \
|
--exp-dir ./zipformer_mmi/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import sentencepiece as spm
|
import k2
|
||||||
import torch
|
import torch
|
||||||
from scaling_converter import convert_scaled_to_non_scaled
|
from scaling_converter import convert_scaled_to_non_scaled
|
||||||
from train import add_model_arguments, get_ctc_model, get_params
|
from train import add_model_arguments, get_ctc_model, get_params
|
||||||
@ -97,7 +97,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import str2bool
|
from icefall.utils import num_tokens, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -154,10 +154,10 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/lang_bpe_500/bpe.model",
|
default="data/lang_bpe_500/tokens.txt",
|
||||||
help="Path to the BPE model",
|
help="Path to the tokens.txt.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -190,12 +190,14 @@ def main():
|
|||||||
|
|
||||||
logging.info(f"device: {device}")
|
logging.info(f"device: {device}")
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.unk_id = token_table["<unk>"]
|
||||||
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ You can generate the checkpoint with the following command:
|
|||||||
|
|
||||||
./zipformer_mmi/export.py \
|
./zipformer_mmi/export.py \
|
||||||
--exp-dir ./zipformer_mmi/exp \
|
--exp-dir ./zipformer_mmi/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
@ -30,14 +30,14 @@ Usage of this script:
|
|||||||
(1) 1best
|
(1) 1best
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--method 1best \
|
--method 1best \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
/path/to/bar.wav
|
/path/to/bar.wav
|
||||||
(2) nbest
|
(2) nbest
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--nbest-scale 1.2 \
|
--nbest-scale 1.2 \
|
||||||
--method nbest \
|
--method nbest \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -45,7 +45,7 @@ Usage of this script:
|
|||||||
(3) nbest-rescoring-LG
|
(3) nbest-rescoring-LG
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--nbest-scale 1.2 \
|
--nbest-scale 1.2 \
|
||||||
--method nbest-rescoring-LG \
|
--method nbest-rescoring-LG \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -53,7 +53,7 @@ Usage of this script:
|
|||||||
(4) nbest-rescoring-3-gram
|
(4) nbest-rescoring-3-gram
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--nbest-scale 1.2 \
|
--nbest-scale 1.2 \
|
||||||
--method nbest-rescoring-3-gram \
|
--method nbest-rescoring-3-gram \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -61,7 +61,7 @@ Usage of this script:
|
|||||||
(5) nbest-rescoring-4-gram
|
(5) nbest-rescoring-4-gram
|
||||||
./zipformer_mmi/pretrained.py \
|
./zipformer_mmi/pretrained.py \
|
||||||
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
--checkpoint ./zipformer_mmi/exp/pretrained.pt \
|
||||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
--tokens data/lang_bpe_500/tokens.txt \
|
||||||
--nbest-scale 1.2 \
|
--nbest-scale 1.2 \
|
||||||
--method nbest-rescoring-4-gram \
|
--method nbest-rescoring-4-gram \
|
||||||
/path/to/foo.wav \
|
/path/to/foo.wav \
|
||||||
@ -83,7 +83,6 @@ from typing import List
|
|||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from decode import get_decoding_params
|
from decode import get_decoding_params
|
||||||
@ -97,7 +96,7 @@ from icefall.decode import (
|
|||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
)
|
)
|
||||||
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
||||||
from icefall.utils import get_texts
|
from icefall.utils import get_texts, num_tokens
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@ -115,9 +114,9 @@ def get_parser():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--bpe-model",
|
"--tokens",
|
||||||
type=str,
|
type=str,
|
||||||
help="""Path to bpe.model.""",
|
help="""Path to tokens.txt.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -247,13 +246,14 @@ def main():
|
|||||||
params.update(get_decoding_params())
|
params.update(get_decoding_params())
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
sp = spm.SentencePieceProcessor()
|
# Load tokens.txt here
|
||||||
sp.load(params.bpe_model)
|
token_table = k2.SymbolTable.from_file(params.tokens)
|
||||||
|
|
||||||
|
# Load id of the <blk> token and the vocab size
|
||||||
# <blk> is defined in local/train_bpe_model.py
|
# <blk> is defined in local/train_bpe_model.py
|
||||||
params.blank_id = sp.piece_to_id("<blk>")
|
params.blank_id = token_table["<blk>"]
|
||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = token_table["<unk>"]
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = num_tokens(token_table) + 1 # +1 for <blk>
|
||||||
|
|
||||||
logging.info(f"{params}")
|
logging.info(f"{params}")
|
||||||
|
|
||||||
@ -298,8 +298,6 @@ def main():
|
|||||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||||
|
|
||||||
bpe_model = spm.SentencePieceProcessor()
|
|
||||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
|
||||||
mmi_graph_compiler = MmiTrainingGraphCompiler(
|
mmi_graph_compiler = MmiTrainingGraphCompiler(
|
||||||
params.lang_dir,
|
params.lang_dir,
|
||||||
uniq_filename="lexicon.txt",
|
uniq_filename="lexicon.txt",
|
||||||
@ -313,6 +311,12 @@ def main():
|
|||||||
if not hasattr(HP, "lm_scores"):
|
if not hasattr(HP, "lm_scores"):
|
||||||
HP.lm_scores = HP.scores.clone()
|
HP.lm_scores = HP.scores.clone()
|
||||||
|
|
||||||
|
def token_ids_to_words(token_ids: List[int]) -> str:
|
||||||
|
text = ""
|
||||||
|
for i in token_ids:
|
||||||
|
text += token_table[i]
|
||||||
|
return text.replace("▁", " ").strip()
|
||||||
|
|
||||||
method = params.method
|
method = params.method
|
||||||
assert method in (
|
assert method in (
|
||||||
"1best",
|
"1best",
|
||||||
@ -390,14 +394,11 @@ def main():
|
|||||||
#
|
#
|
||||||
# token_ids is a lit-of-list of IDs
|
# token_ids is a lit-of-list of IDs
|
||||||
token_ids = get_texts(best_path)
|
token_ids = get_texts(best_path)
|
||||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
hyps = [token_ids_to_words(ids) for ids in token_ids]
|
||||||
hyps = bpe_model.decode(token_ids)
|
|
||||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
|
||||||
hyps = [s.split() for s in hyps]
|
|
||||||
s = "\n"
|
s = "\n"
|
||||||
for filename, hyp in zip(params.sound_files, hyps):
|
for filename, hyp in zip(params.sound_files, hyps):
|
||||||
words = " ".join(hyp)
|
s += f"{filename}:\n{hyp}\n\n"
|
||||||
s += f"{filename}:\n{words}\n\n"
|
|
||||||
logging.info(s)
|
logging.info(s)
|
||||||
|
|
||||||
logging.info("Decoding Done")
|
logging.info("Decoding Done")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user