mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Various fixes
This commit is contained in:
parent
d34eafa623
commit
01af88c2f6
@ -56,7 +56,7 @@ class Conformer(EncoderInterface):
|
|||||||
conformer_subsampling_factor: int = 4,
|
conformer_subsampling_factor: int = 4,
|
||||||
d_model: Tuple[int] = (256, 384, 512),
|
d_model: Tuple[int] = (256, 384, 512),
|
||||||
nhead: Tuple[int] = (8, 8),
|
nhead: Tuple[int] = (8, 8),
|
||||||
dim_feedforward: Tuple[int] = (1536, 2048),
|
feedforward_dim: Tuple[int] = (1536, 2048),
|
||||||
num_encoder_layers: Tuple[int] = (12, 12),
|
num_encoder_layers: Tuple[int] = (12, 12),
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
layer_dropout: float = 0.075,
|
layer_dropout: float = 0.075,
|
||||||
@ -81,7 +81,7 @@ class Conformer(EncoderInterface):
|
|||||||
encoder_layer1 = ConformerEncoderLayer(
|
encoder_layer1 = ConformerEncoderLayer(
|
||||||
d_model[0],
|
d_model[0],
|
||||||
nhead[0],
|
nhead[0],
|
||||||
dim_feedforward[0],
|
feedforward_dim[0],
|
||||||
dropout,
|
dropout,
|
||||||
layer_dropout,
|
layer_dropout,
|
||||||
cnn_module_kernel[0],
|
cnn_module_kernel[0],
|
||||||
@ -95,7 +95,7 @@ class Conformer(EncoderInterface):
|
|||||||
encoder_layer2 = ConformerEncoderLayer(
|
encoder_layer2 = ConformerEncoderLayer(
|
||||||
d_model[1],
|
d_model[1],
|
||||||
nhead[1],
|
nhead[1],
|
||||||
dim_feedforward[1],
|
feedforward_dim[1],
|
||||||
dropout,
|
dropout,
|
||||||
layer_dropout,
|
layer_dropout,
|
||||||
cnn_module_kernel[1],
|
cnn_module_kernel[1],
|
||||||
@ -150,12 +150,12 @@ class Conformer(EncoderInterface):
|
|||||||
mask = make_pad_mask(lengths)
|
mask = make_pad_mask(lengths)
|
||||||
|
|
||||||
# x1:
|
# x1:
|
||||||
x1, x_no_combine = self.encoder1(
|
x1, x1_no_combine = self.encoder1(
|
||||||
x, src_key_padding_mask=mask, warmup=warmup
|
x, src_key_padding_mask=mask, warmup=warmup
|
||||||
) # (T, N, C) where C == d_model[0]
|
) # (T, N, C) where C == d_model[0]
|
||||||
|
|
||||||
x2 = self.encoder1(
|
x2 = self.encoder2(
|
||||||
x1, src_key_padding_mask=mask, warmup=warmup
|
x1_no_combine, src_key_padding_mask=mask, warmup=warmup
|
||||||
) # (T, N, C) where C == d_model[1]
|
) # (T, N, C) where C == d_model[1]
|
||||||
|
|
||||||
x = torch.cat((x1, x2), dim=2)
|
x = torch.cat((x1, x2), dim=2)
|
||||||
@ -175,7 +175,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
d_model: the number of expected features in the input (required).
|
d_model: the number of expected features in the input (required).
|
||||||
nhead: the number of heads in the multiheadattention models (required).
|
nhead: the number of heads in the multiheadattention models (required).
|
||||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
feedforward_dim: the dimension of the feedforward network model (default=2048).
|
||||||
dropout: the dropout value (default=0.1).
|
dropout: the dropout value (default=0.1).
|
||||||
cnn_module_kernel (int): Kernel size of convolution module.
|
cnn_module_kernel (int): Kernel size of convolution module.
|
||||||
|
|
||||||
@ -190,7 +190,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
d_model: int,
|
d_model: int,
|
||||||
nhead: int,
|
nhead: int,
|
||||||
dim_feedforward: int = 2048,
|
feedforward_dim: int = 2048,
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
layer_dropout: float = 0.075,
|
layer_dropout: float = 0.075,
|
||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
@ -206,22 +206,22 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.feed_forward = nn.Sequential(
|
self.feed_forward = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, feedforward_dim),
|
||||||
ActivationBalancer(dim_feedforward,
|
ActivationBalancer(feedforward_dim,
|
||||||
channel_dim=-1, max_abs=10.0),
|
channel_dim=-1, max_abs=10.0),
|
||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
ScaledLinear(dim_feedforward, d_model,
|
ScaledLinear(feedforward_dim, d_model,
|
||||||
initial_scale=0.1),
|
initial_scale=0.1),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.feed_forward_macaron = nn.Sequential(
|
self.feed_forward_macaron = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, feedforward_dim),
|
||||||
ActivationBalancer(dim_feedforward,
|
ActivationBalancer(feedforward_dim,
|
||||||
channel_dim=-1, max_abs=10.0),
|
channel_dim=-1, max_abs=10.0),
|
||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
ScaledLinear(dim_feedforward, d_model,
|
ScaledLinear(feedforward_dim, d_model,
|
||||||
initial_scale=0.1),
|
initial_scale=0.1),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -420,7 +420,7 @@ class DownsampledConformerEncoder(nn.Module):
|
|||||||
downsample: int):
|
downsample: int):
|
||||||
super(DownsampledConformerEncoder, self).__init__()
|
super(DownsampledConformerEncoder, self).__init__()
|
||||||
|
|
||||||
self.downsample = downsample
|
self.downsample_factor = downsample
|
||||||
|
|
||||||
# note: we'll pad manually.
|
# note: we'll pad manually.
|
||||||
self.downsample = nn.Conv1d(
|
self.downsample = nn.Conv1d(
|
||||||
@ -459,43 +459,44 @@ class DownsampledConformerEncoder(nn.Module):
|
|||||||
src_key_padding_mask: (N, S).
|
src_key_padding_mask: (N, S).
|
||||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
|
|
||||||
Returns: (x, x_no_combine), both of shape (S, N, E)
|
Returns: output of shape (S, N, F) where F is the number of output features
|
||||||
|
(output_dim to constructor)
|
||||||
"""
|
"""
|
||||||
(seq_len, batch_size, embedding_dim) = x.shape
|
(seq_len, batch_size, embedding_dim) = src.shape
|
||||||
ds = self.downsample
|
ds = self.downsample_factor
|
||||||
d_seq_len = (seq_len + ds - 1) // ds
|
d_seq_len = (seq_len + ds - 1) // ds
|
||||||
x_orig = x
|
src_orig = src
|
||||||
if seq_len != d_seq_len * ds:
|
if seq_len != d_seq_len * ds:
|
||||||
# right-pad x
|
# right-pad src, repeating the last element.
|
||||||
pad = seq_len - d_seq_len * ds
|
pad = d_seq_len * ds - seq_len
|
||||||
x = torch.nn.functional.pad(x,
|
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
|
||||||
(0, pad, 0, 0, 0, 0),
|
src = torch.cat((src, src_extra), dim=0)
|
||||||
mode='replicate')
|
assert src.shape[0] == d_seq_len * ds
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
mask = mask[::ds,::ds]
|
mask = mask[::ds,::ds]
|
||||||
if src_key_padding_mask is not None:
|
if src_key_padding_mask is not None:
|
||||||
src_key_padding_mask = src_key_padding_mask[::ds]
|
src_key_padding_mask = src_key_padding_mask[::ds]
|
||||||
|
|
||||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
src = src.permute(1, 2, 0) # (#batch, channels, time).
|
||||||
x = self.downsample(x)
|
src = self.downsample(src)
|
||||||
x = x.permute(2, 0, 1) # (time, batch, channels)
|
src = src.permute(2, 0, 1) # (time, batch, channels)
|
||||||
|
|
||||||
|
|
||||||
x, _x_no_combine = self.encoder(
|
src, _src_no_combine = self.encoder(
|
||||||
x, src_key_padding_mask=mask, warmup=warmup
|
src, src_key_padding_mask=mask, warmup=warmup
|
||||||
) # (T, N, C)
|
) # (T, N, C)
|
||||||
|
|
||||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
src = src.permute(1, 2, 0) # (#batch, channels, time).
|
||||||
x = self.upsample(x)
|
src = self.upsample(src)
|
||||||
x = x.permute(2, 0, 1) # (time, batch, channels)
|
src = src.permute(2, 0, 1) # (time, batch, channels)
|
||||||
|
|
||||||
new_seq_len = x.shape[0]
|
new_seq_len = src.shape[0]
|
||||||
assert new_seq_len >= seq_len
|
assert new_seq_len >= seq_len
|
||||||
if new_seq_len > seq_len:
|
if new_seq_len > seq_len:
|
||||||
x = x[:seq_len]
|
src = src[:seq_len]
|
||||||
|
|
||||||
return x
|
return src
|
||||||
|
|
||||||
|
|
||||||
class RelPositionalEncoding(torch.nn.Module):
|
class RelPositionalEncoding(torch.nn.Module):
|
||||||
|
|||||||
@ -413,11 +413,11 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
return list(map(int, s.split(',')))
|
return list(map(int, s.split(',')))
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
subsampling_factor=params.subsamplng_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
conformer_subsampling_factor=params.conformer_subsamplng_factor,
|
conformer_subsampling_factor=params.conformer_subsampling_factor,
|
||||||
d_model=to_int_list(params.encoder_dims),
|
d_model=to_int_list(params.encoder_dims),
|
||||||
nhead=to_int_list(params.nhead),
|
nhead=to_int_list(params.nhead),
|
||||||
feedforward_dims=to_int_list(params.feedforward_dims),
|
feedforward_dim=to_int_list(params.feedforward_dims),
|
||||||
num_encoder_layers=to_int_list(params.num_encoder_layers),
|
num_encoder_layers=to_int_list(params.num_encoder_layers),
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user