Apply single_prob mask, so sometimes we just get one layer as output.
parent d8f7310118
commit 056b9a4f9a
@@ -1041,6 +1041,7 @@ class AttentionCombine(nn.Module):
         num_channels: int,
         num_inputs: int,
         random_prob: float = 0.333,
+        single_prob: float = 0.5,
     ) -> None:
         """
         Args:
@@ -1053,16 +1054,20 @@ class AttentionCombine(nn.Module):
           random_prob:
             the probability with which we apply a nontrivial mask, in training
             mode.
+          single_prob:
+            the probability with which we mask to allow just a single
+            module's output (in training)
         """
         super().__init__()

         self.random_prob = random_prob
+        self.single_prob = single_prob
         self.weight = torch.nn.Parameter(torch.zeros(num_channels,
                                                      num_inputs))
         self.bias = torch.nn.Parameter(torch.zeros(num_inputs))

         assert 0 <= random_prob <= 1, random_prob
+        assert 0 <= single_prob <= 1, single_prob

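As a rough, illustrative aside (not part of the commit): the docstring describes random_prob as the probability of applying a nontrivial mask. In the forward pass changed below, mask_start is drawn uniformly from [1, int(num_inputs / random_prob)), and the mask only removes inputs when mask_start < num_inputs, so that probability works out to roughly random_prob. A small sketch with a hypothetical num_inputs:

# Illustrative sketch only; num_inputs = 5 is a hypothetical value.
num_inputs = 5
for random_prob in (0.333, 0.5):
    high = int(num_inputs / random_prob)          # upper bound used by torch.randint
    p_nontrivial = (num_inputs - 1) / (high - 1)  # mask_start drawn from {1, ..., high - 1}
    print(random_prob, high, round(p_nontrivial, 3))
# prints: 0.333 15 0.286   and   0.5 10 0.444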
@@ -1097,10 +1102,20 @@ class AttentionCombine(nn.Module):
         if self.training:
             # random masking..
             mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
-                                       size=(num_frames,), device=weights.device)
+                                       size=(num_frames,), device=weights.device).unsqueeze(1)
             # mask will have rows like: [ False, False, False, True, True, .. ]
-            mask = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
-                num_frames, num_inputs) >= mask_start.unsqueeze(1)
+            arange = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
+                num_frames, num_inputs)
+            mask = arange >= mask_start
+
+            apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
+                                                             device=weights.device) < self.single_prob,
+                                                  mask_start < num_inputs)
+            single_prob_mask = torch.logical_and(apply_single_prob,
+                                                 arange < mask_start - 1)
+
+            mask = torch.logical_or(mask,
+                                    single_prob_mask)
+
             weights = weights.masked_fill(mask, float('-inf'))
         weights = weights.softmax(dim=1)
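A minimal sketch (again, not part of the commit) of what the new mask does for a single frame, using hypothetical values num_inputs = 5 and mask_start = 3 with the single_prob branch firing: the base mask already hides columns >= mask_start, and single_prob_mask additionally hides columns below mask_start - 1, so exactly one module's output survives the softmax.

import torch

# Illustrative values: one frame, num_inputs = 5, mask_start = 3,
# and rand() < single_prob so the single_prob branch fires.
num_inputs = 5
mask_start = torch.tensor([[3]])                 # shape (num_frames=1, 1)
arange = torch.arange(num_inputs).unsqueeze(0)   # shape (1, num_inputs)

base_mask = arange >= mask_start                 # [[False, False, False, True, True]]
apply_single_prob = torch.tensor([[True]])       # stands in for rand() < self.single_prob
single_prob_mask = torch.logical_and(apply_single_prob, arange < mask_start - 1)
mask = torch.logical_or(base_mask, single_prob_mask)
print(mask)                                      # [[True, True, False, True, True]]
# Only column mask_start - 1 stays unmasked, so after masked_fill with -inf
# and softmax, all the weight goes to that single module's output.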
@@ -1124,7 +1139,8 @@ def _test_random_combine():
     m = AttentionCombine(
         num_channels=num_channels,
         num_inputs=num_inputs,
-        random_prob=0.5)
+        random_prob=0.5,
+        single_prob=0.0)


     x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
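For context, a hypothetical usage sketch mirroring _test_random_combine above; it assumes AttentionCombine is in scope and, as the test's construction of x suggests, that the module is called on the list of per-module outputs. With a nonzero single_prob and the module in training mode, some frames end up weighting only a single layer's output.

import torch

num_channels, num_inputs = 64, 5
m = AttentionCombine(num_channels=num_channels,
                     num_inputs=num_inputs,
                     random_prob=0.5,
                     single_prob=0.5)  # nonzero here, unlike the test, to exercise the new mask
m.train()
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
y = m(x)          # call signature assumed from how the test builds x
print(y.shape)    # expected (not verified here) to match one input: (3, 4, num_channels)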