Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Simplify how dim changes are dealt with; see also scaled_adam_exp977
commit 0191e8f3e4
parent 90180ce5e7
scaling.py:

```diff
@@ -2165,6 +2165,16 @@ class SwooshR(torch.nn.Module):
             return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
         return SwooshRFunction.apply(x)
 
+
+def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
+    if num_channels <= x.shape[-1]:
+        return x[..., :num_channels]
+    else:
+        shape = list(x.shape)
+        shape[-1] = num_channels - shape[-1]
+        zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device)
+        return torch.cat((x, zeros), dim=-1)
+
 
 
 def _test_max_eig():
```
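The new `convert_num_channels` helper either truncates the last (channel) dimension or zero-pads it up to the requested width. A minimal usage sketch (toy shapes; assumes `scaling.py` is importable):

```python
import torch

from scaling import convert_num_channels  # the helper added in the hunk above

x = torch.randn(10, 2, 256)  # (seq_len, batch_size, num_channels)

# Truncation: keep only the first 192 channels.
y = convert_num_channels(x, 192)
assert y.shape == (10, 2, 192)

# Padding: append 128 zero-valued channels.
z = convert_num_channels(x, 384)
assert z.shape == (10, 2, 384)
assert torch.all(z[..., 256:] == 0)
```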
zipformer.py:

```diff
@@ -49,6 +49,7 @@ from scaling import (
     ScheduledFloat,
     FloatLike,
     limit_param_value,
+    convert_num_channels,
     ScaleGrad,
 )
 from torch import Tensor, nn
```
```diff
@@ -225,8 +226,7 @@ class Zipformer(EncoderInterface):
             if downsampling_factor[i] != 1:
                 encoder = DownsampledZipformerEncoder(
                     encoder,
-                    input_dim=encoder_dim[i-1] if i > 0 else encoder_dim[0],
-                    output_dim=encoder_dim[i],
+                    dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
                     dropout=dropout,
                 )
```
```diff
@@ -242,7 +242,6 @@ class Zipformer(EncoderInterface):
         self._init_skip_modules()
 
         self.downsample_output = SimpleDownsample(max(encoder_dim),
-                                                  max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
 
```
```diff
@@ -269,8 +268,7 @@ class Zipformer(EncoderInterface):
                     logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
                                  f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
                     skip_layers.append(j)
-                    skip_modules.append(SimpleCombiner(self.encoder_dim[j],
-                                                       self.encoder_dim[i-1],
+                    skip_modules.append(SimpleCombiner(self.encoder_dim[i-1],
                                                        min_weight=(0.0, 0.25)))
                     break
         self.skip_layers = skip_layers
```
```diff
@@ -412,6 +410,7 @@ class Zipformer(EncoderInterface):
                 x = torch.where(mask, skip_x, x)
             else:
                 x = skip_x
+            x = convert_num_channels(x, self.encoder_dim[i])
             x = module(x,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
```
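Each encoder stack may have its own `encoder_dim`, so before feeding stack `i` the running tensor is truncated or zero-padded to that stack's width. A toy sketch of this inter-stack adaptation (hypothetical dims; the stack modules themselves are elided):

```python
import torch

from scaling import convert_num_channels

encoder_dim = [192, 256, 384]            # hypothetical per-stack dims
x = torch.randn(100, 8, encoder_dim[0])  # (seq_len, batch_size, channels)
for i in range(len(encoder_dim)):
    x = convert_num_channels(x, encoder_dim[i])  # adapt channel count for stack i
    # x = module(x, chunk_size=..., feature_mask=...)  # run stack i (elided)
assert x.shape[-1] == encoder_dim[-1]
```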
```diff
@@ -871,19 +870,16 @@ class DownsampledZipformerEncoder(nn.Module):
     """
     def __init__(self,
                  encoder: nn.Module,
-                 input_dim: int,
-                 output_dim: int,
+                 dim: int,
                  downsample: int,
                  dropout: FloatLike):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = SimpleDownsample(input_dim, output_dim,
+        self.downsample = SimpleDownsample(dim,
                                            downsample, dropout)
         self.encoder = encoder
-        self.upsample = SimpleUpsample(output_dim, downsample)
-        self.out_combiner = SimpleCombiner(input_dim,
-                                           output_dim,
-                                           min_weight=(0.0, 0.25))
+        self.upsample = SimpleUpsample(dim, downsample)
+        self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25))
 
 
     def forward(self,
```
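With a single `dim` throughout, the wrapper's data flow is simply: downsample, run the inner encoder at the lower frame rate, upsample, then blend with the original input. A data-flow sketch with crude stand-ins for `SimpleDownsample`, `SimpleUpsample` and `SimpleCombiner` (not the real learned modules):

```python
import torch
from torch import nn, Tensor

def downsampled_encoder_sketch(src: Tensor, encoder: nn.Module, downsample: int) -> Tensor:
    # Sketch of DownsampledZipformerEncoder.forward after this commit;
    # the real submodules learn their weights, these stand-ins do not.
    src_orig = src
    x = src[::downsample]                       # stand-in for SimpleDownsample
    x = encoder(x)                              # inner encoder, same dim throughout
    x = x.repeat_interleave(downsample, dim=0)  # stand-in for SimpleUpsample
    x = x[: src_orig.shape[0]]                  # trim back to the original length
    w = 0.1                                     # initial_weight1 in SimpleCombiner
    return src_orig * w + x * (1.0 - w)         # stand-in for out_combiner(src_orig, x)

y = downsampled_encoder_sketch(torch.randn(100, 8, 256), nn.Identity(), downsample=2)
assert y.shape == (100, 8, 256)
```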
```diff
@@ -933,13 +929,9 @@ class SimpleDownsample(torch.nn.Module):
     Does downsampling with attention, by weighted sum, and a projection..
     """
     def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
+                 channels: int,
                  downsample: int,
                  dropout: FloatLike):
-        """
-        Require out_channels > in_channels.
-        """
         super(SimpleDownsample, self).__init__()
 
         self.bias = nn.Parameter(torch.zeros(downsample))
```
```diff
@@ -947,22 +939,14 @@ class SimpleDownsample(torch.nn.Module):
         self.name = None # will be set from training code
         self.dropout = copy.deepcopy(dropout)
 
-        # fill in the extra dimensions with a projection of the input
-        if out_channels > in_channels:
-            self.extra_proj = nn.Linear(in_channels * downsample,
-                                        out_channels - in_channels,
-                                        bias=False)
-        else:
-            self.extra_proj = None
         self.downsample = downsample
-        self.out_channels = out_channels
 
     def forward(self,
                 src: Tensor) -> Tensor:
         """
         x: (seq_len, batch_size, in_channels)
         Returns a tensor of shape
-           ( (seq_len+downsample-1)//downsample, batch_size, out_channels)
+           ( (seq_len+downsample-1)//downsample, batch_size, channels)
         """
         (seq_len, batch_size, in_channels) = src.shape
         ds = self.downsample
```
```diff
@@ -984,13 +968,7 @@ class SimpleDownsample(torch.nn.Module):
 
         # ans1 is the first `in_channels` channels of the output
         ans = (src * weights).sum(dim=1)
-        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
-
-        if self.extra_proj is not None:
-            ans2 = self.extra_proj(src)
-            ans = torch.cat((ans, ans2), dim=2)
-        ans = ans[..., :self.out_channels]
         return ans
 
 
```
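With the extra projection gone, `SimpleDownsample` reduces to a learned softmax-weighted average over each group of `downsample` frames, leaving the channel count unchanged. A sketch of the core computation (the padding strategy here is an assumption; the real `forward()` body is only partially shown in this diff):

```python
import torch

seq_len, batch, channels, ds = 10, 2, 4, 3
src = torch.randn(seq_len, batch, channels)
bias = torch.zeros(ds)  # learned per-position logits, as in SimpleDownsample

d_seq_len = (seq_len + ds - 1) // ds
pad = d_seq_len * ds - seq_len
# pad seq_len up to a multiple of ds by repeating the last frame (assumed padding)
src = torch.cat((src, src[-1:].expand(pad, batch, channels)), dim=0)

src = src.reshape(d_seq_len, ds, batch, channels)
weights = bias.softmax(dim=0).reshape(1, ds, 1, 1)
ans = (src * weights).sum(dim=1)  # the `ans = (src * weights).sum(dim=1)` line above
assert ans.shape == (d_seq_len, batch, channels)
```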
```diff
@@ -1028,26 +1006,25 @@ class SimpleCombiner(torch.nn.Module):
     The output will have the same dimension as dim2.
     """
     def __init__(self,
-                 dim1: int,
-                 dim2: int,
+                 dim: int,
                  min_weight: Tuple[float, float] = (0., 0.)):
         super(SimpleCombiner, self).__init__()
         initial_weight1 = 0.1
-        self.weight1 = nn.Parameter(torch.full((dim2,), initial_weight1))
+        self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
         self.min_weight = min_weight
 
     def forward(self,
                 src1: Tensor,
                 src2: Tensor) -> Tensor:
         """
-        src1: (*, dim1)
-        src2: (*, dim2)
+        src1: (*, other_dim)
+        src2: (*, dim)
 
-        Returns: a tensor of shape (*, dim2)
+        Returns: a tensor of shape (*, dim)
         """
         assert src1.shape[:-1] == src2.shape[:-1]
-        dim1 = src1.shape[-1]
-        dim2 = src2.shape[-1]
+        num_channels = src2.shape[-1]
+        src1 = convert_num_channels(src1, num_channels)
 
 
         weight1 = limit_param_value(self.weight1,
```
```diff
@@ -1055,22 +1032,7 @@ class SimpleCombiner(torch.nn.Module):
                                     max=1.0-self.min_weight[1],
                                     training=self.training)
 
-        src1_dim = src1.shape[-1]
-        src2_dim = src2.shape[-1]
-        if src1_dim != src2_dim:
-            if src1_dim < src2_dim:
-                zeros_shape = list(src1.shape[:-1]) + [src2_dim - src1_dim]
-                src1 = torch.cat((src1, torch.zeros(*zeros_shape,
-                                                    device=src1.device,
-                                                    dtype=src1.dtype)),
-                                 dim=-1)
-            else:
-                src1 = src1[...,:src2_dim]
-
-        src1 = src1 * weight1
-        src2 = src2 * (1.0 - weight1)
-
-        return src1 + src2
+        return src1 * weight1 + src2 * (1.0 - weight1)
 
 
```
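After this simplification, `SimpleCombiner.forward` is a per-channel convex combination, with the old inline pad/truncate logic replaced by a call to `convert_num_channels`. A self-contained sketch of the resulting computation (the `limit_param_value` clamping is omitted):

```python
import torch

def combine(src1: torch.Tensor, src2: torch.Tensor, weight1: torch.Tensor) -> torch.Tensor:
    # Sketch of SimpleCombiner.forward after this commit (clamping omitted).
    num_channels = src2.shape[-1]
    # inline convert_num_channels: truncate or zero-pad src1 to src2's width
    if num_channels <= src1.shape[-1]:
        src1 = src1[..., :num_channels]
    else:
        pad = torch.zeros(*src1.shape[:-1], num_channels - src1.shape[-1],
                          dtype=src1.dtype, device=src1.device)
        src1 = torch.cat((src1, pad), dim=-1)
    return src1 * weight1 + src2 * (1.0 - weight1)

src1 = torch.randn(100, 8, 192)    # e.g. an earlier stack's output
src2 = torch.randn(100, 8, 256)    # the later stack's output
weight1 = torch.full((256,), 0.1)  # per-channel weight, initial_weight1 = 0.1
assert combine(src1, src2, weight1).shape == (100, 8, 256)
```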