From 0816be86ae4c3ac71d0d0ce3030091ea75281d4e Mon Sep 17 00:00:00 2001 From: jinzr <60612200+JinZr@users.noreply.github.com> Date: Mon, 24 Jul 2023 23:48:36 +0800 Subject: [PATCH] updates for the `zipformer_mmi` and `transducer_stateless` recipes --- ...speech-transducer-stateless2-2022-04-19.sh | 4 +- ...un-librispeech-zipformer-mmi-2022-12-08.sh | 4 +- ...d-transducer-stateless-librispeech-100h.sh | 4 +- ...d-transducer-stateless-librispeech-960h.sh | 4 +- .../run-pre-trained-transducer-stateless.sh | 4 +- .github/scripts/run-pre-trained-transducer.sh | 2 +- egs/librispeech/ASR/transducer/export.py | 22 +++++---- egs/librispeech/ASR/transducer/pretrained.py | 33 +++++++------ .../ASR/transducer_stateless/export.py | 22 +++++---- .../ASR/transducer_stateless/pretrained.py | 36 ++++++++------ .../ASR/transducer_stateless2/export.py | 22 +++++---- .../ASR/transducer_stateless2/pretrained.py | 36 ++++++++------ .../export.py | 22 +++++---- .../pretrained.py | 36 ++++++++------ egs/librispeech/ASR/zipformer_mmi/export.py | 24 +++++----- .../ASR/zipformer_mmi/pretrained.py | 47 ++++++++++--------- 16 files changed, 181 insertions(+), 141 deletions(-) diff --git a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh index b4aca1b6b..ff77855a2 100755 --- a/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh +++ b/.github/scripts/run-librispeech-transducer-stateless2-2022-04-19.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh index a58b8ec56..c59921055 100755 --- a/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh +++ b/.github/scripts/run-librispeech-zipformer-mmi-2022-12-08.sh @@ -37,7 +37,7 @@ log "Export to torchscript model" ./zipformer_mmi/export.py \ --exp-dir $repo/exp \ --use-averaged-model false \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ --epoch 99 \ --avg 1 \ --jit 1 @@ -61,7 +61,7 @@ for method in 1best nbest nbest-rescoring-LG nbest-rescoring-3-gram nbest-rescor --method $method \ --checkpoint $repo/exp/pretrained.pt \ --lang-dir $repo/data/lang_bpe_500 \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh index 89115e88d..7b686328d 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-100h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh index 85e2c89e6..a8eeeb514 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless-librispeech-960h.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer-stateless.sh b/.github/scripts/run-pre-trained-transducer-stateless.sh index 41456f11b..2e2360435 100755 --- a/.github/scripts/run-pre-trained-transducer-stateless.sh +++ b/.github/scripts/run-pre-trained-transducer-stateless.sh @@ -28,7 +28,7 @@ for sym in 1 2 3; do --method greedy_search \ --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav @@ -41,7 +41,7 @@ for method in fast_beam_search modified_beam_search beam_search; do --method $method \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/scripts/run-pre-trained-transducer.sh b/.github/scripts/run-pre-trained-transducer.sh index 1331c966c..b865f8d13 100755 --- a/.github/scripts/run-pre-trained-transducer.sh +++ b/.github/scripts/run-pre-trained-transducer.sh @@ -27,7 +27,7 @@ log "Beam search decoding" --method beam_search \ --beam-size 4 \ --checkpoint $repo/exp/pretrained.pt \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py index 6db0272f0..3b9e4a5dc 100755 --- a/egs/librispeech/ASR/transducer/export.py +++ b/egs/librispeech/ASR/transducer/export.py @@ -22,7 +22,7 @@ Usage: ./transducer/export.py \ --exp-dir ./transducer/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 34 \ --avg 11 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from conformer import Conformer from decoder import Decoder @@ -55,7 +55,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -90,10 +90,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 511610245..c2413f5de 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -19,7 +19,7 @@ Usage: ./transducer/pretrained.py \ --checkpoint ./transducer/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav \ @@ -36,8 +36,8 @@ import logging import math from typing import List +import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import beam_search, greedy_search @@ -48,7 +48,7 @@ from model import Transducer from torch.nn.utils.rnn import pad_sequence from icefall.env import get_env_info -from icefall.utils import AttributeDict +from icefall.utils import AttributeDict, num_tokens def get_parser(): @@ -66,11 +66,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, + help="Path to tokens.txt.", ) parser.add_argument( @@ -204,12 +202,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -257,6 +257,12 @@ def main(): x=features, x_lens=feature_lengths ) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + num_waves = encoder_out.size(0) hyps = [] for i in range(num_waves): @@ -272,12 +278,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 89359f1a4..c397eb171 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless/export.py \ --exp-dir ./transducer_stateless/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,7 +46,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -56,7 +56,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -91,10 +91,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -191,12 +191,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index 915a6069d..5898dd0f5 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless/pretrained.py \ --checkpoint ./transducer_stateless/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py index d33d02642..f4b6f5554 100755 --- a/egs/librispeech/ASR/transducer_stateless2/export.py +++ b/egs/librispeech/ASR/transducer_stateless2/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless2/export.py \ --exp-dir ./transducer_stateless2/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -46,12 +46,12 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from train import get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -86,10 +86,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", ) parser.add_argument( @@ -123,12 +123,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index 0738f30c0..b69b347ef 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless2/pretrained.py \ --checkpoint ./transducer_stateless2/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py index 3735ef452..6d31dfe34 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py @@ -22,7 +22,7 @@ Usage: ./transducer_stateless_multi_datasets/export.py \ --exp-dir ./transducer_stateless_multi_datasets/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -47,7 +47,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch import torch.nn as nn from conformer import Conformer @@ -57,7 +57,7 @@ from model import Transducer from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.env import get_env_info -from icefall.utils import AttributeDict, str2bool +from icefall.utils import AttributeDict, num_tokens, str2bool def get_parser(): @@ -92,10 +92,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -192,12 +192,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index 8c7726367..4f29d6f1f 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -20,7 +20,7 @@ Usage: (1) greedy search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method greedy_search \ --max-sym-per-frame 1 \ /path/to/foo.wav \ @@ -29,7 +29,7 @@ Usage: (2) beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -38,7 +38,7 @@ Usage: (3) modified beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -47,7 +47,7 @@ Usage: (4) fast beam search ./transducer_stateless_multi_datasets/pretrained.py \ --checkpoint ./transducer_stateless_multi_datasets/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -67,7 +67,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -80,6 +79,8 @@ from beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import get_params, get_transducer_model +from icefall.utils import num_tokens + def get_parser(): parser = argparse.ArgumentParser( @@ -96,9 +97,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -213,12 +214,14 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -273,6 +276,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_list = fast_beam_search_one_best( @@ -318,12 +327,11 @@ def main(): raise ValueError(f"Unsupported method: {params.method}") hyp_list.append(hyp) - hyps = [sp.decode(hyp).split() for hyp in hyp_list] + hyps = [token_ids_to_words(hyp) for hyp in hyp_list] s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/zipformer_mmi/export.py b/egs/librispeech/ASR/zipformer_mmi/export.py index 0af7bd367..1aec56420 100755 --- a/egs/librispeech/ASR/zipformer_mmi/export.py +++ b/egs/librispeech/ASR/zipformer_mmi/export.py @@ -26,7 +26,7 @@ Usage: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +45,7 @@ for how to use the exported models outside of icefall. ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -86,7 +86,7 @@ import argparse import logging from pathlib import Path -import sentencepiece as spm +import k2 import torch from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_ctc_model, get_params @@ -97,7 +97,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.utils import str2bool +from icefall.utils import num_tokens, str2bool def get_parser(): @@ -154,10 +154,10 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -190,12 +190,14 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) diff --git a/egs/librispeech/ASR/zipformer_mmi/pretrained.py b/egs/librispeech/ASR/zipformer_mmi/pretrained.py index 0e7fd0daf..3ba4da5dd 100755 --- a/egs/librispeech/ASR/zipformer_mmi/pretrained.py +++ b/egs/librispeech/ASR/zipformer_mmi/pretrained.py @@ -21,7 +21,7 @@ You can generate the checkpoint with the following command: ./zipformer_mmi/export.py \ --exp-dir ./zipformer_mmi/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -30,14 +30,14 @@ Usage of this script: (1) 1best ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --method 1best \ /path/to/foo.wav \ /path/to/bar.wav (2) nbest ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest \ /path/to/foo.wav \ @@ -45,7 +45,7 @@ Usage of this script: (3) nbest-rescoring-LG ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-LG \ /path/to/foo.wav \ @@ -53,7 +53,7 @@ Usage of this script: (4) nbest-rescoring-3-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-3-gram \ /path/to/foo.wav \ @@ -61,7 +61,7 @@ Usage of this script: (5) nbest-rescoring-4-gram ./zipformer_mmi/pretrained.py \ --checkpoint ./zipformer_mmi/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --nbest-scale 1.2 \ --method nbest-rescoring-4-gram \ /path/to/foo.wav \ @@ -83,7 +83,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from decode import get_decoding_params @@ -97,7 +96,7 @@ from icefall.decode import ( one_best_decoding, ) from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import get_texts +from icefall.utils import get_texts, num_tokens def get_parser(): @@ -115,9 +114,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="""Path to tokens.txt.""", ) parser.add_argument( @@ -247,13 +246,14 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) + # Load id of the token and the vocab size # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -298,8 +298,6 @@ def main(): features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) feature_lengths = torch.tensor(feature_lengths, device=device) - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) mmi_graph_compiler = MmiTrainingGraphCompiler( params.lang_dir, uniq_filename="lexicon.txt", @@ -313,6 +311,12 @@ def main(): if not hasattr(HP, "lm_scores"): HP.lm_scores = HP.scores.clone() + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + method = params.method assert method in ( "1best", @@ -390,14 +394,11 @@ def main(): # # token_ids is a lit-of-list of IDs token_ids = get_texts(best_path) - # hyps is a list of str, e.g., ['xxx yyy zzz', ...] - hyps = bpe_model.decode(token_ids) - # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] - hyps = [s.split() for s in hyps] + hyps = [token_ids_to_words(ids) for ids in token_ids] + s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done")