diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py deleted file mode 100755 index 1b0e8d3b9..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/export.py +++ /dev/null @@ -1,321 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless7/decode.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --lang-dir data/lang_char - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/marcoyang/icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21 - # You will find the pre-trained model in icefall-asr-aishell-zipformer-pruned-transducer-stateless7-2023-03-21exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train2 import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/export.py b/egs/aishell/ASR/pruned_transducer_stateless7/export.py new file mode 120000 index 000000000..2713792e6 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export_unified.py b/egs/aishell/ASR/pruned_transducer_stateless7/export2.py old mode 100755 new mode 100644 similarity index 98% rename from egs/librispeech/ASR/pruned_transducer_stateless7/export_unified.py rename to egs/aishell/ASR/pruned_transducer_stateless7/export2.py index e491c3762..824b619e7 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export_unified.py +++ b/egs/aishell/ASR/pruned_transducer_stateless7/export2.py @@ -46,7 +46,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --tokens data/lang_bpe_500/tokens.txt \ + --tokens data/lang_char/tokens.txt \ --epoch 20 \ --avg 10 @@ -66,7 +66,7 @@ you can do: --avg 1 \ --max-duration 600 \ --decoding-method greedy_search \ - --tokens data/lang_bpe_500/tokens.txt \ + --tokens data/lang_char/tokens.txt Check ./pretrained.py for its usage. @@ -89,11 +89,10 @@ from pathlib import Path import re import k2 -import sentencepiece as spm import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model +from train2 import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( average_checkpoints, diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py deleted file mode 100644 index cc54027d6..000000000 --- a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1,348 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --lang-dir data/lang_char \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by -./pruned_transducer_stateless7/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - token_table = lexicon.token_table - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp_tokens = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp_tokens = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py new file mode 120000 index 000000000..068f0f57f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/pretrained.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py old mode 100755 new mode 100644 index 3e3160e7e..b77737a11 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -26,7 +27,7 @@ Usage: ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 30 \ --avg 9 \ --jit 1 @@ -45,7 +46,7 @@ for how to use the exported models outside of icefall. ./pruned_transducer_stateless7/export.py \ --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ + --tokens data/lang_bpe_500/tokens.txt \ --epoch 20 \ --avg 10 @@ -65,7 +66,7 @@ you can do: --avg 1 \ --max-duration 600 \ --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model + --tokens data/lang_bpe_500/tokens.txt \ Check ./pretrained.py for its usage. @@ -85,8 +86,9 @@ with the following commands: import argparse import logging from pathlib import Path +import re -import sentencepiece as spm +import k2 import torch import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled @@ -101,6 +103,26 @@ from icefall.checkpoint import ( from icefall.utils import str2bool +def num_tokens( + token_table: k2.SymbolTable, disambig_pattern: str = re.compile(r"^#\d+$") +) -> int: + """Return the number of tokens excluding those from + disambiguation symbols. + + Caution: + 0 is not a token ID so it is excluded from the return value. + """ + symbols = token_table.symbols + ans = [] + for s in symbols: + if not disambig_pattern.match(s): + ans.append(token_table[s]) + num_tokens = len(ans) + if 0 in ans: + num_tokens -= 1 + return num_tokens + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -155,10 +177,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -198,12 +219,12 @@ def main(): logging.info(f"device: {device}") - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(params) @@ -292,7 +313,7 @@ def main(): model.to("cpu") model.eval() - if params.jit is True: + if params.jit: convert_scaled_to_non_scaled(model, inplace=True) # We won't use the forward() method of the model in C++, so just ignore # it here. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py old mode 100755 new mode 100644 index d05bafcfb..d97e1c138 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Zengrui Jin) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -29,7 +30,7 @@ Usage of this script: (1) greedy search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method greedy_search \ /path/to/foo.wav \ /path/to/bar.wav @@ -37,7 +38,7 @@ Usage of this script: (2) beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -46,7 +47,7 @@ Usage of this script: (3) modified beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method modified_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -55,7 +56,7 @@ Usage of this script: (4) fast beam search ./pruned_transducer_stateless7/pretrained.py \ --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ + --tokens ./data/lang_bpe_500/tokens.txt \ --method fast_beam_search \ --beam-size 4 \ /path/to/foo.wav \ @@ -75,7 +76,6 @@ from typing import List import k2 import kaldifeat -import sentencepiece as spm import torch import torchaudio from beam_search import ( @@ -87,6 +87,7 @@ from beam_search import ( ) from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model +from export import num_tokens from icefall.utils import str2bool @@ -106,9 +107,9 @@ def get_parser(): ) parser.add_argument( - "--bpe-model", + "--tokens", type=str, - help="""Path to bpe.model.""", + help="Path to the tokens.txt.", ) parser.add_argument( @@ -225,13 +226,13 @@ def main(): params.update(vars(args)) - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) + # Load tokens.txt here + token_table = k2.SymbolTable.from_file(params.tokens) - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() + # Load id of the token and the vocab size + params.blank_id = token_table[""] + params.unk_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 # +1 for logging.info(f"{params}") @@ -286,6 +287,12 @@ def main(): msg += f" with beam size {params.beam_size}" logging.info(msg) + def token_ids_to_words(token_ids: List[int]) -> str: + text = "" + for i in token_ids: + text += token_table[i] + return text.replace("▁", " ").strip() + if params.method == "fast_beam_search": decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) hyp_tokens = fast_beam_search_one_best( @@ -297,8 +304,8 @@ def main(): max_contexts=params.max_contexts, max_states=params.max_states, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "modified_beam_search": hyp_tokens = modified_beam_search( model=model, @@ -307,16 +314,16 @@ def main(): beam=params.beam_size, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) elif params.method == "greedy_search" and params.max_sym_per_frame == 1: hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in hyp_tokens: + hyps.append(token_ids_to_words(hyp)) else: for i in range(num_waves): # fmt: off @@ -337,12 +344,11 @@ def main(): else: raise ValueError(f"Unsupported method: {params.method}") - hyps.append(sp.decode(hyp).split()) + hyps.append(token_ids_to_words(hyp)) s = "\n" for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" + s += f"{filename}:\n{hyp}\n\n" logging.info(s) logging.info("Decoding Done") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained_unified.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained_unified.py deleted file mode 100755 index 9285479b6..000000000 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained_unified.py +++ /dev/null @@ -1,362 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Zengrui Jin) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7/export.py \ - --exp-dir ./pruned_transducer_stateless7/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ - --tokens ./data/lang_bpe_500/tokens.txt \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by -./pruned_transducer_stateless7/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model -from export_unified import num_tokens - -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt.", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - # Load tokens.txt here - token_table = k2.SymbolTable.from_file(params.tokens) - - # Load id of the token and the vocab size - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 # +1 for - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(token_ids_to_words(hyp)) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - s += f"{filename}:\n{hyp}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main()