mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Various fixes
This commit is contained in:
parent
d34eafa623
commit
01af88c2f6
@ -56,7 +56,7 @@ class Conformer(EncoderInterface):
|
||||
conformer_subsampling_factor: int = 4,
|
||||
d_model: Tuple[int] = (256, 384, 512),
|
||||
nhead: Tuple[int] = (8, 8),
|
||||
dim_feedforward: Tuple[int] = (1536, 2048),
|
||||
feedforward_dim: Tuple[int] = (1536, 2048),
|
||||
num_encoder_layers: Tuple[int] = (12, 12),
|
||||
dropout: float = 0.1,
|
||||
layer_dropout: float = 0.075,
|
||||
@ -81,7 +81,7 @@ class Conformer(EncoderInterface):
|
||||
encoder_layer1 = ConformerEncoderLayer(
|
||||
d_model[0],
|
||||
nhead[0],
|
||||
dim_feedforward[0],
|
||||
feedforward_dim[0],
|
||||
dropout,
|
||||
layer_dropout,
|
||||
cnn_module_kernel[0],
|
||||
@ -95,7 +95,7 @@ class Conformer(EncoderInterface):
|
||||
encoder_layer2 = ConformerEncoderLayer(
|
||||
d_model[1],
|
||||
nhead[1],
|
||||
dim_feedforward[1],
|
||||
feedforward_dim[1],
|
||||
dropout,
|
||||
layer_dropout,
|
||||
cnn_module_kernel[1],
|
||||
@ -150,12 +150,12 @@ class Conformer(EncoderInterface):
|
||||
mask = make_pad_mask(lengths)
|
||||
|
||||
# x1:
|
||||
x1, x_no_combine = self.encoder1(
|
||||
x1, x1_no_combine = self.encoder1(
|
||||
x, src_key_padding_mask=mask, warmup=warmup
|
||||
) # (T, N, C) where C == d_model[0]
|
||||
|
||||
x2 = self.encoder1(
|
||||
x1, src_key_padding_mask=mask, warmup=warmup
|
||||
x2 = self.encoder2(
|
||||
x1_no_combine, src_key_padding_mask=mask, warmup=warmup
|
||||
) # (T, N, C) where C == d_model[1]
|
||||
|
||||
x = torch.cat((x1, x2), dim=2)
|
||||
@ -175,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
Args:
|
||||
d_model: the number of expected features in the input (required).
|
||||
nhead: the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
feedforward_dim: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
|
||||
@ -190,7 +190,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
self,
|
||||
d_model: int,
|
||||
nhead: int,
|
||||
dim_feedforward: int = 2048,
|
||||
feedforward_dim: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
layer_dropout: float = 0.075,
|
||||
cnn_module_kernel: int = 31,
|
||||
@ -206,22 +206,22 @@ class ConformerEncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
ActivationBalancer(dim_feedforward,
|
||||
nn.Linear(d_model, feedforward_dim),
|
||||
ActivationBalancer(feedforward_dim,
|
||||
channel_dim=-1, max_abs=10.0),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model,
|
||||
ScaledLinear(feedforward_dim, d_model,
|
||||
initial_scale=0.1),
|
||||
)
|
||||
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
ActivationBalancer(dim_feedforward,
|
||||
nn.Linear(d_model, feedforward_dim),
|
||||
ActivationBalancer(feedforward_dim,
|
||||
channel_dim=-1, max_abs=10.0),
|
||||
DoubleSwish(),
|
||||
nn.Dropout(dropout),
|
||||
ScaledLinear(dim_feedforward, d_model,
|
||||
ScaledLinear(feedforward_dim, d_model,
|
||||
initial_scale=0.1),
|
||||
)
|
||||
|
||||
@ -420,7 +420,7 @@ class DownsampledConformerEncoder(nn.Module):
|
||||
downsample: int):
|
||||
super(DownsampledConformerEncoder, self).__init__()
|
||||
|
||||
self.downsample = downsample
|
||||
self.downsample_factor = downsample
|
||||
|
||||
# note: we'll pad manually.
|
||||
self.downsample = nn.Conv1d(
|
||||
@ -459,43 +459,44 @@ class DownsampledConformerEncoder(nn.Module):
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
|
||||
Returns: (x, x_no_combine), both of shape (S, N, E)
|
||||
Returns: output of shape (S, N, F) where F is the number of output features
|
||||
(output_dim to constructor)
|
||||
"""
|
||||
(seq_len, batch_size, embedding_dim) = x.shape
|
||||
ds = self.downsample
|
||||
(seq_len, batch_size, embedding_dim) = src.shape
|
||||
ds = self.downsample_factor
|
||||
d_seq_len = (seq_len + ds - 1) // ds
|
||||
x_orig = x
|
||||
src_orig = src
|
||||
if seq_len != d_seq_len * ds:
|
||||
# right-pad x
|
||||
pad = seq_len - d_seq_len * ds
|
||||
x = torch.nn.functional.pad(x,
|
||||
(0, pad, 0, 0, 0, 0),
|
||||
mode='replicate')
|
||||
# right-pad src, repeating the last element.
|
||||
pad = d_seq_len * ds - seq_len
|
||||
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
|
||||
src = torch.cat((src, src_extra), dim=0)
|
||||
assert src.shape[0] == d_seq_len * ds
|
||||
|
||||
if mask is not None:
|
||||
mask = mask[::ds,::ds]
|
||||
if src_key_padding_mask is not None:
|
||||
src_key_padding_mask = src_key_padding_mask[::ds]
|
||||
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
x = self.downsample(x)
|
||||
x = x.permute(2, 0, 1) # (time, batch, channels)
|
||||
src = src.permute(1, 2, 0) # (#batch, channels, time).
|
||||
src = self.downsample(src)
|
||||
src = src.permute(2, 0, 1) # (time, batch, channels)
|
||||
|
||||
|
||||
x, _x_no_combine = self.encoder(
|
||||
x, src_key_padding_mask=mask, warmup=warmup
|
||||
src, _src_no_combine = self.encoder(
|
||||
src, src_key_padding_mask=mask, warmup=warmup
|
||||
) # (T, N, C)
|
||||
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
x = self.upsample(x)
|
||||
x = x.permute(2, 0, 1) # (time, batch, channels)
|
||||
src = src.permute(1, 2, 0) # (#batch, channels, time).
|
||||
src = self.upsample(src)
|
||||
src = src.permute(2, 0, 1) # (time, batch, channels)
|
||||
|
||||
new_seq_len = x.shape[0]
|
||||
new_seq_len = src.shape[0]
|
||||
assert new_seq_len >= seq_len
|
||||
if new_seq_len > seq_len:
|
||||
x = x[:seq_len]
|
||||
src = src[:seq_len]
|
||||
|
||||
return x
|
||||
return src
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
|
||||
@ -413,11 +413,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
return list(map(int, s.split(',')))
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
subsampling_factor=params.subsamplng_factor,
|
||||
conformer_subsampling_factor=params.conformer_subsamplng_factor,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
conformer_subsampling_factor=params.conformer_subsampling_factor,
|
||||
d_model=to_int_list(params.encoder_dims),
|
||||
nhead=to_int_list(params.nhead),
|
||||
feedforward_dims=to_int_list(params.feedforward_dims),
|
||||
feedforward_dim=to_int_list(params.feedforward_dims),
|
||||
num_encoder_layers=to_int_list(params.num_encoder_layers),
|
||||
)
|
||||
return encoder
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user