Merge branch 'scaled_adam_exp27' into scaled_adam_exp69

# Conflicts:
#	egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py
commit e4c9786e4a
Author: Daniel Povey
Date:   2022-10-06 18:04:48 +08:00
2 changed files with 295 additions and 78 deletions

egs/librispeech/ASR/pruned_transducer_stateless7/conformer.py

@@ -41,7 +41,7 @@ class Conformer(EncoderInterface):
     Args:
         num_features (int): Number of input features
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): attention dimension, also the output dimension
+        d_model (int): (attention_dim1, attention_dim2, output_dim)
         nhead (int): number of head
         dim_feedforward (int): feedforward dimention
         num_encoder_layers (int): number of encoder layers
@@ -55,13 +55,13 @@ class Conformer(EncoderInterface):
         self,
         num_features: int,
         subsampling_factor: int = 4,
-        d_model: int = 256,
-        nhead: int = 4,
-        dim_feedforward: int = 2048,
-        num_encoder_layers: int = 12,
+        conformer_subsampling_factor: int = 4,
+        d_model: Tuple[int] = (256, 384, 512),
+        nhead: Tuple[int] = (8, 8),
+        feedforward_dim: Tuple[int] = (1536, 2048),
+        num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
-        layer_dropout: float = 0.25,
-        cnn_module_kernel: int = 31,
+        cnn_module_kernel: Tuple[int] = (31, 31),
     ) -> None:
         super(Conformer, self).__init__()
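For reference, the new signature would be exercised like this (a sketch; num_features=80 is just an illustrative value, the tuples are the defaults above):

    # Hypothetical instantiation with the new per-stack tuples: each entry
    # configures one of the two encoder stacks, and d_model carries a third
    # entry for the output dimension, per the docstring above.
    model = Conformer(
        num_features=80,
        d_model=(256, 384, 512),
        nhead=(8, 8),
        feedforward_dim=(1536, 2048),
        num_encoder_layers=(12, 12),
        cnn_module_kernel=(31, 31),
    )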
@@ -75,22 +75,42 @@ class Conformer(EncoderInterface):
         # That is, it does two things simultaneously:
         #   (1) subsampling: T -> T//subsampling_factor
         #   (2) embedding: num_features -> d_model
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
-
-        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-
-        encoder_layer = ConformerEncoderLayer(
-            d_model,
-            nhead,
-            dim_feedforward,
-            dropout,
-            cnn_module_kernel,
-        )
-        self.encoder = ConformerEncoder(
-            encoder_layer,
-            num_encoder_layers,
-            layer_dropout=layer_dropout,
-        )
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model[0],
+                                               dropout=dropout)
+
+        encoder_layer1 = ConformerEncoderLayer(
+            d_model[0],
+            nhead[0],
+            feedforward_dim[0],
+            dropout,
+            cnn_module_kernel[0],
+        )
+        self.encoder1 = ConformerEncoder(
+            encoder_layer1,
+            num_encoder_layers[0],
+            dropout,
+        )
+        encoder_layer2 = ConformerEncoderLayer(
+            d_model[1],
+            nhead[1],
+            feedforward_dim[1],
+            dropout,
+            cnn_module_kernel[1],
+        )
+        self.encoder2 = DownsampledConformerEncoder(
+            ConformerEncoder(
+                encoder_layer2,
+                num_encoder_layers[1],
+                dropout,
+            ),
+            input_dim=d_model[0],
+            output_dim=d_model[1],
+            downsample=conformer_subsampling_factor,
+        )
+        self.out_combiner = SimpleCombiner(d_model[0],
+                                           d_model[1])

     def forward(
@@ -110,7 +130,7 @@ class Conformer(EncoderInterface):
           of frames in `embeddings` before padding.
         """
         x = self.encoder_embed(x)
-        x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

         with warnings.catch_warnings():
@@ -120,9 +140,17 @@ class Conformer(EncoderInterface):
             assert x.size(0) == lengths.max().item()
             mask = make_pad_mask(lengths)

-        x = self.encoder(
-            x, pos_emb, src_key_padding_mask=mask,
-        )  # (T, N, C)
+        # x1:
+        x1 = self.encoder1(
+            x, src_key_padding_mask=mask,
+        )  # (T, N, C) where C == d_model[0]
+        x2 = self.encoder2(
+            x1, src_key_padding_mask=mask,
+        )  # (T, N, C) where C == d_model[1]
+        x = self.out_combiner(x1, x2)

         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
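A rough shape walk through the new forward pass, assuming batch size 8, 404 input frames, 80 features and the default dims (illustrative numbers only):

    # x: (8, 404, 80) --encoder_embed--> (8, 100, 256)   # T -> ((T-1)//2 - 1)//2
    # permute          --> (100, 8, 256)                 # (T, N, C), C == d_model[0]
    # encoder1         --> x1: (100, 8, 256)
    # encoder2         --> x2: (100, 8, 384)             # downsampled internally, upsampled back
    # out_combiner     --> (100, 8, 384)                 # SimpleCombiner output dim == d_model[1]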
@@ -137,7 +165,7 @@ class ConformerEncoderLayer(nn.Module):
     Args:
         d_model: the number of expected features in the input (required).
         nhead: the number of heads in the multiheadattention models (required).
-        dim_feedforward: the dimension of the feedforward network model (default=2048).
+        feedforward_dim: the dimension of the feedforward network model (default=2048).
         dropout: the dropout value (default=0.1).
         cnn_module_kernel (int): Kernel size of convolution module.
@@ -151,7 +179,7 @@ class ConformerEncoderLayer(nn.Module):
         self,
         d_model: int,
         nhead: int,
-        dim_feedforward: int = 2048,
+        feedforward_dim: int = 2048,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
     ) -> None:
@@ -164,22 +192,22 @@ class ConformerEncoderLayer(nn.Module):
         )

         self.feed_forward = nn.Sequential(
-            nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(dim_feedforward,
+            nn.Linear(d_model, feedforward_dim),
+            ActivationBalancer(feedforward_dim,
                                channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model,
+            ScaledLinear(feedforward_dim, d_model,
                          initial_scale=0.01),
         )

         self.feed_forward_macaron = nn.Sequential(
-            nn.Linear(d_model, dim_feedforward),
-            ActivationBalancer(dim_feedforward,
+            nn.Linear(d_model, feedforward_dim),
+            ActivationBalancer(feedforward_dim,
                                channel_dim=-1, max_abs=10.0),
             DoubleSwish(),
             nn.Dropout(dropout),
-            ScaledLinear(dim_feedforward, d_model,
+            ScaledLinear(feedforward_dim, d_model,
                          initial_scale=0.01),
         )
@@ -261,22 +289,23 @@ class ConformerEncoder(nn.Module):
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
         >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
-        >>> pos_emb = torch.rand(32, 19, 512)
-        >>> out = conformer_encoder(src, pos_emb)
+        >>> out = conformer_encoder(src)
+
+    Returns: (combined_output, output),
+       where `combined_output` has gone through the RandomCombiner module and `output` is just the
+       original output, in case you need to bypass the RandomCombiner module.
     """
     def __init__(
         self,
         encoder_layer: nn.Module,
         num_layers: int,
-        layer_dropout: float = 0.25
+        dropout: float
     ) -> None:
         super().__init__()
-        assert 0 < layer_dropout < 0.5
-        # `count` tracks how many times the forward function has been called
-        # since we initialized the model (it is not written to disk or read when
-        # we resume training).  It is used for random seeding for layer dropping.
-        self.count = 0
-        self.layer_dropout = layer_dropout
+        self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
+                                                 dropout)
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
@@ -287,19 +316,16 @@ class ConformerEncoder(nn.Module):
         num_channels = encoder_layer.norm_final.num_channels

     def forward(
         self,
         src: Tensor,
-        pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
+    ) -> Tuple[Tensor, Tensor]:
         r"""Pass the input through the encoder layers in turn.

         Args:
             src: the sequence to the encoder (required).
-            pos_emb: Positional embedding tensor (required).
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
@@ -310,7 +336,9 @@ class ConformerEncoder(nn.Module):
             src_key_padding_mask: (N, S).
             S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number

+        Returns: (x, x_no_combine), both of shape (S, N, E)
         """
+        pos_emb = self.encoder_pos(src)
         output = src

         outputs = []
@@ -337,7 +365,6 @@ class ConformerEncoder(nn.Module):
             frame_mask = torch.logical_or(frame_mask,
                                           torch.rand_like(src[:,:1,:1]) < 0.1)
-            feature_mask[..., feature_unmasked_dim:] *= frame_mask
@@ -364,11 +391,190 @@ class ConformerEncoder(nn.Module):
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
             )

         output = output * feature_mask

         return output

+
+class DownsampledConformerEncoder(nn.Module):
+    r"""
+    DownsampledConformerEncoder is a conformer encoder evaluated at a reduced frame rate,
+    after convolutional downsampling, and then upsampled again at the output
+    so that the output has the same shape as the input.
+    """
+    def __init__(self,
+                 encoder: nn.Module,
+                 input_dim: int,
+                 output_dim: int,
+                 downsample: int):
+        super(DownsampledConformerEncoder, self).__init__()
+        self.downsample_factor = downsample
+        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
+        self.encoder = encoder
+        self.upsample = SimpleUpsample(output_dim, downsample)
+
+    def forward(self,
+                src: Tensor,
+                mask: Optional[Tensor] = None,
+                src_key_padding_mask: Optional[Tensor] = None,
+                ) -> Tensor:
+        r"""Downsample, go through encoder, upsample.
+
+        Args:
+            src: the sequence to the encoder (required).
+            mask: the mask for the src sequence (optional).  CAUTION: we need to downsample
+                  this, if we are to support it.  Won't work correctly yet.
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, N is the batch size, E is the feature number.
+
+        Returns: output of shape (S, N, F) where F is the number of output features
+            (output_dim to constructor)
+        """
+        src_orig = src
+        src = self.downsample(src)
+        ds = self.downsample_factor
+        if mask is not None:
+            mask = mask[::ds, ::ds]
+        if src_key_padding_mask is not None:
+            # (N, S): subsample along the frame axis, not the batch axis.
+            src_key_padding_mask = src_key_padding_mask[..., ::ds]
+
+        src = self.encoder(
+            src, src_key_padding_mask=src_key_padding_mask,
+        )
+        src = self.upsample(src)
+        # remove any extra frames that are not a multiple of downsample_factor
+        src = src[:src_orig.shape[0]]
+
+        return src
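The `[..., ::ds]` slicing above subsamples the padding mask at the same rate as the frames, and the trailing trim undoes the padding added before downsampling. A standalone check of the intended mask behaviour (assuming a boolean (N, S) mask that is True at padded positions, as produced by make_pad_mask):

    import torch

    m = torch.tensor([[False, False, False, False, True, True]])  # (N=1, S=6)
    print(m[..., ::2])  # tensor([[False, False,  True]]): every 2nd frame's mask kept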
+class AttentionDownsample(torch.nn.Module):
+    """
+    Does downsampling with attention, by weighted sum, and a projection.
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 downsample: int):
+        """
+        Require out_channels >= in_channels.
+        """
+        super(AttentionDownsample, self).__init__()
+        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
+
+        # fill in the extra dimensions with a projection of the input
+        if out_channels > in_channels:
+            self.extra_proj = nn.Linear(in_channels * downsample,
+                                        out_channels - in_channels,
+                                        bias=False)
+        else:
+            self.extra_proj = None
+        self.downsample = downsample
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, in_channels)
+        Returns a tensor of shape
+           ((seq_len+downsample-1)//downsample, batch_size, out_channels)
+        """
+        (seq_len, batch_size, in_channels) = src.shape
+        ds = self.downsample
+        d_seq_len = (seq_len + ds - 1) // ds
+
+        # Pad to an exact multiple of self.downsample
+        if seq_len != d_seq_len * ds:
+            # right-pad src, repeating the last element.
+            pad = d_seq_len * ds - seq_len
+            src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2])
+            src = torch.cat((src, src_extra), dim=0)
+            assert src.shape[0] == d_seq_len * ds
+
+        src = src.reshape(d_seq_len, ds, batch_size, in_channels)
+        scores = (src * self.query).sum(dim=-1, keepdim=True)
+        weights = scores.softmax(dim=1)
+
+        # ans is the first `in_channels` channels of the output
+        ans = (src * weights).sum(dim=1)
+        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
+
+        if self.extra_proj is not None:
+            ans2 = self.extra_proj(src)
+            ans = torch.cat((ans, ans2), dim=2)
+        return ans
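A self-contained sketch of the pooling this module performs (same tensor layout, toy sizes; an illustration, not the module itself):

    import torch

    seq_len, batch, ds, C = 7, 2, 2, 4
    src = torch.randn(seq_len, batch, C)
    query = torch.randn(C) * C ** -0.5

    pad = (ds - seq_len % ds) % ds                        # right-pad by repeating the last frame
    src = torch.cat((src, src[-1:].expand(pad, batch, C)), dim=0)

    g = src.reshape(-1, ds, batch, C)                     # (4, ds, batch, C): groups of ds frames
    w = (g * query).sum(-1, keepdim=True).softmax(dim=1)  # per-group attention weights
    pooled = (g * w).sum(dim=1)                           # (4, batch, C): one frame per group
    print(pooled.shape)  # torch.Size([4, 2, 4])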
+class SimpleUpsample(torch.nn.Module):
+    """
+    A very simple form of upsampling that mostly just repeats the input, but
+    also adds a position-specific bias.
+    """
+    def __init__(self,
+                 num_channels: int,
+                 upsample: int):
+        super(SimpleUpsample, self).__init__()
+        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)
+
+    def forward(self,
+                src: Tensor) -> Tensor:
+        """
+        src: (seq_len, batch_size, num_channels)
+        Returns a tensor of shape
+           (seq_len*upsample, batch_size, num_channels)
+        """
+        upsample = self.bias.shape[0]
+        (seq_len, batch_size, num_channels) = src.shape
+        src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels)
+        src = src + self.bias.unsqueeze(1)
+        src = src.reshape(seq_len * upsample, batch_size, num_channels)
+        return src
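And the corresponding repeat-plus-bias upsampling, standalone (toy sizes again):

    import torch

    ds, batch, C = 2, 5, 4
    bias = torch.randn(ds, C) * 0.01                      # position-specific bias
    x = torch.randn(3, batch, C)                          # (seq_len, batch, C)
    y = x.unsqueeze(1).expand(3, ds, batch, C) + bias.unsqueeze(1)
    y = y.reshape(3 * ds, batch, C)                       # (6, batch, C): each frame repeated ds times
    print(y.shape)  # torch.Size([6, 5, 4])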
+class SimpleCombiner(torch.nn.Module):
+    """
+    A very simple way of combining 2 vectors of 2 different dims, via a
+    learned weighted combination in the shared part of the dim.
+    """
+    def __init__(self,
+                 dim1: int,
+                 dim2: int):
+        super(SimpleCombiner, self).__init__()
+        assert dim2 >= dim1
+        self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
+        self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
+
+    def forward(self,
+                src1: Tensor,
+                src2: Tensor) -> Tensor:
+        """
+        src1: (*, dim1)
+        src2: (*, dim2)
+        Returns: a tensor of shape (*, dim2)
+        """
+        assert src1.shape[:-1] == src2.shape[:-1]
+        dim1 = src1.shape[-1]
+
+        weight1 = (src1 * self.to_weight1).sum(dim=-1, keepdim=True)
+        weight2 = (src2 * self.to_weight2).sum(dim=-1, keepdim=True)
+        weight = (weight1 + weight2).sigmoid()
+
+        src2_part1 = src2[..., :dim1]
+        part1 = src1 * weight + src2_part1 * (1.0 - weight)
+        part2 = src2[..., dim1:]
+        return torch.cat((part1, part2), dim=-1)
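A worked example of the gated combination (a standalone sketch; the learned weight vectors are random stand-ins here):

    import torch

    dim1, dim2 = 3, 5
    to_w1 = torch.randn(dim1) * 0.01
    to_w2 = torch.randn(dim2) * 0.01
    a = torch.randn(7, dim1)                              # e.g. encoder1 output
    b = torch.randn(7, dim2)                              # e.g. encoder2 output
    w = ((a * to_w1).sum(-1, keepdim=True)
         + (b * to_w2).sum(-1, keepdim=True)).sigmoid()   # scalar gate per frame
    out = torch.cat((a * w + b[:, :dim1] * (1 - w),       # blend the shared dims
                     b[:, dim1:]), dim=-1)                # pass the extra dims through
    print(out.shape)  # torch.Size([7, 5])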

 class RelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding module.
@@ -385,7 +591,7 @@ class RelPositionalEncoding(torch.nn.Module):
     def __init__(
         self, d_model: int, dropout_rate: float, max_len: int = 5000
     ) -> None:
-        """Construct an PositionalEncoding object."""
+        """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
         self.dropout = torch.nn.Dropout(dropout_rate)
@@ -397,7 +603,7 @@ class RelPositionalEncoding(torch.nn.Module):
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
+            if self.pe.size(1) >= x.size(0) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -407,9 +613,9 @@ class RelPositionalEncoding(torch.nn.Module):
         # Suppose `i` means to the position of query vector and `j` means the
         # position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x.size(1), self.d_model)
-        pe_negative = torch.zeros(x.size(1), self.d_model)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        pe_positive = torch.zeros(x.size(0), self.d_model)
+        pe_negative = torch.zeros(x.size(0), self.d_model)
+        position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
             torch.arange(0, self.d_model, 2, dtype=torch.float32)
             * -(math.log(10000.0) / self.d_model)
@@ -431,7 +637,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Add positional encoding.

         Args:
-            x (torch.Tensor): Input tensor (batch, time, `*`).
+            x (torch.Tensor): Input tensor (time, batch, `*`).

         Returns:
             torch.Tensor: Encoded tensor (batch, time, `*`).
@@ -442,11 +648,11 @@ class RelPositionalEncoding(torch.nn.Module):
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
-            - x.size(1)
+            - x.size(0)
             + 1 : self.pe.size(1) // 2  # noqa E203
-            + x.size(1),
+            + x.size(0),
         ]
-        return self.dropout(x), self.dropout(pos_emb)
+        return self.dropout(pos_emb)
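For intuition about the switch to time-first indexing: with an input of length T, `self.pe` holds 2*T-1 relative positions and the slice above selects offsets T-1 down to -(T-1). For example, with T = 3 and `pe` sized exactly for it, `pe.size(1) // 2 == 2`, so the slice is 0:5, i.e. all five offsets (2, 1, 0, -1, -2); `x.size(0)` is used now because the input is (time, batch, channels).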
 class RelPositionMultiheadAttention(nn.Module):
@@ -478,7 +684,7 @@ class RelPositionMultiheadAttention(nn.Module):
         self.head_dim = embed_dim // (num_heads * 2)
         assert (
             self.head_dim * num_heads == self.embed_dim // 2
-        ), "embed_dim must be divisible by num_heads"
+        ), "embed_dim//2 must be divisible by num_heads"

         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
         self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
@@ -965,6 +1171,7 @@ class Conv2dSubsampling(nn.Module):
         layer1_channels: int = 8,
         layer2_channels: int = 32,
         layer3_channels: int = 128,
+        dropout: float = 0.1,
     ) -> None:
         """
         Args:
@@ -1012,6 +1219,7 @@ class Conv2dSubsampling(nn.Module):
         )
         out_height = (((in_channels - 1) // 2 - 1) // 2)
         self.out = ScaledLinear(out_height * layer3_channels, out_channels)
+        self.dropout = nn.Dropout(dropout)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1031,6 +1239,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+        x = self.dropout(x)
         return x

 class AttentionCombine(nn.Module):
@@ -1166,14 +1375,13 @@ def _test_random_combine():

 def _test_conformer_main():
     feature_dim = 50
-    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20

     feature_dim = 50
     # Just make sure the forward pass runs.
     c = Conformer(
-        num_features=feature_dim, d_model=128, nhead=4
+        num_features=feature_dim, d_model=(64,96,128), nhead=(4,4)
     )
     batch_size = 5
     seq_len = 20
@@ -1191,8 +1399,6 @@ def _test_conformer_main():
     f  # to remove flake8 warnings

 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)

egs/librispeech/ASR/pruned_transducer_stateless7/train.py

@@ -91,30 +91,38 @@ LRSchedulerType = Union[
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
-        type=int,
-        default=24,
-        help="Number of conformer encoder layers.",
+        type=str,
+        default="12,12",
+        help="Number of conformer encoder layers, comma separated.",
     )

     parser.add_argument(
-        "--dim-feedforward",
-        type=int,
-        default=1536,
-        help="Feedforward dimension of the conformer encoder layer.",
+        "--feedforward-dims",
+        type=str,
+        default="1536,1536",
+        help="Feedforward dimension of the conformer encoder layers, comma separated.",
     )

     parser.add_argument(
         "--nhead",
-        type=int,
-        default=8,
-        help="Number of attention heads in the conformer encoder layer.",
+        type=str,
+        default="8,8",
+        help="Number of attention heads in the conformer encoder layers.",
     )

     parser.add_argument(
-        "--encoder-dim",
-        type=int,
-        default=384,
-        help="Attention dimension in the conformer encoder layer.",
+        "--encoder-dims",
+        type=str,
+        default="384,384",
+        help="Attention dimension in 2 blocks of conformer encoder layers, comma separated, "
+        "and the output dim of the encoder",
+    )
+
+    parser.add_argument(
+        "--conformer-subsampling-factor",
+        type=int,
+        default=2,
+        help="Subsampling factor for 2nd stack of encoder layers.",
     )

     parser.add_argument(
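A quick standalone check of the retyped flags (plain argparse mirroring the defaults above, not the full parser; the strings are split into ints later, see get_encoder_model below):

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--num-encoder-layers", type=str, default="12,12")
    p.add_argument("--encoder-dims", type=str, default="384,384")
    args = p.parse_args([])
    assert args.num_encoder_layers == "12,12"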
@@ -401,13 +409,16 @@ def get_params() -> AttributeDict:

 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
+    def to_int_list(s: str):
+        return list(map(int, s.split(',')))
     encoder = Conformer(
         num_features=params.feature_dim,
         subsampling_factor=params.subsampling_factor,
-        d_model=params.encoder_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
+        conformer_subsampling_factor=params.conformer_subsampling_factor,
+        d_model=to_int_list(params.encoder_dims),
+        nhead=to_int_list(params.nhead),
+        feedforward_dim=to_int_list(params.feedforward_dims),
+        num_encoder_layers=to_int_list(params.num_encoder_layers),
     )
     return encoder
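The helper turns those comma-separated strings into per-stack lists, e.g.:

    def to_int_list(s: str):
        return list(map(int, s.split(',')))

    assert to_int_list("12,12") == [12, 12]
    assert to_int_list("384,384") == [384, 384]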
@@ -424,7 +435,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=params.encoder_dim,
+        encoder_dim=int(params.encoder_dims.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -441,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=params.encoder_dim,
+        encoder_dim=int(params.encoder_dims.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
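Both the joiner and the transducer wrapper take the last entry of --encoder-dims as the encoder output dimension, e.g.:

    assert int("384,384".split(',')[-1]) == 384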