diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 6465d5a55..2d9350ea3 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -308,11 +308,10 @@ class ConformerEncoder(nn.Module):
         self.aux_layers = set(aux_layers + [num_layers - 1])
         num_channels = encoder_layer.norm_final.num_channels
-        self.combiner = RandomCombine(
+        self.combiner = AttentionCombine(
+            num_channels=encoder_layer.d_model,
             num_inputs=len(self.aux_layers),
-            final_weight=0.5,
-            pure_prob=0.333,
-            stddev=2.0,
+            random_prob=0.5,
         )
 
     def forward(
@@ -1019,7 +1018,7 @@ class Conv2dSubsampling(nn.Module):
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
         return x
 
-class RandomCombine(nn.Module):
+class AttentionCombine(nn.Module):
     """
     This module combines a list of Tensors, all with the same shape, to
    produce a single output of that same shape which, in training time,
@@ -1038,59 +1037,32 @@ class RandomCombine(nn.Module):
 
     def __init__(
         self,
+        num_channels: int,
         num_inputs: int,
-        final_weight: float = 0.5,
-        pure_prob: float = 0.5,
-        stddev: float = 2.0,
+        random_prob: float = 0.5,
     ) -> None:
         """
         Args:
+          num_channels:
+            the number of channels
           num_inputs:
             The number of tensor inputs, which equals the number of layers'
             outputs that are fed into this module.  E.g. in an 18-layer neural
             net if we output layers 16, 12, 18, num_inputs would be 3.
-          final_weight:
-            The amount of weight or probability we assign to the
-            final layer when randomly choosing layers or when choosing
-            continuous layer weights.
-          pure_prob:
-            The probability, on each frame, with which we choose
-            only a single layer to output (rather than an interpolation)
-          stddev:
-            A standard deviation that we add to log-probs for computing
-            randomized weights.
+          random_prob:
+            the probability with which we apply a nontrivial mask, in training
+            mode.
 
-        The method of choosing which layers, or combinations of layers, to use,
-        is conceptually as follows::
-
-            With probability `pure_prob`::
-               With probability `final_weight`: choose final layer,
-               Else: choose random non-final layer.
-            Else::
-               Choose initial log-weights that correspond to assigning
-               weight `final_weight` to the final layer and equal
-               weights to other layers; then add Gaussian noise
-               with variance `stddev` to these log-weights, and normalize
-               to weights (note: the average weight assigned to the
-               final layer here will not be `final_weight` if stddev>0).
         """
         super().__init__()
-        assert 0 <= pure_prob <= 1, pure_prob
-        assert 0 < final_weight < 1, final_weight
-        assert num_inputs >= 1
-
-        self.num_inputs = num_inputs
-        self.final_weight = final_weight
-        self.pure_prob = pure_prob
-        self.stddev = stddev
+        self.random_prob = random_prob
+        self.weight = torch.nn.Parameter(torch.zeros(num_channels,
+                                                     num_inputs))
+        self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
+
+        assert 0 <= random_prob <= 1, random_prob
 
-        self.final_log_weight = (
-            torch.tensor(
-                (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
-            )
-            .log()
-            .item()
-        )
 
     def forward(self, inputs: List[Tensor]) -> Tensor:
@@ -1103,10 +1075,8 @@ class RandomCombine(nn.Module):
           A Tensor of shape (*, num_channels).  In test mode
             this is just the final input.
         """
-        num_inputs = self.num_inputs
+        num_inputs = self.weight.shape[1]
         assert len(inputs) == num_inputs
-        if not self.training:
-            return inputs[-1]
 
         # Shape of weights: (*, num_inputs)
         num_channels = inputs[0].shape[-1]
@@ -1118,14 +1088,21 @@ class RandomCombine(nn.Module):
             (num_frames, num_channels, num_inputs)
         )
 
-        # weights: (num_frames, num_inputs)
-        weights = self._get_random_weights(
-            inputs[0].dtype, inputs[0].device, num_frames
-        )
+        weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
 
-        weights = weights.reshape(num_frames, num_inputs, 1)
-        # ans: (num_frames, num_channels, 1)
-        ans = torch.matmul(stacked_inputs, weights)
+        if self.training:
+            # random masking..
+            mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
+                                       size=(num_frames,), device=weights.device)
+            # mask will have rows like: [ False, False, False, True, True, .. ]
+            mask = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
+                num_frames, num_inputs) >= mask_start.unsqueeze(1)
+
+            weights = weights.masked_fill(mask, float('-inf'))
+        weights = weights.softmax(dim=1)
+
+        # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
+        ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
 
         # ans: (*, num_channels)
         ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
@@ -1134,104 +1111,17 @@ class RandomCombine(nn.Module):
             print("Weights = ", weights.reshape(num_frames, num_inputs))
         return ans
 
-    def _get_random_weights(
-        self, dtype: torch.dtype, device: torch.device, num_frames: int
-    ) -> Tensor:
-        """Return a tensor of random weights, of shape
-        `(num_frames, self.num_inputs)`,
-        Args:
-          dtype:
-            The data-type desired for the answer, e.g. float, double.
-          device:
-            The device needed for the answer.
-          num_frames:
-            The number of sets of weights desired
-        Returns:
-          A tensor of shape (num_frames, self.num_inputs), such that
-          `ans.sum(dim=1)` is all ones.
-        """
-        pure_prob = self.pure_prob
-        if pure_prob == 0.0:
-            return self._get_random_mixed_weights(dtype, device, num_frames)
-        elif pure_prob == 1.0:
-            return self._get_random_pure_weights(dtype, device, num_frames)
-        else:
-            p = self._get_random_pure_weights(dtype, device, num_frames)
-            m = self._get_random_mixed_weights(dtype, device, num_frames)
-            return torch.where(
-                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
-            )
-
-    def _get_random_pure_weights(
-        self, dtype: torch.dtype, device: torch.device, num_frames: int
-    ):
-        """Return a tensor of random one-hot weights, of shape
-        `(num_frames, self.num_inputs)`,
-        Args:
-          dtype:
-            The data-type desired for the answer, e.g. float, double.
-          device:
-            The device needed for the answer.
-          num_frames:
-            The number of sets of weights desired.
-        Returns:
-          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
-          exactly one weight equal to 1.0 on each frame.
-        """
-        final_prob = self.final_weight
-
-        # final contains self.num_inputs - 1 in all elements
-        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
-        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
-        nonfinal = torch.randint(
-            self.num_inputs - 1, (num_frames,), device=device
-        )
-
-        indexes = torch.where(
-            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
-        )
-        ans = torch.nn.functional.one_hot(
-            indexes, num_classes=self.num_inputs
-        ).to(dtype=dtype)
-        return ans
-
-    def _get_random_mixed_weights(
-        self, dtype: torch.dtype, device: torch.device, num_frames: int
-    ):
-        """Return a tensor of random one-hot weights, of shape
-        `(num_frames, self.num_inputs)`,
-        Args:
-          dtype:
-            The data-type desired for the answer, e.g. float, double.
-          device:
-            The device needed for the answer.
-          num_frames:
-            The number of sets of weights desired.
-        Returns:
-          A tensor of shape (num_frames, self.num_inputs), which elements
-          in [0..1] that sum to one over the second axis, i.e.
-          `ans.sum(dim=1)` is all ones.
-        """
-        logprobs = (
-            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
-            * self.stddev
-        )
-        logprobs[:, -1] += self.final_log_weight
-        return logprobs.softmax(dim=1)
 
-def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
-    print(
-        f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}"
-    )
+def _test_random_combine():
+    print("_test_random_combine()")
     num_inputs = 3
     num_channels = 50
-    m = RandomCombine(
+    m = AttentionCombine(
+        num_channels=num_channels,
         num_inputs=num_inputs,
-        final_weight=final_weight,
-        pure_prob=pure_prob,
-        stddev=stddev,
-    )
+        random_prob=0.5)
+
 
     x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
@@ -1240,15 +1130,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
     assert torch.allclose(y, x[0])  # .. since actually all ones.
 
 
-def _test_random_combine_main():
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.0)
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.3)
-    _test_random_combine(0.5, 1, 0.3)
-    _test_random_combine(0.5, 0.5, 0.3)
-
-
 def _test_conformer_main():
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
@@ -1277,5 +1158,5 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_random_combine()
    _test_conformer_main()
-    _test_random_combine_main()
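Note for reviewers: the sketch below is a self-contained paraphrase of the patched `AttentionCombine`, included only to make the masking scheme easy to poke at in isolation. It is not the patch itself, and the toy dimensions (`num_channels=8`, three inputs, batch of `(2, 5)` frames) are made up for the demo. Because the softmax weights always form a convex combination, feeding identical inputs must reproduce them, which is the same invariant `_test_random_combine()` relies on with its all-ones inputs.

```python
from typing import List

import torch
from torch import Tensor


class AttentionCombine(torch.nn.Module):
    """Standalone paraphrase of the module added in this patch; illustration only."""

    def __init__(self, num_channels: int, num_inputs: int,
                 random_prob: float = 0.5) -> None:
        super().__init__()
        # Strict lower bound here: forward() divides by random_prob.
        assert 0 < random_prob <= 1, random_prob
        self.random_prob = random_prob
        # One scoring column per input: each frame gets a per-input logit.
        self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs))
        self.bias = torch.nn.Parameter(torch.zeros(num_inputs))

    def forward(self, inputs: List[Tensor]) -> Tensor:
        num_inputs = self.weight.shape[1]
        assert len(inputs) == num_inputs
        num_channels = inputs[0].shape[-1]
        num_frames = inputs[0].numel() // num_channels

        # (num_frames, num_channels, num_inputs)
        stacked_inputs = torch.stack(inputs, dim=-1).reshape(
            num_frames, num_channels, num_inputs)

        # Per-frame attention logits over the inputs: (num_frames, num_inputs)
        weights = (stacked_inputs * self.weight).sum(dim=1) + self.bias

        if self.training:
            # Randomly hide a suffix of the inputs.  Rows where
            # mask_start >= num_inputs stay unmasked; input 0 is never hidden.
            mask_start = torch.randint(
                low=1, high=int(num_inputs / self.random_prob),
                size=(num_frames,), device=weights.device)
            mask = (torch.arange(num_inputs, device=weights.device).unsqueeze(0)
                    >= mask_start.unsqueeze(1))
            weights = weights.masked_fill(mask, float("-inf"))
        weights = weights.softmax(dim=1)

        # Convex combination of the stacked inputs, per frame.
        ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
        return ans.reshape(*inputs[0].shape[:-1], num_channels)


if __name__ == "__main__":
    m = AttentionCombine(num_channels=8, num_inputs=3)
    x = torch.randn(2, 5, 8)
    m.eval()
    y = m([x, x, x])
    # Weights sum to one, so identical inputs are reproduced
    # up to floating-point error.
    assert torch.allclose(y, x, atol=1e-5)
    print("OK:", y.shape)
```

On the masking distribution: with `num_inputs = 3` and `random_prob = 0.5`, `mask_start` is uniform over {1, ..., 5}, so a frame sees only its first input with probability 1/5, its first two with probability 1/5, and all three otherwise; the chance of a nontrivial mask is 2/5 here and approaches `random_prob` as `num_inputs` grows. Masking suffixes rather than arbitrary subsets matches the layer ordering: the final encoder layer is last in `inputs` (the old `RandomCombine` returned `inputs[-1]` in test mode), so training sometimes forces the combiner to rely on earlier layers alone.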