Mirror of https://github.com/k2-fsa/icefall.git

First version of rand-combine iterated-training-like idea.

parent 63d8d935d4
commit c1063def95
@@ -18,7 +18,7 @@
 import copy
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Sequence

 import torch
 from torch import Tensor, nn
@@ -56,6 +56,7 @@ class Conformer(Transformer):
         cnn_module_kernel: int = 31,
         normalize_before: bool = True,
         vgg_frontend: bool = False,
+        aux_layer_period: int = 3
     ) -> None:
         super(Conformer, self).__init__(
             num_features=num_features,
@@ -80,10 +81,11 @@ class Conformer(Transformer):
             cnn_module_kernel,
             normalize_before,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
+                                        aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)))
         self.normalize_before = normalize_before
         if self.normalize_before:
-            self.after_norm = nn.LayerNorm(d_model)
+            self.after_norm = nn.LayerNorm(d_model)  # TODO: remove.
         else:
             # Note: TorchScript detects that self.after_norm could be used inside forward()
             # and throws an error without this change.
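Note: with the default aux_layer_period of 3, the expression above taps every third encoder layer starting from layer 0, and ConformerEncoder later adds the final layer itself. A quick sketch of which layer indices end up feeding the combiner (the layer count here is hypothetical, purely for illustration):

    num_encoder_layers = 12
    aux_layer_period = 3
    aux_layers = list(range(0, num_encoder_layers - 1, aux_layer_period))
    # aux_layers == [0, 3, 6, 9]
    # ConformerEncoder.__init__ then adds the last layer as well:
    # sorted(set(aux_layers + [num_encoder_layers - 1])) == [0, 3, 6, 9, 11]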
@@ -280,12 +282,21 @@ class ConformerEncoder(nn.Module):
     """

     def __init__(
-        self, encoder_layer: nn.Module, num_layers: int
+        self, encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: Sequence[int],
     ) -> None:
         super(ConformerEncoder, self).__init__()
         self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for i in range(num_layers)])
+        self.aux_layers = set(aux_layers + [num_layers - 1])
+        assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
+        num_channels = encoder_layer.norm_final.weight.numel()
+        self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
+                                      num_channels=num_channels,
+                                      final_weight=0.5,
+                                      pure_prob=0.333,
+                                      stddev=2.0)

     def forward(
         self,
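Note: `num_channels` is read from `encoder_layer.norm_final.weight.numel()`, i.e. the size of the layer's final LayerNorm weight, so the combiner does not need the model dimension passed in separately. A minimal illustration of that lookup (the 512 below is a hypothetical d_model, not from this patch):

    layer_norm = nn.LayerNorm(512)
    num_channels = layer_norm.weight.numel()  # 512: one weight per channel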
@@ -312,14 +323,19 @@ class ConformerEncoder(nn.Module):
         """
         output = src

-        for mod in self.layers:
+        outputs = []
+
+        for i, mod in enumerate(self.layers):
             output = mod(
                 output,
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
         return output

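Note: by construction, `outputs` ends up with one tensor per index in `self.aux_layers` (always including the final layer), all with the same shape as `src`, and the combiner reduces them to a single tensor of that shape. A minimal sketch of the contract, with hypothetical sizes:

    combiner = RandomCombine(num_inputs=4, num_channels=256)
    xs = [torch.randn(100, 8, 256) for _ in range(4)]  # e.g. (seq_len, batch, channels)
    y = combiner(xs)          # training mode: a random per-frame combination, same shape
    combiner.eval()
    assert torch.equal(combiner(xs), xs[-1])  # test mode: just the final input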
@@ -918,7 +934,203 @@ def identity(x):
     return x


+class RandomCombine(torch.nn.Module):
+    """
+    This module combines a list of Tensors, all with the same shape, to
+    produce a single output of that same shape which, in training time,
+    is a random combination of all the inputs; but which in test time
+    will be just the last input.
+
+    All but the last input will have a linear transform before we
+    randomly combine them; these linear transforms will be initialized
+    to the identity transform.
+
+    The idea is that the list of Tensors will be a list of outputs of multiple
+    conformer layers.  This has a similar effect as iterated loss.  (See:
+    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
+    NETWORKS).
+    """
+    def __init__(self, num_inputs: int,
+                 num_channels: int,
+                 final_weight: float = 0.5,
+                 pure_prob: float = 0.5,
+                 stddev: float = 2.0) -> None:
+        """
+        Args:
+          num_inputs: The number of tensor inputs, which equals the number of layers'
+                outputs that are fed into this module.  E.g. in an 18-layer neural
+                net if we output layers 16, 12, 18, num_inputs would be 3.
+          num_channels: The number of channels on the input, e.g. 512.
+          final_weight: The amount of weight or probability we assign to the
+                final layer when randomly choosing layers or when choosing
+                continuous layer weights.
+          pure_prob: The probability, on each frame, with which we choose
+                only a single layer to output (rather than an interpolation).
+          stddev: A standard deviation that we add to log-probs for computing
+                randomized weights.
+
+        The method of choosing which layers, or combinations of layers, to use,
+        is conceptually as follows:
+            With probability `pure_prob`:
+               With probability `final_weight`: choose final layer,
+               Else: choose random non-final layer.
+            Else:
+               Choose initial log-weights that correspond to assigning
+               weight `final_weight` to the final layer and equal
+               weights to other layers; then add Gaussian noise
+               with variance `stddev` to these log-weights, and normalize
+               to weights (note: the average weight assigned to the
+               final layer here will not be `final_weight` if stddev>0).
+        """
+        super(RandomCombine, self).__init__()
+        assert pure_prob >= 0 and pure_prob <= 1
+        assert final_weight > 0 and final_weight < 1
+        assert num_inputs >= 1
+        self.linear = nn.ModuleList([nn.Linear(num_channels, num_channels, bias=True)
+                                     for _ in range(num_inputs - 1)])
+
+        self.num_inputs = num_inputs
+        self.final_weight = final_weight
+        self.pure_prob = pure_prob
+        self.stddev = stddev
+
+        self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item()
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        for i in range(len(self.linear)):
+            nn.init.eye_(self.linear[i].weight)
+            nn.init.constant_(self.linear[i].bias, 0.0)
+
+    def forward(self, inputs: Sequence[Tensor]) -> Tensor:
+        """
+        Forward function.
+        Args:
+           inputs: a list of Tensors, e.g. from various layers of a transformer.
+               All must be the same shape, of (*, num_channels).
+        Returns:
+           a Tensor of shape (*, num_channels).  In test mode
+           this is just the final input.
+        """
+        num_inputs = self.num_inputs
+        assert len(inputs) == num_inputs
+        if not self.training:
+            return inputs[-1]
+
+        # Shape of weights: (*, num_inputs)
+        num_channels = inputs[0].shape[-1]
+        num_frames = inputs[0].numel() // num_channels
+
+        mod_inputs = []
+        for i in range(num_inputs - 1):
+            mod_inputs.append(self.linear[i](inputs[i]))
+        mod_inputs.append(inputs[num_inputs - 1])
+
+        ndim = inputs[0].ndim
+        # stacked_inputs: (num_frames, num_channels, num_inputs)
+        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames,
+                                                                    num_channels,
+                                                                    num_inputs))
+
+        # weights: (num_frames, num_inputs)
+        weights = self._get_random_weights(inputs[0].dtype, inputs[0].device,
+                                           num_frames)
+
+        weights = weights.reshape(num_frames, num_inputs, 1)
+        # ans: (num_frames, num_channels, 1)
+        ans = torch.matmul(stacked_inputs, weights)
+        # ans: (*, num_channels)
+        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
+
+        if __name__ == "__main__":
+            # for testing only...
+            print("Weights = ", weights.reshape(num_frames, num_inputs))
+        return ans
+
+    def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor:
+        """
+        Return a tensor of random weights, of shape (num_frames, self.num_inputs).
+        Args:
+            dtype: the data-type desired for the answer, e.g. float, double
+            device: the device needed for the answer
+            num_frames: the number of sets of weights desired
+        Returns: a tensor of shape (num_frames, self.num_inputs), such that
+            ans.sum(dim=1) is all ones.
+        """
+        pure_prob = self.pure_prob
+        if pure_prob == 0.0:
+            return self._get_random_mixed_weights(dtype, device, num_frames)
+        elif pure_prob == 1.0:
+            return self._get_random_pure_weights(dtype, device, num_frames)
+        else:
+            p = self._get_random_pure_weights(dtype, device, num_frames)
+            m = self._get_random_mixed_weights(dtype, device, num_frames)
+            return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m)
+
+    def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
+        """
+        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs).
+        Args:
+            dtype: the data-type desired for the answer, e.g. float, double
+            device: the device needed for the answer
+            num_frames: the number of sets of weights desired
+        Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with
+            exactly one weight equal to 1.0 on each frame.
+        """
+        final_prob = self.final_weight
+
+        # final contains self.num_inputs - 1 in all elements
+        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
+        # nonfinal contains random integers in [0..num_inputs - 2]; these are for non-final weights.
+        nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
+
+        indexes = torch.where(torch.rand(num_frames, device=device) < final_prob,
+                              final, nonfinal)
+        ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype)
+        return ans
+
+    def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
+        """
+        Return a tensor of random nonnegative weights, of shape (num_frames, self.num_inputs).
+        Args:
+            dtype: the data-type desired for the answer, e.g. float, double
+            device: the device needed for the answer
+            num_frames: the number of sets of weights desired
+        Returns: a tensor of shape (num_frames, self.num_inputs), with elements in [0..1] that
+            sum to one over the second axis, i.e. ans.sum(dim=1) is all ones.
+        """
+        logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev
+        logprobs[:, -1] += self.final_log_weight
+        return logprobs.softmax(dim=1)
+
+
+def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
+    print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}")
+    num_inputs = 3
+    num_channels = 50
+    m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels,
+                      final_weight=final_weight, pure_prob=pure_prob, stddev=stddev)
+
+    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
+
+    y = m(x)
+    assert y.shape == x[0].shape
+    assert torch.allclose(y, x[0])  # .. since actually all ones.
+
+
 if __name__ == '__main__':
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.0)
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.3)
+    _test_random_combine(0.5, 1, 0.3)
+    _test_random_combine(0.5, 0.5, 0.3)
+
     feature_dim = 50
     c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
     batch_size = 5
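Note on the weight parameterization in _get_random_mixed_weights: `final_log_weight` is chosen so that, before the Gaussian noise is added, the softmax over [0, ..., 0, final_log_weight] assigns exactly `final_weight` to the final layer and splits the remainder evenly over the other layers. A quick numeric check (the values below assume final_weight=0.5 and three inputs, just as an example):

    import torch

    num_inputs, final_weight = 3, 0.5
    final_log_weight = torch.tensor(
        (final_weight / (1 - final_weight)) * (num_inputs - 1)
    ).log().item()                     # log(2) ~= 0.693 for these values

    logprobs = torch.zeros(num_inputs)
    logprobs[-1] += final_log_weight
    print(logprobs.softmax(dim=0))     # tensor([0.2500, 0.2500, 0.5000])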
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/specaugmod_baseline",
+        default="transducer_stateless/specaugmod_baseline_randcombine1",
         help="""The experiment dir.
         It specifies the directory where all training related
        files, e.g., checkpoints, log, etc, are saved