mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge branch 'scaled_adam_exp27' into scaled_adam_exp69
# Conflicts: # egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
This commit is contained in:
commit
e4c9786e4a
@ -41,7 +41,7 @@ class Conformer(EncoderInterface):
|
|||||||
Args:
|
Args:
|
||||||
num_features (int): Number of input features
|
num_features (int): Number of input features
|
||||||
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
||||||
d_model (int): attention dimension, also the output dimension
|
d_model (int): (attention_dim1, attention_dim2, output_dim)
|
||||||
nhead (int): number of head
|
nhead (int): number of head
|
||||||
dim_feedforward (int): feedforward dimention
|
dim_feedforward (int): feedforward dimention
|
||||||
num_encoder_layers (int): number of encoder layers
|
num_encoder_layers (int): number of encoder layers
|
||||||
@ -55,13 +55,13 @@ class Conformer(EncoderInterface):
|
|||||||
self,
|
self,
|
||||||
num_features: int,
|
num_features: int,
|
||||||
subsampling_factor: int = 4,
|
subsampling_factor: int = 4,
|
||||||
d_model: int = 256,
|
conformer_subsampling_factor: int = 4,
|
||||||
nhead: int = 4,
|
d_model: Tuple[int] = (256, 384, 512),
|
||||||
dim_feedforward: int = 2048,
|
nhead: Tuple[int] = (8, 8),
|
||||||
num_encoder_layers: int = 12,
|
feedforward_dim: Tuple[int] = (1536, 2048),
|
||||||
|
num_encoder_layers: Tuple[int] = (12, 12),
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
layer_dropout: float = 0.25,
|
cnn_module_kernel: Tuple[int] = (31, 31),
|
||||||
cnn_module_kernel: int = 31,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Conformer, self).__init__()
|
super(Conformer, self).__init__()
|
||||||
|
|
||||||
@ -75,22 +75,42 @@ class Conformer(EncoderInterface):
|
|||||||
# That is, it does two things simultaneously:
|
# That is, it does two things simultaneously:
|
||||||
# (1) subsampling: T -> T//subsampling_factor
|
# (1) subsampling: T -> T//subsampling_factor
|
||||||
# (2) embedding: num_features -> d_model
|
# (2) embedding: num_features -> d_model
|
||||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
self.encoder_embed = Conv2dSubsampling(num_features, d_model[0],
|
||||||
|
dropout=dropout)
|
||||||
|
|
||||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
encoder_layer1 = ConformerEncoderLayer(
|
||||||
|
d_model[0],
|
||||||
encoder_layer = ConformerEncoderLayer(
|
nhead[0],
|
||||||
d_model,
|
feedforward_dim[0],
|
||||||
nhead,
|
|
||||||
dim_feedforward,
|
|
||||||
dropout,
|
dropout,
|
||||||
cnn_module_kernel,
|
cnn_module_kernel[0],
|
||||||
|
|
||||||
)
|
)
|
||||||
self.encoder = ConformerEncoder(
|
self.encoder1 = ConformerEncoder(
|
||||||
encoder_layer,
|
encoder_layer1,
|
||||||
num_encoder_layers,
|
num_encoder_layers[0],
|
||||||
layer_dropout=layer_dropout,
|
dropout,
|
||||||
)
|
)
|
||||||
|
encoder_layer2 = ConformerEncoderLayer(
|
||||||
|
d_model[1],
|
||||||
|
nhead[1],
|
||||||
|
feedforward_dim[1],
|
||||||
|
dropout,
|
||||||
|
cnn_module_kernel[1],
|
||||||
|
)
|
||||||
|
self.encoder2 = DownsampledConformerEncoder(
|
||||||
|
ConformerEncoder(
|
||||||
|
encoder_layer2,
|
||||||
|
num_encoder_layers[1],
|
||||||
|
dropout,
|
||||||
|
),
|
||||||
|
input_dim=d_model[0],
|
||||||
|
output_dim=d_model[1],
|
||||||
|
downsample=conformer_subsampling_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_combiner = SimpleCombiner(d_model[0],
|
||||||
|
d_model[1])
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -110,7 +130,7 @@ class Conformer(EncoderInterface):
|
|||||||
of frames in `embeddings` before padding.
|
of frames in `embeddings` before padding.
|
||||||
"""
|
"""
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
x, pos_emb = self.encoder_pos(x)
|
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
@ -120,9 +140,17 @@ class Conformer(EncoderInterface):
|
|||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
mask = make_pad_mask(lengths)
|
mask = make_pad_mask(lengths)
|
||||||
|
|
||||||
x = self.encoder(
|
# x1:
|
||||||
x, pos_emb, src_key_padding_mask=mask,
|
x1 = self.encoder1(
|
||||||
) # (T, N, C)
|
x, src_key_padding_mask=mask,
|
||||||
|
) # (T, N, C) where C == d_model[0]
|
||||||
|
|
||||||
|
x2 = self.encoder2(
|
||||||
|
x1, src_key_padding_mask=mask,
|
||||||
|
) # (T, N, C) where C == d_model[1]
|
||||||
|
|
||||||
|
x = self.out_combiner(x1, x2)
|
||||||
|
|
||||||
|
|
||||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
@ -137,7 +165,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
d_model: the number of expected features in the input (required).
|
d_model: the number of expected features in the input (required).
|
||||||
nhead: the number of heads in the multiheadattention models (required).
|
nhead: the number of heads in the multiheadattention models (required).
|
||||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
feedforward_dim: the dimension of the feedforward network model (default=2048).
|
||||||
dropout: the dropout value (default=0.1).
|
dropout: the dropout value (default=0.1).
|
||||||
cnn_module_kernel (int): Kernel size of convolution module.
|
cnn_module_kernel (int): Kernel size of convolution module.
|
||||||
|
|
||||||
@ -151,7 +179,7 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
self,
|
self,
|
||||||
d_model: int,
|
d_model: int,
|
||||||
nhead: int,
|
nhead: int,
|
||||||
dim_feedforward: int = 2048,
|
feedforward_dim: int = 2048,
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -164,22 +192,22 @@ class ConformerEncoderLayer(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.feed_forward = nn.Sequential(
|
self.feed_forward = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, feedforward_dim),
|
||||||
ActivationBalancer(dim_feedforward,
|
ActivationBalancer(feedforward_dim,
|
||||||
channel_dim=-1, max_abs=10.0),
|
channel_dim=-1, max_abs=10.0),
|
||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
ScaledLinear(dim_feedforward, d_model,
|
ScaledLinear(feedforward_dim, d_model,
|
||||||
initial_scale=0.01),
|
initial_scale=0.01),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.feed_forward_macaron = nn.Sequential(
|
self.feed_forward_macaron = nn.Sequential(
|
||||||
nn.Linear(d_model, dim_feedforward),
|
nn.Linear(d_model, feedforward_dim),
|
||||||
ActivationBalancer(dim_feedforward,
|
ActivationBalancer(feedforward_dim,
|
||||||
channel_dim=-1, max_abs=10.0),
|
channel_dim=-1, max_abs=10.0),
|
||||||
DoubleSwish(),
|
DoubleSwish(),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
ScaledLinear(dim_feedforward, d_model,
|
ScaledLinear(feedforward_dim, d_model,
|
||||||
initial_scale=0.01),
|
initial_scale=0.01),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -261,22 +289,23 @@ class ConformerEncoder(nn.Module):
|
|||||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||||
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
|
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
|
||||||
>>> src = torch.rand(10, 32, 512)
|
>>> src = torch.rand(10, 32, 512)
|
||||||
>>> pos_emb = torch.rand(32, 19, 512)
|
>>> out = conformer_encoder(src)
|
||||||
>>> out = conformer_encoder(src, pos_emb)
|
|
||||||
|
|
||||||
|
Returns: (combined_output, output),
|
||||||
|
where `combined_output` has gone through the RandomCombiner module and `output` is just the
|
||||||
|
original output, in case you need to bypass the RandomCombiner module.
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
encoder_layer: nn.Module,
|
encoder_layer: nn.Module,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
layer_dropout: float = 0.25
|
dropout: float
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert 0 < layer_dropout < 0.5
|
|
||||||
# `count` tracks how many times the forward function has been called
|
self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
|
||||||
# since we initialized the model (it is not written to disk or read when
|
dropout)
|
||||||
# we resume training). It is used for random seeding for layer dropping.
|
|
||||||
self.count = 0
|
|
||||||
self.layer_dropout = layer_dropout
|
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||||
@ -287,19 +316,16 @@ class ConformerEncoder(nn.Module):
|
|||||||
num_channels = encoder_layer.norm_final.num_channels
|
num_channels = encoder_layer.norm_final.num_channels
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
src: Tensor,
|
src: Tensor,
|
||||||
pos_emb: Tensor,
|
|
||||||
mask: Optional[Tensor] = None,
|
mask: Optional[Tensor] = None,
|
||||||
src_key_padding_mask: Optional[Tensor] = None,
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
r"""Pass the input through the encoder layers in turn.
|
r"""Pass the input through the encoder layers in turn.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
src: the sequence to the encoder (required).
|
src: the sequence to the encoder (required).
|
||||||
pos_emb: Positional embedding tensor (required).
|
|
||||||
mask: the mask for the src sequence (optional).
|
mask: the mask for the src sequence (optional).
|
||||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
|
||||||
@ -310,7 +336,9 @@ class ConformerEncoder(nn.Module):
|
|||||||
src_key_padding_mask: (N, S).
|
src_key_padding_mask: (N, S).
|
||||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
|
|
||||||
|
Returns: (x, x_no_combine), both of shape (S, N, E)
|
||||||
"""
|
"""
|
||||||
|
pos_emb = self.encoder_pos(src)
|
||||||
output = src
|
output = src
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
@ -337,7 +365,6 @@ class ConformerEncoder(nn.Module):
|
|||||||
frame_mask = torch.logical_or(frame_mask,
|
frame_mask = torch.logical_or(frame_mask,
|
||||||
torch.rand_like(src[:,:1,:1]) < 0.1)
|
torch.rand_like(src[:,:1,:1]) < 0.1)
|
||||||
|
|
||||||
|
|
||||||
feature_mask[..., feature_unmasked_dim:] *= frame_mask
|
feature_mask[..., feature_unmasked_dim:] *= frame_mask
|
||||||
|
|
||||||
|
|
||||||
@ -364,11 +391,190 @@ class ConformerEncoder(nn.Module):
|
|||||||
src_mask=mask,
|
src_mask=mask,
|
||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
output = output * feature_mask
|
output = output * feature_mask
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class DownsampledConformerEncoder(nn.Module):
|
||||||
|
r"""
|
||||||
|
DownsampledConformerEncoder is a conformer encoder evaluated at a reduced frame rate,
|
||||||
|
after convolutional downsampling, and then upsampled again at the output
|
||||||
|
so that the output has the same shape as the input.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
encoder: nn.Module,
|
||||||
|
input_dim: int,
|
||||||
|
output_dim: int,
|
||||||
|
downsample: int):
|
||||||
|
super(DownsampledConformerEncoder, self).__init__()
|
||||||
|
self.downsample_factor = downsample
|
||||||
|
self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
|
||||||
|
self.encoder = encoder
|
||||||
|
self.upsample = SimpleUpsample(output_dim, downsample)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
src: Tensor,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
) -> Tuple[Tensor, Tensor]:
|
||||||
|
r"""Downsample, go through encoder, upsample.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder (required).
|
||||||
|
mask: the mask for the src sequence (optional). CAUTION: we need to downsample
|
||||||
|
this, if we are to support it. Won't work correctly yet.
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
|
||||||
|
Shape:
|
||||||
|
src: (S, N, E).
|
||||||
|
mask: (S, S).
|
||||||
|
src_key_padding_mask: (N, S).
|
||||||
|
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||||
|
|
||||||
|
Returns: output of shape (S, N, F) where F is the number of output features
|
||||||
|
(output_dim to constructor)
|
||||||
|
"""
|
||||||
|
src_orig = src
|
||||||
|
src = self.downsample(src)
|
||||||
|
ds = self.downsample_factor
|
||||||
|
if mask is not None:
|
||||||
|
mask = mask[::ds,::ds]
|
||||||
|
if src_key_padding_mask is not None:
|
||||||
|
src_key_padding_mask = src_key_padding_mask[::ds]
|
||||||
|
|
||||||
|
src = self.encoder(
|
||||||
|
src, src_key_padding_mask=mask,
|
||||||
|
)
|
||||||
|
src = self.upsample(src)
|
||||||
|
# remove any extra frames that are not a multiple of downsample_factor
|
||||||
|
src = src[:src_orig.shape[0]]
|
||||||
|
|
||||||
|
return src
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionDownsample(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Does downsampling with attention, by weighted sum, and a projection..
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
in_channels: int,
|
||||||
|
out_channels: int,
|
||||||
|
downsample: int):
|
||||||
|
"""
|
||||||
|
Require out_channels > in_channels.
|
||||||
|
"""
|
||||||
|
super(AttentionDownsample, self).__init__()
|
||||||
|
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
|
||||||
|
|
||||||
|
# fill in the extra dimensions with a projection of the input
|
||||||
|
if out_channels > in_channels:
|
||||||
|
self.extra_proj = nn.Linear(in_channels * downsample,
|
||||||
|
out_channels - in_channels,
|
||||||
|
bias=False)
|
||||||
|
else:
|
||||||
|
self.extra_proj = None
|
||||||
|
self.downsample = downsample
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
src: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
x: (seq_len, batch_size, in_channels)
|
||||||
|
Returns a tensor of shape
|
||||||
|
( (seq_len+downsample-1)//downsample, batch_size, out_channels)
|
||||||
|
"""
|
||||||
|
(seq_len, batch_size, in_channels) = src.shape
|
||||||
|
ds = self.downsample
|
||||||
|
d_seq_len = (seq_len + ds - 1) // ds
|
||||||
|
src_orig = src
|
||||||
|
# Pad to an exact multiple of self.downsample
|
||||||
|
if seq_len != d_seq_len * ds:
|
||||||
|
# right-pad src, repeating the last element.
|
||||||
|
pad = d_seq_len * ds - seq_len
|
||||||
|
src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
|
||||||
|
src = torch.cat((src, src_extra), dim=0)
|
||||||
|
assert src.shape[0] == d_seq_len * ds
|
||||||
|
|
||||||
|
src = src.reshape(d_seq_len, ds, batch_size, in_channels)
|
||||||
|
scores = (src * self.query).sum(dim=-1, keepdim=True)
|
||||||
|
weights = scores.softmax(dim=1)
|
||||||
|
|
||||||
|
# ans1 is the first `in_channels` channels of the output
|
||||||
|
ans = (src * weights).sum(dim=1)
|
||||||
|
src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
|
||||||
|
|
||||||
|
if self.extra_proj is not None:
|
||||||
|
ans2 = self.extra_proj(src)
|
||||||
|
ans = torch.cat((ans, ans2), dim=2)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleUpsample(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A very simple form of upsampling that mostly just repeats the input, but
|
||||||
|
also adds a position-specific bias.
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_channels: int,
|
||||||
|
upsample: int):
|
||||||
|
super(SimpleUpsample, self).__init__()
|
||||||
|
self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
src: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
x: (seq_len, batch_size, num_channels)
|
||||||
|
Returns a tensor of shape
|
||||||
|
( (seq_len*upsample), batch_size, num_channels)
|
||||||
|
"""
|
||||||
|
upsample = self.bias.shape[0]
|
||||||
|
(seq_len, batch_size, num_channels) = src.shape
|
||||||
|
src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
|
||||||
|
src = src + self.bias.unsqueeze(1)
|
||||||
|
src = src.reshape(seq_len * upsample, batch_size, num_channels)
|
||||||
|
return src
|
||||||
|
|
||||||
|
class SimpleCombiner(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A very simple way of combining 2 vectors of 2 different dims, via a
|
||||||
|
learned weighted combination in the shared part of the dim.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
dim1: int,
|
||||||
|
dim2: int):
|
||||||
|
super(SimpleCombiner, self).__init__()
|
||||||
|
assert dim2 >= dim1
|
||||||
|
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
|
||||||
|
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
src1: Tensor,
|
||||||
|
src2: Tensor) -> Tensor:
|
||||||
|
"""
|
||||||
|
src1: (*, dim1)
|
||||||
|
src2: (*, dim2)
|
||||||
|
|
||||||
|
Returns: a tensor of shape (*, dim2)
|
||||||
|
"""
|
||||||
|
assert src1.shape[:-1] == src2.shape[:-1]
|
||||||
|
dim1 = src1.shape[-1]
|
||||||
|
dim2 = src2.shape[-1]
|
||||||
|
|
||||||
|
weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
|
||||||
|
weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
|
||||||
|
weight = (weight1 + weight2).sigmoid()
|
||||||
|
|
||||||
|
src2_part1 = src2[...,:dim1]
|
||||||
|
part1 = src1 * weight + src2_part1 * (1.0 - weight)
|
||||||
|
part2 = src2[...,dim1:]
|
||||||
|
return torch.cat((part1, part2), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RelPositionalEncoding(torch.nn.Module):
|
class RelPositionalEncoding(torch.nn.Module):
|
||||||
"""Relative positional encoding module.
|
"""Relative positional encoding module.
|
||||||
|
|
||||||
@ -385,7 +591,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Construct an PositionalEncoding object."""
|
"""Construct a PositionalEncoding object."""
|
||||||
super(RelPositionalEncoding, self).__init__()
|
super(RelPositionalEncoding, self).__init__()
|
||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||||
@ -397,7 +603,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
if self.pe is not None:
|
if self.pe is not None:
|
||||||
# self.pe contains both positive and negative parts
|
# self.pe contains both positive and negative parts
|
||||||
# the length of self.pe is 2 * input_len - 1
|
# the length of self.pe is 2 * input_len - 1
|
||||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
if self.pe.size(1) >= x.size(0) * 2 - 1:
|
||||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||||
x.device
|
x.device
|
||||||
@ -407,9 +613,9 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||||
# position of key vector. We use position relative positions when keys
|
# position of key vector. We use position relative positions when keys
|
||||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
pe_positive = torch.zeros(x.size(0), self.d_model)
|
||||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
pe_negative = torch.zeros(x.size(0), self.d_model)
|
||||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
|
||||||
div_term = torch.exp(
|
div_term = torch.exp(
|
||||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||||
* -(math.log(10000.0) / self.d_model)
|
* -(math.log(10000.0) / self.d_model)
|
||||||
@ -431,7 +637,7 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
"""Add positional encoding.
|
"""Add positional encoding.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
x (torch.Tensor): Input tensor (time, batch, `*`).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||||
@ -442,11 +648,11 @@ class RelPositionalEncoding(torch.nn.Module):
|
|||||||
pos_emb = self.pe[
|
pos_emb = self.pe[
|
||||||
:,
|
:,
|
||||||
self.pe.size(1) // 2
|
self.pe.size(1) // 2
|
||||||
- x.size(1)
|
- x.size(0)
|
||||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||||
+ x.size(1),
|
+ x.size(0),
|
||||||
]
|
]
|
||||||
return self.dropout(x), self.dropout(pos_emb)
|
return self.dropout(pos_emb)
|
||||||
|
|
||||||
|
|
||||||
class RelPositionMultiheadAttention(nn.Module):
|
class RelPositionMultiheadAttention(nn.Module):
|
||||||
@ -478,7 +684,7 @@ class RelPositionMultiheadAttention(nn.Module):
|
|||||||
self.head_dim = embed_dim // (num_heads * 2)
|
self.head_dim = embed_dim // (num_heads * 2)
|
||||||
assert (
|
assert (
|
||||||
self.head_dim * num_heads == self.embed_dim // 2
|
self.head_dim * num_heads == self.embed_dim // 2
|
||||||
), "embed_dim must be divisible by num_heads"
|
), "embed_dim//2 must be divisible by num_heads"
|
||||||
|
|
||||||
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
|
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
|
||||||
self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
|
self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
|
||||||
@ -965,6 +1171,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
layer1_channels: int = 8,
|
layer1_channels: int = 8,
|
||||||
layer2_channels: int = 32,
|
layer2_channels: int = 32,
|
||||||
layer3_channels: int = 128,
|
layer3_channels: int = 128,
|
||||||
|
dropout: float = 0.1,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -1012,6 +1219,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
)
|
)
|
||||||
out_height = (((in_channels - 1) // 2 - 1) // 2)
|
out_height = (((in_channels - 1) // 2 - 1) // 2)
|
||||||
self.out = ScaledLinear(out_height * layer3_channels, out_channels)
|
self.out = ScaledLinear(out_height * layer3_channels, out_channels)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
@ -1031,6 +1239,7 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
b, c, t, f = x.size()
|
b, c, t, f = x.size()
|
||||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||||
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
|
||||||
|
x = self.dropout(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class AttentionCombine(nn.Module):
|
class AttentionCombine(nn.Module):
|
||||||
@ -1166,14 +1375,13 @@ def _test_random_combine():
|
|||||||
|
|
||||||
def _test_conformer_main():
|
def _test_conformer_main():
|
||||||
feature_dim = 50
|
feature_dim = 50
|
||||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
|
||||||
batch_size = 5
|
batch_size = 5
|
||||||
seq_len = 20
|
seq_len = 20
|
||||||
feature_dim = 50
|
feature_dim = 50
|
||||||
# Just make sure the forward pass runs.
|
# Just make sure the forward pass runs.
|
||||||
|
|
||||||
c = Conformer(
|
c = Conformer(
|
||||||
num_features=feature_dim, d_model=128, nhead=4
|
num_features=feature_dim, d_model=(64,96,128), nhead=(4,4)
|
||||||
)
|
)
|
||||||
batch_size = 5
|
batch_size = 5
|
||||||
seq_len = 20
|
seq_len = 20
|
||||||
@ -1191,8 +1399,6 @@ def _test_conformer_main():
|
|||||||
f # to remove flake8 warnings
|
f # to remove flake8 warnings
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
logging.getLogger().setLevel(logging.INFO)
|
logging.getLogger().setLevel(logging.INFO)
|
||||||
torch.set_num_threads(1)
|
torch.set_num_threads(1)
|
||||||
|
|||||||
@ -91,30 +91,38 @@ LRSchedulerType = Union[
|
|||||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-encoder-layers",
|
"--num-encoder-layers",
|
||||||
type=int,
|
type=str,
|
||||||
default=24,
|
default="12,12",
|
||||||
help="Number of conformer encoder layers..",
|
help="Number of conformer encoder layers, comma separated.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dim-feedforward",
|
"--feedforward-dims",
|
||||||
type=int,
|
type=str,
|
||||||
default=1536,
|
default="1536,1536",
|
||||||
help="Feedforward dimension of the conformer encoder layer.",
|
help="Feedforward dimension of the conformer encoder layers, comma separated.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--nhead",
|
"--nhead",
|
||||||
type=int,
|
type=str,
|
||||||
default=8,
|
default="8,8",
|
||||||
help="Number of attention heads in the conformer encoder layer.",
|
help="Number of attention heads in the conformer encoder layers.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-dim",
|
"--encoder-dims",
|
||||||
|
type=str,
|
||||||
|
default="384,384",
|
||||||
|
help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, "
|
||||||
|
"and the output dim of the encoder",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--conformer-subsampling-factor",
|
||||||
type=int,
|
type=int,
|
||||||
default=384,
|
default=2,
|
||||||
help="Attention dimension in the conformer encoder layer.",
|
help="Subsampling factor for 2nd stack of encoder layers.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -401,13 +409,16 @@ def get_params() -> AttributeDict:
|
|||||||
|
|
||||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
# TODO: We can add an option to switch between Conformer and Transformer
|
||||||
|
def to_int_list(s: str):
|
||||||
|
return list(map(int, s.split(',')))
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
d_model=params.encoder_dim,
|
conformer_subsampling_factor=params.conformer_subsampling_factor,
|
||||||
nhead=params.nhead,
|
d_model=to_int_list(params.encoder_dims),
|
||||||
dim_feedforward=params.dim_feedforward,
|
nhead=to_int_list(params.nhead),
|
||||||
num_encoder_layers=params.num_encoder_layers,
|
feedforward_dim=to_int_list(params.feedforward_dims),
|
||||||
|
num_encoder_layers=to_int_list(params.num_encoder_layers),
|
||||||
)
|
)
|
||||||
return encoder
|
return encoder
|
||||||
|
|
||||||
@ -424,7 +435,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
encoder_dim=params.encoder_dim,
|
encoder_dim=int(params.encoder_dims.split(',')[-1]),
|
||||||
decoder_dim=params.decoder_dim,
|
decoder_dim=params.decoder_dim,
|
||||||
joiner_dim=params.joiner_dim,
|
joiner_dim=params.joiner_dim,
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
@ -441,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
joiner=joiner,
|
joiner=joiner,
|
||||||
encoder_dim=params.encoder_dim,
|
encoder_dim=int(params.encoder_dims.split(',')[-1]),
|
||||||
decoder_dim=params.decoder_dim,
|
decoder_dim=params.decoder_dim,
|
||||||
joiner_dim=params.joiner_dim,
|
joiner_dim=params.joiner_dim,
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user