Various fixes

Daniel Povey 2022-09-27 16:09:30 +08:00
parent d34eafa623
commit 01af88c2f6
2 changed files with 39 additions and 38 deletions


@@ -56,7 +56,7 @@ class Conformer(EncoderInterface):
         conformer_subsampling_factor: int = 4,
         d_model: Tuple[int] = (256, 384, 512),
         nhead: Tuple[int] = (8, 8),
-        dim_feedforward: Tuple[int] = (1536, 2048),
+        feedforward_dim: Tuple[int] = (1536, 2048),
         num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
@@ -81,7 +81,7 @@ class Conformer(EncoderInterface):
         encoder_layer1 = ConformerEncoderLayer(
             d_model[0],
             nhead[0],
-            dim_feedforward[0],
+            feedforward_dim[0],
             dropout,
             layer_dropout,
             cnn_module_kernel[0],
@@ -95,7 +95,7 @@ class Conformer(EncoderInterface):
         encoder_layer2 = ConformerEncoderLayer(
             d_model[1],
             nhead[1],
-            dim_feedforward[1],
+            feedforward_dim[1],
             dropout,
             layer_dropout,
             cnn_module_kernel[1],
@@ -150,12 +150,12 @@ class Conformer(EncoderInterface):
         mask = make_pad_mask(lengths)

         # x1:
-        x1, x_no_combine = self.encoder1(
+        x1, x1_no_combine = self.encoder1(
             x, src_key_padding_mask=mask, warmup=warmup
         )  # (T, N, C) where C == d_model[0]

-        x2 = self.encoder1(
-            x1, src_key_padding_mask=mask, warmup=warmup
+        x2 = self.encoder2(
+            x1_no_combine, src_key_padding_mask=mask, warmup=warmup
         )  # (T, N, C) where C == d_model[1]

         x = torch.cat((x1, x2), dim=2)
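
The old code invoked `self.encoder1` twice, so `self.encoder2` was never used and `x2` was bound to the returned tuple rather than a tensor; the fix routes the uncombined output of the first stack into the second. A runnable sketch of the fixed dataflow, with dummy linear stacks standing in for the real encoders and an illustrative width of 256 for both (the actual per-stack widths come from `d_model`):

```python
import torch
import torch.nn as nn

class DummyEncoder1(nn.Module):
    """Stand-in for the first stack: returns a (combined, uncombined) pair."""
    def __init__(self, dim: int):
        super().__init__()
        self.layer = nn.Linear(dim, dim)
    def forward(self, x):
        y = self.layer(x)
        return y, y  # the real encoder combines layer outputs for the first element

class DummyEncoder2(nn.Module):
    """Stand-in for the downsampled second stack: returns a single tensor."""
    def __init__(self, dim: int):
        super().__init__()
        self.layer = nn.Linear(dim, dim)
    def forward(self, x):
        return self.layer(x)

encoder1, encoder2 = DummyEncoder1(256), DummyEncoder2(256)
x = torch.randn(100, 8, 256)        # (T, N, C)
x1, x1_no_combine = encoder1(x)
x2 = encoder2(x1_no_combine)        # the fix: encoder2, fed the uncombined output
x = torch.cat((x1, x2), dim=2)      # (T, N, 512)
assert x.shape == (100, 8, 512)
```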
@@ -175,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
     Args:
         d_model: the number of expected features in the input (required).
         nhead: the number of heads in the multiheadattention models (required).
-        dim_feedforward: the dimension of the feedforward network model (default=2048).
+        feedforward_dim: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
@@ -190,7 +190,7 @@ class ConformerEncoderLayer(nn.Module):
         self,
         d_model: int,
         nhead: int,
-        dim_feedforward: int = 2048,
+        feedforward_dim: int = 2048,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
@@ -206,22 +206,22 @@ class ConformerEncoderLayer(nn.Module):
         )

         self.feed_forward = nn.Sequential(
-            nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(dim_feedforward,
+            nn.Linear(d_model, feedforward_dim),
+            ActivationBalancer(feedforward_dim,
                                channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model,
+            ScaledLinear(feedforward_dim, d_model,
                          initial_scale=0.1),
         )

         self.feed_forward_macaron = nn.Sequential(
-            nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(dim_feedforward,
+            nn.Linear(d_model, feedforward_dim),
+            ActivationBalancer(feedforward_dim,
                                channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model,
+            ScaledLinear(feedforward_dim, d_model,
                          initial_scale=0.1),
         )
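
The rename leaves the dataflow of both feed-forward branches unchanged: project up to `feedforward_dim`, apply the nonlinearity and dropout, then project back to `d_model`, with `initial_scale=0.1` presumably keeping the residual branch small at initialization. A shape-check sketch with plain PyTorch layers standing in for the repo's `ActivationBalancer`, `DoubleSwish`, and `ScaledLinear`:

```python
import torch
import torch.nn as nn

d_model, feedforward_dim = 384, 1536

# nn.SiLU and nn.Linear are stand-ins; the repo's modules add activation
# balancing and a scaled reparameterization on top of this same shape flow.
feed_forward = nn.Sequential(
    nn.Linear(d_model, feedforward_dim),
    nn.SiLU(),
    nn.Dropout(0.1),
    nn.Linear(feedforward_dim, d_model),
)

x = torch.randn(100, 8, d_model)           # (T, N, C)
assert feed_forward(x).shape == x.shape    # up-project, then back to d_model
```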
@@ -420,7 +420,7 @@ class DownsampledConformerEncoder(nn.Module):
                  downsample: int):
         super(DownsampledConformerEncoder, self).__init__()
-        self.downsample = downsample
+        self.downsample_factor = downsample
         # note: we'll pad manually.
         self.downsample = nn.Conv1d(
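
Before this change, `self.downsample` was first assigned the integer factor and then immediately clobbered by the `nn.Conv1d` a few lines later, so reads of `self.downsample` in `forward()` picked up the module rather than the factor. A small demonstration of the clobbering (the Conv1d sizes are arbitrary placeholders):

```python
import torch.nn as nn

class Before(nn.Module):
    def __init__(self, downsample: int):
        super().__init__()
        self.downsample = downsample          # the integer factor...
        self.downsample = nn.Conv1d(4, 4, 1)  # ...silently overwritten

class After(nn.Module):
    def __init__(self, downsample: int):
        super().__init__()
        self.downsample_factor = downsample   # factor kept under its own name
        self.downsample = nn.Conv1d(4, 4, 1)

print(type(Before(2).downsample))  # <class 'torch.nn.modules.conv.Conv1d'>, not int
print(After(2).downsample_factor)  # 2
```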
@@ -459,43 +459,44 @@ class DownsampledConformerEncoder(nn.Module):
             src_key_padding_mask: (N, S).
         S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number

-        Returns: (x, x_no_combine), both of shape (S, N, E)
+        Returns: output of shape (S, N, F) where F is the number of output features
+            (output_dim to constructor)
         """
-        (seq_len, batch_size, embedding_dim) = x.shape
-        ds = self.downsample
+        (seq_len, batch_size, embedding_dim) = src.shape
+        ds = self.downsample_factor
         d_seq_len = (seq_len + ds - 1) // ds
-        x_orig = x
+        src_orig = src
         if seq_len != d_seq_len * ds:
-            # right-pad x
-            pad = seq_len - d_seq_len * ds
-            x = torch.nn.functional.pad(x,
-                                        (0, pad, 0, 0, 0, 0),
-                                        mode='replicate')
+            # right-pad src, repeating the last element.
+            pad = d_seq_len * ds - seq_len
+            src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
+            src = torch.cat((src, src_extra), dim=0)
+            assert src.shape[0] == d_seq_len * ds

         if mask is not None:
             mask = mask[::ds,::ds]
         if src_key_padding_mask is not None:
             src_key_padding_mask = src_key_padding_mask[::ds]

-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-        x = self.downsample(x)
-        x = x.permute(2, 0, 1)  # (time, batch, channels)
+        src = src.permute(1, 2, 0)  # (#batch, channels, time).
+        src = self.downsample(src)
+        src = src.permute(2, 0, 1)  # (time, batch, channels)

-        x, _x_no_combine = self.encoder(
-            x, src_key_padding_mask=mask, warmup=warmup
+        src, _src_no_combine = self.encoder(
+            src, src_key_padding_mask=mask, warmup=warmup
         )  # (T, N, C)

-        x = x.permute(1, 2, 0)  # (#batch, channels, time).
-        x = self.upsample(x)
-        x = x.permute(2, 0, 1)  # (time, batch, channels)
+        src = src.permute(1, 2, 0)  # (#batch, channels, time).
+        src = self.upsample(src)
+        src = src.permute(2, 0, 1)  # (time, batch, channels)

-        new_seq_len = x.shape[0]
+        new_seq_len = src.shape[0]
         assert new_seq_len >= seq_len
         if new_seq_len > seq_len:
-            x = x[:seq_len]
+            src = src[:seq_len]

-        return x
+        return src

class RelPositionalEncoding(torch.nn.Module):
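
The old padding branch computed a non-positive `pad` (`seq_len - d_seq_len * ds` can never exceed zero), and its pad tuple appears to target the wrong dimension of a (T, N, C) tensor; the new code sidesteps `F.pad` entirely by repeating the last frame. A runnable sketch of the new scheme:

```python
import torch

ds = 4                                  # downsampling factor
src = torch.randn(10, 8, 256)           # (T, N, C) with T not a multiple of ds
seq_len = src.shape[0]
d_seq_len = (seq_len + ds - 1) // ds    # 3: downsampled frame count, rounded up

# Right-pad by repeating the last frame, as in the new code above.
pad = d_seq_len * ds - seq_len          # 2 (the old expression gave -2 here)
src_extra = src[src.shape[0] - 1:].expand(pad, src.shape[1], src.shape[2])
src = torch.cat((src, src_extra), dim=0)
assert src.shape[0] == d_seq_len * ds   # 12: last frame repeated twice
```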


@@ -413,11 +413,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         return list(map(int, s.split(',')))
     encoder = Conformer(
         num_features=params.feature_dim,
-        subsampling_factor=params.subsamplng_factor,
-        conformer_subsampling_factor=params.conformer_subsamplng_factor,
+        subsampling_factor=params.subsampling_factor,
+        conformer_subsampling_factor=params.conformer_subsampling_factor,
         d_model=to_int_list(params.encoder_dims),
         nhead=to_int_list(params.nhead),
-        feedforward_dims=to_int_list(params.feedforward_dims),
+        feedforward_dim=to_int_list(params.feedforward_dims),
         num_encoder_layers=to_int_list(params.num_encoder_layers),
     )
     return encoder
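
Besides correcting the `subsamplng` misspellings (which pointed at attributes that would not exist on `params`), the call site now matches the renamed `feedforward_dim` keyword of the Conformer constructor. For reference, the `to_int_list` helper shown above turns the comma-separated config strings into per-stack lists:

```python
def to_int_list(s: str):
    return list(map(int, s.split(',')))

print(to_int_list("1536,2048"))   # [1536, 2048] -> feedforward_dim
print(to_int_list("12,12"))       # [12, 12]     -> num_encoder_layers
```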