From 2767b9ff11e2410c2307b2041442ea87a0441022 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 27 Apr 2023 14:36:36 +0800 Subject: [PATCH] Support exporting RNNLM to ONNX. (#1014) * Support exporting RNNLM to ONNX. * add int8 models * fix style issues * Fix EOS padding * support exporting for streaming ASR --- icefall/rnn_lm/.gitignore | 1 + icefall/rnn_lm/check-onnx-streaming.py | 132 +++++++++ icefall/rnn_lm/check-onnx.py | 119 ++++++++ icefall/rnn_lm/export-onnx.py | 395 +++++++++++++++++++++++++ icefall/rnn_lm/export-onnx.sh | 26 ++ icefall/rnn_lm/export.py | 7 +- icefall/rnn_lm/export.sh | 27 ++ icefall/rnn_lm/model.py | 46 +++ 8 files changed, 752 insertions(+), 1 deletion(-) create mode 100644 icefall/rnn_lm/.gitignore create mode 100755 icefall/rnn_lm/check-onnx-streaming.py create mode 100755 icefall/rnn_lm/check-onnx.py create mode 100755 icefall/rnn_lm/export-onnx.py create mode 100755 icefall/rnn_lm/export-onnx.sh create mode 100755 icefall/rnn_lm/export.sh diff --git a/icefall/rnn_lm/.gitignore b/icefall/rnn_lm/.gitignore new file mode 100644 index 000000000..877fb1e18 --- /dev/null +++ b/icefall/rnn_lm/.gitignore @@ -0,0 +1 @@ +icefall-librispeech-rnn-lm diff --git a/icefall/rnn_lm/check-onnx-streaming.py b/icefall/rnn_lm/check-onnx-streaming.py new file mode 100755 index 000000000..8850c1c71 --- /dev/null +++ b/icefall/rnn_lm/check-onnx-streaming.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation + +""" +Usage: + +./check-onnx-streaming.py \ + --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \ + --onnx ./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx + +Note: You can download pre-trained models from +https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + +""" + +import argparse +import logging +from typing import Tuple + +import onnxruntime as ort +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx", + required=True, + type=str, + help="Path to the onnx model", + ) + + return parser + + +class OnnxModel: + def __init__(self, filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.model = ort.InferenceSession( + filename, + sess_options=session_opts, + ) + + meta_data = self.model.get_modelmeta().custom_metadata_map + self.sos_id = int(meta_data["sos_id"]) + self.eos_id = int(meta_data["eos_id"]) + self.vocab_size = int(meta_data["vocab_size"]) + self.num_layers = int(meta_data["num_layers"]) + self.hidden_size = int(meta_data["hidden_size"]) + print(meta_data) + + def __call__( + self, x: torch.Tensor, y: torch.Tensor, h0: torch.Tensor, c0: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + out = self.model.run( + [ + self.model.get_outputs()[0].name, + self.model.get_outputs()[1].name, + self.model.get_outputs()[2].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: y.numpy(), + self.model.get_inputs()[2].name: h0.numpy(), + self.model.get_inputs()[3].name: c0.numpy(), + }, + ) + return ( + torch.from_numpy(out[0]), + torch.from_numpy(out[1]), + torch.from_numpy(out[2]), + ) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit).cpu() + onnx_model = 
OnnxModel(args.onnx) + N = torch.arange(1, 5).tolist() + + num_layers = onnx_model.num_layers + hidden_size = onnx_model.hidden_size + + for n in N: + L = torch.randint(low=1, high=100, size=(1,)).item() + x = torch.randint( + low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 + ) + y = torch.randint( + low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 + ) + h0 = torch.rand(num_layers, n, hidden_size) + c0 = torch.rand(num_layers, n, hidden_size) + + torch_nll, torch_h0, torch_c0 = torch_model.streaming_forward(x, y, h0, c0) + onnx_nll, onnx_h0, onnx_c0 = onnx_model(x, y, h0, c0) + + for torch_v, onnx_v in zip( + (torch_nll, torch_h0, torch_c0), (onnx_nll, onnx_h0, onnx_c0) + ): + + assert torch.allclose(torch_v, onnx_v, atol=1e-5), ( + torch_v.shape, + onnx_v.shape, + (torch_v - onnx_v).abs().max(), + ) + print(n, L, torch_v.sum(), onnx_v.sum()) + + +if __name__ == "__main__": + torch.manual_seed(20230423) + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/check-onnx.py b/icefall/rnn_lm/check-onnx.py new file mode 100755 index 000000000..24c5395f8 --- /dev/null +++ b/icefall/rnn_lm/check-onnx.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation + +""" +Usage: + +./check-onnx.py \ + --jit ./icefall-librispeech-rnn-lm/exp/cpu_jit.pt \ + --onnx ./icefall-librispeech-rnn-lm/exp/no-state-epoch-99-avg-1.onnx + +Note: You can download pre-trained models from +https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + +""" + +import argparse +import logging + +import onnxruntime as ort +import torch + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--jit", + required=True, + type=str, + help="Path to the torchscript model", + ) + + parser.add_argument( + "--onnx", + required=True, + type=str, + help="Path to the onnx model", + ) + + return parser + + +class OnnxModel: + def __init__(self, filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 1 + + self.model = ort.InferenceSession( + filename, + sess_options=session_opts, + ) + + meta_data = self.model.get_modelmeta().custom_metadata_map + self.sos_id = int(meta_data["sos_id"]) + self.eos_id = int(meta_data["eos_id"]) + self.vocab_size = int(meta_data["vocab_size"]) + print(meta_data) + + def __call__(self, x: torch.Tensor, x_lens: torch.Tensor) -> torch.Tensor: + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: x.numpy(), + self.model.get_inputs()[1].name: x_lens.numpy(), + }, + ) + return torch.from_numpy(out[0]) + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + logging.info(vars(args)) + + torch_model = torch.jit.load(args.jit).cpu() + onnx_model = OnnxModel(args.onnx) + N = torch.arange(1, 5).tolist() + + for n in N: + L = torch.randint(low=1, high=100, size=(1,)).item() + x = torch.randint( + low=1, high=onnx_model.vocab_size, size=(n, L), dtype=torch.int64 + ) + x_lens = torch.full((n,), fill_value=L, dtype=torch.int64) + if n > 1: + x_lens[0] = L // 2 + 1 + + sos = torch.full((1,), fill_value=onnx_model.sos_id).expand(n, 1) + sos_x = torch.cat([sos, x], dim=1) + + pad_col = torch.zeros((1,), dtype=x.dtype).expand(n, 1) + x_eos = torch.cat([x, pad_col], dim=1) + + row_index = torch.arange(0, n, 
dtype=x.dtype)
+        x_eos[row_index, x_lens] = onnx_model.eos_id
+
+        torch_nll = torch_model(sos_x, x_eos, x_lens + 1).sum(dim=-1)
+        onnx_nll = onnx_model(x, x_lens)
+        # Note: For int8 models the differences may be quite large;
+        # e.g., the maximum absolute difference can be up to about 0.9.
+        assert torch.allclose(torch_nll, onnx_nll), (
+            torch_nll,
+            onnx_nll,
+        )
+        print(n, L, torch_nll, onnx_nll)
+
+
+if __name__ == "__main__":
+    torch.manual_seed(20230420)
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
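Note: the SOS/EOS padding scheme used by check-onnx.py above, and again by
RnnLmModelWrapper in export-onnx.py below, is easiest to see on a tiny
example. The following sketch is illustrative only (it is not part of the
patch) and uses made-up token ids:

    import torch

    # Assume sos_id = 1, eos_id = 2, and a batch of n = 2 padded token
    # sequences with max length L = 3 and valid lengths x_lens = [2, 3].
    sos_id, eos_id = 1, 2
    x = torch.tensor([[5, 6, 0], [7, 8, 9]])
    x_lens = torch.tensor([2, 3])
    n = x.size(0)

    # Model input: prepend SOS to every row -> shape (n, L + 1).
    sos = torch.full((n, 1), sos_id, dtype=x.dtype)
    sos_x = torch.cat([sos, x], dim=1)  # [[1, 5, 6, 0], [1, 7, 8, 9]]

    # Target: append one pad column, then write EOS at each row's true
    # length so that EOS lands right after the last valid token.
    pad_col = torch.zeros((n, 1), dtype=x.dtype)
    x_eos = torch.cat([x, pad_col], dim=1)
    x_eos[torch.arange(n), x_lens] = eos_id  # [[5, 6, 2, 0], [7, 8, 9, 2]]

    # The model is then scored with lengths x_lens + 1 to account for the
    # prepended SOS, which is why the scripts pass x_lens + 1.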
", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + return parser + + +def export_without_state( + model: RnnLmModel, + filename: str, + params: AttributeDict, + opset_version: int, +): + model_wrapper = RnnLmModelWrapper( + model, + sos_id=params.sos_id, + eos_id=params.eos_id, + ) + + N = 1 + L = 20 + x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) + x_lens = torch.full((N,), fill_value=L, dtype=torch.int64) + + # Note(fangjun): The following warnings can be ignored. + # We can use ./check-onnx.py to validate the exported model with batch_size > 1 + """ + torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX + with a batch_size other than 1, with a variable length with LSTM can cause + an error when running the ONNX model with a different batch size. Make sure + to save the model with a batch size of 1, or define the initial states + (h0/c0) as inputs of the model. warnings.warn("Exporting a model to ONNX + with a batch_size other than 1, " + + """ + + torch.onnx.export( + model_wrapper, + (x, x_lens), + filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens"], + output_names=["nll"], + dynamic_axes={ + "x": {0: "N", 1: "L"}, + "x_lens": {0: "N"}, + "nll": {0: "N"}, + }, + ) + + meta_data = { + "model_type": "rnnlm", + "version": "1", + "model_author": "k2-fsa", + "comment": "rnnlm without state", + "sos_id": str(params.sos_id), + "eos_id": str(params.eos_id), + "vocab_size": str(params.vocab_size), + "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm", + } + logging.info(f"meta_data: {meta_data}") + + add_meta_data(filename=filename, meta_data=meta_data) + + +def export_with_state( + model: RnnLmModel, + filename: str, + params: AttributeDict, + opset_version: int, +): + N = 1 + L = 20 + num_layers = model.rnn.num_layers + hidden_size = model.rnn.hidden_size + + x = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) + y = torch.randint(low=1, high=params.vocab_size, size=(N, L), dtype=torch.int64) + h0 = torch.zeros(num_layers, N, hidden_size) + c0 = torch.zeros(num_layers, N, hidden_size) + + # Note(fangjun): The following warnings can be ignored. 
+    # We can use ./check-onnx-streaming.py to validate the exported model
+    # with batch_size > 1:
+    """
+    torch/onnx/symbolic_opset9.py:2119: UserWarning: Exporting a model to ONNX
+    with a batch_size other than 1, with a variable length with LSTM can cause
+    an error when running the ONNX model with a different batch size. Make sure
+    to save the model with a batch size of 1, or define the initial states
+    (h0/c0) as inputs of the model.
+    """
+
+    torch.onnx.export(
+        model,
+        (x, y, h0, c0),
+        filename,
+        verbose=False,
+        opset_version=opset_version,
+        input_names=["x", "y", "h0", "c0"],
+        output_names=["nll", "next_h0", "next_c0"],
+        dynamic_axes={
+            "x": {0: "N", 1: "L"},
+            "y": {0: "N", 1: "L"},
+            "h0": {1: "N"},
+            "c0": {1: "N"},
+            "nll": {0: "N"},
+            "next_h0": {1: "N"},
+            "next_c0": {1: "N"},
+        },
+    )
+
+    meta_data = {
+        "model_type": "rnnlm",
+        "version": "1",
+        "model_author": "k2-fsa",
+        "comment": "rnnlm with state",
+        "sos_id": str(params.sos_id),
+        "eos_id": str(params.eos_id),
+        "vocab_size": str(params.vocab_size),
+        "num_layers": str(num_layers),
+        "hidden_size": str(hidden_size),
+        "url": "https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm",
+    }
+    logging.info(f"meta_data: {meta_data}")
+
+    add_meta_data(filename=filename, meta_data=meta_data)
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    logging.info(params)
+
+    device = torch.device("cpu")
+    logging.info(f"device: {device}")
+
+    model = RnnLmModel(
+        vocab_size=params.vocab_size,
+        embedding_dim=params.embedding_dim,
+        hidden_dim=params.hidden_dim,
+        num_layers=params.num_layers,
+        tie_weights=params.tie_weights,
+    )
+
+    model.to(device)
+
+    if params.iter > 0:
+        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+            : params.avg
+        ]
+        if len(filenames) == 0:
+            raise ValueError(
+                f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
+            )
+        elif len(filenames) < params.avg:
+            raise ValueError(
+                f"Not enough checkpoints ({len(filenames)}) found for"
+                f" --iter {params.iter}, --avg {params.avg}"
+            )
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+    elif params.avg == 1:
+        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    else:
+        start = params.epoch - params.avg + 1
+        filenames = []
+        for i in range(start, params.epoch + 1):
+            if i >= 0:
+                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+        logging.info(f"averaging {filenames}")
+        model.to(device)
+        model.load_state_dict(
+            average_checkpoints(filenames, device=device), strict=False
+        )
+
+    model.to("cpu")
+    model.eval()
+
+    if params.iter > 0:
+        suffix = f"iter-{params.iter}"
+    else:
+        suffix = f"epoch-{params.epoch}"
+
+    suffix += f"-avg-{params.avg}"
+
+    opset_version = 13
+
+    logging.info("Exporting model without state")
+    filename = params.exp_dir / f"no-state-{suffix}.onnx"
+    export_without_state(
+        model=model,
+        filename=filename,
+        params=params,
+        opset_version=opset_version,
+    )
+
+    filename_int8 = params.exp_dir / f"no-state-{suffix}.int8.onnx"
+    quantize_dynamic(
+        model_input=filename,
+        model_output=filename_int8,
+        weight_type=QuantType.QInt8,
+    )
+
+    # Now export the streaming model
+    saved_forward = model.__class__.forward
+    model.__class__.forward = model.__class__.streaming_forward
+    streaming_filename = params.exp_dir / 
f"with-state-{suffix}.onnx" + export_with_state( + model=model, + filename=streaming_filename, + params=params, + opset_version=opset_version, + ) + model.__class__.forward = saved_forward + + streaming_filename_int8 = params.exp_dir / f"with-state-{suffix}.int8.onnx" + quantize_dynamic( + model_input=streaming_filename, + model_output=streaming_filename_int8, + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/export-onnx.sh b/icefall/rnn_lm/export-onnx.sh new file mode 100755 index 000000000..6e3262b5e --- /dev/null +++ b/icefall/rnn_lm/export-onnx.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +# We use the model from +# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main +# as an example + +export CUDA_VISIBLE_DEVICES= + +if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + pushd icefall-librispeech-rnn-lm/exp + git lfs pull --include "pretrained.pt" + ln -s pretrained.pt epoch-99.pt + popd +fi + +python3 ./export-onnx.py \ + --exp-dir ./icefall-librispeech-rnn-lm/exp \ + --epoch 99 \ + --avg 1 \ + --vocab-size 500 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 3 \ + --tie-weights 1 + diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index a8598a1ce..be4e7f8c5 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -26,7 +26,7 @@ import torch from model import RnnLmModel from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import AttributeDict, load_averaged_model, str2bool +from icefall.utils import AttributeDict, str2bool def get_parser(): @@ -118,6 +118,7 @@ def get_parser(): return parser +@torch.no_grad() def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) @@ -180,6 +181,10 @@ def main(): if params.jit: logging.info("Using torch.jit.script") + + model.__class__.streaming_forward = torch.jit.export( + model.__class__.streaming_forward + ) model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" model.save(str(filename)) diff --git a/icefall/rnn_lm/export.sh b/icefall/rnn_lm/export.sh new file mode 100755 index 000000000..678bc294e --- /dev/null +++ b/icefall/rnn_lm/export.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# We use the model from +# https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm/tree/main +# as an example + +export CUDA_VISIBLE_DEVICES= + +if [ ! -f ./icefall-librispeech-rnn-lm/exp/pretrained.pt ]; then + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + pushd icefall-librispeech-rnn-lm/exp + git lfs pull --include "pretrained.pt" + ln -s pretrained.pt epoch-99.pt + popd +fi + +python3 ./export.py \ + --exp-dir ./icefall-librispeech-rnn-lm/exp \ + --epoch 99 \ + --avg 1 \ + --vocab-size 500 \ + --embedding-dim 2048 \ + --hidden-dim 2048 \ + --num-layers 3 \ + --tie-weights 1 \ + --jit 1 + diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index ebb3128e3..8d3e16432 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -15,6 +15,7 @@ # limitations under the License. 
 import logging
+from typing import Tuple
 
 import torch
 import torch.nn.functional as F
@@ -47,6 +48,11 @@ class RnnLmModel(torch.nn.Module):
           and https://arxiv.org/abs/1611.01462
         """
         super().__init__()
+        self.vocab_size = vocab_size
+        self.embedding_dim = embedding_dim
+        self.hidden_dim = hidden_dim
+        self.num_layers = num_layers
+        self.tie_weights = tie_weights
 
         self.input_embedding = torch.nn.Embedding(
             num_embeddings=vocab_size,
@@ -74,6 +80,46 @@ class RnnLmModel(torch.nn.Module):
 
         self.cache = {}
 
+    def streaming_forward(
+        self,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        h0: torch.Tensor,
+        c0: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Args:
+          x:
+            A 2-D tensor of shape (N, L). We won't prepend it with SOS.
+          y:
+            A 2-D tensor of shape (N, L). We won't append it with EOS.
+          h0:
+            A 3-D tensor of shape (num_layers, N, hidden_size).
+            (If proj_size > 0, then it is (num_layers, N, proj_size).)
+          c0:
+            A 3-D tensor of shape (num_layers, N, hidden_size).
+        Returns:
+          Return a tuple containing 3 tensors:
+            - the negative log-likelihood (nll), a 1-D tensor of shape (N,)
+            - next_h0, a 3-D tensor with the same shape as h0
+            - next_c0, a 3-D tensor with the same shape as c0
+        """
+        assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
+        assert x.shape == y.shape, (x.shape, y.shape)
+
+        # embedding is of shape (N, L, embedding_dim)
+        embedding = self.input_embedding(x)
+        # Note: The RNN is constructed with batch_first=True
+        rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0))
+        logits = self.output_linear(rnn_out)
+        nll_loss = F.cross_entropy(
+            logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
+        )
+
+        batch_size = x.size(0)
+        nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1)
+        return nll_loss, next_h0, next_c0
+
     def forward(
         self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
     ) -> torch.Tensor:
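Note: once export-onnx.py has run, the with-state model can be driven token by
token from onnxruntime, which is what the streaming export is for. The
following is a minimal sketch (illustrative only, not part of the patch); it
assumes the file name produced above with --epoch 99 --avg 1, and the token
ids are made up:

    import onnxruntime as ort
    import torch

    session = ort.InferenceSession(
        "./icefall-librispeech-rnn-lm/exp/with-state-epoch-99-avg-1.onnx"
    )
    meta = session.get_modelmeta().custom_metadata_map
    num_layers = int(meta["num_layers"])
    hidden_size = int(meta["hidden_size"])

    # One stream (N=1); feed one (current, next) token pair at a time and
    # carry the LSTM states (h0, c0) over between calls.
    h0 = torch.zeros(num_layers, 1, hidden_size).numpy()
    c0 = torch.zeros(num_layers, 1, hidden_size).numpy()
    tokens = [int(meta["sos_id"]), 25, 100, 7]  # made-up token ids
    total_nll = 0.0
    for cur, nxt in zip(tokens[:-1], tokens[1:]):
        x = torch.tensor([[cur]], dtype=torch.int64).numpy()
        y = torch.tensor([[nxt]], dtype=torch.int64).numpy()
        nll, h0, c0 = session.run(
            ["nll", "next_h0", "next_c0"],
            {"x": x, "y": y, "h0": h0, "c0": c0},
        )
        total_nll += float(nll[0])

    print("total negative log-likelihood:", total_nll)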