First version of rand-combine iterated-training-like idea.

Daniel Povey 2022-02-27 17:34:58 +08:00
parent 63d8d935d4
commit c1063def95
2 changed files with 219 additions and 7 deletions

File 1 of 2:

@@ -18,7 +18,7 @@
 import copy
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Sequence

 import torch
 from torch import Tensor, nn
@@ -56,6 +56,7 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
+        aux_layer_period: int = 3
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -80,10 +81,11 @@ class Conformer(Transformer):
             cnn_module_kernel,
             normalize_before,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
+                                        aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)))
         self.normalize_before = normalize_before
         if self.normalize_before:
-            self.after_norm = nn.LayerNorm(d_model)
+            self.after_norm = nn.LayerNorm(d_model)  # TODO: remove.
         else:
             # Note: TorchScript detects that self.after_norm could be used inside forward()
             # and throws an error without this change.
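
A quick check of the tap-selection arithmetic above, under an assumed num_encoder_layers=12 (illustrative only; the diff itself fixes only aux_layer_period=3):

    # hypothetical sizes: num_encoder_layers = 12, aux_layer_period = 3
    list(range(0, 12 - 1, 3))   # -> [0, 3, 6, 9]
    # ConformerEncoder.__init__ (next hunk) folds in the final layer itself:
    # set([0, 3, 6, 9] + [11]) == {0, 3, 6, 9, 11}, so RandomCombine gets
    # num_inputs = 5 tensors, and the assert holds since 11 is not a tap.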
@@ -280,12 +282,21 @@ class ConformerEncoder(nn.Module):
     """

     def __init__(
-        self, encoder_layer: nn.Module, num_layers: int
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: Sequence[int],
     ) -> None:
         super(ConformerEncoder, self).__init__()
         self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)])
+        self.aux_layers = set(aux_layers + [num_layers - 1])
+        assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
+        num_channels = encoder_layer.norm_final.weight.numel()
+        self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
+                                      num_channels=num_channels,
+                                      final_weight=0.5,
+                                      pure_prob=0.333,
+                                      stddev=2.0)

     def forward(
         self,
@@ -312,14 +323,19 @@ class ConformerEncoder(nn.Module):
         """
         output = src

-        for mod in self.layers:
+        outputs = []
+        for i, mod in enumerate(self.layers):
             output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)

         return output
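
The collect-then-combine pattern introduced in this forward() can be exercised in isolation. Below is a minimal sketch with a deterministic averaging stand-in for the RandomCombine module added further down; all sizes and names in the sketch are illustrative, not part of the commit:

    import torch

    layers = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(6)])
    aux_layers = {0, 3, 5}  # period-3 taps plus the final layer, as built in __init__
    output = torch.randn(2, 8)
    outputs = []
    for i, mod in enumerate(layers):
        output = mod(output)
        if i in aux_layers:
            outputs.append(output)
    # RandomCombine draws a random convex combination of these per frame;
    # a plain average stands in for it here.
    output = torch.stack(outputs).mean(dim=0)
    assert output.shape == (2, 8)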
@@ -918,7 +934,203 @@ def identity(x):
     return x

+class RandomCombine(torch.nn.Module):
+    """
+    This module combines a list of Tensors, all with the same shape, to
+    produce a single output of that same shape which, at training time,
+    is a random combination of all the inputs, but which at test time
+    is just the last input.
+
+    All but the last input have a linear transform applied before the
+    random combination; these linear transforms are initialized to the
+    identity transform.
+
+    The idea is that the list of Tensors will be the outputs of multiple
+    conformer layers.  This has a similar effect to an iterated loss (see:
+    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
+    NETWORKS).
+    """
+
+    def __init__(self, num_inputs: int,
+                 num_channels: int,
+                 final_weight: float = 0.5,
+                 pure_prob: float = 0.5,
+                 stddev: float = 2.0) -> None:
+        """
+        Args:
+          num_inputs: The number of tensor inputs, which equals the number of layers'
+                outputs that are fed into this module.  E.g. in an 18-layer neural
+                net, if we combine the outputs of layers 12, 16 and 18, num_inputs
+                would be 3.
+          num_channels: The number of channels on the input, e.g. 512.
+          final_weight: The amount of weight or probability we assign to the
+                final layer when randomly choosing layers or when choosing
+                continuous layer weights.
+          pure_prob: The probability, on each frame, with which we choose
+                only a single layer to output (rather than an interpolation).
+          stddev: The standard deviation of the Gaussian noise added to the
+                log-weights when computing randomized weights.
+
+        The method of choosing which layers, or combinations of layers, to use
+        is conceptually as follows:
+            With probability `pure_prob`:
+                with probability `final_weight`, choose the final layer;
+                else, choose a random non-final layer.
+            Else:
+                choose initial log-weights that correspond to assigning
+                weight `final_weight` to the final layer and equal weights
+                to the other layers; then add Gaussian noise with standard
+                deviation `stddev` to these log-weights and normalize to
+                weights (note: the average weight assigned to the final
+                layer here will not be `final_weight` if stddev > 0).
+        """
+        super(RandomCombine, self).__init__()
+        assert pure_prob >= 0 and pure_prob <= 1
+        assert final_weight > 0 and final_weight < 1
+        assert num_inputs >= 1
+        self.linear = nn.ModuleList([nn.Linear(num_channels, num_channels, bias=True)
+                                     for _ in range(num_inputs - 1)])
+
+        self.num_inputs = num_inputs
+        self.final_weight = final_weight
+        self.pure_prob = pure_prob
+        self.stddev = stddev
+
+        self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item()
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        for i in range(len(self.linear)):
+            nn.init.eye_(self.linear[i].weight)
+            nn.init.constant_(self.linear[i].bias, 0.0)
+
+    def forward(self, inputs: Sequence[Tensor]) -> Tensor:
+        """
+        Forward function.
+        Args:
+          inputs: A list of Tensors, e.g. from various layers of a transformer.
+              All must be the same shape, of (*, num_channels).
+        Returns:
+          A Tensor of shape (*, num_channels).  In test mode this is just
+          the final input.
+        """
+        num_inputs = self.num_inputs
+        assert len(inputs) == num_inputs
+        if not self.training:
+            return inputs[-1]
+
+        # Shape of weights: (*, num_inputs)
+        num_channels = inputs[0].shape[-1]
+        num_frames = inputs[0].numel() // num_channels
+
+        mod_inputs = []
+        for i in range(num_inputs - 1):
+            mod_inputs.append(self.linear[i](inputs[i]))
+        mod_inputs.append(inputs[num_inputs - 1])
+
+        ndim = inputs[0].ndim
+        # stacked_inputs: (num_frames, num_channels, num_inputs)
+        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames,
+                                                                    num_channels,
+                                                                    num_inputs))
+        # weights: (num_frames, num_inputs)
+        weights = self._get_random_weights(inputs[0].dtype, inputs[0].device,
+                                           num_frames)
+
+        weights = weights.reshape(num_frames, num_inputs, 1)
+        # ans: (num_frames, num_channels, 1)
+        ans = torch.matmul(stacked_inputs, weights)
+        # ans: (*, num_channels)
+        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
+
+        if __name__ == "__main__":
+            # for testing only...
+            print("Weights = ", weights.reshape(num_frames, num_inputs))
+        return ans
+
+    def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor:
+        """
+        Return a tensor of random weights, of shape (num_frames, self.num_inputs).
+        Args:
+          dtype: The data-type desired for the answer, e.g. float, double.
+          device: The device needed for the answer.
+          num_frames: The number of sets of weights desired.
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), such that
+          ans.sum(dim=1) is all ones.
+        """
+        pure_prob = self.pure_prob
+        if pure_prob == 0.0:
+            return self._get_random_mixed_weights(dtype, device, num_frames)
+        elif pure_prob == 1.0:
+            return self._get_random_pure_weights(dtype, device, num_frames)
+        else:
+            p = self._get_random_pure_weights(dtype, device, num_frames)
+            m = self._get_random_mixed_weights(dtype, device, num_frames)
+            return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m)
+
+    def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor:
+        """
+        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs).
+        Args:
+          dtype: The data-type desired for the answer, e.g. float, double.
+          device: The device needed for the answer.
+          num_frames: The number of sets of weights desired.
+        Returns:
+          A one-hot tensor of shape (num_frames, self.num_inputs), with
+          exactly one weight equal to 1.0 on each frame.
+        """
+        final_prob = self.final_weight
+        # final contains self.num_inputs - 1 in all elements
+        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
+        # nonfinal contains random integers in [0..num_inputs - 2]; these index non-final layers.
+        nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
+
+        indexes = torch.where(torch.rand(num_frames, device=device) < final_prob,
+                              final, nonfinal)
+        ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype)
+        return ans
+
+    def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor:
+        """
+        Return a tensor of random weights, of shape (num_frames, self.num_inputs).
+        Args:
+          dtype: The data-type desired for the answer, e.g. float, double.
+          device: The device needed for the answer.
+          num_frames: The number of sets of weights desired.
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), whose elements
+          are in [0..1] and sum to one over the second axis, i.e.
+          ans.sum(dim=1) is all ones.
+        """
+        logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev
+        logprobs[:, -1] += self.final_log_weight
+        return logprobs.softmax(dim=1)
+
+
+def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
+    print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}")
+    num_inputs = 3
+    num_channels = 50
+    m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels,
+                      final_weight=final_weight, pure_prob=pure_prob, stddev=stddev)
+
+    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
+    y = m(x)
+    assert y.shape == x[0].shape
+    assert torch.allclose(y, x[0])  # all inputs are all-ones and the weights sum to one.
+
+
 if __name__ == '__main__':
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.0)
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.3)
+    _test_random_combine(0.5, 1, 0.3)
+    _test_random_combine(0.5, 0.5, 0.3)
+
     feature_dim = 50
     c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
     batch_size = 5
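
A numerical check of the final_log_weight computation in RandomCombine.__init__ above, a sketch using final_weight=0.5 and an assumed num_inputs=3: with no noise, the log-weights should softmax back to a distribution that puts exactly final_weight on the last input.

    import math
    import torch

    final_weight, num_inputs = 0.5, 3
    final_log_weight = math.log((final_weight / (1 - final_weight)) * (num_inputs - 1))  # log(2) ~= 0.693
    logprobs = torch.zeros(num_inputs)
    logprobs[-1] += final_log_weight
    print(logprobs.softmax(dim=0))  # tensor([0.2500, 0.2500, 0.5000])

With stddev > 0, Gaussian noise is added before the softmax, so the average weight on the final layer drifts away from final_weight, as the docstring warns.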
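
The reshape-and-batched-matmul in RandomCombine.forward() is also easier to follow with concrete sizes (assumed here, not from the commit): every leading dimension is flattened into num_frames, the per-layer tensors are stacked along a new trailing axis, and that axis is contracted against a per-frame weight column.

    import torch

    num_inputs, num_channels = 3, 8
    inputs = [torch.randn(2, 5, num_channels) for _ in range(num_inputs)]
    num_frames = inputs[0].numel() // num_channels        # 2 * 5 = 10
    stacked = torch.stack(inputs, dim=inputs[0].ndim).reshape(
        (num_frames, num_channels, num_inputs))           # (10, 8, 3)
    weights = torch.full((num_frames, num_inputs, 1), 1.0 / num_inputs)  # uniform, sums to 1
    ans = torch.matmul(stacked, weights)                  # (10, 8, 1)
    ans = ans.reshape(2, 5, num_channels)                 # back to (*, num_channels)
    assert ans.shape == (2, 5, num_channels)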

File 2 of 2:

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline",
+        default="transducer_stateless/specaugmod_baseline_randcombine1",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved