diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index 07b80076d..327849485 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -18,7 +18,7 @@ import copy import math import warnings -from typing import Optional, Tuple +from typing import Optional, Tuple, Sequence import torch from torch import Tensor, nn @@ -56,6 +56,7 @@ class Conformer(Transformer): cnn_module_kernel: int = 31, normalize_before: bool = True, vgg_frontend: bool = False, + aux_layer_period: int = 3 ) -> None: super(Conformer, self).__init__( num_features=num_features, @@ -80,10 +81,11 @@ class Conformer(Transformer): cnn_module_kernel, normalize_before, ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers, + aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period))) self.normalize_before = normalize_before if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) + self.after_norm = nn.LayerNorm(d_model) # TODO: remove. else: # Note: TorchScript detects that self.after_norm could be used inside forward() # and throws an error without this change. 
@@ -280,12 +282,21 @@ class ConformerEncoder(nn.Module): """ def __init__( - self, encoder_layer: nn.Module, num_layers: int + self, encoder_layer: nn.Module, + num_layers: int, + aux_layers: Sequence[int], ) -> None: super(ConformerEncoder, self).__init__() self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)]) + self.aux_layers = set(aux_layers + [num_layers - 1]) + assert num_layers - 1 not in aux_layers self.num_layers = num_layers - + num_channels = encoder_layer.norm_final.weight.numel() + self.combiner = RandomCombine(num_inputs=len(self.aux_layers), + num_channels=num_channels, + final_weight=0.5, + pure_prob=0.333, + stddev=2.0) def forward( self, @@ -312,14 +323,19 @@ class ConformerEncoder(nn.Module): """ output = src - for mod in self.layers: + outputs = [] + + for i, mod in enumerate(self.layers): output = mod( output, pos_emb, src_mask=mask, src_key_padding_mask=src_key_padding_mask, ) + if i in self.aux_layers: + outputs.append(output) + output = self.combiner(outputs) return output @@ -918,7 +934,203 @@ def identity(x): return x +class RandomCombine(torch.nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + All but the last input will have a linear transform before we + randomly combine them; these linear transforms will be initialized + to the identity transform. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). 
+ """ + def __init__(self, num_inputs: int, + num_channels: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0) -> None: + """ + Args: + num_inputs: The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + num_channels: The number of channels on the input, e.g. 512. + final_weight: The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: The standard deviation of the Gaussian noise that we add to + log-probs when computing randomized weights. + + The method of choosing which layers, + or combinations of layers, to use, is conceptually as follows. + With probability `pure_prob`: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with standard deviation `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). 
+ """ + super(RandomCombine, self).__init__() + assert pure_prob >= 0 and pure_prob <= 1 + assert final_weight > 0 and final_weight < 1 + assert num_inputs >= 1 + self.linear = nn.ModuleList([nn.Linear(num_channels, num_channels, bias=True) + for _ in range(num_inputs - 1)]) + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev= stddev + + self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item() + self._reset_parameters() + + def _reset_parameters(self): + for i in range(len(self.linear)): + nn.init.eye_(self.linear[i].weight) + nn.init.constant_(self.linear[i].bias, 0.0) + + def forward(self, inputs: Sequence[Tensor]) -> Tensor: + """ + Forward function. + Args: + inputs: a list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + a Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training: + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + mod_inputs = [] + for i in range(num_inputs - 1): + mod_inputs.append(self.linear[i](inputs[i])) + mod_inputs.append(inputs[num_inputs - 1]) + + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames, + num_channels, + num_inputs)) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights(inputs[0].dtype, inputs[0].device, + num_frames) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) + + if __name__ == "__main__": + # for testing only... 
+ print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + + def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor: + """ + Return a tensor of random weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), such that + ans.sum(dim=1) is all ones. + + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m) + + def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with + exactly one weight equal to 1.0 on each frame. + """ + + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. 
+ nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + + indexes = torch.where(torch.rand(num_frames, device=device) < final_prob, + final, nonfinal) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype) + return ans + + + def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int): + """ + Return a tensor of random mixed (not one-hot) weights, of shape (num_frames, self.num_inputs), + Args: + dtype: the data-type desired for the answer, e.g. float, double + device: the device needed for the answer + num_frames: the number of sets of weights desired + Returns: a tensor of shape (num_frames, self.num_inputs), with elements in [0..1] that + sum to one over the second axis, i.e. ans.sum(dim=1) is all ones. + """ + logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev + logprobs[:,-1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}") + num_inputs = 3 + num_channels = 50 + m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels, + final_weight=final_weight, pure_prob=pure_prob, stddev=stddev) + + x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. 
+ + if __name__ == '__main__': + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + feature_dim = 50 c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) batch_size = 5 diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 0e1bbeaff..8877d4e75 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -110,7 +110,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="transducer_stateless/specaugmod_baseline", + default="transducer_stateless/specaugmod_baseline_randcombine1", help="""The experiment dir. It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved