Implement AttentionCombine as replacement for RandomCombine

Daniel Povey 2022-09-29 13:44:03 +08:00
parent e5a0d8929b
commit 461ad3655a


@@ -308,11 +308,10 @@ class ConformerEncoder(nn.Module):
         self.aux_layers = set(aux_layers + [num_layers - 1])
         num_channels = encoder_layer.norm_final.num_channels
-        self.combiner = RandomCombine(
+        self.combiner = AttentionCombine(
+            num_channels=encoder_layer.d_model,
             num_inputs=len(self.aux_layers),
-            final_weight=0.5,
-            pure_prob=0.333,
-            stddev=2.0,
+            random_prob=0.5,
         )

    def forward(
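A minimal sketch (not part of the diff) of how the renamed module is driven, assuming the class definition later in this file; the d_model of 256, the three selected layers, and the (T, N, C) shapes are illustrative values, not taken from the recipe:

    import torch

    # AttentionCombine as defined below; all sizes here are illustrative.
    combiner = AttentionCombine(num_channels=256, num_inputs=3, random_prob=0.5)

    # One output per selected encoder layer, all of the same shape (T, N, C).
    layer_outputs = [torch.randn(100, 8, 256) for _ in range(3)]

    combined = combiner(layer_outputs)  # same shape as each input: (100, 8, 256)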
@@ -1019,7 +1018,7 @@ class Conv2dSubsampling(nn.Module):
         # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
         return x


-class RandomCombine(nn.Module):
+class AttentionCombine(nn.Module):
     """
     This module combines a list of Tensors, all with the same shape, to
     produce a single output of that same shape which, in training time,
@@ -1038,59 +1037,32 @@ class RandomCombine(nn.Module):
     def __init__(
         self,
+        num_channels: int,
         num_inputs: int,
-        final_weight: float = 0.5,
-        pure_prob: float = 0.5,
-        stddev: float = 2.0,
+        random_prob: float = 0.5,
     ) -> None:
         """
         Args:
+          num_channels:
+            the number of channels
           num_inputs:
             The number of tensor inputs, which equals the number of layers'
             outputs that are fed into this module.  E.g. in an 18-layer neural
             net if we output layers 16, 12, 18, num_inputs would be 3.
-          final_weight:
-            The amount of weight or probability we assign to the
-            final layer when randomly choosing layers or when choosing
-            continuous layer weights.
-          pure_prob:
-            The probability, on each frame, with which we choose
-            only a single layer to output (rather than an interpolation)
-          stddev:
-            A standard deviation that we add to log-probs for computing
-            randomized weights.
+          random_prob:
+            the probability with which we apply a nontrivial mask, in training
+            mode.
-
-        The method of choosing which layers, or combinations of layers, to use,
-        is conceptually as follows::
-
-            With probability `pure_prob`::
-               With probability `final_weight`: choose final layer,
-               Else: choose random non-final layer.
-            Else::
-               Choose initial log-weights that correspond to assigning
-               weight `final_weight` to the final layer and equal
-               weights to other layers; then add Gaussian noise
-               with variance `stddev` to these log-weights, and normalize
-               to weights (note: the average weight assigned to the
-               final layer here will not be `final_weight` if stddev>0).
         """
         super().__init__()
-        assert 0 <= pure_prob <= 1, pure_prob
-        assert 0 < final_weight < 1, final_weight
         assert num_inputs >= 1
-
-        self.num_inputs = num_inputs
-        self.final_weight = final_weight
-        self.pure_prob = pure_prob
-        self.stddev = stddev
+        self.random_prob = random_prob
+        self.weight = torch.nn.Parameter(torch.zeros(num_channels,
+                                                     num_inputs))
+        self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
+        assert 0 <= random_prob <= 1, random_prob
-
-        self.final_log_weight = (
-            torch.tensor(
-                (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
-            )
-            .log()
-            .item()
-        )

     def forward(self, inputs: List[Tensor]) -> Tensor:
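For orientation (not in the commit): the learned parameters give each input i a per-frame scalar score, namely the dot product of that frame's channels for input i with column i of self.weight, plus bias[i]. A small sketch with illustrative shapes, showing that the expression used in forward() below matches an explicit einsum:

    import torch

    num_frames, num_channels, num_inputs = 6, 4, 3
    stacked_inputs = torch.randn(num_frames, num_channels, num_inputs)
    weight = torch.randn(num_channels, num_inputs)  # stands in for self.weight
    bias = torch.randn(num_inputs)                  # stands in for self.bias

    # Expression from forward(): broadcast multiply, then sum over channels.
    scores = (stacked_inputs * weight).sum(dim=(1,)) + bias

    # The same per-input dot product, written as an einsum.
    scores2 = torch.einsum("fci,ci->fi", stacked_inputs, weight) + bias

    assert torch.allclose(scores, scores2)  # both (num_frames, num_inputs)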
@@ -1103,10 +1075,8 @@ class RandomCombine(nn.Module):
             A Tensor of shape (*, num_channels).  In test mode
             this is just the final input.
         """
-        num_inputs = self.num_inputs
+        num_inputs = self.weight.shape[1]
         assert len(inputs) == num_inputs
-        if not self.training:
-            return inputs[-1]

         # Shape of weights: (*, num_inputs)
         num_channels = inputs[0].shape[-1]
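A side effect worth noting (not stated in the diff): with the eval-mode early return removed, evaluation now applies the softmax of the unmasked learned scores rather than literally returning inputs[-1]. A tiny illustrative check, relying only on the zero initialization of self.weight and self.bias above:

    import torch

    combiner = AttentionCombine(num_channels=8, num_inputs=2, random_prob=0.5)
    combiner.eval()

    xs = [torch.randn(5, 8) for _ in range(2)]
    y = combiner(xs)

    # Zero-initialized weight and bias give all-zero scores, so the softmax
    # weights are uniform and the eval output is the mean of the inputs.
    assert torch.allclose(y, (xs[0] + xs[1]) / 2, atol=1e-6)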
@@ -1118,14 +1088,21 @@ class RandomCombine(nn.Module):
             (num_frames, num_channels, num_inputs)
         )

         # weights: (num_frames, num_inputs)
-        weights = self._get_random_weights(
-            inputs[0].dtype, inputs[0].device, num_frames
-        )
+        weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias

-        weights = weights.reshape(num_frames, num_inputs, 1)
-        # ans: (num_frames, num_channels, 1)
-        ans = torch.matmul(stacked_inputs, weights)
+        if self.training:
+            # random masking..
+            mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
+                                       size=(num_frames,), device=weights.device)
+            # mask will have rows like: [ False, False, False, True, True, .. ]
+            mask = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
+                num_frames, num_inputs) >= mask_start.unsqueeze(1)
+            weights = weights.masked_fill(mask, float('-inf'))
+        weights = weights.softmax(dim=1)
+
+        # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1)
+        ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))

         # ans: (*, num_channels)
         ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
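To make the masking concrete (not part of the commit), here is the same construction on illustrative numbers. With num_inputs=3 and random_prob=0.5, high is int(3 / 0.5) = 6, so mask_start is drawn uniformly from {1, ..., 5}; any frame that draws 3 or more keeps all inputs, so only 2 of the 5 possible draws yield a nontrivial mask:

    import torch

    num_frames, num_inputs = 4, 3
    random_prob = 0.5

    scores = torch.randn(num_frames, num_inputs)
    mask_start = torch.randint(low=1, high=int(num_inputs / random_prob),
                               size=(num_frames,))
    mask = torch.arange(num_inputs).unsqueeze(0).expand(
        num_frames, num_inputs) >= mask_start.unsqueeze(1)

    # For example, a draw mask_start = [5, 1, 4, 2] gives the mask rows
    #   [F, F, F], [F, T, T], [F, F, F], [F, F, T]
    weights = scores.masked_fill(mask, float("-inf")).softmax(dim=1)
    # Masked positions get weight exactly 0.0; every row still sums to 1.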
@@ -1134,104 +1111,17 @@ class RandomCombine(nn.Module):
             print("Weights = ", weights.reshape(num_frames, num_inputs))
         return ans

-    def _get_random_weights(
-        self, dtype: torch.dtype, device: torch.device, num_frames: int
-    ) -> Tensor:
-        """Return a tensor of random weights, of shape
-        `(num_frames, self.num_inputs)`.
-        Args:
-          dtype:
-            The data-type desired for the answer, e.g. float, double.
-          device:
-            The device needed for the answer.
-          num_frames:
-            The number of sets of weights desired.
-        Returns:
-          A tensor of shape (num_frames, self.num_inputs), such that
-          `ans.sum(dim=1)` is all ones.
-        """
-        pure_prob = self.pure_prob
-        if pure_prob == 0.0:
-            return self._get_random_mixed_weights(dtype, device, num_frames)
-        elif pure_prob == 1.0:
-            return self._get_random_pure_weights(dtype, device, num_frames)
-        else:
-            p = self._get_random_pure_weights(dtype, device, num_frames)
-            m = self._get_random_mixed_weights(dtype, device, num_frames)
-            return torch.where(
-                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
-            )
-
-    def _get_random_pure_weights(
-        self, dtype: torch.dtype, device: torch.device, num_frames: int
-    ):
-        """Return a tensor of random one-hot weights, of shape
-        `(num_frames, self.num_inputs)`.
-        Args:
-          dtype:
-            The data-type desired for the answer, e.g. float, double.
-          device:
-            The device needed for the answer.
-          num_frames:
-            The number of sets of weights desired.
-        Returns:
-          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
-          exactly one weight equal to 1.0 on each frame.
-        """
-        final_prob = self.final_weight
-
-        # final contains self.num_inputs - 1 in all elements
-        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
-        # nonfinal contains random integers in [0..num_inputs - 2]; these are for non-final weights.
-        nonfinal = torch.randint(
-            self.num_inputs - 1, (num_frames,), device=device
-        )
-
-        indexes = torch.where(
-            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
-        )
-        ans = torch.nn.functional.one_hot(
-            indexes, num_classes=self.num_inputs
-        ).to(dtype=dtype)
-        return ans
-
-    def _get_random_mixed_weights(
-        self, dtype: torch.dtype, device: torch.device, num_frames: int
-    ):
-        """Return a tensor of random mixed weights, of shape
-        `(num_frames, self.num_inputs)`.
-        Args:
-          dtype:
-            The data-type desired for the answer, e.g. float, double.
-          device:
-            The device needed for the answer.
-          num_frames:
-            The number of sets of weights desired.
-        Returns:
-          A tensor of shape (num_frames, self.num_inputs), with elements
-          in [0..1] that sum to one over the second axis, i.e.
-          `ans.sum(dim=1)` is all ones.
-        """
-        logprobs = (
-            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
-            * self.stddev
-        )
-        logprobs[:, -1] += self.final_log_weight
-        return logprobs.softmax(dim=1)


-def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
-    print(
-        f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}"
-    )
+def _test_random_combine():
+    print("_test_random_combine()")
     num_inputs = 3
     num_channels = 50
-    m = RandomCombine(
+    m = AttentionCombine(
+        num_channels=num_channels,
         num_inputs=num_inputs,
-        final_weight=final_weight,
-        pure_prob=pure_prob,
-        stddev=stddev,
-    )
+        random_prob=0.5)

     x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
@@ -1240,15 +1130,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
     assert torch.allclose(y, x[0])  # .. since actually all ones.


-def _test_random_combine_main():
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.0)
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.3)
-    _test_random_combine(0.5, 1, 0.3)
-    _test_random_combine(0.5, 0.5, 0.3)


 def _test_conformer_main():
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
@@ -1277,5 +1158,5 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_random_combine()
     _test_conformer_main()
-    _test_random_combine_main()
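As a closing aside (not in the diff): the unchanged assertion torch.allclose(y, x[0]) in _test_random_combine() still holds under the new module because the softmax weights sum to 1 along the input axis, and any convex combination of identical all-ones inputs is again all ones, regardless of the random masking. A minimal check of that invariant:

    import torch

    weights = torch.randn(5, 3).softmax(dim=1)       # any softmax weights
    ones = torch.ones(5, 4, 3)                       # identical all-ones inputs, stacked
    out = torch.matmul(ones, weights.unsqueeze(2))   # per-frame convex combination
    assert torch.allclose(out, torch.ones(5, 4, 1))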