mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Implement AttentionCombine as replacement for RandomCombine
This commit is contained in:
parent e5a0d8929b
commit 461ad3655a
@@ -308,11 +308,10 @@ class ConformerEncoder(nn.Module):
         self.aux_layers = set(aux_layers + [num_layers - 1])
 
         num_channels = encoder_layer.norm_final.num_channels
-        self.combiner = RandomCombine(
+        self.combiner = AttentionCombine(
             num_channels=encoder_layer.d_model,
             num_inputs=len(self.aux_layers),
-            final_weight=0.5,
-            pure_prob=0.333,
-            stddev=2.0,
+            random_prob=0.5,
         )
 
     def forward(
@@ -1019,7 +1018,7 @@ class Conv2dSubsampling(nn.Module):
         # Now x is of shape (N, ((T-1)//2 - 1)//2, odim)
         return x
 
-class RandomCombine(nn.Module):
+class AttentionCombine(nn.Module):
     """
     This module combines a list of Tensors, all with the same shape, to
     produce a single output of that same shape which, in training time,
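
(For orientation, a minimal usage sketch, assuming the AttentionCombine definition in the hunks below; the shapes are illustrative stand-ins for (seq_len, batch, d_model) encoder-layer outputs.)

import torch

# Hypothetical smoke test: combine three identically-shaped layer outputs.
combiner = AttentionCombine(num_channels=256, num_inputs=3, random_prob=0.5)
inputs = [torch.randn(100, 8, 256) for _ in range(3)]
out = combiner(inputs)  # same shape as each input
assert out.shape == (100, 8, 256)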
@@ -1038,59 +1037,32 @@ class RandomCombine(nn.Module):
     def __init__(
         self,
         num_channels: int,
         num_inputs: int,
-        final_weight: float = 0.5,
-        pure_prob: float = 0.5,
-        stddev: float = 2.0,
         random_prob: float = 0.5,
     ) -> None:
         """
         Args:
           num_channels:
             the number of channels
           num_inputs:
             The number of tensor inputs, which equals the number of layers'
             outputs that are fed into this module.  E.g. in an 18-layer neural
             net if we output layers 16, 12, 18, num_inputs would be 3.
-          final_weight:
-            The amount of weight or probability we assign to the
-            final layer when randomly choosing layers or when choosing
-            continuous layer weights.
-          pure_prob:
-            The probability, on each frame, with which we choose
-            only a single layer to output (rather than an interpolation)
-          stddev:
-            A standard deviation that we add to log-probs for computing
-            randomized weights.
           random_prob:
             the probability with which we apply a nontrivial mask, in training
             mode.
-
-        The method of choosing which layers, or combinations of layers, to use,
-        is conceptually as follows::
-
-            With probability `pure_prob`::
-               With probability `final_weight`: choose final layer,
-               Else: choose random non-final layer.
-            Else::
-               Choose initial log-weights that correspond to assigning
-               weight `final_weight` to the final layer and equal
-               weights to other layers; then add Gaussian noise
-               with variance `stddev` to these log-weights, and normalize
-               to weights (note: the average weight assigned to the
-               final layer here will not be `final_weight` if stddev>0).
         """
         super().__init__()
-        assert 0 <= pure_prob <= 1, pure_prob
-        assert 0 < final_weight < 1, final_weight
-        assert num_inputs >= 1
-
-        self.num_inputs = num_inputs
-        self.final_weight = final_weight
-        self.pure_prob = pure_prob
-        self.stddev = stddev
         self.random_prob = random_prob
+        self.weight = torch.nn.Parameter(torch.zeros(num_channels,
+                                                     num_inputs))
+        self.bias = torch.nn.Parameter(torch.zeros(num_inputs))
 
         assert 0 <= random_prob <= 1, random_prob
-
-        self.final_log_weight = (
-            torch.tensor(
-                (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
-            )
-            .log()
-            .item()
-        )
 
     def forward(self, inputs: List[Tensor]) -> Tensor:
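
(For reference, the selection scheme the deleted docstring text describes, condensed from the removed RandomCombine helpers further down into one function; the function name is hypothetical, not code from this commit.)

import torch

def old_random_weights(num_inputs: int, num_frames: int,
                       final_weight: float = 0.5,
                       pure_prob: float = 0.5,
                       stddev: float = 2.0) -> torch.Tensor:
    # "Pure" one-hot weights: pick the final layer with probability
    # final_weight, otherwise a uniformly random non-final layer.
    # (Requires num_inputs >= 2.)
    final = torch.full((num_frames,), num_inputs - 1)
    nonfinal = torch.randint(num_inputs - 1, (num_frames,))
    idx = torch.where(torch.rand(num_frames) < final_weight, final, nonfinal)
    pure = torch.nn.functional.one_hot(idx, num_classes=num_inputs).float()

    # "Mixed" weights: Gaussian noise added to log-weights biased toward
    # the final layer, then normalized with softmax.
    final_log_weight = torch.tensor(
        (final_weight / (1 - final_weight)) * (num_inputs - 1)).log().item()
    logprobs = torch.randn(num_frames, num_inputs) * stddev
    logprobs[:, -1] += final_log_weight
    mixed = logprobs.softmax(dim=1)

    # Each frame independently takes the pure weights with probability
    # pure_prob, else the mixed weights.
    choose_pure = torch.rand(num_frames, 1) < pure_prob
    return torch.where(choose_pure, pure, mixed)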
@@ -1103,10 +1075,8 @@ class RandomCombine(nn.Module):
           A Tensor of shape (*, num_channels).  In test mode
             this is just the final input.
         """
-        num_inputs = self.num_inputs
+        num_inputs = self.weight.shape[1]
         assert len(inputs) == num_inputs
-        if not self.training:
-            return inputs[-1]
 
         # Shape of weights: (*, num_inputs)
         num_channels = inputs[0].shape[-1]
@@ -1118,14 +1088,21 @@ class RandomCombine(nn.Module):
             (num_frames, num_channels, num_inputs)
         )
 
         # weights: (num_frames, num_inputs)
-        weights = self._get_random_weights(
-            inputs[0].dtype, inputs[0].device, num_frames
-        )
+        weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
 
-        weights = weights.reshape(num_frames, num_inputs, 1)
-        # ans: (num_frames, num_channels, 1)
-        ans = torch.matmul(stacked_inputs, weights)
+        if self.training:
+            # random masking..
+            mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
+                                       size=(num_frames,), device=weights.device)
+            # mask will have rows like: [ False, False, False, True, True, .. ]
+            mask = torch.arange(num_inputs, device=weights.device).unsqueeze(0).expand(
+                num_frames, num_inputs) >= mask_start.unsqueeze(1)
+            weights = weights.masked_fill(mask, float('-inf'))
+        weights = weights.softmax(dim=1)
+
+        # (num_frames, num_channels, num_inputs) * (num_frames, num_inputs, 1) -> (num_frames, num_channels, 1),
+        ans = torch.matmul(stacked_inputs, weights.unsqueeze(2))
         # ans: (*, num_channels)
         ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
 
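
(Pieced together, the new weighting-and-masking path amounts to this self-contained sketch; the constants are illustrative, and the learned weight/bias are shown at their zero initialization.)

import torch

num_frames, num_channels, num_inputs, random_prob = 6, 4, 3, 0.5
stacked_inputs = torch.randn(num_frames, num_channels, num_inputs)
weight = torch.zeros(num_channels, num_inputs)  # learned parameter
bias = torch.zeros(num_inputs)                  # learned parameter

# Per-frame scores over the layer outputs: (num_frames, num_inputs)
scores = (stacked_inputs * weight).sum(dim=(1,)) + bias

# Training-time masking: each frame keeps a random-length prefix of layers,
# so earlier layers are sometimes forced to carry the output on their own.
mask_start = torch.randint(low=1, high=int(num_inputs / random_prob),
                           size=(num_frames,))
mask = torch.arange(num_inputs).unsqueeze(0).expand(
    num_frames, num_inputs) >= mask_start.unsqueeze(1)
weights = scores.masked_fill(mask, float('-inf')).softmax(dim=1)

# Convex combination of the stacked layer outputs.
ans = torch.matmul(stacked_inputs, weights.unsqueeze(2)).squeeze(2)
assert ans.shape == (num_frames, num_channels)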
@@ -1134,104 +1111,17 @@ class RandomCombine(nn.Module):
             print("Weights = ", weights.reshape(num_frames, num_inputs))
         return ans
 
-    def _get_random_weights(
-        self, dtype: torch.dtype, device: torch.device, num_frames: int
-    ) -> Tensor:
-        """Return a tensor of random weights, of shape
-        `(num_frames, self.num_inputs)`,
-        Args:
-          dtype:
-            The data-type desired for the answer, e.g. float, double.
-          device:
-            The device needed for the answer.
-          num_frames:
-            The number of sets of weights desired.
-        Returns:
-          A tensor of shape (num_frames, self.num_inputs), such that
-          `ans.sum(dim=1)` is all ones.
-        """
-        pure_prob = self.pure_prob
-        if pure_prob == 0.0:
-            return self._get_random_mixed_weights(dtype, device, num_frames)
-        elif pure_prob == 1.0:
-            return self._get_random_pure_weights(dtype, device, num_frames)
-        else:
-            p = self._get_random_pure_weights(dtype, device, num_frames)
-            m = self._get_random_mixed_weights(dtype, device, num_frames)
-            return torch.where(
-                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
-            )
-
-    def _get_random_pure_weights(
-        self, dtype: torch.dtype, device: torch.device, num_frames: int
-    ):
-        """Return a tensor of random one-hot weights, of shape
-        `(num_frames, self.num_inputs)`,
-        Args:
-          dtype:
-            The data-type desired for the answer, e.g. float, double.
-          device:
-            The device needed for the answer.
-          num_frames:
-            The number of sets of weights desired.
-        Returns:
-          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
-          exactly one weight equal to 1.0 on each frame.
-        """
-        final_prob = self.final_weight
-
-        # final contains self.num_inputs - 1 in all elements
-        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
-        # nonfinal contains random integers in [0..num_inputs - 2]; these are for non-final weights.
-        nonfinal = torch.randint(
-            self.num_inputs - 1, (num_frames,), device=device
-        )
-
-        indexes = torch.where(
-            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
-        )
-        ans = torch.nn.functional.one_hot(
-            indexes, num_classes=self.num_inputs
-        ).to(dtype=dtype)
-        return ans
-
-    def _get_random_mixed_weights(
-        self, dtype: torch.dtype, device: torch.device, num_frames: int
-    ):
-        """Return a tensor of random weights, of shape
-        `(num_frames, self.num_inputs)`,
-        Args:
-          dtype:
-            The data-type desired for the answer, e.g. float, double.
-          device:
-            The device needed for the answer.
-          num_frames:
-            The number of sets of weights desired.
-        Returns:
-          A tensor of shape (num_frames, self.num_inputs), with elements
-          in [0..1] that sum to one over the second axis, i.e.
-          `ans.sum(dim=1)` is all ones.
-        """
-        logprobs = (
-            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
-            * self.stddev
-        )
-        logprobs[:, -1] += self.final_log_weight
-        return logprobs.softmax(dim=1)
 
 
-def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
-    print(
-        f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}"
-    )
+def _test_random_combine():
+    print("_test_random_combine()")
     num_inputs = 3
     num_channels = 50
-    m = RandomCombine(
+    m = AttentionCombine(
         num_channels=num_channels,
         num_inputs=num_inputs,
-        final_weight=final_weight,
-        pure_prob=pure_prob,
-        stddev=stddev,
-    )
+        random_prob=0.5)
 
     x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
@@ -1240,15 +1130,6 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
     assert torch.allclose(y, x[0])  # .. since actually all ones.
 
 
-def _test_random_combine_main():
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.0)
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.3)
-    _test_random_combine(0.5, 1, 0.3)
-    _test_random_combine(0.5, 0.5, 0.3)
-
-
 def _test_conformer_main():
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
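
(The allclose assertion in the hunk above holds for any weights the module samples: the softmax rows sum to one, and a convex combination of identical all-ones tensors is again all ones. A tiny standalone check of that invariant:)

import torch

w = torch.rand(5, 3).softmax(dim=1)             # rows sum to 1
x = torch.ones(5, 4, 3)                         # three identical "layers"
y = torch.matmul(x, w.unsqueeze(2)).squeeze(2)  # convex combination
assert torch.allclose(y, torch.ones(5, 4))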
@@ -1277,5 +1158,5 @@ if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
     torch.set_num_interop_threads(1)
+    _test_random_combine()
     _test_conformer_main()
-    _test_random_combine_main()