# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import math
from typing import List, Optional, Tuple

import torch
from encoder_interface import EncoderInterface
from scaling import (
    ActivationBalancer,
    BasicNorm,
    DoubleSwish,
    ScaledConv2d,
    ScaledLinear,
    ScaledLSTM,
)
from torch import nn

LOG_EPSILON = math.log(1e-10)


def unstack_states(
    states: Tuple[torch.Tensor, torch.Tensor]
) -> List[Tuple[torch.Tensor, torch.Tensor]]:
    """
    Unstack the lstm states corresponding to a batch of utterances into a list
    of states, where the i-th entry is the state from the i-th utterance.

    Args:
      states:
        A tuple of 2 elements.
        ``states[0]`` is the lstm hidden states of a batch of utterances.
        ``states[1]`` is the lstm cell states of a batch of utterances.

    Returns:
      A list of states.
        ``states[i]`` is a tuple of 2 elements of the i-th utterance.
        ``states[i][0]`` is the lstm hidden states of the i-th utterance.
        ``states[i][1]`` is the lstm cell states of the i-th utterance.
    """
    hidden_states, cell_states = states

    # The batch axis is dim 1 (matching the (num_layers, N, ...) layout
    # produced by RNN.get_init_states). unsqueeze(1) keeps a size-1 batch
    # axis on every per-utterance entry.
    list_hidden_states = hidden_states.unbind(dim=1)
    list_cell_states = cell_states.unbind(dim=1)

    ans = [
        (h.unsqueeze(1), c.unsqueeze(1))
        for (h, c) in zip(list_hidden_states, list_cell_states)
    ]
    return ans


def stack_states(
    states_list: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Stack a list of lstm states corresponding to separate utterances into a
    single lstm state, so that it can be used as an input for the lstm when
    those utterances are formed into a batch.

    This is the inverse of :func:`unstack_states`.

    Args:
      states_list:
        Each element in states_list corresponds to the lstm state for a
        single utterance.
        ``states_list[i]`` is a tuple of 2 elements of the i-th utterance.
        ``states_list[i][0]`` is the lstm hidden states of the i-th utterance.
        ``states_list[i][1]`` is the lstm cell states of the i-th utterance.

    Returns:
      A new state corresponding to a batch of utterances.
      It is a tuple of 2 elements.
        ``states[0]`` is the lstm hidden states of a batch of utterances.
        ``states[1]`` is the lstm cell states of a batch of utterances.
    """
    # Concatenate along dim 1, the batch axis (see unstack_states).
    hidden_states = torch.cat([s[0] for s in states_list], dim=1)
    cell_states = torch.cat([s[1] for s in states_list], dim=1)
    ans = (hidden_states, cell_states)
    return ans


class RNN(EncoderInterface):
    """
    Args:
      num_features (int):
        Number of input features.
      subsampling_factor (int):
        Subsampling factor of encoder (convolution layers before lstm layers)
        (default=4). Only 4 is supported.
      d_model (int):
        Output dimension (default=512).
      dim_feedforward (int):
        Feedforward dimension (default=2048).
      rnn_hidden_size (int):
        Hidden dimension for lstm layers (default=1024).
      grad_norm_threshold:
        For each sequence element in batch, its gradient will be
        filtered out if the gradient norm is larger than
        `grad_norm_threshold * median`, where `median` is the median
        value of gradient norms of all elements in batch.
      num_encoder_layers (int):
        Number of encoder layers (default=12).
      dropout (float):
        Dropout rate (default=0.1).
      layer_dropout (float):
        Dropout value for model-level warmup (default=0.075).
      aux_layer_period (int):
        Period of auxiliary layers used for random combiner during training.
        If set to 0, will not use the random combiner (Default).
        You can set a positive integer to use the random combiner, e.g., 3.
    """

    def __init__(
        self,
        num_features: int,
        subsampling_factor: int = 4,
        d_model: int = 512,
        dim_feedforward: int = 2048,
        rnn_hidden_size: int = 1024,
        grad_norm_threshold: float = 10.0,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
        aux_layer_period: int = 0,
    ) -> None:
        super(RNN, self).__init__()

        self.num_features = num_features
        self.subsampling_factor = subsampling_factor
        if subsampling_factor != 4:
            raise NotImplementedError("Support only 'subsampling_factor=4'.")

        # self.encoder_embed converts the input of shape (N, T, num_features)
        # to the shape (N, T//subsampling_factor, d_model).
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding: num_features -> d_model
        self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        self.num_encoder_layers = num_encoder_layers
        self.d_model = d_model
        self.rnn_hidden_size = rnn_hidden_size

        encoder_layer = RNNEncoderLayer(
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            rnn_hidden_size=rnn_hidden_size,
            grad_norm_threshold=grad_norm_threshold,
            dropout=dropout,
            layer_dropout=layer_dropout,
        )

        # Indexes of intermediate layers whose outputs feed the random
        # combiner; RNNEncoder appends the final layer itself. None disables
        # the combiner.
        aux_layers: Optional[List[int]] = None
        if aux_layer_period > 0:
            aux_layers = list(
                range(
                    num_encoder_layers // 3,
                    num_encoder_layers - 1,
                    aux_layer_period,
                )
            )

        self.encoder = RNNEncoder(
            encoder_layer,
            num_encoder_layers,
            aux_layers=aux_layers,
        )

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
          x:
            The input tensor. Its shape is (N, T, C), where N is the batch size,
            T is the sequence length, C is the feature dimension.
          x_lens:
            A tensor of shape (N,), containing the number of frames in `x`
            before padding.
          states:
            A tuple of 2 tensors (optional). It is for streaming inference.
            states[0] is the hidden states of all layers,
              with shape of (num_layers, N, d_model);
            states[1] is the cell states of all layers,
              with shape of (num_layers, N, rnn_hidden_size).
          warmup:
            A floating point value that gradually increases from 0 throughout
            training; when it is >= 1.0 we are "fully warmed up". It is used
            to turn modules on sequentially. Only used when `states` is None.

        Returns:
          A tuple of 3 tensors:
            - embeddings: its shape is (N, T', d_model), where T' is the output
              sequence lengths.
            - lengths: a tensor of shape (batch_size,) containing the number of
              frames in `embeddings` before padding.
            - updated states, whose shape is the same as the input states.
              When `states` is None, this is a pair of empty tensors.
        """
        x = self.encoder_embed(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Equivalent to: lengths = ((x_lens - 3) // 2 - 1) // 2
        # `//` on tensors issues a warning, and rounding_mode in torch.div()
        # is available only in torch >= 1.8.0, so right shifts are used.
        lengths = (((x_lens - 3) >> 1) - 1) >> 1

        if not torch.jit.is_tracing():
            assert x.size(0) == lengths.max().item()

        if states is None:
            x = self.encoder(x, warmup=warmup)[0]
            # torch.jit.trace requires returned types to be the same as annotated  # noqa
            new_states = (torch.empty(0), torch.empty(0))
        else:
            assert not self.training
            assert len(states) == 2
            if not torch.jit.is_tracing():
                # for hidden state
                assert states[0].shape == (
                    self.num_encoder_layers,
                    x.size(1),
                    self.d_model,
                )
                # for cell state
                assert states[1].shape == (
                    self.num_encoder_layers,
                    x.size(1),
                    self.rnn_hidden_size,
                )
            x, new_states = self.encoder(x, states)

        x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        return x, lengths, new_states

    @torch.jit.export
    def get_init_states(
        self, batch_size: int = 1, device: torch.device = torch.device("cpu")
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get model initial states.

        Returns:
          A tuple of zero tensors (hidden_states, cell_states) with shapes
          (num_encoder_layers, batch_size, d_model) and
          (num_encoder_layers, batch_size, rnn_hidden_size).
        """
        # for rnn hidden states
        hidden_states = torch.zeros(
            (self.num_encoder_layers, batch_size, self.d_model), device=device
        )
        cell_states = torch.zeros(
            (self.num_encoder_layers, batch_size, self.rnn_hidden_size),
            device=device,
        )
        return (hidden_states, cell_states)


class RNNEncoderLayer(nn.Module):
    """
    RNNEncoderLayer is made up of lstm and feedforward networks.
    For stable training, in each lstm module, gradient filter
    is applied to filter out extremely large elements in batch gradients
    and also the module parameters with soft masks.

    Args:
      d_model:
        The number of expected features in the input (required).
      dim_feedforward:
        The dimension of feedforward network model (default=2048).
      rnn_hidden_size:
        The hidden dimension of rnn layer. Must be >= d_model.
      grad_norm_threshold:
        For each sequence element in batch, its gradient will be
        filtered out if the gradient norm is larger than
        `grad_norm_threshold * median`, where `median` is the median
        value of gradient norms of all elements in batch.
      dropout:
        The dropout value (default=0.1).
      layer_dropout:
        The dropout value for model-level warmup (default=0.075).
    """

    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        rnn_hidden_size: int,
        grad_norm_threshold: float = 10.0,
        dropout: float = 0.1,
        layer_dropout: float = 0.075,
    ) -> None:
        super(RNNEncoderLayer, self).__init__()
        self.layer_dropout = layer_dropout
        self.d_model = d_model
        self.rnn_hidden_size = rnn_hidden_size

        assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model)
        # When the hidden size is larger than d_model, a projection brings
        # the lstm output back to d_model so the residual add in forward()
        # is shape-compatible (presumably nn.LSTM-style proj_size semantics;
        # see ScaledLSTM).
        self.lstm = ScaledLSTM(
            input_size=d_model,
            hidden_size=rnn_hidden_size,
            proj_size=d_model if rnn_hidden_size > d_model else 0,
            num_layers=1,
            dropout=0.0,
            grad_norm_threshold=grad_norm_threshold,
        )
        self.feed_forward = nn.Sequential(
            ScaledLinear(d_model, dim_feedforward),
            ActivationBalancer(channel_dim=-1),
            DoubleSwish(),
            nn.Dropout(dropout),
            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
        )
        self.norm_final = BasicNorm(d_model)

        # try to ensure the output is close to zero-mean (or at least, zero-median).  # noqa
        self.balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        src: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Pass the input through the encoder layer.

        Args:
          src:
            The sequence to the encoder layer (required).
            Its shape is (S, N, E), where S is the sequence length,
            N is the batch size, and E is the feature number.
          states:
            A tuple of 2 tensors (optional). It is for streaming inference.
            states[0] is the hidden states of this layer,
              with shape of (1, N, d_model);
            states[1] is the cell states of this layer,
              with shape of (1, N, rnn_hidden_size).
          warmup:
            It controls selective bypass of layers; if < 1.0, we will
            bypass layers more frequently.

        Returns:
          A tuple of the output (same shape as `src`) and the updated states
          (a pair of empty tensors when `states` is None).
        """
        src_orig = src

        warmup_scale = min(0.1 + warmup, 1.0)
        # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
        # completely bypass it.
        if self.training:
            alpha = (
                warmup_scale
                if torch.rand(()).item() <= (1.0 - self.layer_dropout)
                else 0.1
            )
        else:
            alpha = 1.0

        # lstm module
        if states is None:
            src_lstm = self.lstm(src)[0]
            # torch.jit.trace requires returned types be the same as annotated
            new_states = (torch.empty(0), torch.empty(0))
        else:
            assert not self.training
            assert len(states) == 2
            if not torch.jit.is_tracing():
                # for hidden state
                assert states[0].shape == (1, src.size(1), self.d_model)
                # for cell state
                assert states[1].shape == (1, src.size(1), self.rnn_hidden_size)
            src_lstm, new_states = self.lstm(src, states)
        src = src + self.dropout(src_lstm)

        # feed forward module
        src = src + self.dropout(self.feed_forward(src))

        src = self.norm_final(self.balancer(src))

        # Partial bypass: blend the layer output with its input.
        if alpha != 1.0:
            src = alpha * src + (1 - alpha) * src_orig

        return src, new_states


class RNNEncoder(nn.Module):
    """
    RNNEncoder is a stack of N encoder layers.

    Args:
      encoder_layer:
        An instance of the RNNEncoderLayer() class (required).
      num_layers:
        The number of sub-encoder-layers in the encoder (required).
      aux_layers:
        Optional indexes of intermediate layers whose outputs are combined
        (together with the final layer's output) by a RandomCombine module.
        Must not contain duplicates or the last layer index. None disables
        the combiner.
    """

    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        aux_layers: Optional[List[int]] = None,
    ) -> None:
        super(RNNEncoder, self).__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers
        self.d_model = encoder_layer.d_model
        self.rnn_hidden_size = encoder_layer.rnn_hidden_size

        self.aux_layers: List[int] = []
        self.combiner: Optional[nn.Module] = None
        if aux_layers is not None:
            assert len(set(aux_layers)) == len(aux_layers)
            assert num_layers - 1 not in aux_layers
            self.aux_layers = aux_layers + [num_layers - 1]
            self.combiner = RandomCombine(
                num_inputs=len(self.aux_layers),
                final_weight=0.5,
                pure_prob=0.333,
                stddev=2.0,
            )

    def forward(
        self,
        src: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Pass the input through the encoder layers in turn.

        Args:
          src:
            The sequence to the encoder layer (required).
            Its shape is (S, N, E), where S is the sequence length,
            N is the batch size, and E is the feature number.
          states:
            A tuple of 2 tensors (optional). It is for streaming inference.
            states[0] is the hidden states of all layers,
              with shape of (num_layers, N, d_model);
            states[1] is the cell states of all layers,
              with shape of (num_layers, N, rnn_hidden_size).
          warmup:
            It controls selective bypass of layers; if < 1.0, we will
            bypass layers more frequently. Only used when `states` is None.

        Returns:
          A tuple of the output and the updated states (a pair of empty
          tensors when `states` is None).
        """
        if states is not None:
            assert not self.training
            assert len(states) == 2
            if not torch.jit.is_tracing():
                # for hidden state
                assert states[0].shape == (
                    self.num_layers,
                    src.size(1),
                    self.d_model,
                )
                # for cell state
                assert states[1].shape == (
                    self.num_layers,
                    src.size(1),
                    self.rnn_hidden_size,
                )

        output = src

        outputs = []

        new_hidden_states = []
        new_cell_states = []

        for i, mod in enumerate(self.layers):
            if states is None:
                output = mod(output, warmup=warmup)[0]
            else:
                # Slice (not index) so each layer gets a (1, N, ...) state.
                layer_state = (
                    states[0][i : i + 1, :, :],  # h: (1, N, d_model)
                    states[1][i : i + 1, :, :],  # c: (1, N, rnn_hidden_size)
                )
                output, (h, c) = mod(output, layer_state)
                new_hidden_states.append(h)
                new_cell_states.append(c)

            if self.combiner is not None and i in self.aux_layers:
                outputs.append(output)

        if self.combiner is not None:
            output = self.combiner(outputs)

        if states is None:
            new_states = (torch.empty(0), torch.empty(0))
        else:
            new_states = (
                torch.cat(new_hidden_states, dim=0),
                torch.cat(new_cell_states, dim=0),
            )

        return output, new_states


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = ((T-3)//2-1)//2, which approximates T' == T//4

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >= 9, in_channels >= 9.
          out_channels:
            Output dim. The output shape is
            (N, ((T-3)//2-1)//2, out_channels)
          layer1_channels:
            Number of channels in layer1
          layer2_channels:
            Number of channels in layer2
          layer3_channels:
            Number of channels in layer3
        """
        assert in_channels >= 9
        super().__init__()

        self.conv = nn.Sequential(
            ScaledConv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=0,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
            ScaledConv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=2,
            ),
            ActivationBalancer(channel_dim=1),
            DoubleSwish(),
        )
        # The frequency axis shrinks exactly like the time axis, hence the
        # ((in_channels - 3) // 2 - 1) // 2 factor.
        self.out = ScaledLinear(
            layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels
        )
        # set learn_eps=False because out_norm is preceded by `out`, and `out`
        # itself has learned scale, so the extra degree of freedom is not
        # needed.
        self.out_norm = BasicNorm(out_channels, learn_eps=False)
        # constrain median of output to be close to zero.
        self.out_balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).

        Returns:
          Return a tensor of shape (N, ((T-3)//2-1)//2, odim)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        x = self.conv(x)
        # Now x is of shape
        # (N, layer3_channels, ((T-3)//2-1)//2, ((idim-3)//2-1)//2)
        b, c, t, f = x.size()
        x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
        # Now x is of shape (N, ((T-3)//2-1)//2, odim)
        x = self.out_norm(x)
        x = self.out_balancer(x)
        return x


class RandomCombine(nn.Module):
    """
    This module combines a list of Tensors, all with the same shape, to
    produce a single output of that same shape which, in training time,
    is a random combination of all the inputs; but which in test time
    will be just the last input.

    The idea is that the list of Tensors will be a list of outputs of multiple
    conformer layers. This has a similar effect as iterated loss. (See:
    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
    NETWORKS).
    """

    def __init__(
        self,
        num_inputs: int,
        final_weight: float = 0.5,
        pure_prob: float = 0.5,
        stddev: float = 2.0,
    ) -> None:
        """
        Args:
          num_inputs:
            The number of tensor inputs, which equals the number of layers'
            outputs that are fed into this module. E.g. in an 18-layer neural
            net if we output layers 16, 12, 18, num_inputs would be 3.
          final_weight:
            The amount of weight or probability we assign to the
            final layer when randomly choosing layers or when choosing
            continuous layer weights.
          pure_prob:
            The probability, on each frame, with which we choose
            only a single layer to output (rather than an interpolation)
          stddev:
            A standard deviation that we add to log-probs for computing
            randomized weights.

        The method of choosing which layers, or combinations of layers, to use,
        is conceptually as follows::

            With probability `pure_prob`::
               With probability `final_weight`: choose final layer,
               Else: choose random non-final layer.
            Else::
               Choose initial log-weights that correspond to assigning
               weight `final_weight` to the final layer and equal
               weights to other layers; then add Gaussian noise
               with standard deviation `stddev` to these log-weights,
               and normalize to weights (note: the average weight
               assigned to the final layer here will not be
               `final_weight` if stddev>0).
        """
        super().__init__()
        assert 0 <= pure_prob <= 1, pure_prob
        assert 0 < final_weight < 1, final_weight
        assert num_inputs >= 1

        self.num_inputs = num_inputs
        self.final_weight = final_weight
        self.pure_prob = pure_prob
        self.stddev = stddev

        # log(final_weight / (1 - final_weight) * (num_inputs - 1)): with the
        # other inputs at log-weight 0, softmax gives the final input weight
        # `final_weight`. It is -inf when num_inputs == 1; forward() never
        # samples weights in that case.
        self.final_log_weight = (
            torch.tensor(
                (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
            )
            .log()
            .item()
        )

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """Forward function.

        Args:
          inputs:
            A list of Tensor, e.g. from various layers of a transformer.
            All must be the same shape, of (*, num_channels)

        Returns:
          A Tensor of shape (*, num_channels). In test mode, or when there is
          only one input, this is just the final input.
        """
        num_inputs = self.num_inputs
        assert len(inputs) == num_inputs
        # With a single input there is nothing to combine; random weight
        # sampling would also fail (randint over an empty range / log(0)).
        if num_inputs == 1 or not self.training or torch.jit.is_scripting():
            return inputs[-1]

        # Shape of weights: (*, num_inputs)
        num_channels = inputs[0].shape[-1]
        num_frames = inputs[0].numel() // num_channels

        ndim = inputs[0].ndim
        # stacked_inputs: (num_frames, num_channels, num_inputs)
        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
            (num_frames, num_channels, num_inputs)
        )

        # weights: (num_frames, num_inputs)
        weights = self._get_random_weights(
            inputs[0].dtype, inputs[0].device, num_frames
        )

        weights = weights.reshape(num_frames, num_inputs, 1)
        # ans: (num_frames, num_channels, 1)
        ans = torch.matmul(stacked_inputs, weights)
        # ans: (*, num_channels)

        ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))

        # The following if causes errors for torch script in torch 1.6.0
        #  if __name__ == "__main__":
        #      # for testing only...
        #      print("Weights = ", weights.reshape(num_frames, num_inputs))
        return ans

    def _get_random_weights(
        self, dtype: torch.dtype, device: torch.device, num_frames: int
    ) -> torch.Tensor:
        """Return a tensor of random weights, of shape
        `(num_frames, self.num_inputs)`,

        Args:
          dtype:
            The data-type desired for the answer, e.g. float, double.
          device:
            The device needed for the answer.
          num_frames:
            The number of sets of weights desired

        Returns:
          A tensor of shape (num_frames, self.num_inputs), such that
          `ans.sum(dim=1)` is all ones.
        """
        pure_prob = self.pure_prob
        if pure_prob == 0.0:
            return self._get_random_mixed_weights(dtype, device, num_frames)
        elif pure_prob == 1.0:
            return self._get_random_pure_weights(dtype, device, num_frames)
        else:
            # Per frame, choose between the one-hot and the mixed weights.
            p = self._get_random_pure_weights(dtype, device, num_frames)
            m = self._get_random_mixed_weights(dtype, device, num_frames)
            return torch.where(
                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
            )

    def _get_random_pure_weights(
        self, dtype: torch.dtype, device: torch.device, num_frames: int
    ):
        """Return a tensor of random one-hot weights, of shape
        `(num_frames, self.num_inputs)`,

        Args:
          dtype:
            The data-type desired for the answer, e.g. float, double.
          device:
            The device needed for the answer.
          num_frames:
            The number of sets of weights desired.

        Returns:
          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
          exactly one weight equal to 1.0 on each frame.
        """
        final_prob = self.final_weight

        # final contains self.num_inputs - 1 in all elements. The dtype is
        # explicit because one_hot() requires int64 indexes.
        final = torch.full(
            (num_frames,),
            self.num_inputs - 1,
            dtype=torch.int64,
            device=device,
        )
        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.  # noqa
        nonfinal = torch.randint(
            self.num_inputs - 1, (num_frames,), device=device
        )

        indexes = torch.where(
            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
        )
        ans = torch.nn.functional.one_hot(
            indexes, num_classes=self.num_inputs
        ).to(dtype=dtype)
        return ans

    def _get_random_mixed_weights(
        self, dtype: torch.dtype, device: torch.device, num_frames: int
    ):
        """Return a tensor of random continuous (mixed) weights, of shape
        `(num_frames, self.num_inputs)`,

        Args:
          dtype:
            The data-type desired for the answer, e.g. float, double.
          device:
            The device needed for the answer.
          num_frames:
            The number of sets of weights desired.

        Returns:
          A tensor of shape (num_frames, self.num_inputs), with elements
          in [0..1] that sum to one over the second axis, i.e.
          `ans.sum(dim=1)` is all ones.
        """
        logprobs = (
            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
            * self.stddev
        )
        logprobs[:, -1] += self.final_log_weight
        return logprobs.softmax(dim=1)


def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
    """Check that RandomCombine preserves shape and values for equal inputs."""
    print(
        f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}"  # noqa
    )
    num_inputs = 3
    num_channels = 50
    m = RandomCombine(
        num_inputs=num_inputs,
        final_weight=final_weight,
        pure_prob=pure_prob,
        stddev=stddev,
    )

    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]

    y = m(x)
    assert y.shape == x[0].shape
    assert torch.allclose(y, x[0])  # .. since actually all ones.


def _test_random_combine_main():
    """Run RandomCombine checks and a smoke-test forward pass of RNN."""
    _test_random_combine(0.999, 0, 0.0)
    _test_random_combine(0.5, 0, 0.0)
    _test_random_combine(0.999, 0, 0.0)
    _test_random_combine(0.5, 0, 0.3)
    _test_random_combine(0.5, 1, 0.3)
    _test_random_combine(0.5, 0.5, 0.3)

    feature_dim = 50
    c = RNN(num_features=feature_dim, d_model=128)
    batch_size = 5
    seq_len = 20
    # Just make sure the forward pass runs.
    f = c(
        torch.randn(batch_size, seq_len, feature_dim),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
    )
    f  # to remove flake8 warnings


if __name__ == "__main__":
    feature_dim = 80
    m = RNN(
        num_features=feature_dim,
        d_model=512,
        rnn_hidden_size=1024,
        dim_feedforward=2048,
        num_encoder_layers=12,
    )
    batch_size = 5
    seq_len = 20
    # Just make sure the forward pass runs.
    f = m(
        torch.randn(batch_size, seq_len, feature_dim),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
        warmup=0.5,
    )
    num_param = sum(p.numel() for p in m.parameters())
    print(f"Number of model parameters: {num_param}")

    _test_random_combine_main()