Closer to working..

This commit is contained in:
Daniel Povey 2022-09-27 15:47:58 +08:00
parent e5a0d8929b
commit d34eafa623
2 changed files with 210 additions and 64 deletions

View File

@ -39,7 +39,7 @@ class Conformer(EncoderInterface):
Args:
num_features (int): Number of input features
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension, also the output dimension
d_model (int): (attention_dim1, attention_dim2, output_dim)
nhead (int): number of head
dim_feedforward (int): feedforward dimention
num_encoder_layers (int): number of encoder layers
@ -53,13 +53,14 @@ class Conformer(EncoderInterface):
self,
num_features: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
conformer_subsampling_factor: int = 4,
d_model: Tuple[int] = (256, 384, 512),
nhead: Tuple[int] = (8, 8),
dim_feedforward: Tuple[int] = (1536, 2048),
num_encoder_layers: Tuple[int] = (12, 12),
dropout: float = 0.1,
layer_dropout: float = 0.075,
cnn_module_kernel: int = 31,
cnn_module_kernel: Tuple[int] = (31, 31),
aux_layer_period: int = 3,
) -> None:
super(Conformer, self).__init__()
@ -74,23 +75,47 @@ class Conformer(EncoderInterface):
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_features -> d_model
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_embed = Conv2dSubsampling(num_features, d_model[0],
dropout=dropout)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
encoder_layer = ConformerEncoderLayer(
d_model,
nhead,
dim_feedforward,
encoder_layer1 = ConformerEncoderLayer(
d_model[0],
nhead[0],
dim_feedforward[0],
dropout,
layer_dropout,
cnn_module_kernel,
cnn_module_kernel[0],
)
self.encoder = ConformerEncoder(
encoder_layer,
num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
self.encoder1 = ConformerEncoder(
encoder_layer1,
num_encoder_layers[0],
aux_layers=list(range(0, num_encoder_layers[0] - 1, aux_layer_period)),
dropout=dropout
)
encoder_layer2 = ConformerEncoderLayer(
d_model[1],
nhead[1],
dim_feedforward[1],
dropout,
layer_dropout,
cnn_module_kernel[1],
)
self.encoder2 = DownsampledConformerEncoder(
ConformerEncoder(
encoder_layer2,
num_encoder_layers[1],
aux_layers=list(range(0, num_encoder_layers[1] - 1, aux_layer_period)),
dropout=dropout
),
input_dim=d_model[0],
module_dim=d_model[1],
output_dim=d_model[1],
downsample=conformer_subsampling_factor,
)
self.out_proj = ScaledLinear(
d_model[0] + d_model[1], d_model[2],
bias=False)
def forward(
@ -114,7 +139,7 @@ class Conformer(EncoderInterface):
of frames in `embeddings` before padding.
"""
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
with warnings.catch_warnings():
@ -124,12 +149,21 @@ class Conformer(EncoderInterface):
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(
x, pos_emb, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
# x1:
x1, x_no_combine = self.encoder1(
x, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C) where C == d_model[0]
x2 = self.encoder1(
x1, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C) where C == d_model[1]
x = torch.cat((x1, x2), dim=2)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
x = self.out_proj(x)
return x, lengths
@ -288,8 +322,12 @@ class ConformerEncoder(nn.Module):
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
>>> src = torch.rand(10, 32, 512)
>>> pos_emb = torch.rand(32, 19, 512)
>>> out = conformer_encoder(src, pos_emb)
>>> out = conformer_encoder(src)
Returns: (combined_output, output),
where `combined_output` has gone through the RandomCombiner module and `output` is just the
original output, in case you need to bypass the RandomCombiner module.
"""
def __init__(
@ -297,8 +335,13 @@ class ConformerEncoder(nn.Module):
encoder_layer: nn.Module,
num_layers: int,
aux_layers: List[int],
dropout: float,
) -> None:
super().__init__()
self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
dropout)
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
@ -318,16 +361,14 @@ class ConformerEncoder(nn.Module):
def forward(
self,
src: Tensor,
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
) -> Tuple[Tensor, Tensor]:
r"""Pass the input through the encoder layers in turn.
Args:
src: the sequence to the encoder (required).
pos_emb: Positional embedding tensor (required).
mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
@ -338,7 +379,9 @@ class ConformerEncoder(nn.Module):
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
Returns: (x, x_no_combine), both of shape (S, N, E)
"""
pos_emb = self.encoder_pos(src)
output = src
outputs = []
@ -356,11 +399,103 @@ class ConformerEncoder(nn.Module):
if i in self.aux_layers:
outputs.append(output)
output = self.combiner(outputs)
combined_output = self.combiner(outputs)
output = output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop
combined_output = combined_output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop
return output
return combined_output, output
class DownsampledConformerEncoder(nn.Module):
r"""
DownsampledConformerEncoder is a conformer encoder evaluated at a reduced frame rate,
after convolutional downsampling, and then upsampled again at the output
so that the output has the same shape as the input.
"""
def __init__(self,
encoder: nn.Module,
input_dim: int,
module_dim: int,
output_dim: int,
downsample: int):
super(DownsampledConformerEncoder, self).__init__()
self.downsample = downsample
# note: we'll pad manually.
self.downsample = nn.Conv1d(
input_dim,
module_dim,
kernel_size=downsample,
stride=downsample,
padding=0)
self.encoder = encoder
self.upsample = nn.ConvTranspose1d(
module_dim,
output_dim,
kernel_size=downsample,
stride=downsample,
padding=0)
def forward(self,
src: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tuple[Tensor, Tensor]:
r"""Downsample, go through encoder, upsample.
Args:
src: the sequence to the encoder (required).
mask: the mask for the src sequence (optional). CAUTION: we need to downsample
this, if we are to support it. Won't work correctly yet.
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
src: (S, N, E).
mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
Returns: (x, x_no_combine), both of shape (S, N, E)
"""
(seq_len, batch_size, embedding_dim) = x.shape
ds = self.downsample
d_seq_len = (seq_len + ds - 1) // ds
x_orig = x
if seq_len != d_seq_len * ds:
# right-pad x
pad = seq_len - d_seq_len * ds
x = torch.nn.functional.pad(x,
(0, pad, 0, 0, 0, 0),
mode='replicate')
if mask is not None:
mask = mask[::ds,::ds]
if src_key_padding_mask is not None:
src_key_padding_mask = src_key_padding_mask[::ds]
x = x.permute(1, 2, 0) # (#batch, channels, time).
x = self.downsample(x)
x = x.permute(2, 0, 1) # (time, batch, channels)
x, _x_no_combine = self.encoder(
x, src_key_padding_mask=mask, warmup=warmup
) # (T, N, C)
x = x.permute(1, 2, 0) # (#batch, channels, time).
x = self.upsample(x)
x = x.permute(2, 0, 1) # (time, batch, channels)
new_seq_len = x.shape[0]
assert new_seq_len >= seq_len
if new_seq_len > seq_len:
x = x[:seq_len]
return x
class RelPositionalEncoding(torch.nn.Module):
@ -379,7 +514,7 @@ class RelPositionalEncoding(torch.nn.Module):
def __init__(
self, d_model: int, dropout_rate: float, max_len: int = 5000
) -> None:
"""Construct an PositionalEncoding object."""
"""Construct a PositionalEncoding object."""
super(RelPositionalEncoding, self).__init__()
self.d_model = d_model
self.dropout = torch.nn.Dropout(dropout_rate)
@ -391,7 +526,7 @@ class RelPositionalEncoding(torch.nn.Module):
if self.pe is not None:
# self.pe contains both positive and negative parts
# the length of self.pe is 2 * input_len - 1
if self.pe.size(1) >= x.size(1) * 2 - 1:
if self.pe.size(1) >= x.size(0) * 2 - 1:
# Note: TorchScript doesn't implement operator== for torch.Device
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
x.device
@ -401,9 +536,9 @@ class RelPositionalEncoding(torch.nn.Module):
# Suppose `i` means to the position of query vecotr and `j` means the
# position of key vector. We use position relative positions when keys
# are to the left (i>j) and negative relative positions otherwise (i<j).
pe_positive = torch.zeros(x.size(1), self.d_model)
pe_negative = torch.zeros(x.size(1), self.d_model)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
pe_positive = torch.zeros(x.size(0), self.d_model)
pe_negative = torch.zeros(x.size(0), self.d_model)
position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
@ -425,7 +560,7 @@ class RelPositionalEncoding(torch.nn.Module):
"""Add positional encoding.
Args:
x (torch.Tensor): Input tensor (batch, time, `*`).
x (torch.Tensor): Input tensor (time, batch, `*`).
Returns:
torch.Tensor: Encoded tensor (batch, time, `*`).
@ -436,11 +571,11 @@ class RelPositionalEncoding(torch.nn.Module):
pos_emb = self.pe[
:,
self.pe.size(1) // 2
- x.size(1)
- x.size(0)
+ 1 : self.pe.size(1) // 2 # noqa E203
+ x.size(1),
+ x.size(0),
]
return self.dropout(x), self.dropout(pos_emb)
return self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
@ -472,7 +607,7 @@ class RelPositionMultiheadAttention(nn.Module):
self.head_dim = embed_dim // (num_heads * 2)
assert (
self.head_dim * num_heads == self.embed_dim // 2
), "embed_dim must be divisible by num_heads"
), "embed_dim//2 must be divisible by num_heads"
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
@ -951,6 +1086,7 @@ class Conv2dSubsampling(nn.Module):
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
dropout: float = 0.1,
) -> None:
"""
Args:
@ -998,6 +1134,7 @@ class Conv2dSubsampling(nn.Module):
)
out_height = (((in_channels - 1) // 2 - 1) // 2)
self.out = ScaledLinear(out_height * layer3_channels, out_channels)
self.dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -1017,6 +1154,7 @@ class Conv2dSubsampling(nn.Module):
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
# Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
x = self.dropout(x)
return x
class RandomCombine(nn.Module):
@ -1251,14 +1389,13 @@ def _test_random_combine_main():
def _test_conformer_main():
feature_dim = 50
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
batch_size = 5
seq_len = 20
feature_dim = 50
# Just make sure the forward pass runs.
c = Conformer(
num_features=feature_dim, d_model=128, nhead=4
num_features=feature_dim, d_model=(64,96,128), nhead=(4,4)
)
batch_size = 5
seq_len = 20
@ -1271,8 +1408,6 @@ def _test_conformer_main():
f # to remove flake8 warnings
if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
torch.set_num_threads(1)

View File

@ -91,30 +91,38 @@ LRSchedulerType = Union[
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
type=int,
default=24,
help="Number of conformer encoder layers..",
type=str,
default="12,12",
help="Number of conformer encoder layers, comma separated.",
)
parser.add_argument(
"--dim-feedforward",
type=int,
default=1536,
help="Feedforward dimension of the conformer encoder layer.",
"--feedforward-dims",
type=str,
default="1536,1536",
help="Feedforward dimension of the conformer encoder layers, comma separated.",
)
parser.add_argument(
"--nhead",
type=int,
default=8,
help="Number of attention heads in the conformer encoder layer.",
type=str,
default="8,8",
help="Number of attention heads in the conformer encoder layers.",
)
parser.add_argument(
"--encoder-dim",
"--encoder-dims",
type=str,
default="320,512,512",
help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, "
"and the output dim of the encoder",
)
parser.add_argument(
"--conformer-subsampling-factor",
type=int,
default=384,
help="Attention dimension in the conformer encoder layer.",
default=4,
help="Subsampling factor for 2nd stack of encoder layers.",
)
parser.add_argument(
@ -401,13 +409,16 @@ def get_params() -> AttributeDict:
def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
def to_int_list(s: str):
return list(map(int, s.split(',')))
encoder = Conformer(
num_features=params.feature_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.encoder_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
subsampling_factor=params.subsamplng_factor,
conformer_subsampling_factor=params.conformer_subsamplng_factor,
d_model=to_int_list(params.encoder_dims),
nhead=to_int_list(params.nhead),
feedforward_dims=to_int_list(params.feedforward_dims),
num_encoder_layers=to_int_list(params.num_encoder_layers),
)
return encoder
@ -424,7 +435,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
encoder_dim=params.encoder_dim,
encoder_dim=int(params.encoder_dims.split(',')[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
@ -441,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=params.encoder_dim,
encoder_dim=int(params.encoder_dims.split(',')[-1]),
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,