# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
#                                        Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class TdnnLiGRU(nn.Module):
    def __init__(
        self, num_features: int, num_classes: int, subsampling_factor: int = 3
    ) -> None:
        """
        Args:
          num_features:
            The input dimension of the model.
          num_classes:
            The output dimension of the model.
          subsampling_factor:
            It reduces the number of output frames by this factor.
        """
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.subsampling_factor = subsampling_factor
        self.tdnn = nn.Sequential(
            nn.Conv1d(
                in_channels=num_features,
                out_channels=512,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512, affine=False),
            nn.Conv1d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512, affine=False),
            nn.Conv1d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512, affine=False),
            nn.Conv1d(
                in_channels=512,
                out_channels=512,
                kernel_size=3,
                stride=self.subsampling_factor,  # stride: subsampling_factor!
                padding=1,
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(num_features=512, affine=False),
        )

        self.ligrus = nn.ModuleList(
            [
                LiGRU(
                    input_shape=[None, None, 512],
                    hidden_size=512,
                    num_layers=1,
                    bidirectional=True,
                )
                for _ in range(4)
            ]
        )
        self.linears = nn.ModuleList(
            [nn.Linear(in_features=1024, out_features=512) for _ in range(4)]
        )
        self.bnorms = nn.ModuleList(
            [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(4)]
        )
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(in_features=512, out_features=self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            Its shape is [N, C, T].
        Returns:
          The output tensor has shape [N, T, C].
        """
        x = self.tdnn(x)
        x = x.permute(0, 2, 1)
        for ligru, linear, bnorm in zip(self.ligrus, self.linears, self.bnorms):
            x_new, _ = ligru(x)
            x_new = linear(x_new)
            # (N, T, C) -> (N, C, T) -> (N, T, C)
            x_new = bnorm(x_new.permute(0, 2, 1)).permute(0, 2, 1)
            x_new = self.dropout(x_new)
            x = x_new + x  # skip connections
        x = self.linear(x)
        x = nn.functional.log_softmax(x, dim=-1)
        return x
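

# A minimal usage sketch (not part of the original recipe): it exercises the
# [N, C, T] -> [N, T, C] contract documented in TdnnLiGRU.forward above.
# The feature dimension (80) and the number of classes (500) are assumed
# placeholders for illustration, not values taken from a specific recipe.
def _test_tdnn_ligru_shapes() -> None:
    num_features, num_classes = 80, 500  # assumed example dimensions
    model = TdnnLiGRU(num_features=num_features, num_classes=num_classes)
    model.eval()

    N, T = 2, 100
    x = torch.rand(N, num_features, T)  # [N, C, T]
    with torch.no_grad():
        y = model(x)

    # The last Conv1d uses kernel_size=3, padding=1 and
    # stride=subsampling_factor=3, so the output length is (T - 1) // 3 + 1.
    assert y.shape == (N, (T - 1) // 3 + 1, num_classes), y.shape
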

class LiGRU(torch.nn.Module):
    """This class implements a Light GRU (liGRU).

    This LiGRU model comes from speechbrain; please see
    https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/RNN.py

    LiGRU is a single-gate GRU model based on batch-norm + relu
    activations + recurrent dropout. For more info see:

    "M. Ravanelli, P. Brakel, M. Omologo, Y. Bengio,
    Light Gated Recurrent Units for Speech Recognition,
    in IEEE Transactions on Emerging Topics in Computational Intelligence,
    2018" (https://arxiv.org/abs/1803.10225)

    This is a custom RNN and, to speed it up, it must be compiled with
    the torch just-in-time compiler (jit) right before using it.
    You can compile it with:
    compiled_model = torch.jit.script(model)

    It accepts input tensors formatted as (batch, time, fea).
    In the case of 4d inputs like (batch, time, fea, channel) the tensor is
    flattened as (batch, time, fea*channel).

    Arguments
    ---------
    hidden_size : int
        Number of output neurons (i.e., the dimensionality of the output).
    input_shape : tuple
        The shape of an example input.
    nonlinearity : str
        Type of nonlinearity (tanh, relu).
    normalization : str
        Type of normalization for the ligru model (batchnorm, layernorm).
        Every string different from batchnorm and layernorm will result
        in no normalization.
    num_layers : int
        Number of layers to employ in the RNN architecture.
    bias : bool
        If True, the additive bias b is adopted.
    dropout : float
        It is the dropout factor (must be between 0 and 1).
    bidirectional : bool
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used.

    Example
    -------
    >>> inp_tensor = torch.rand([4, 10, 20])
    >>> net = LiGRU(input_shape=inp_tensor.shape, hidden_size=5)
    >>> out_tensor, _ = net(inp_tensor)
    >>> out_tensor.shape
    torch.Size([4, 10, 5])
    """

    def __init__(
        self,
        hidden_size,
        input_shape,
        nonlinearity="relu",
        normalization="batchnorm",
        num_layers=1,
        bias=True,
        dropout=0.0,
        bidirectional=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.num_layers = num_layers
        self.normalization = normalization
        self.bias = bias
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.reshape = False

        # Computing the feature dimensionality
        if len(input_shape) > 3:
            self.reshape = True

        self.fea_dim = float(torch.prod(torch.tensor(input_shape[2:])))
        self.batch_size = input_shape[0]
        self.rnn = self._init_layers()

    def _init_layers(self):
        """Initializes the layers of the liGRU."""
        rnn = torch.nn.ModuleList([])
        current_dim = self.fea_dim

        for i in range(self.num_layers):
            rnn_lay = LiGRU_Layer(
                current_dim,
                self.hidden_size,
                self.num_layers,
                self.batch_size,
                dropout=self.dropout,
                nonlinearity=self.nonlinearity,
                normalization=self.normalization,
                bidirectional=self.bidirectional,
            )
            rnn.append(rnn_lay)

            if self.bidirectional:
                current_dim = self.hidden_size * 2
            else:
                current_dim = self.hidden_size
        return rnn

    def forward(self, x, hx: Optional[Tensor] = None):
        """Returns the output of the liGRU.

        Arguments
        ---------
        x : torch.Tensor
            The input tensor.
        hx : torch.Tensor
            Starting hidden state.
        """
        # Reshaping input tensors for 4d inputs
        if self.reshape:
            if x.ndim == 4:
                x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        # run ligru
        output, hh = self._forward_ligru(x, hx=hx)

        return output, hh

    def _forward_ligru(self, x, hx: Optional[Tensor]):
        """Returns the output of the vanilla liGRU.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Starting hidden state.
        """
        h = []
        if hx is not None:
            if self.bidirectional:
                hx = hx.reshape(
                    self.num_layers, self.batch_size * 2, self.hidden_size
                )
        # Processing the different layers
        for i, ligru_lay in enumerate(self.rnn):
            if hx is not None:
                x = ligru_lay(x, hx=hx[i])
            else:
                x = ligru_lay(x, hx=None)
            h.append(x[:, -1, :])
        h = torch.stack(h, dim=1)

        if self.bidirectional:
            h = h.reshape(h.shape[1] * 2, h.shape[0], self.hidden_size)
        else:
            h = h.transpose(0, 1)

        return x, h


class LiGRU_Layer(torch.nn.Module):
    """This class implements a Light-Gated Recurrent Units (liGRU) layer.

    Arguments
    ---------
    input_size : int
        Feature dimensionality of the input tensors.
    batch_size : int
        Batch size of the input tensors.
    hidden_size : int
        Number of output neurons.
    num_layers : int
        Number of layers to employ in the RNN architecture.
    nonlinearity : str
        Type of nonlinearity (tanh, relu).
    normalization : str
        Type of normalization (batchnorm, layernorm).
        Every string different from batchnorm and layernorm will result
        in no normalization.
    dropout : float
        It is the dropout factor (must be between 0 and 1).
    bidirectional : bool
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers,
        batch_size,
        dropout=0.0,
        nonlinearity="relu",
        normalization="batchnorm",
        bidirectional=False,
    ):
        super(LiGRU_Layer, self).__init__()
        self.hidden_size = int(hidden_size)
        self.input_size = int(input_size)
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.dropout = dropout

        self.drop = torch.nn.Dropout(p=self.dropout, inplace=False)

        # Initializing the recurrent dropout; this registers the drop_masks
        # and drop_mask_te buffers used by _sample_drop_mask() below.
        self._init_drop(self.batch_size)

        self.w = nn.Linear(self.input_size, 2 * self.hidden_size, bias=False)
        self.u = nn.Linear(self.hidden_size, 2 * self.hidden_size, bias=False)

        # Initializing batch norm
        self.normalize = False

        if normalization == "batchnorm":
            self.norm = nn.BatchNorm1d(2 * self.hidden_size, momentum=0.05)
            self.normalize = True
        elif normalization == "layernorm":
            self.norm = torch.nn.LayerNorm(2 * self.hidden_size)
            self.normalize = True
        else:
            # Normalization is disabled here. self.norm is only formally
            # initialized to avoid jit issues.
            self.norm = torch.nn.LayerNorm(2 * self.hidden_size)
            self.normalize = False

        # Initial state
        self.register_buffer("h_init", torch.zeros(1, self.hidden_size))

        # Setting the activation function
        if nonlinearity == "tanh":
            self.act = torch.nn.Tanh()
        elif nonlinearity == "sin":
            self.act = torch.sin
        elif nonlinearity == "leaky_relu":
            self.act = torch.nn.LeakyReLU()
        else:
            self.act = torch.nn.ReLU()

    def forward(self, x, hx: Optional[Tensor] = None):
        # type: (Tensor, Optional[Tensor]) -> Tensor # noqa F821
        """Returns the output of the liGRU layer.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor
            Starting hidden state.
        """
        if self.bidirectional:
            x_flip = x.flip(1)
            x = torch.cat([x, x_flip], dim=0)

        # Change batch size if needed
        self._change_batch_size(x)

        # Feed-forward affine transformations (all steps in parallel)
        w = self.w(x)

        # Apply batch normalization
        if self.normalize:
            w_bn = self.norm(w.reshape(w.shape[0] * w.shape[1], w.shape[2]))
            w = w_bn.reshape(w.shape[0], w.shape[1], w.shape[2])

        # Processing time steps
        if hx is not None:
            h = self._ligru_cell(w, hx)
        else:
            h = self._ligru_cell(w, self.h_init)

        if self.bidirectional:
            h_f, h_b = h.chunk(2, dim=0)
            h_b = h_b.flip(1)
            h = torch.cat([h_f, h_b], dim=2)

        return h

    def _ligru_cell(self, w, ht):
        """Returns the hidden states for each time step.

        Arguments
        ---------
        w : torch.Tensor
            Linearly transformed input.
        ht : torch.Tensor
            Initial hidden state.
        """
        hiddens = []

        # Sampling dropout mask
        drop_mask = self._sample_drop_mask(w)

        # Loop over time axis
        for k in range(w.shape[1]):
            gates = w[:, k] + self.u(ht)
            at, zt = gates.chunk(2, 1)
            zt = torch.sigmoid(zt)
            hcand = self.act(at) * drop_mask
            ht = zt * ht + (1 - zt) * hcand
            hiddens.append(ht)

        # Stacking hidden states
        h = torch.stack(hiddens, dim=1)
        return h

    def _init_drop(self, batch_size):
        """Initializes the recurrent dropout operation.

        To speed it up, the dropout masks are sampled in advance.
""" self.N_drop_masks = 16000 self.drop_mask_cnt = 0 self.register_buffer( "drop_masks", self.drop(torch.ones(self.N_drop_masks, self.hidden_size)).data, ) self.register_buffer("drop_mask_te", torch.tensor([1.0]).float()) def _sample_drop_mask(self, w): """Selects one of the pre-defined dropout masks""" if self.training: # Sample new masks when needed if self.drop_mask_cnt + self.batch_size > self.N_drop_masks: self.drop_mask_cnt = 0 self.drop_masks = self.drop( torch.ones(self.N_drop_masks, self.hidden_size, device=w.device) ).data # Sampling the mask left_boundary = self.drop_mask_cnt right_boundary = self.drop_mask_cnt + self.batch_size drop_mask = self.drop_masks[left_boundary:right_boundary] self.drop_mask_cnt = self.drop_mask_cnt + self.batch_size else: self.drop_mask_te = self.drop_mask_te.to(w.device) drop_mask = self.drop_mask_te return drop_mask def _change_batch_size(self, x): """This function changes the batch size when it is different from the one detected in the initialization method. This might happen in the case of multi-gpu or when we have different batch sizes in train and test. We also update the h_int and drop masks. """ if self.batch_size != x.shape[0]: self.batch_size = x.shape[0] if self.training: self.drop_masks = self.drop( torch.ones( self.N_drop_masks, self.hidden_size, device=x.device, ) ).data