Repository: https://github.com/k2-fsa/icefall.git
Commit: 0191e8f3e4 (parent: 90180ce5e7)

Simplify how dim changes are dealt with; see also scaled_adam_exp977
scaling.py
@@ -2165,6 +2165,16 @@ class SwooshR(torch.nn.Module):
             return torch.logaddexp(zero, x - 1.) - 0.08 * x - 0.313261687
         return SwooshRFunction.apply(x)
 
 
+def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
+    if num_channels <= x.shape[-1]:
+        return x[..., :num_channels]
+    else:
+        shape = list(x.shape)
+        shape[-1] = num_channels - shape[-1]
+        zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device)
+        return torch.cat((x, zeros), dim=-1)
+
+
 def _test_max_eig():
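For reference, a minimal sketch of how the new helper behaves; the helper body is copied from the hunk above, while the shapes in the checks below are illustrative, not from the patch:

    import torch
    from torch import Tensor

    def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
        # Keep the first num_channels channels if the target is narrower...
        if num_channels <= x.shape[-1]:
            return x[..., :num_channels]
        else:
            # ...otherwise zero-pad the last dimension up to num_channels.
            shape = list(x.shape)
            shape[-1] = num_channels - shape[-1]
            zeros = torch.zeros(*shape, dtype=x.dtype, device=x.device)
            return torch.cat((x, zeros), dim=-1)

    x = torch.randn(10, 2, 256)  # (seq_len, batch_size, num_channels)
    assert convert_num_channels(x, 192).shape == (10, 2, 192)  # truncated
    assert convert_num_channels(x, 384).shape == (10, 2, 384)  # zero-padded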
zipformer.py
@@ -49,6 +49,7 @@ from scaling import (
     ScheduledFloat,
     FloatLike,
     limit_param_value,
+    convert_num_channels,
     ScaleGrad,
 )
 from torch import Tensor, nn
@@ -225,8 +226,7 @@ class Zipformer(EncoderInterface):
             if downsampling_factor[i] != 1:
                 encoder = DownsampledZipformerEncoder(
                     encoder,
-                    input_dim=encoder_dim[i-1] if i > 0 else encoder_dim[0],
-                    output_dim=encoder_dim[i],
+                    dim=encoder_dim[i],
                     downsample=downsampling_factor[i],
                     dropout=dropout,
                 )
@@ -242,7 +242,6 @@ class Zipformer(EncoderInterface):
         self._init_skip_modules()
 
         self.downsample_output = SimpleDownsample(max(encoder_dim),
-                                                  max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
 
@@ -269,8 +268,7 @@ class Zipformer(EncoderInterface):
                 logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
                              f"combine the outputs of layers {j} and {i-1}, with downsampling_factor={z[j]} and {z[i-1]}.")
                 skip_layers.append(j)
-                skip_modules.append(SimpleCombiner(self.encoder_dim[j],
-                                                   self.encoder_dim[i-1],
+                skip_modules.append(SimpleCombiner(self.encoder_dim[i-1],
                                                    min_weight=(0.0, 0.25)))
                 break
         self.skip_layers = skip_layers
@@ -412,6 +410,7 @@ class Zipformer(EncoderInterface):
                     x = torch.where(mask, skip_x, x)
                 else:
                     x = skip_x
+            x = convert_num_channels(x, self.encoder_dim[i])
             x = module(x,
                        chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
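Note: with this one-line `convert_num_channels` call, each encoder stack adapts whatever channel count it receives to its own `encoder_dim[i]` before running, zero-padding when the next stack is wider and truncating when it is narrower. This is what lets the explicit `input_dim`/`output_dim` plumbing be dropped from the constructors in the surrounding hunks.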
@@ -871,19 +870,16 @@ class DownsampledZipformerEncoder(nn.Module):
     """
     def __init__(self,
                  encoder: nn.Module,
-                 input_dim: int,
-                 output_dim: int,
+                 dim: int,
                  downsample: int,
                  dropout: FloatLike):
         super(DownsampledZipformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsample = SimpleDownsample(input_dim, output_dim,
+        self.downsample = SimpleDownsample(dim,
                                            downsample, dropout)
         self.encoder = encoder
-        self.upsample = SimpleUpsample(output_dim, downsample)
-        self.out_combiner = SimpleCombiner(input_dim,
-                                           output_dim,
-                                           min_weight=(0.0, 0.25))
+        self.upsample = SimpleUpsample(dim, downsample)
+        self.out_combiner = SimpleCombiner(dim, min_weight=(0.0, 0.25))
 
 
     def forward(self,
@@ -933,13 +929,9 @@ class SimpleDownsample(torch.nn.Module):
     Does downsampling with attention, by weighted sum, and a projection..
     """
     def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
+                 channels: int,
                  downsample: int,
                  dropout: FloatLike):
-        """
-        Require out_channels > in_channels.
-        """
         super(SimpleDownsample, self).__init__()
 
         self.bias = nn.Parameter(torch.zeros(downsample))
@@ -947,22 +939,14 @@ class SimpleDownsample(torch.nn.Module):
         self.name = None # will be set from training code
         self.dropout = copy.deepcopy(dropout)
 
-        # fill in the extra dimensions with a projection of the input
-        if out_channels > in_channels:
-            self.extra_proj = nn.Linear(in_channels * downsample,
-                                        out_channels - in_channels,
-                                        bias=False)
-        else:
-            self.extra_proj = None
         self.downsample = downsample
-        self.out_channels = out_channels
 
     def forward(self,
                 src: Tensor) -> Tensor:
         """
         x: (seq_len, batch_size, in_channels)
         Returns a tensor of shape
-           ( (seq_len+downsample-1)//downsample, batch_size, out_channels)
+           ( (seq_len+downsample-1)//downsample, batch_size, channels)
         """
         (seq_len, batch_size, in_channels) = src.shape
         ds = self.downsample
@@ -984,13 +968,7 @@ class SimpleDownsample(torch.nn.Module):
 
         # ans1 is the first `in_channels` channels of the output
         ans = (src * weights).sum(dim=1)
-        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
-
-        if self.extra_proj is not None:
-            ans2 = self.extra_proj(src)
-            ans = torch.cat((ans, ans2), dim=2)
-
-        ans = ans[..., :self.out_channels]
         return ans
 
 
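Note: because every stack input is already converted to the right channel count, SimpleDownsample no longer changes the channel count itself, so the `extra_proj` branch and the `out_channels` bookkeeping removed above become unnecessary. A minimal sketch of the surviving path, assuming (this is not shown in the hunk) that `src` has been reshaped to (d_seq_len, ds, batch_size, channels) and that `weights` is a softmax over the `ds` frames in each group, derived from `self.bias`:

    import torch

    seq_len, ds, batch_size, channels = 12, 3, 2, 256
    d_seq_len = (seq_len + ds - 1) // ds

    src = torch.randn(d_seq_len, ds, batch_size, channels)
    bias = torch.zeros(ds)                             # stand-in for self.bias
    weights = bias.softmax(dim=0).reshape(1, ds, 1, 1)

    ans = (src * weights).sum(dim=1)  # weighted sum over each group of ds frames
    assert ans.shape == (d_seq_len, batch_size, channels)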
@@ -1028,26 +1006,25 @@ class SimpleCombiner(torch.nn.Module):
     The output will have the same dimension as dim2.
     """
     def __init__(self,
-                 dim1: int,
-                 dim2: int,
+                 dim: int,
                  min_weight: Tuple[float, float] = (0., 0.)):
         super(SimpleCombiner, self).__init__()
         initial_weight1 = 0.1
-        self.weight1 = nn.Parameter(torch.full((dim2,), initial_weight1))
+        self.weight1 = nn.Parameter(torch.full((dim,), initial_weight1))
         self.min_weight = min_weight
 
     def forward(self,
                 src1: Tensor,
                 src2: Tensor) -> Tensor:
         """
-        src1: (*, dim1)
-        src2: (*, dim2)
+        src1: (*, other_dim)
+        src2: (*, dim)
 
-        Returns: a tensor of shape (*, dim2)
+        Returns: a tensor of shape (*, dim)
         """
         assert src1.shape[:-1] == src2.shape[:-1]
-        dim1 = src1.shape[-1]
-        dim2 = src2.shape[-1]
+        num_channels = src2.shape[-1]
+        src1 = convert_num_channels(src1, num_channels)
 
 
         weight1 = limit_param_value(self.weight1,
@@ -1055,22 +1032,7 @@ class SimpleCombiner(torch.nn.Module):
                                     max=1.0-self.min_weight[1],
                                     training=self.training)
 
-        src1_dim = src1.shape[-1]
-        src2_dim = src2.shape[-1]
-
-        if src1_dim != src2_dim:
-            if src1_dim < src2_dim:
-                zeros_shape = list(src1.shape[:-1]) + [src2_dim - src1_dim]
-                src1 = torch.cat((src1, torch.zeros(*zeros_shape,
-                                                    device=src1.device,
-                                                    dtype=src1.dtype)),
-                                 dim=-1)
-            else:
-                src1 = src1[...,:src2_dim]
-
-        src1 = src1 * weight1
-        src2 = src2 * (1.0 - weight1)
-
-        return src1 + src2
+        return src1 * weight1 + src2 * (1.0 - weight1)
 
 
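Taken together, the two SimpleCombiner hunks reduce the module to a clamped per-channel interpolation between its two inputs. A self-contained sketch of the simplified behavior, using `torch.clamp` as a stand-in for icefall's `limit_param_value` (an assumption made only so the snippet runs on its own) and the `convert_num_channels` helper defined in scaling.py above:

    import torch
    from torch import Tensor, nn
    from typing import Tuple

    class SimpleCombinerSketch(nn.Module):
        def __init__(self, dim: int,
                     min_weight: Tuple[float, float] = (0., 0.)):
            super().__init__()
            self.weight1 = nn.Parameter(torch.full((dim,), 0.1))
            self.min_weight = min_weight

        def forward(self, src1: Tensor, src2: Tensor) -> Tensor:
            # Truncate or zero-pad src1 to src2's channel count.
            src1 = convert_num_channels(src1, src2.shape[-1])
            # Stand-in for limit_param_value(self.weight1, ...).
            w = self.weight1.clamp(self.min_weight[0],
                                   1.0 - self.min_weight[1])
            return src1 * w + src2 * (1.0 - w)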