diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py
new file mode 100755
index 000000000..6a9b52caa
--- /dev/null
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py
@@ -0,0 +1,581 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script exports a trained model, optionally averaging
+# several saved checkpoints into a single one.
+"""
+
+Usage:
+
+(1) Export to a torchscript model using torch.jit.trace()
+
+./lstm_transducer_stateless/export.py \
+  --exp-dir ./lstm_transducer_stateless/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --jit-trace 1
+
+It will generate 3 files: `encoder_jit_trace.pt`,
+`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`.
+
+
+(2) Export to ONNX format
+
+./lstm_transducer_stateless/export.py \
+  --exp-dir ./lstm_transducer_stateless/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10 \
+  --onnx 1
+
+It will generate the following files in the given `exp_dir`.
+Check `onnx_check.py` for how to use them.
+
+    - encoder.onnx
+    - decoder.onnx
+    - joiner.onnx
+    - all_in_one.onnx
+
+
+(3) Export `model.state_dict()`
+
+./lstm_transducer_stateless/export.py \
+  --exp-dir ./lstm_transducer_stateless/exp \
+  --bpe-model data/lang_bpe_500/bpe.model \
+  --epoch 20 \
+  --avg 10
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `lstm_transducer_stateless/decode.py`,
+you can do:
+
+    cd /path/to/exp_dir
+    ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/librispeech/ASR
+    ./lstm_transducer_stateless/decode.py \
+        --exp-dir ./lstm_transducer_stateless/exp \
+        --epoch 9999 \
+        --avg 1 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
+        --bpe-model data/lang_bpe_500/bpe.model
+
+Check ./pretrained.py for its usage.
+
+Note: If you don't want to train a model from scratch, we have
+provided one for you.
You can get it at
+
+https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+
+with the following commands:
+
+    sudo apt-get install git-lfs
+    git lfs install
+    git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13
+    # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/exp
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import onnx
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall.checkpoint import (
+    average_checkpoints,
+    find_checkpoints,
+    load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=28,
+        help="""It specifies the checkpoint to use for averaging.
+        Note: Epoch counts from 0.
+        You can specify --avg to use more checkpoints for model averaging.""",
+    )
+
+    parser.add_argument(
+        "--iter",
+        type=int,
+        default=0,
+        help="""If positive, --epoch is ignored and it
+        will use the checkpoint exp_dir/checkpoint-iter.pt.
+        You can specify --avg to use more checkpoints for model averaging.
+        """,
+    )
+
+    parser.add_argument(
+        "--avg",
+        type=int,
+        default=15,
+        help="Number of checkpoints to average. Automatically select "
+        "consecutive checkpoints before the checkpoint specified by "
+        "'--epoch' or '--iter'",
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="lstm_transducer_stateless/exp",
+        help="""It specifies the directory where all training related
+        files, e.g., checkpoints and logs, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="data/lang_bpe_500/bpe.model",
+        help="Path to the BPE model",
+    )
+
+    parser.add_argument(
+        "--jit-trace",
+        type=str2bool,
+        default=False,
+        help="""True to save a model after applying torch.jit.trace.
+        It will generate 3 files:
+         - encoder_jit_trace.pt
+         - decoder_jit_trace.pt
+         - joiner_jit_trace.pt
+
+        Check ./jit_pretrained.py for how to use them.
+        """,
+    )
+
+    parser.add_argument(
+        "--onnx",
+        type=str2bool,
+        default=False,
+        help="""If True, --jit-trace is ignored and it exports the model
+        to onnx format. Three files will be generated:
+
+         - encoder.onnx
+         - decoder.onnx
+         - joiner.onnx
+
+        Check ./onnx_check.py and ./onnx_pretrained.py for how to use them.
+        """,
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means trigram",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def export_encoder_model_jit_trace(
+    encoder_model: nn.Module,
+    encoder_filename: str,
+) -> None:
+    """Export the given encoder model with torch.jit.trace()
+
+    Note: The warmup argument is fixed to 1.
+
+    Args:
+      encoder_model:
+        The input encoder model
+      encoder_filename:
+        The filename to save the exported model.
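+
+    A minimal usage sketch (the path below is hypothetical; it assumes
+    `model` was built by get_transducer_model() and already converted
+    with convert_scaled_to_non_scaled()):
+
+        export_encoder_model_jit_trace(
+            model.encoder, "exp/encoder_jit_trace.pt"
+        )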
+    """
+    x = torch.zeros(1, 100, 80, dtype=torch.float32)
+    x_lens = torch.tensor([100], dtype=torch.int64)
+    states = encoder_model.get_init_states()
+    states = (states[0].unsqueeze(1), states[1].unsqueeze(1))
+
+    traced_model = torch.jit.trace(encoder_model, (x, x_lens, states))
+    traced_model.save(encoder_filename)
+    logging.info(f"Saved to {encoder_filename}")
+
+
+def export_decoder_model_jit_trace(
+    decoder_model: nn.Module,
+    decoder_filename: str,
+) -> None:
+    """Export the given decoder model with torch.jit.trace()
+
+    Note: The argument need_pad is fixed to False.
+
+    Args:
+      decoder_model:
+        The input decoder model
+      decoder_filename:
+        The filename to save the exported model.
+    """
+    y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64)
+    need_pad = torch.tensor([False])
+
+    traced_model = torch.jit.trace(decoder_model, (y, need_pad))
+    traced_model.save(decoder_filename)
+    logging.info(f"Saved to {decoder_filename}")
+
+
+def export_joiner_model_jit_trace(
+    joiner_model: nn.Module,
+    joiner_filename: str,
+) -> None:
+    """Export the given joiner model with torch.jit.trace()
+
+    Note: The argument project_input is fixed to True. A user should not
+    project the encoder_out/decoder_out by himself/herself. The exported
+    joiner will do that for the user.
+
+    Args:
+      joiner_model:
+        The input joiner model
+      joiner_filename:
+        The filename to save the exported model.
+    """
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+
+    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
+    traced_model.save(joiner_filename)
+    logging.info(f"Saved to {joiner_filename}")
+
+
+def export_encoder_model_onnx(
+    encoder_model: nn.Module,
+    encoder_filename: str,
+    opset_version: int = 11,
+) -> None:
+    """Export the given encoder model to ONNX format.
+    The exported model has the following inputs:
+
+        - x, a tensor of shape (N, T, C); dtype is torch.float32
+        - x_lens, a tensor of shape (N,); dtype is torch.int64
+        - hidden_states, a tensor of shape (num_layers, N, d_model)
+        - cell_states, a tensor of shape (num_layers, N, rnn_hidden_size)
+
+    and the following outputs:
+
+        - encoder_out, a tensor of shape (N, T, C)
+        - encoder_out_lens, a tensor of shape (N,)
+        - new_hidden_states, a tensor of shape (num_layers, N, d_model)
+        - new_cell_states, a tensor of shape (num_layers, N, rnn_hidden_size)
+
+    Note: The warmup argument is fixed to 1.
+
+    Args:
+      encoder_model:
+        The input encoder model
+      encoder_filename:
+        The filename to save the exported ONNX model.
+      opset_version:
+        The opset version to use.
+    """
+    x = torch.zeros(1, 100, 80, dtype=torch.float32)
+    x_lens = torch.tensor([100], dtype=torch.int64)
+    states = encoder_model.get_init_states()
+    hidden_states = states[0].unsqueeze(1)
+    cell_states = states[1].unsqueeze(1)
+    #  encoder_model = torch.jit.script(encoder_model)
+    # It throws the following error for the above statement
+    #
+    # RuntimeError: Exporting the operator __is_ to ONNX opset version
+    # 11 is not supported. Please feel free to request support or
+    # submit a pull request on PyTorch GitHub.
+    #
+    # I cannot find which statement causes the above error.
+ # torch.onnx.export() will use torch.jit.trace() internally, which + # works well for the current reworked model + warmup = 1.0 + torch.onnx.export( + encoder_model, + (x, x_lens, (hidden_states, cell_states), warmup), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens", "hidden_states", "cell_states", "warmup"], + output_names=[ + "encoder_out", + "encoder_out_lens", + "new_hidden_states", + "new_cell_states", + ], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "hidden_states": {1: "N"}, + "cell_states": {1: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + "new_hidden_states": {1: "N"}, + "new_cell_states": {1: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. + The exported model has two inputs: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and has one output: + + - joiner_out: a tensor of shape (N, vocab_size) + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. 
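+
+    A sketch of running the exported file with onnxruntime (the file name
+    is hypothetical; `enc` and `dec` are float32 numpy arrays of shape
+    (N, encoder_out_dim) and (N, decoder_out_dim)):
+
+        import onnxruntime as ort
+
+        session = ort.InferenceSession("joiner.onnx")
+        (logit,) = session.run(
+            ["logit"],
+            {"encoder_out": enc, "decoder_out": dec},
+        )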
+    """
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+
+    project_input = True
+    # Note: It uses torch.jit.trace() internally
+    torch.onnx.export(
+        joiner_model,
+        (encoder_out, decoder_out, project_input),
+        joiner_filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["encoder_out", "decoder_out", "project_input"],
+        output_names=["logit"],
+        dynamic_axes={
+            "encoder_out": {0: "N"},
+            "decoder_out": {0: "N"},
+            "logit": {0: "N"},
+        },
+    )
+    logging.info(f"Saved to {joiner_filename}")
+
+
+def export_all_in_one_onnx(
+    encoder_filename: str,
+    decoder_filename: str,
+    joiner_filename: str,
+    all_in_one_filename: str,
+):
+    encoder_onnx = onnx.load(encoder_filename)
+    decoder_onnx = onnx.load(decoder_filename)
+    joiner_onnx = onnx.load(joiner_filename)
+
+    encoder_onnx = onnx.compose.add_prefix(encoder_onnx, prefix="encoder/")
+    decoder_onnx = onnx.compose.add_prefix(decoder_onnx, prefix="decoder/")
+    joiner_onnx = onnx.compose.add_prefix(joiner_onnx, prefix="joiner/")
+
+    combined_model = onnx.compose.merge_models(
+        encoder_onnx, decoder_onnx, io_map={}
+    )
+    combined_model = onnx.compose.merge_models(
+        combined_model, joiner_onnx, io_map={}
+    )
+    onnx.save(combined_model, all_in_one_filename)
+    logging.info(f"Saved to {all_in_one_filename}")
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    model.to(device)
+
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+    elif params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+
+    model.to("cpu")
+    model.eval()
+    convert_scaled_to_non_scaled(model, inplace=True)
+
+    if params.onnx is True:
+        opset_version = 11
+        logging.info("Exporting to onnx format")
+        encoder_filename = params.exp_dir / "encoder.onnx"
+        export_encoder_model_onnx(
+            model.encoder,
+            encoder_filename,
+            opset_version=opset_version,
+        )
+
+        decoder_filename = params.exp_dir / "decoder.onnx"
+        export_decoder_model_onnx(
+            model.decoder,
+            decoder_filename,
opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + all_in_one_filename = params.exp_dir / "all_in_one.onnx" + export_all_in_one_onnx( + encoder_filename, + decoder_filename, + joiner_filename, + all_in_one_filename, + ) + elif params.jit_trace is True: + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + else: + logging.info("Not using torchscript") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py new file mode 100755 index 000000000..eeb7dadad --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -0,0 +1,319 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, either exported by `torch.jit.trace()` +or by `torch.jit.script()`, and uses them to decode waves. +You can use the following command to get the exported models: + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +Usage of this script: + +./lstm_transducer_stateless/jit_pretrained.py \ + --encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. 
", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. 
+ """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens, _ = encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for 
filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index 1d6fda0b4..7eef94fea 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -51,7 +51,8 @@ class RNN(EncoderInterface): Dropout value for model-level warmup (default=0.075). aux_layer_period (int): Peroid of auxiliary layers used for randomly combined during training. - If not larger than 0, will not use the random combiner. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. """ def __init__( @@ -64,7 +65,7 @@ class RNN(EncoderInterface): num_encoder_layers: int = 12, dropout: float = 0.1, layer_dropout: float = 0.075, - aux_layer_period: int = 3, + aux_layer_period: int = 0, ) -> None: super(RNN, self).__init__() @@ -106,62 +107,11 @@ class RNN(EncoderInterface): ) def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (N, T, C), where N is the batch size, - T is the sequence length, C is the feature dimension. - x_lens: - A tensor of shape (N,), containing the number of frames in `x` - before padding. - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - - Returns: - A tuple of 2 tensors: - - embeddings: its shape is (N, T', d_model), where T' is the output - sequence lengths. - - lengths: a tensor of shape (batch_size,) containing the number of - frames in `embeddings` before padding. - """ - x = self.encoder_embed(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning - # - # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 1) >> 1) - 1) >> 1 - assert x.size(0) == lengths.max().item() - - x = self.encoder(x, warmup) - - x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - return x, lengths - - @torch.jit.export - def get_init_states( - self, device: torch.device - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Get model initial states.""" - # for rnn hidden states - hidden_states = torch.zeros( - (self.num_encoder_layers, self.d_model), device=device - ) - cell_states = torch.zeros( - (self.num_encoder_layers, self.rnn_hidden_size), device=device - ) - return (hidden_states, cell_states) - - @torch.jit.export - def infer( self, x: torch.Tensor, x_lens: torch.Tensor, - states: Tuple[torch.Tensor, torch.Tensor], + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Args: @@ -172,11 +122,15 @@ class RNN(EncoderInterface): A tensor of shape (N,), containing the number of frames in `x` before padding. states: - It is a list of 2 tensors. + A tuple of 2 tensors (optional). It is for streaming inference. 
states[0] is the hidden states of all layers, with shape of (num_layers, N, d_model); states[1] is the cell states of all layers, with shape of (num_layers, N, rnn_hidden_size). + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. Returns: A tuple of 3 tensors: @@ -186,36 +140,57 @@ class RNN(EncoderInterface): frames in `embeddings` before padding. - updated states, whose shape is same as the input states. """ - assert not self.training - assert len(states) == 2 - # for hidden state - assert states[0].shape == ( - self.num_encoder_layers, - x.size(0), - self.d_model, - ) - # for cell state - assert states[1].shape == ( - self.num_encoder_layers, - x.size(0), - self.rnn_hidden_size, - ) + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning # # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 lengths = (((x_lens - 1) >> 1) - 1) >> 1 - # we will cut off 1 frame on each side of encoder_embed output - lengths -= 2 + if not torch.jit.is_tracing(): + assert x.size(0) == lengths.max().item() - embed = self.encoder_embed(x) - embed = embed[:, 1:-1, :] - embed = embed.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + if states is None: + x = self.encoder(x, warmup=warmup)[0] + # torch.jit.trace requires returned types be the same as annotated + new_states = (torch.empty(0), torch.empty(0)) + else: + # we cut off 1 frame on each side of encoder_embed output + lengths -= 2 + x = x[1:-1, :, :] - x, states = self.encoder.infer(embed, states) + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_encoder_layers, + x.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_encoder_layers, + x.size(1), + self.rnn_hidden_size, + ) + x, new_states = self.encoder(x, states) x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - return x, lengths, states + return x, lengths, new_states + + def get_init_states( + self, device: torch.device = torch.device("cpu") + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get model initial states.""" + # for rnn hidden states + hidden_states = torch.zeros( + (self.num_encoder_layers, self.d_model), device=device + ) + cell_states = torch.zeros( + (self.num_encoder_layers, self.rnn_hidden_size), device=device + ) + return (hidden_states, cell_states) class RNNEncoderLayer(nn.Module): @@ -271,7 +246,12 @@ class RNNEncoderLayer(nn.Module): ) self.dropout = nn.Dropout(dropout) - def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor: + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Pass the input through the encoder layer. @@ -280,6 +260,12 @@ class RNNEncoderLayer(nn.Module): The sequence to the encoder layer (required). Its shape is (S, N, E), where S is the sequence length, N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (1, N, d_model); + states[1] is the cell states of all layers, + with shape of (1, N, rnn_hidden_size). warmup: It controls selective bypass of of layers; if < 1.0, we will bypass layers more frequently. 
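With infer() folded into forward(), streaming inference is now just a call
with explicit states. A minimal sketch of that call path (illustrative
only, not part of the patch; the configuration mirrors the self-test at
the bottom of lstm.py):

    import torch
    from lstm import RNN

    encoder = RNN(
        num_features=80,
        d_model=512,
        rnn_hidden_size=1536,
        dim_feedforward=2048,
        num_encoder_layers=10,
    )
    encoder.eval()  # the streaming branch asserts not self.training

    h, c = encoder.get_init_states()
    states = (h.unsqueeze(1), c.unsqueeze(1))  # add batch dim: N = 1

    chunk = torch.rand(1, 20, 80)  # (N, T, num_features)
    chunk_lens = torch.tensor([20], dtype=torch.int64)
    with torch.no_grad():
        out, out_lens, states = encoder(chunk, chunk_lens, states)
    # 20 frames -> 4 after 4x subsampling, minus 2 edge frames,
    # so out has shape (1, 2, 512)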
@@ -299,7 +285,19 @@ class RNNEncoderLayer(nn.Module): alpha = 1.0 # lstm module - src_lstm = self.lstm(src)[0] + if states is None: + src_lstm = self.lstm(src)[0] + # torch.jit.trace requires returned types be the same as annotated + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == (1, src.size(1), self.d_model) + # for cell state + assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) + src_lstm, new_states = self.lstm(src, states) src = src + self.dropout(src_lstm) # feed forward module @@ -310,41 +308,6 @@ class RNNEncoderLayer(nn.Module): if alpha != 1.0: src = alpha * src + (1 - alpha) * src_orig - return src - - @torch.jit.export - def infer( - self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor] - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer. - - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, d_model), where S is the sequence length, - N is the batch size. - states: - It is a tuple of 2 tensors. - states[0] is the hidden state, with shape of (1, N, d_model); - states[1] is the cell state, with shape of (1, N, rnn_hidden_size). - """ - assert not self.training - assert len(states) == 2 - # for hidden state - assert states[0].shape == (1, src.size(1), self.d_model) - # for cell state - assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) - - # lstm module - src_lstm, new_states = self.lstm(src, states) - src = src + self.dropout(src_lstm) - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - return src, new_states @@ -373,11 +336,11 @@ class RNNEncoder(nn.Module): self.d_model = encoder_layer.d_model self.rnn_hidden_size = encoder_layer.rnn_hidden_size - self.use_random_combiner = False + self.aux_layers: List[int] = [] + self.combiner: Optional[nn.Module] = None if aux_layers is not None: assert len(set(aux_layers)) == len(aux_layers) assert num_layers - 1 not in aux_layers - self.use_random_combiner = True self.aux_layers = aux_layers + [num_layers - 1] self.combiner = RandomCombine( num_inputs=len(self.aux_layers), @@ -386,7 +349,12 @@ class RNNEncoder(nn.Module): stddev=2.0, ) - def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor: + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Pass the input through the encoder layer in turn. @@ -395,75 +363,66 @@ class RNNEncoder(nn.Module): The sequence to the encoder layer (required). Its shape is (S, N, E), where S is the sequence length, N is the batch size, and E is the feature number. - warmup: - It controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - """ - output = src - - outputs = [] - - for i, mod in enumerate(self.layers): - output = mod(output, warmup=warmup) - if self.use_random_combiner: - if i in self.aux_layers: - outputs.append(output) - - if self.use_random_combiner: - output = self.combiner(outputs) - - return output - - @torch.jit.export - def infer( - self, src: torch.Tensor, states: Tuple[torch.Tensor, torch.Tensor] - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer. 
- - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, d_model), where S is the sequence length, - N is the batch size. states: - It is a list of 2 tensors. + A tuple of 2 tensors (optional). It is for streaming inference. states[0] is the hidden states of all layers, with shape of (num_layers, N, d_model); states[1] is the cell states of all layers, with shape of (num_layers, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. """ - assert not self.training - assert len(states) == 2 - # for hidden state - assert states[0].shape == (self.num_layers, src.size(1), self.d_model) - # for cell state - assert states[1].shape == ( - self.num_layers, - src.size(1), - self.rnn_hidden_size, - ) + if states is not None: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_layers, + src.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_layers, + src.size(1), + self.rnn_hidden_size, + ) output = src + + outputs = [] + new_hidden_states = [] new_cell_states = [] - for layer_index, mod in enumerate(self.layers): - layer_states = ( - states[0][ - layer_index : layer_index + 1, :, : - ], # h: (1, N, d_model) - states[1][ - layer_index : layer_index + 1, :, : - ], # c: (1, N, rnn_hidden_size) - ) - output, (h, c) = mod.infer(output, layer_states) - new_hidden_states.append(h) - new_cell_states.append(c) - new_states = ( - torch.cat(new_hidden_states, dim=0), - torch.cat(new_cell_states, dim=0), - ) + for i, mod in enumerate(self.layers): + if states is None: + output = mod(output, warmup=warmup)[0] + else: + layer_state = ( + states[0][i : i + 1, :, :], # h: (1, N, d_model) + states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size) + ) + output, (h, c) = mod(output, layer_state) + new_hidden_states.append(h) + new_cell_states.append(c) + + if self.combiner is not None and i in self.aux_layers: + outputs.append(output) + + if self.combiner is not None: + output = self.combiner(outputs) + + if states is None: + new_states = (torch.empty(0), torch.empty(0)) + else: + new_states = ( + torch.cat(new_hidden_states, dim=0), + torch.cat(new_cell_states, dim=0), + ) + return output, new_states @@ -804,9 +763,9 @@ if __name__ == "__main__": m = RNN( num_features=feature_dim, d_model=512, - rnn_hidden_size=1024, + rnn_hidden_size=1536, dim_feedforward=2048, - num_encoder_layers=12, + num_encoder_layers=10, ) batch_size = 5 seq_len = 20 diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/test_scaling_converter.py b/egs/librispeech/ASR/lstm_transducer_stateless/test_scaling_converter.py index b4bc8d91f..0a7e63b60 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/test_scaling_converter.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/test_scaling_converter.py @@ -19,7 +19,7 @@ To run this file, do: cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless3/test_scaling_converter.py + python ./lstm_transducer_stateless/test_scaling_converter.py """ import copy diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index 354815f92..b6351b1d3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -389,8 +389,9 @@ class ScaledLSTM(nn.LSTM): initial_speed: float = 1.0, **kwargs ): - # 
Hardcode bidirectional=False - super(ScaledLSTM, self).__init__(*args, bidirectional=False, **kwargs) + if "bidirectional" in kwargs: + assert kwargs["bidirectional"] is False + super(ScaledLSTM, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self._scales_names = [] self._scales = [] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index e82c52e6f..cc7889d41 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -170,7 +170,7 @@ def scaled_embedding_to_embedding( return embedding -def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM): +def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: """Convert an instance of ScaledLSTM to nn.LSTM. Args:
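The conversion folds each learned scale into the corresponding flat weight
of a plain nn.LSTM. A minimal sketch of the idea (assuming, as in
scaling.py, that ScaledLSTM registers a companion parameter named
`<weight>_scale` for every flat weight; this is not the verbatim
implementation):

    import torch
    import torch.nn as nn

    def lstm_from_scaled(scaled_lstm) -> nn.LSTM:
        # Build a plain LSTM with the same configuration.
        lstm = nn.LSTM(
            input_size=scaled_lstm.input_size,
            hidden_size=scaled_lstm.hidden_size,
            num_layers=scaled_lstm.num_layers,
            bias=scaled_lstm.bias,
            batch_first=scaled_lstm.batch_first,
            bidirectional=scaled_lstm.bidirectional,
        )
        with torch.no_grad():
            for name in scaled_lstm._flat_weights_names:
                # e.g., weight_ih_l0 * weight_ih_l0_scale.exp()
                scale = getattr(scaled_lstm, name + "_scale")
                weight = getattr(scaled_lstm, name) * scale.exp()
                getattr(lstm, name).copy_(weight)
        return lstm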