diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
index 3291bc351..8b9cd9982 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
@@ -1041,6 +1041,7 @@ class AttentionCombine(nn.Module):
         num_channels: int,
         num_inputs: int,
         random_prob: float = 0.333,
+        single_prob: float = 0.5,
     ) -> None:
         """
         Args:
@@ -1053,16 +1054,20 @@ class AttentionCombine(nn.Module):
           random_prob:
             the probability with which we apply a nontrivial mask, in
             training mode.
-
+          single_prob:
+            the probability with which we mask to allow just a single
+            module's output (in training)
         """
         super().__init__()
 
         self.random_prob = random_prob
+        self.single_prob = single_prob
         self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs))
         self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
 
         assert 0 <= random_prob <= 1, random_prob
+        assert 0 <= single_prob <= 1, single_prob
@@ -1097,10 +1102,20 @@ class AttentionCombine(nn.Module):
         if self.training:
             # random masking..
             mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
-                                       size=(num_frames,), device=weights.device)
+                                       size=(num_frames,), device=weights.device).unsqueeze(1)
             # mask will have rows like: [ False, False, False, True, True, .. ]
-            mask = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
-                num_frames, num_inputs) >= mask_start.unsqueeze(1)
+            arange = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
+                num_frames, num_inputs)
+            mask = arange >= mask_start
+
+            apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
+                                                             device=weights.device) < self.single_prob,
+                                                  mask_start < num_inputs)
+            single_prob_mask = torch.logical_and(apply_single_prob,
+                                                 arange < mask_start - 1)
+
+            mask = torch.logical_or(mask,
+                                    single_prob_mask)
+
             weights = weights.masked_fill(mask, float('-inf'))
 
         weights = weights.softmax(dim=1)
@@ -1124,7 +1139,8 @@ def _test_random_combine():
     m = AttentionCombine(
         num_channels=num_channels,
         num_inputs=num_inputs,
-        random_prob=0.5)
+        random_prob=0.5,
+        single_prob=0.0)
 
     x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
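
For reference, below is a minimal standalone sketch (not part of the patch) that reproduces the new masking logic in isolation, so the effect of single_prob can be inspected directly. The tensor names mirror the diff; the num_frames, num_inputs, and probability values here are arbitrary, and device handling is dropped for brevity.

import torch

# Assumed example values; the patch derives these from the input tensors.
num_frames, num_inputs = 8, 4
random_prob, single_prob = 0.333, 0.5

# Shape (num_frames, 1): per-frame index of the first input to mask out.
mask_start = torch.randint(low=1, high=int(num_inputs / random_prob),
                           size=(num_frames,)).unsqueeze(1)

# Rows like [False, False, True, True]: mask inputs at or after mask_start.
arange = torch.arange(num_inputs).unsqueeze(0).expand(num_frames, num_inputs)
mask = arange >= mask_start

# With probability single_prob, additionally mask inputs *before* index
# (mask_start - 1), leaving exactly one input visible for that frame.
# Frames with mask_start >= num_inputs (a trivial mask) are excluded.
apply_single_prob = torch.logical_and(
    torch.rand(size=(num_frames, 1)) < single_prob,
    mask_start < num_inputs)
single_prob_mask = torch.logical_and(apply_single_prob,
                                     arange < mask_start - 1)

mask = torch.logical_or(mask, single_prob_mask)
print(mask)  # rows with a single False keep only one module's output

Rows of mask that end up with a single False are the frames where, per the new docstring, just one module's output is allowed through the softmax-weighted combination; the mask_start < num_inputs condition ensures this extra masking is applied only to frames whose random mask was already nontrivial.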