Apply single_prob mask, so sometimes we just get one layer as output.

This commit is contained in:
Daniel Povey 2022-09-29 15:26:27 +08:00
parent d8f7310118
commit 056b9a4f9a

View File

@ -1041,6 +1041,7 @@ class AttentionCombine(nn.Module):
num_channels: int,
num_inputs: int,
random_prob: float = 0.333,
single_prob: float = 0.5,
) -> None:
"""
Args:
@ -1053,16 +1054,20 @@ class AttentionCombine(nn.Module):
random_prob:
the probability with which we apply a nontrivial mask, in training
mode.
single_prob:
the probability with which we mask to allow just a single
module's output (in training)
"""
super().__init__()
self.random_prob = random_prob
self.single_prob = single_prob
self.weight = torch.nn.Parameter(torch.zeros(num_channels,
num_inputs))
self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
assert 0 <= random_prob <= 1, random_prob
assert 0 <= single_prob <= 1, single_prob
@ -1097,10 +1102,20 @@ class AttentionCombine(nn.Module):
if self.training:
# random masking..
mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
size=(num_frames,), device=weights.device)
size=(num_frames,), device=weights.device).unsqueeze(1)
# mask will have rows like: [ False, False, False, True, True, .. ]
mask = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
num_frames, num_inputs) >= mask_start.unsqueeze(1)
arange = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
num_frames, num_inputs)
mask = arange >= mask_start
apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
device=weights.device) < self.single_prob,
mask_start < num_inputs)
single_prob_mask = torch.logical_and(apply_single_prob,
arange < mask_start - 1)
mask = torch.logical_or(mask,
single_prob_mask)
weights = weights.masked_fill(mask, float('-inf'))
weights = weights.softmax(dim=1)
@ -1124,7 +1139,8 @@ def _test_random_combine():
m = AttentionCombine(
num_channels=num_channels,
num_inputs=num_inputs,
random_prob=0.5)
random_prob=0.5,
single_prob=0.0)
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]