Simplify how dim changes are dealt with; see also scaled_adam_exp977

This commit is contained in:
Daniel Povey 2023-02-22 11:40:33 +08:00
parent 90180ce5e7
commit 0191e8f3e4
2 changed files with 28 additions and 56 deletions

View File

@ -2165,6 +2165,16 @@ class SwooshR(torch.nn.Module):
return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687 return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
return SwooshRFunction.apply(x) return SwooshRFunction.apply(x)
def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
if num_channels <= x.shape[-1]:
return x[..., :num_channels]
else:
shape = list(x.shape)
shape[-1] = num_channels - shape[-1]
zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device)
return torch.cat((x, zeros), dim=-1)
def _test_max_eig(): def _test_max_eig():

View File

@ -49,6 +49,7 @@ from scaling import (
ScheduledFloat, ScheduledFloat,
FloatLike, FloatLike,
limit_param_value, limit_param_value,
convert_num_channels,
ScaleGrad, ScaleGrad,
) )
from torch import Tensor, nn from torch import Tensor, nn
@ -225,8 +226,7 @@ class Zipformer(EncoderInterface):
if downsampling_factor[i] != 1: if downsampling_factor[i] != 1:
encoder = DownsampledZipformerEncoder( encoder = DownsampledZipformerEncoder(
encoder, encoder,
input_dim=encoder_dim[i-1] if i > 0 else encoder_dim[0], dim=encoder_dim[i],
output_dim=encoder_dim[i],
downsample=downsampling_factor[i], downsample=downsampling_factor[i],
dropout=dropout, dropout=dropout,
) )
@ -242,7 +242,6 @@ class Zipformer(EncoderInterface):
self._init_skip_modules() self._init_skip_modules()
self.downsample_output = SimpleDownsample(max(encoder_dim), self.downsample_output = SimpleDownsample(max(encoder_dim),
max(encoder_dim),
downsample=output_downsampling_factor, downsample=output_downsampling_factor,
dropout=dropout) dropout=dropout)
@ -269,8 +268,7 @@ class Zipformer(EncoderInterface):
logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.") f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
skip_layers.append(j) skip_layers.append(j)
skip_modules.append(SimpleCombiner(self.encoder_dim[j], skip_modules.append(SimpleCombiner(self.encoder_dim[i-1],
self.encoder_dim[i-1],
min_weight=(0.0, 0.25))) min_weight=(0.0, 0.25)))
break break
self.skip_layers = skip_layers self.skip_layers = skip_layers
@ -412,6 +410,7 @@ class Zipformer(EncoderInterface):
x = torch.where(mask, skip_x, x) x = torch.where(mask, skip_x, x)
else: else:
x = skip_x x = skip_x
x = convert_num_channels(x, self.encoder_dim[i])
x = module(x, x = module(x,
chunk_size=chunk_size, chunk_size=chunk_size,
feature_mask=feature_masks[i], feature_mask=feature_masks[i],
@ -871,19 +870,16 @@ class DownsampledZipformerEncoder(nn.Module):
""" """
def __init__(self, def __init__(self,
encoder: nn.Module, encoder: nn.Module,
input_dim: int, dim: int,
output_dim: int,
downsample: int, downsample: int,
dropout: FloatLike): dropout: FloatLike):
super(DownsampledZipformerEncoder, self).__init__() super(DownsampledZipformerEncoder, self).__init__()
self.downsample_factor = downsample self.downsample_factor = downsample
self.downsample = SimpleDownsample(input_dim, output_dim, self.downsample = SimpleDownsample(dim,
downsample, dropout) downsample, dropout)
self.encoder = encoder self.encoder = encoder
self.upsample = SimpleUpsample(output_dim, downsample) self.upsample = SimpleUpsample(dim, downsample)
self.out_combiner = SimpleCombiner(input_dim, self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25))
output_dim,
min_weight=(0.0, 0.25))
def forward(self, def forward(self,
@ -933,13 +929,9 @@ class SimpleDownsample(torch.nn.Module):
Does downsampling with attention, by weighted sum, and a projection.. Does downsampling with attention, by weighted sum, and a projection..
""" """
def __init__(self, def __init__(self,
in_channels: int, channels: int,
out_channels: int,
downsample: int, downsample: int,
dropout: FloatLike): dropout: FloatLike):
"""
Require out_channels > in_channels.
"""
super(SimpleDownsample, self).__init__() super(SimpleDownsample, self).__init__()
self.bias = nn.Parameter(torch.zeros(downsample)) self.bias = nn.Parameter(torch.zeros(downsample))
@ -947,22 +939,14 @@ class SimpleDownsample(torch.nn.Module):
self.name = None # will be set from training code self.name = None # will be set from training code
self.dropout = copy.deepcopy(dropout) self.dropout = copy.deepcopy(dropout)
# fill in the extra dimensions with a projection of the input
if out_channels > in_channels:
self.extra_proj = nn.Linear(in_channels * downsample,
out_channels - in_channels,
bias=False)
else:
self.extra_proj = None
self.downsample = downsample self.downsample = downsample
self.out_channels = out_channels
def forward(self, def forward(self,
src: Tensor) -> Tensor: src: Tensor) -> Tensor:
""" """
x: (seq_len, batch_size, in_channels) x: (seq_len, batch_size, in_channels)
Returns a tensor of shape Returns a tensor of shape
( (seq_len+downsample-1)//downsample, batch_size, out_channels) ( (seq_len+downsample-1)//downsample, batch_size, channels)
""" """
(seq_len, batch_size, in_channels) = src.shape (seq_len, batch_size, in_channels) = src.shape
ds = self.downsample ds = self.downsample
@ -984,13 +968,7 @@ class SimpleDownsample(torch.nn.Module):
# ans1 is the first `in_channels` channels of the output # ans1 is the first `in_channels` channels of the output
ans = (src * weights).sum(dim=1) ans = (src * weights).sum(dim=1)
src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
if self.extra_proj is not None:
ans2 = self.extra_proj(src)
ans = torch.cat((ans, ans2), dim=2)
ans = ans[..., :self.out_channels]
return ans return ans
@ -1028,26 +1006,25 @@ class SimpleCombiner(torch.nn.Module):
The output will have the same dimension as dim2. The output will have the same dimension as dim2.
""" """
def __init__(self, def __init__(self,
dim1: int, dim: int,
dim2: int,
min_weight: Tuple[float, float] = (0., 0.)): min_weight: Tuple[float, float] = (0., 0.)):
super(SimpleCombiner, self).__init__() super(SimpleCombiner, self).__init__()
initial_weight1 = 0.1 initial_weight1 = 0.1
self.weight1 = nn.Parameter(torch.full((dim2,), initial_weight1)) self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
self.min_weight = min_weight self.min_weight = min_weight
def forward(self, def forward(self,
src1: Tensor, src1: Tensor,
src2: Tensor) -> Tensor: src2: Tensor) -> Tensor:
""" """
src1: (*, dim1) src1: (*, other_dim)
src2: (*, dim2) src2: (*, dim)
Returns: a tensor of shape (*, dim2) Returns: a tensor of shape (*, dim)
""" """
assert src1.shape[:-1] == src2.shape[:-1] assert src1.shape[:-1] == src2.shape[:-1]
dim1 = src1.shape[-1] num_channels = src2.shape[-1]
dim2 = src2.shape[-1] src1 = convert_num_channels(src1, num_channels)
weight1 = limit_param_value(self.weight1, weight1 = limit_param_value(self.weight1,
@ -1055,22 +1032,7 @@ class SimpleCombiner(torch.nn.Module):
max=1.0-self.min_weight[1], max=1.0-self.min_weight[1],
training=self.training) training=self.training)
src1_dim = src1.shape[-1] return src1 * weight1 + src2 * (1.0 - weight1)
src2_dim = src2.shape[-1]
if src1_dim != src2_dim:
if src1_dim < src2_dim:
zeros_shape = list(src1.shape[:-1]) + [src2_dim - src1_dim]
src1 = torch.cat((src1, torch.zeros(*zeros_shape,
device=src1.device,
dtype=src1.dtype)),
dim=-1)
else:
src1 = src1[...,:src2_dim]
src1 = src1 * weight1
src2 = src2 * (1.0 - weight1)
return src1 + src2