Simplify how dim changes are dealt with; see also scaled_adam_exp977

This commit is contained in:
Daniel Povey 2023-02-22 11:40:33 +08:00
parent 90180ce5e7
commit 0191e8f3e4
2 changed files with 28 additions and 56 deletions

View File

@ -2165,6 +2165,16 @@ class SwooshR(torch.nn.Module):
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
return SwooshRFunction.apply(x)
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
if num_channels <= x.shape[-1]:
return x[..., :num_channels]
else:
shape = list(x.shape)
shape[-1] = num_channels - shape[-1]
zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device)
return torch.cat((x, zeros), dim=-1)
def _test_max_eig():

View File

@ -49,6 +49,7 @@ from scaling import (
ScheduledFloat,
FloatLike,
limit_param_value,
convert_num_channels,
ScaleGrad,
)
from torch import Tensor, nn
@ -225,8 +226,7 @@ class Zipformer(EncoderInterface):
if downsampling_factor[i] != 1:
encoder = DownsampledZipformerEncoder(
encoder,
input_dim=encoder_dim[i-1] if i > 0 else encoder_dim[0],
output_dim=encoder_dim[i],
dim=encoder_dim[i],
downsample=downsampling_factor[i],
dropout=dropout,
)
@ -242,7 +242,6 @@ class Zipformer(EncoderInterface):
self._init_skip_modules()
self.downsample_output = SimpleDownsample(max(encoder_dim),
max(encoder_dim),
downsample=output_downsampling_factor,
dropout=dropout)
@ -269,8 +268,7 @@ class Zipformer(EncoderInterface):
logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
skip_layers.append(j)
skip_modules.append(SimpleCombiner(self.encoder_dim[j],
self.encoder_dim[i-1],
skip_modules.append(SimpleCombiner(self.encoder_dim[i-1],
min_weight=(0.0, 0.25)))
break
self.skip_layers = skip_layers
@ -412,6 +410,7 @@ class Zipformer(EncoderInterface):
x = torch.where(mask, skip_x, x)
else:
x = skip_x
x = convert_num_channels(x, self.encoder_dim[i])
x = module(x,
chunk_size=chunk_size,
feature_mask=feature_masks[i],
@ -871,19 +870,16 @@ class DownsampledZipformerEncoder(nn.Module):
"""
def __init__(self,
encoder: nn.Module,
input_dim: int,
output_dim: int,
dim: int,
downsample: int,
dropout: FloatLike):
super(DownsampledZipformerEncoder, self).__init__()
self.downsample_factor = downsample
self.downsample = SimpleDownsample(input_dim, output_dim,
self.downsample = SimpleDownsample(dim,
downsample, dropout)
self.encoder = encoder
self.upsample = SimpleUpsample(output_dim, downsample)
self.out_combiner = SimpleCombiner(input_dim,
output_dim,
min_weight=(0.0, 0.25))
self.upsample = SimpleUpsample(dim, downsample)
self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25))
def forward(self,
@ -933,13 +929,9 @@ class SimpleDownsample(torch.nn.Module):
Does downsampling with attention, by weighted sum, and a projection..
"""
def __init__(self,
in_channels: int,
out_channels: int,
channels: int,
downsample: int,
dropout: FloatLike):
"""
Require out_channels > in_channels.
"""
super(SimpleDownsample, self).__init__()
self.bias = nn.Parameter(torch.zeros(downsample))
@ -947,22 +939,14 @@ class SimpleDownsample(torch.nn.Module):
self.name = None # will be set from training code
self.dropout = copy.deepcopy(dropout)
# fill in the extra dimensions with a projection of the input
if out_channels > in_channels:
self.extra_proj = nn.Linear(in_channels * downsample,
out_channels - in_channels,
bias=False)
else:
self.extra_proj = None
self.downsample = downsample
self.out_channels = out_channels
def forward(self,
src: Tensor) -> Tensor:
"""
x: (seq_len, batch_size, in_channels)
Returns a tensor of shape
( (seq_len+downsample-1)//downsample, batch_size, out_channels)
( (seq_len+downsample-1)//downsample, batch_size, channels)
"""
(seq_len, batch_size, in_channels) = src.shape
ds = self.downsample
@ -984,13 +968,7 @@ class SimpleDownsample(torch.nn.Module):
# ans1 is the first `in_channels` channels of the output
ans = (src * weights).sum(dim=1)
src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
if self.extra_proj is not None:
ans2 = self.extra_proj(src)
ans = torch.cat((ans, ans2), dim=2)
ans = ans[..., :self.out_channels]
return ans
@ -1028,26 +1006,25 @@ class SimpleCombiner(torch.nn.Module):
The output will have the same dimension as dim2.
"""
def __init__(self,
dim1: int,
dim2: int,
dim: int,
min_weight: Tuple[float, float] = (0., 0.)):
super(SimpleCombiner, self).__init__()
initial_weight1 = 0.1
self.weight1 = nn.Parameter(torch.full((dim2,), initial_weight1))
self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
self.min_weight = min_weight
def forward(self,
src1: Tensor,
src2: Tensor) -> Tensor:
"""
src1: (*, dim1)
src2: (*, dim2)
src1: (*, other_dim)
src2: (*, dim)
Returns: a tensor of shape (*, dim2)
Returns: a tensor of shape (*, dim)
"""
assert src1.shape[:-1] == src2.shape[:-1]
dim1 = src1.shape[-1]
dim2 = src2.shape[-1]
num_channels = src2.shape[-1]
src1 = convert_num_channels(src1, num_channels)
weight1 = limit_param_value(self.weight1,
@ -1055,22 +1032,7 @@ class SimpleCombiner(torch.nn.Module):
max=1.0-self.min_weight[1],
training=self.training)
src1_dim = src1.shape[-1]
src2_dim = src2.shape[-1]
if src1_dim != src2_dim:
if src1_dim < src2_dim:
zeros_shape = list(src1.shape[:-1]) + [src2_dim - src1_dim]
src1 = torch.cat((src1, torch.zeros(*zeros_shape,
device=src1.device,
dtype=src1.dtype)),
dim=-1)
else:
src1 = src1[...,:src2_dim]
src1 = src1 * weight1
src2 = src2 * (1.0 - weight1)
return src1 + src2
return src1 * weight1 + src2 * (1.0 - weight1)