# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Tuple

import torch
import torch.nn.functional as F

from icefall.utils import add_eos, add_sos, make_pad_mask


class RnnLmModel(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_dim: int,
        num_layers: int,
        tie_weights: bool = False,
    ):
        """
        Args:
          vocab_size:
            Vocabulary size of the BPE model.
          embedding_dim:
            Input embedding dimension.
          hidden_dim:
            Hidden dimension of the RNN layers.
          num_layers:
            Number of RNN layers.
          tie_weights:
            True to share the weights between the input embedding layer and
            the last output linear layer. See https://arxiv.org/abs/1608.05859
            and https://arxiv.org/abs/1611.01462
        """
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.tie_weights = tie_weights

        self.input_embedding = torch.nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

        self.rnn = torch.nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

        self.output_linear = torch.nn.Linear(
            in_features=hidden_dim, out_features=vocab_size
        )

        if tie_weights:
            logging.info("Tying weights")
            assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim)
            self.output_linear.weight = self.input_embedding.weight
        else:
            logging.info("Not tying weights")

        self.cache = {}

    def streaming_forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        h0: torch.Tensor,
        c0: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, L). We won't prepend it with SOS.
          y:
            A 2-D tensor of shape (N, L). We won't append it with EOS.
          h0:
            A 3-D tensor of shape (num_layers, N, hidden_size).
            (If proj_size > 0, then it is (num_layers, N, proj_size).)
          c0:
            A 3-D tensor of shape (num_layers, N, hidden_size).
        Returns:
          Return a tuple containing 3 tensors:
            - negative log-likelihood (nll), a 1-D tensor of shape (N,)
            - next_h0, a 3-D tensor with the same shape as h0
            - next_c0, a 3-D tensor with the same shape as c0
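
        Example:
          A minimal sketch of chunk-by-chunk scoring; the model sizes and
          token IDs below are illustrative only, not from a real setup::

            model = RnnLmModel(
                vocab_size=100, embedding_dim=16, hidden_dim=16, num_layers=2
            )
            h0 = torch.zeros(2, 1, 16)
            c0 = torch.zeros(2, 1, 16)
            x = torch.tensor([[5, 7]])  # current chunk of token IDs
            y = torch.tensor([[7, 9]])  # x shifted left by one token
            nll, h0, c0 = model.streaming_forward(x, y, h0, c0)
            # nll has shape (1,); pass (h0, c0) on to the next chunk.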
        """
        assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
        assert x.shape == y.shape, (x.shape, y.shape)

        # embedding is of shape (N, L, embedding_dim)
        embedding = self.input_embedding(x)

        # Note: We use batch_first==True
        rnn_out, (next_h0, next_c0) = self.rnn(embedding, (h0, c0))
        logits = self.output_linear(rnn_out)

        nll_loss = F.cross_entropy(
            logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
        )

        batch_size = x.size(0)
        nll_loss = nll_loss.reshape(batch_size, -1).sum(dim=1)
        return nll_loss, next_h0, next_c0

    def forward(
        self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor with shape (N, L). Each row contains token IDs for
            a sentence and starts with the SOS token.
          y:
            A shifted version of `x` with EOS appended.
          lengths:
            A 1-D tensor of shape (N,). It contains the sentence lengths
            before padding.
        Returns:
          Return a 2-D tensor of shape (N, L) containing negative
          log-likelihood loss values. Note: Loss values for padding
          positions are set to 0.
        """
        assert x.ndim == y.ndim == 2, (x.ndim, y.ndim)
        assert lengths.ndim == 1, lengths.ndim
        assert x.shape == y.shape, (x.shape, y.shape)

        batch_size = x.size(0)
        assert lengths.size(0) == batch_size, (lengths.size(0), batch_size)

        # embedding is of shape (N, L, embedding_dim)
        embedding = self.input_embedding(x)

        # Note: We use batch_first==True
        rnn_out, _ = self.rnn(embedding)
        logits = self.output_linear(rnn_out)

        # Note: No need to use log_softmax() here since F.cross_entropy()
        # expects unnormalized logits.

        # nll_loss is of shape (N*L,)
        # nll -> negative log-likelihood
        nll_loss = F.cross_entropy(
            logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none"
        )

        # Set loss values for padding positions to 0.
        mask = make_pad_mask(lengths).reshape(-1)
        nll_loss.masked_fill_(mask, 0)

        nll_loss = nll_loss.reshape(batch_size, -1)

        return nll_loss

    def predict_batch(self, tokens, token_lens, sos_id, eos_id, blank_id):
        device = next(self.parameters()).device
        batch_size = len(token_lens)

        # Note: `tokens` is expected to be a k2.RaggedTensor of token IDs;
        # add_sos()/add_eos() and .pad() below operate on ragged tensors.
        sos_tokens = add_sos(tokens, sos_id)
        tokens_eos = add_eos(tokens, eos_id)
        sos_tokens_row_splits = sos_tokens.shape.row_splits(1)

        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]

        x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
        y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)

        x_tokens = x_tokens.to(torch.int64).to(device)
        y_tokens = y_tokens.to(torch.int64).to(device)
        sentence_lengths = sentence_lengths.to(torch.int64).to(device)

        embedding = self.input_embedding(x_tokens)

        # Note: We use batch_first==True
        rnn_out, states = self.rnn(embedding)
        logits = self.output_linear(rnn_out)

        # Keep only the logits at position token_lens[i] of each row, i.e.,
        # the output at the last real token, which predicts the token that
        # follows the sentence.
        mask = torch.zeros(logits.shape, dtype=torch.bool, device=device)
        for i in range(batch_size):
            mask[i, token_lens[i], :] = True
        logits = logits[mask].reshape(batch_size, -1)

        return logits.log_softmax(-1), states

    def clean_cache(self):
        self.cache = {}

    def score_token(self, x: torch.Tensor, x_lens: torch.Tensor, state=None):
        """Score a batch of tokens, i.e., each sample in the batch should be
        a single token. For example, x = torch.tensor([[5], [10], [20]]).

        Args:
          x (torch.Tensor):
            A batch of tokens of shape (N, 1).
          x_lens (torch.Tensor):
            The number of tokens in each sample of the batch before padding.
          state (optional):
            Either None or a tuple of two torch.Tensor. Each tensor has
            the shape (num_layers, N, hidden_dim).

        Returns:
          A tuple of two elements:
            - log-probs over the vocabulary, a 2-D tensor of shape
              (N, vocab_size)
            - the updated (h, c) states of the RNN
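
        Example:
          A minimal sketch of incremental rescoring; the token IDs below are
          illustrative only::

            x = torch.tensor([[5], [10], [20]])
            x_lens = torch.tensor([1, 1, 1])
            log_probs, state = model.score_token(x, x_lens)
            # Score the next token of each hypothesis, reusing the state:
            x = torch.tensor([[6], [11], [21]])
            log_probs, state = model.score_token(x, x_lens, state)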
        """
        device = next(self.parameters()).device
        batch_size = x.size(0)
        if state is not None:
            h, c = state
        else:
            h = torch.zeros(
                self.rnn.num_layers, batch_size, self.rnn.hidden_size, device=device
            )
            c = torch.zeros(
                self.rnn.num_layers, batch_size, self.rnn.hidden_size, device=device
            )

        embedding = self.input_embedding(x)
        rnn_out, states = self.rnn(embedding, (h, c))
        logits = self.output_linear(rnn_out)

        return logits[:, 0].log_softmax(-1), states

    def score_token_onnx(
        self,
        x: torch.Tensor,
        state_h: torch.Tensor,
        state_c: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score a batch of tokens, i.e., each sample in the batch should be
        a single token. For example, x = torch.tensor([[5], [10], [20]]).

        Unlike score_token(), the RNN states are passed as two separate
        tensors instead of a tuple, which makes the method easier to export
        to ONNX.

        Args:
          x (torch.Tensor):
            A batch of tokens of shape (N, 1).
          state_h:
            The h state of the RNN, of shape (num_layers, N, hidden_dim).
          state_c:
            The c state of the RNN, of shape (num_layers, N, hidden_dim).

        Returns:
          A tuple of three elements:
            - log-probs over the vocabulary, a 2-D tensor of shape
              (N, vocab_size)
            - next_h0, a tensor with the same shape as state_h
            - next_c0, a tensor with the same shape as state_c
        """
        embedding = self.input_embedding(x)
        rnn_out, (next_h0, next_c0) = self.rnn(embedding, (state_h, state_c))
        logits = self.output_linear(rnn_out)

        return logits[:, 0].log_softmax(-1), next_h0, next_c0

    def forward_with_state(
        self, tokens, token_lens, sos_id, eos_id, blank_id, state=None
    ):
        device = next(self.parameters()).device
        batch_size = len(token_lens)
        if state is not None:
            h, c = state
        else:
            # Create the initial states on the model's device so that this
            # also works when the model is on GPU.
            h = torch.zeros(
                self.rnn.num_layers, batch_size, self.rnn.hidden_size, device=device
            )
            c = torch.zeros(
                self.rnn.num_layers, batch_size, self.rnn.hidden_size, device=device
            )

        sos_tokens = add_sos(tokens, sos_id)
        tokens_eos = add_eos(tokens, eos_id)
        sos_tokens_row_splits = sos_tokens.shape.row_splits(1)

        sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1]

        x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id)
        y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id)

        x_tokens = x_tokens.to(torch.int64).to(device)
        y_tokens = y_tokens.to(torch.int64).to(device)
        sentence_lengths = sentence_lengths.to(torch.int64).to(device)

        embedding = self.input_embedding(x_tokens)

        # Note: We use batch_first==True
        rnn_out, states = self.rnn(embedding, (h, c))
        logits = self.output_linear(rnn_out)

        return logits, states
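

if __name__ == "__main__":
    # A minimal smoke test of forward(), kept as a sketch: the
    # hyperparameters and token IDs below are illustrative only and do not
    # come from a real training setup.
    torch.manual_seed(0)

    model = RnnLmModel(
        vocab_size=100,
        embedding_dim=16,
        hidden_dim=16,
        num_layers=2,
        tie_weights=True,
    )

    # Two sentences of lengths 3 and 2, padded to length 3. x starts with
    # SOS (assumed ID 1 here) and y is x shifted left with EOS (assumed
    # ID 2) appended; 0 is used as the padding value.
    x = torch.tensor([[1, 5, 7], [1, 9, 0]])
    y = torch.tensor([[5, 7, 2], [9, 2, 0]])
    lengths = torch.tensor([3, 2])

    nll = model(x, y, lengths)
    # nll has shape (N, L); entries at padding positions are 0.
    print(nll.shape, nll.sum(dim=1))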