mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Apply single_prob mask, so sometimes we just get one layer as output.
This commit is contained in:
parent
d8f7310118
commit
056b9a4f9a
@ -1041,6 +1041,7 @@ class AttentionCombine(nn.Module):
|
||||
num_channels: int,
|
||||
num_inputs: int,
|
||||
random_prob: float = 0.333,
|
||||
single_prob: float = 0.5,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@ -1053,16 +1054,20 @@ class AttentionCombine(nn.Module):
|
||||
random_prob:
|
||||
the probability with which we apply a nontrivial mask, in training
|
||||
mode.
|
||||
|
||||
single_prob:
|
||||
the probability with which we mask to allow just a single
|
||||
module's output (in training)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.random_prob = random_prob
|
||||
self.single_prob = single_prob
|
||||
self.weight = torch.nn.Parameter(torch.zeros(num_channels,
|
||||
num_inputs))
|
||||
self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
|
||||
|
||||
assert 0 <= random_prob <= 1, random_prob
|
||||
assert 0 <= single_prob <= 1, single_prob
|
||||
|
||||
|
||||
|
||||
@ -1097,10 +1102,20 @@ class AttentionCombine(nn.Module):
|
||||
if self.training:
|
||||
# random masking..
|
||||
mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
|
||||
size=(num_frames,), device=weights.device)
|
||||
size=(num_frames,), device=weights.device).unsqueeze(1)
|
||||
# mask will have rows like: [ False, False, False, True, True, .. ]
|
||||
mask = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
|
||||
num_frames, num_inputs) >= mask_start.unsqueeze(1)
|
||||
arange = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
|
||||
num_frames, num_inputs)
|
||||
mask = arange >= mask_start
|
||||
|
||||
apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1),
|
||||
device=weights.device) < self.single_prob,
|
||||
mask_start < num_inputs)
|
||||
single_prob_mask = torch.logical_and(apply_single_prob,
|
||||
arange < mask_start - 1)
|
||||
|
||||
mask = torch.logical_or(mask,
|
||||
single_prob_mask)
|
||||
|
||||
weights = weights.masked_fill(mask, float('-inf'))
|
||||
weights = weights.softmax(dim=1)
|
||||
@ -1124,7 +1139,8 @@ def _test_random_combine():
|
||||
m = AttentionCombine(
|
||||
num_channels=num_channels,
|
||||
num_inputs=num_inputs,
|
||||
random_prob=0.5)
|
||||
random_prob=0.5,
|
||||
single_prob=0.0)
|
||||
|
||||
|
||||
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user