From 0191e8f3e44345f7c828cdb48e9a9e99331ed704 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Wed, 22 Feb 2023 11:40:33 +0800
Subject: [PATCH] Simplify how dim changes are dealt with; see also
 scaled_adam_exp977

---
 .../pruned_transducer_stateless7/scaling.py   | 10 +++
 .../pruned_transducer_stateless7/zipformer.py | 74 +++++--------
 2 files changed, 28 insertions(+), 56 deletions(-)

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index c2d1ab804..f8ba1dc54 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -2165,6 +2165,16 @@ class SwooshR(torch.nn.Module):
             return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
         return SwooshRFunction.apply(x)
 
+def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
+    if num_channels <= x.shape[-1]:
+        return x[..., :num_channels]
+    else:
+        shape = list(x.shape)
+        shape[-1] = num_channels - shape[-1]
+        zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device)
+        return torch.cat((x, zeros), dim=-1)
+
+
 
 def _test_max_eig():
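The helper added above either truncates or zero-pads the last (channel)
dimension so that a tensor can be handed to a module expecting a different
channel count. A minimal standalone sketch of how it behaves, assuming
convert_num_channels is importable from scaling as the patch below does (the
shapes here are illustrative, not taken from the patch):

    import torch
    from scaling import convert_num_channels  # added by this patch

    x = torch.randn(100, 8, 384)  # (seq_len, batch_size, num_channels)

    # Narrower target: channels beyond the target count are dropped.
    y = convert_num_channels(x, 256)
    assert y.shape == (100, 8, 256) and torch.equal(y, x[..., :256])

    # Wider target: the extra channels are filled with zeros.
    z = convert_num_channels(x, 512)
    assert z.shape == (100, 8, 512)
    assert torch.equal(z[..., 384:], torch.zeros(100, 8, 128))

Because the padding is zeros rather than a learned projection, no parameters
are involved; that is what lets the changes below delete the extra_proj
machinery from SimpleDownsample and the ad-hoc pad/truncate code from
SimpleCombiner.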
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 68077af81..c49c012ba 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -49,6 +49,7 @@ from scaling import (
     ScheduledFloat,
     FloatLike,
     limit_param_value,
+    convert_num_channels,
     ScaleGrad,
 )
 from torch import Tensor, nn
@@ -225,8 +226,7 @@ class Zipformer(EncoderInterface):
             if downsampling_factor[i] != 1:
                 encoder = DownsampledZipformerEncoder(
                     encoder,
-                    input_dim=encoder_dim[i-1] if i > 0 else encoder_dim[0],
-                    output_dim=encoder_dim[i],
+                    dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
                     dropout=dropout,
                 )
@@ -242,7 +242,6 @@ class Zipformer(EncoderInterface):
         self._init_skip_modules()
 
         self.downsample_output = SimpleDownsample(max(encoder_dim),
-                                                  max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
 
@@ -269,8 +268,7 @@ class Zipformer(EncoderInterface):
                     logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
                                  f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
                     skip_layers.append(j)
-                    skip_modules.append(SimpleCombiner(self.encoder_dim[j],
-                                                       self.encoder_dim[i-1],
+                    skip_modules.append(SimpleCombiner(self.encoder_dim[i-1],
                                                        min_weight=(0.0, 0.25)))
                     break
         self.skip_layers = skip_layers
@@ -412,6 +410,7 @@ class Zipformer(EncoderInterface):
                     x = torch.where(mask, skip_x, x)
                 else:
                     x = skip_x
+            x = convert_num_channels(x, self.encoder_dim[i])
             x = module(x,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
@@ -871,19 +870,16 @@ class DownsampledZipformerEncoder(nn.Module):
     """
     def __init__(self,
                  encoder: nn.Module,
-                 input_dim: int,
-                 output_dim: int,
+                 dim: int,
                  downsample: int,
                  dropout: FloatLike):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = SimpleDownsample(input_dim, output_dim,
+        self.downsample = SimpleDownsample(dim,
                                            downsample, dropout)
         self.encoder = encoder
-        self.upsample = SimpleUpsample(output_dim, downsample)
-        self.out_combiner = SimpleCombiner(input_dim,
-                                           output_dim,
-                                           min_weight=(0.0, 0.25))
+        self.upsample = SimpleUpsample(dim, downsample)
+        self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25))
 
 
     def forward(self,
@@ -933,13 +929,9 @@ class SimpleDownsample(torch.nn.Module):
     Does downsampling with attention, by weighted sum, and a projection..
     """
     def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
+                 channels: int,
                  downsample: int,
                  dropout: FloatLike):
-        """
-        Require out_channels > in_channels.
-        """
         super(SimpleDownsample, self).__init__()
 
         self.bias = nn.Parameter(torch.zeros(downsample))
@@ -947,22 +939,14 @@
         self.name = None # will be set from training code
         self.dropout = copy.deepcopy(dropout)
 
-        # fill in the extra dimensions with a projection of the input
-        if out_channels > in_channels:
-            self.extra_proj = nn.Linear(in_channels * downsample,
-                                        out_channels - in_channels,
-                                        bias=False)
-        else:
-            self.extra_proj = None
         self.downsample = downsample
-        self.out_channels = out_channels
 
     def forward(self,
                 src: Tensor) -> Tensor:
         """
         x: (seq_len, batch_size, in_channels)
         Returns a tensor of shape
-           ( (seq_len+downsample-1)//downsample, batch_size, out_channels)
+           ( (seq_len+downsample-1)//downsample, batch_size, channels)
         """
         (seq_len, batch_size, in_channels) = src.shape
         ds = self.downsample
@@ -984,13 +968,7 @@
 
         # ans1 is the first `in_channels` channels of the output
         ans = (src * weights).sum(dim=1)
-        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
 
-        if self.extra_proj is not None:
-            ans2 = self.extra_proj(src)
-            ans = torch.cat((ans, ans2), dim=2)
-
-        ans = ans[..., :self.out_channels]
         return ans
 
@@ -1028,26 +1006,25 @@ class SimpleCombiner(torch.nn.Module):
     The output will have the same dimension as dim2.
     """
     def __init__(self,
-                 dim1: int,
-                 dim2: int,
+                 dim: int,
                  min_weight: Tuple[float, float] = (0., 0.)):
         super(SimpleCombiner, self).__init__()
         initial_weight1 = 0.1
-        self.weight1 = nn.Parameter(torch.full((dim2,), initial_weight1))
+        self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
         self.min_weight = min_weight
 
 
     def forward(self,
                 src1: Tensor,
                 src2: Tensor) -> Tensor:
         """
-        src1: (*, dim1)
-        src2: (*, dim2)
+        src1: (*, other_dim)
+        src2: (*, dim)
 
-        Returns: a tensor of shape (*, dim2)
+        Returns: a tensor of shape (*, dim)
         """
         assert src1.shape[:-1] == src2.shape[:-1]
-        dim1 = src1.shape[-1]
-        dim2 = src2.shape[-1]
+        num_channels = src2.shape[-1]
+        src1 = convert_num_channels(src1, num_channels)
 
         weight1 = limit_param_value(self.weight1,
@@ -1055,22 +1032,7 @@
                                     max=1.0-self.min_weight[1],
                                     training=self.training)
 
-        src1_dim = src1.shape[-1]
-        src2_dim = src2.shape[-1]
-        if src1_dim != src2_dim:
-            if src1_dim < src2_dim:
-                zeros_shape = list(src1.shape[:-1]) + [src2_dim - src1_dim]
-                src1 = torch.cat((src1, torch.zeros(*zeros_shape,
-                                                    device=src1.device,
-                                                    dtype=src1.dtype)),
-                                 dim=-1)
-            else:
-                src1 = src1[...,:src2_dim]
-
-        src1 = src1 * weight1
-        src2 = src2 * (1.0 - weight1)
-
-        return src1 + src2
+        return src1 * weight1 + src2 * (1.0 - weight1)
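Taken together, SimpleCombiner no longer needs to know src1's dimension at
all: forward() normalizes src1 to src2's channel count via
convert_num_channels and returns the per-channel weighted combination in a
single expression. A rough self-contained sketch of the simplified module
(SimpleCombinerSketch is a hypothetical name; the limit_param_value clamping
is elided, and the 0.1 initial weight is copied from the patch):

    import torch
    from torch import Tensor, nn
    from scaling import convert_num_channels  # added by this patch

    class SimpleCombinerSketch(nn.Module):
        # Combines src1 (any channel count) with src2; the output has
        # src2's channel count.
        def __init__(self, dim: int):
            super().__init__()
            self.weight1 = nn.Parameter(torch.full((dim,), 0.1))

        def forward(self, src1: Tensor, src2: Tensor) -> Tensor:
            assert src1.shape[:-1] == src2.shape[:-1]
            # Truncate or zero-pad src1 to src2's channel count, then mix.
            src1 = convert_num_channels(src1, src2.shape[-1])
            return src1 * self.weight1 + src2 * (1.0 - self.weight1)

    combiner = SimpleCombinerSketch(dim=384)
    out = combiner(torch.randn(100, 8, 256), torch.randn(100, 8, 384))
    assert out.shape == (100, 8, 384)

When src1 is narrower than src2 this matches the deleted torch.cat branch
exactly (the zero-padded channels of src1 contribute nothing, leaving
src2 * (1 - weight1) there), and when src1 is wider it matches the old
truncation branch, so the behavior is unchanged while the dimension
bookkeeping lives in one shared helper.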