Closer to working..

parent e5a0d8929b
commit d34eafa623
@@ -39,7 +39,7 @@ class Conformer(EncoderInterface):
     Args:
         num_features (int): Number of input features
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): attention dimension, also the output dimension
+        d_model (int): (attention_dim1, attention_dim2, output_dim)
         nhead (int): number of head
         dim_feedforward (int): feedforward dimention
         num_encoder_layers (int): number of encoder layers
@@ -53,13 +53,14 @@ class Conformer(EncoderInterface):
         self,
         num_features: int,
         subsampling_factor: int = 4,
-        d_model: int = 256,
-        nhead: int = 4,
-        dim_feedforward: int = 2048,
-        num_encoder_layers: int = 12,
+        conformer_subsampling_factor: int = 4,
+        d_model: Tuple[int] = (256, 384, 512),
+        nhead: Tuple[int] = (8, 8),
+        dim_feedforward: Tuple[int] = (1536, 2048),
+        num_encoder_layers: Tuple[int] = (12, 12),
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
-        cnn_module_kernel: int = 31,
+        cnn_module_kernel: Tuple[int] = (31, 31),
         aux_layer_period: int = 3,
     ) -> None:
         super(Conformer, self).__init__()
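For orientation, a minimal sketch of constructing the model with the tuple-valued defaults from the new signature above (the 80-dim feature count is an assumption, and it presumes the Conformer class in this file is importable):

    model = Conformer(
        num_features=80,                  # assumed input feature dim, not from the diff
        subsampling_factor=4,             # frontend Conv2dSubsampling factor
        conformer_subsampling_factor=4,   # extra downsampling for the second stack
        d_model=(256, 384, 512),          # (attention_dim1, attention_dim2, output_dim)
        nhead=(8, 8),
        dim_feedforward=(1536, 2048),
        num_encoder_layers=(12, 12),
        cnn_module_kernel=(31, 31),
    )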
@@ -74,23 +75,47 @@ class Conformer(EncoderInterface):
         # That is, it does two things simultaneously:
         # (1) subsampling: T -> T//subsampling_factor
         # (2) embedding: num_features -> d_model
-        self.encoder_embed = Conv2dSubsampling(num_features, d_model)
+        self.encoder_embed = Conv2dSubsampling(num_features, d_model[0],
+                                               dropout=dropout)
 
-        self.encoder_pos = RelPositionalEncoding(d_model, dropout)
-
-        encoder_layer = ConformerEncoderLayer(
-            d_model,
-            nhead,
-            dim_feedforward,
+        encoder_layer1 = ConformerEncoderLayer(
+            d_model[0],
+            nhead[0],
+            dim_feedforward[0],
             dropout,
             layer_dropout,
-            cnn_module_kernel,
+            cnn_module_kernel[0],
         )
-        self.encoder = ConformerEncoder(
-            encoder_layer,
-            num_encoder_layers,
-            aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
+        self.encoder1 = ConformerEncoder(
+            encoder_layer1,
+            num_encoder_layers[0],
+            aux_layers=list(range(0, num_encoder_layers[0] - 1, aux_layer_period)),
+            dropout=dropout
         )
+        encoder_layer2 = ConformerEncoderLayer(
+            d_model[1],
+            nhead[1],
+            dim_feedforward[1],
+            dropout,
+            layer_dropout,
+            cnn_module_kernel[1],
+        )
+        self.encoder2 = DownsampledConformerEncoder(
+            ConformerEncoder(
+                encoder_layer2,
+                num_encoder_layers[1],
+                aux_layers=list(range(0, num_encoder_layers[1] - 1, aux_layer_period)),
+                dropout=dropout
+            ),
+            input_dim=d_model[0],
+            module_dim=d_model[1],
+            output_dim=d_model[1],
+            downsample=conformer_subsampling_factor,
+        )
+
+        self.out_proj = ScaledLinear(
+            d_model[0] + d_model[1], d_model[2],
+            bias=False)
 
 
     def forward(
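One way to read the new __init__: encoder1 runs at the full post-frontend frame rate with width d_model[0]; encoder2 is a second stack wrapped in DownsampledConformerEncoder, so it works at 1/conformer_subsampling_factor of that rate with width d_model[1] and is upsampled back; the two outputs are concatenated and projected to d_model[2]. A runnable shape check with stand-in tensors (a plain nn.Linear standing in for ScaledLinear):

    import torch
    import torch.nn as nn

    T, N = 100, 8                        # frames after the frontend, batch size
    x1 = torch.zeros(T, N, 256)          # stand-in for the encoder1 output, d_model[0]
    x2 = torch.zeros(T, N, 384)          # stand-in for the encoder2 output after upsampling, d_model[1]
    x = torch.cat((x1, x2), dim=2)       # (T, N, 640)
    out_proj = nn.Linear(256 + 384, 512, bias=False)
    print(out_proj(x).shape)             # torch.Size([100, 8, 512]), i.e. (T, N, d_model[2])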
@@ -114,7 +139,7 @@ class Conformer(EncoderInterface):
             of frames in `embeddings` before padding.
         """
         x = self.encoder_embed(x)
-        x, pos_emb = self.encoder_pos(x)
+
         x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
 
         with warnings.catch_warnings():
@@ -124,12 +149,21 @@ class Conformer(EncoderInterface):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
 
-        x = self.encoder(
-            x, pos_emb, src_key_padding_mask=mask, warmup=warmup
-        ) # (T, N, C)
+        # x1:
+        x1, x_no_combine = self.encoder1(
+            x, src_key_padding_mask=mask, warmup=warmup
+        ) # (T, N, C) where C == d_model[0]
+
+        x2 = self.encoder1(
+            x1, src_key_padding_mask=mask, warmup=warmup
+        ) # (T, N, C) where C == d_model[1]
+
+        x = torch.cat((x1, x2), dim=2)
 
         x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
 
+        x = self.out_proj(x)
+
         return x, lengths
 
 
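In the hunk above both encoder calls go through self.encoder1, while self.encoder2 built in __init__ is never invoked, even though the second call's comment promises C == d_model[1]. Given the "Closer to working.." commit message, the second call was presumably meant to use self.encoder2. A hedged sketch of that presumed intent (a fragment of the forward body, not code from this commit; self.encoder2 returns a single tensor here, so there is no tuple to unpack):

    # Presumed intent for the two-stack forward pass (sketch only).
    x1, x_no_combine = self.encoder1(
        x, src_key_padding_mask=mask, warmup=warmup
    )  # (T, N, d_model[0])

    x2 = self.encoder2(
        x1, src_key_padding_mask=mask, warmup=warmup
    )  # (T, N, d_model[1]) after internal downsample/upsample

    x = torch.cat((x1, x2), dim=2)  # (T, N, d_model[0] + d_model[1])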
@@ -288,8 +322,12 @@ class ConformerEncoder(nn.Module):
         >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
         >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
         >>> src = torch.rand(10, 32, 512)
-        >>> pos_emb = torch.rand(32, 19, 512)
-        >>> out = conformer_encoder(src, pos_emb)
+        >>> out = conformer_encoder(src)
+
+
+    Returns: (combined_output, output),
+        where `combined_output` has gone through the RandomCombiner module and `output` is just the
+        original output, in case you need to bypass the RandomCombiner module.
     """
 
     def __init__(
@@ -297,8 +335,13 @@ class ConformerEncoder(nn.Module):
         encoder_layer: nn.Module,
         num_layers: int,
         aux_layers: List[int],
+        dropout: float,
     ) -> None:
         super().__init__()
+
+        self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model,
+                                                 dropout)
+
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
@@ -318,16 +361,14 @@ class ConformerEncoder(nn.Module):
     def forward(
         self,
         src: Tensor,
-        pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         warmup: float = 1.0,
-    ) -> Tensor:
+    ) -> Tuple[Tensor, Tensor]:
         r"""Pass the input through the encoder layers in turn.
 
         Args:
             src: the sequence to the encoder (required).
-            pos_emb: Positional embedding tensor (required).
             mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
 
@@ -338,7 +379,9 @@ class ConformerEncoder(nn.Module):
             src_key_padding_mask: (N, S).
             S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
 
+        Returns: (x, x_no_combine), both of shape (S, N, E)
         """
+        pos_emb = self.encoder_pos(src)
         output = src
 
         outputs = []
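The two hunks above make ConformerEncoder compute its own relative positional encoding and return a pair. A usage sketch based on the updated docstring example; note that aux_layers and dropout are now required by __init__ even though the docstring example omits them, and the aux_layers value here is just an illustrative choice:

    encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
    conformer_encoder = ConformerEncoder(
        encoder_layer, num_layers=6, aux_layers=[2, 4], dropout=0.1
    )
    src = torch.rand(10, 32, 512)            # (S, N, E)
    combined, raw = conformer_encoder(src)   # no pos_emb argument any more
    # `combined` went through the RandomCombine module over the collected
    # aux-layer outputs; `raw` is the plain final-layer output for callers
    # that want to bypass the combiner.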
@@ -356,11 +399,103 @@ class ConformerEncoder(nn.Module):
             if i in self.aux_layers:
                 outputs.append(output)
 
-        output = self.combiner(outputs)
+        combined_output = self.combiner(outputs)
 
-        output = output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop
+        combined_output = combined_output + 0.0 * attn_scores.sum() # just ensure attn_scores is used in backprop
 
-        return output
+        return combined_output, output
 
 
+class DownsampledConformerEncoder(nn.Module):
+    r"""
+    DownsampledConformerEncoder is a conformer encoder evaluated at a reduced frame rate,
+    after convolutional downsampling, and then upsampled again at the output
+    so that the output has the same shape as the input.
+    """
+    def __init__(self,
+                 encoder: nn.Module,
+                 input_dim: int,
+                 module_dim: int,
+                 output_dim: int,
+                 downsample: int):
+        super(DownsampledConformerEncoder, self).__init__()
+
+        self.downsample = downsample
+
+        # note: we'll pad manually.
+        self.downsample = nn.Conv1d(
+            input_dim,
+            module_dim,
+            kernel_size=downsample,
+            stride=downsample,
+            padding=0)
+
+        self.encoder = encoder
+
+        self.upsample = nn.ConvTranspose1d(
+            module_dim,
+            output_dim,
+            kernel_size=downsample,
+            stride=downsample,
+            padding=0)
+
+    def forward(self,
+                src: Tensor,
+                mask: Optional[Tensor] = None,
+                src_key_padding_mask: Optional[Tensor] = None,
+                warmup: float = 1.0,
+                ) -> Tuple[Tensor, Tensor]:
+        r"""Downsample, go through encoder, upsample.
+
+        Args:
+            src: the sequence to the encoder (required).
+            mask: the mask for the src sequence (optional). CAUTION: we need to downsample
+                this, if we are to support it. Won't work correctly yet.
+            src_key_padding_mask: the mask for the src keys per batch (optional).
+
+        Shape:
+            src: (S, N, E).
+            mask: (S, S).
+            src_key_padding_mask: (N, S).
+            S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
+
+        Returns: (x, x_no_combine), both of shape (S, N, E)
+        """
+        (seq_len, batch_size, embedding_dim) = x.shape
+        ds = self.downsample
+        d_seq_len = (seq_len + ds - 1) // ds
+        x_orig = x
+        if seq_len != d_seq_len * ds:
+            # right-pad x
+            pad = seq_len - d_seq_len * ds
+            x = torch.nn.functional.pad(x,
+                                        (0, pad, 0, 0, 0, 0),
+                                        mode='replicate')
+
+        if mask is not None:
+            mask = mask[::ds,::ds]
+        if src_key_padding_mask is not None:
+            src_key_padding_mask = src_key_padding_mask[::ds]
+
+        x = x.permute(1, 2, 0) # (#batch, channels, time).
+        x = self.downsample(x)
+        x = x.permute(2, 0, 1) # (time, batch, channels)
+
+
+        x, _x_no_combine = self.encoder(
+            x, src_key_padding_mask=mask, warmup=warmup
+        ) # (T, N, C)
+
+        x = x.permute(1, 2, 0) # (#batch, channels, time).
+        x = self.upsample(x)
+        x = x.permute(2, 0, 1) # (time, batch, channels)
+
+        new_seq_len = x.shape[0]
+        assert new_seq_len >= seq_len
+        if new_seq_len > seq_len:
+            x = x[:seq_len]
+
+        return x
+
+
 class RelPositionalEncoding(torch.nn.Module):
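As committed, DownsampledConformerEncoder.forward cannot run yet: the parameter is named src but the body reads x, __init__ overwrites the integer self.downsample with the Conv1d of the same name, the pad amount seq_len - d_seq_len * ds is never positive, the replicate padding is applied to the wrong dimension, and the (N, S) padding mask is subsampled along dim 0 and then handed to the wrapped encoder under the name mask. A small self-contained sketch of the time-axis padding the code appears to be aiming for (the helper name is mine, not from the diff):

    import torch

    def pad_to_multiple(src: torch.Tensor, ds: int) -> torch.Tensor:
        """Right-pad a (T, N, C) tensor along time so T is a multiple of ds,
        replicating the last frame (what mode='replicate' seems intended to do)."""
        seq_len = src.shape[0]
        d_seq_len = (seq_len + ds - 1) // ds
        pad = d_seq_len * ds - seq_len            # the diff computes the negation of this
        if pad > 0:
            last = src[-1:].expand(pad, -1, -1)   # repeat the final frame `pad` times
            src = torch.cat((src, last), dim=0)
        return src

    x = torch.rand(10, 2, 8)                      # (T=10, N=2, C=8)
    print(pad_to_multiple(x, 4).shape)            # torch.Size([12, 2, 8])
    # The per-batch padding mask is (N, S), so its time axis is dim 1:
    # src_key_padding_mask = src_key_padding_mask[:, ::ds]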
@@ -379,7 +514,7 @@ class RelPositionalEncoding(torch.nn.Module):
     def __init__(
         self, d_model: int, dropout_rate: float, max_len: int = 5000
     ) -> None:
-        """Construct an PositionalEncoding object."""
+        """Construct a PositionalEncoding object."""
         super(RelPositionalEncoding, self).__init__()
         self.d_model = d_model
         self.dropout = torch.nn.Dropout(dropout_rate)
@@ -391,7 +526,7 @@ class RelPositionalEncoding(torch.nn.Module):
         if self.pe is not None:
             # self.pe contains both positive and negative parts
             # the length of self.pe is 2 * input_len - 1
-            if self.pe.size(1) >= x.size(1) * 2 - 1:
+            if self.pe.size(1) >= x.size(0) * 2 - 1:
                 # Note: TorchScript doesn't implement operator== for torch.Device
                 if self.pe.dtype != x.dtype or str(self.pe.device) != str(
                     x.device
@@ -401,9 +536,9 @@ class RelPositionalEncoding(torch.nn.Module):
         # Suppose `i` means to the position of query vecotr and `j` means the
         # position of key vector. We use position relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
-        pe_positive = torch.zeros(x.size(1), self.d_model)
-        pe_negative = torch.zeros(x.size(1), self.d_model)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        pe_positive = torch.zeros(x.size(0), self.d_model)
+        pe_negative = torch.zeros(x.size(0), self.d_model)
+        position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
             torch.arange(0, self.d_model, 2, dtype=torch.float32)
             * -(math.log(10000.0) / self.d_model)
@@ -425,7 +560,7 @@ class RelPositionalEncoding(torch.nn.Module):
         """Add positional encoding.
 
         Args:
-            x (torch.Tensor): Input tensor (batch, time, `*`).
+            x (torch.Tensor): Input tensor (time, batch, `*`).
 
         Returns:
             torch.Tensor: Encoded tensor (batch, time, `*`).
@@ -436,11 +571,11 @@ class RelPositionalEncoding(torch.nn.Module):
         pos_emb = self.pe[
             :,
             self.pe.size(1) // 2
-            - x.size(1)
+            - x.size(0)
            + 1 : self.pe.size(1) // 2  # noqa E203
-            + x.size(1),
+            + x.size(0),
         ]
-        return self.dropout(x), self.dropout(pos_emb)
+        return self.dropout(pos_emb)
 
 
 class RelPositionMultiheadAttention(nn.Module):
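These hunks switch RelPositionalEncoding to the (time, batch, channel) layout that ConformerEncoder now feeds it and drop the input from the return value, so forward() hands back only the positional embedding. A small shape sketch, assuming the rest of the class is unchanged icefall code:

    import torch

    T, N, C = 25, 8, 256
    x = torch.rand(T, N, C)        # (time, batch, channel), as documented above
    # extend_pe() now keys off x.size(0) == T; the cached table covers relative
    # offsets -(T-1) .. T-1, i.e. 2*T - 1 positions of dimension C each,
    num_positions = 2 * T - 1      # 49
    # and forward() returns the dropout of a (1, 2*T - 1, C) slice of self.pe.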
@@ -472,7 +607,7 @@ class RelPositionMultiheadAttention(nn.Module):
         self.head_dim = embed_dim // (num_heads * 2)
         assert (
             self.head_dim * num_heads == self.embed_dim // 2
-        ), "embed_dim must be divisible by num_heads"
+        ), "embed_dim//2 must be divisible by num_heads"
 
         self.in_proj = nn.Linear(embed_dim, 3 * embed_dim // 2, bias=True)
         self.in_balancer = ActivationBalancer(3 * embed_dim // 2,
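The reworded assertion matches the arithmetic already in the surrounding context: queries, keys and values each live in an embed_dim // 2 space here, so head_dim is embed_dim // (num_heads * 2). For example:

    embed_dim, num_heads = 512, 8
    head_dim = embed_dim // (num_heads * 2)         # 32
    assert head_dim * num_heads == embed_dim // 2   # 32 * 8 == 256
    in_proj_out = 3 * embed_dim // 2                # 768: q, k, v of 256 dims each,
                                                    # matching nn.Linear(embed_dim, 3 * embed_dim // 2)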
@@ -951,6 +1086,7 @@ class Conv2dSubsampling(nn.Module):
         layer1_channels: int = 8,
         layer2_channels: int = 32,
         layer3_channels: int = 128,
+        dropout: float = 0.1,
     ) -> None:
         """
         Args:
@@ -998,6 +1134,7 @@ class Conv2dSubsampling(nn.Module):
         )
         out_height = (((in_channels - 1) // 2 - 1) // 2)
         self.out = ScaledLinear(out_height * layer3_channels, out_channels)
+        self.dropout = nn.Dropout(dropout)
 
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -1017,6 +1154,7 @@ class Conv2dSubsampling(nn.Module):
         b, c, t, f = x.size()
         x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
         # Now x is of shape (N, ((T-1)//2 - 1))//2, odim)
+        x = self.dropout(x)
         return x
 
 class RandomCombine(nn.Module):
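The only functional change across these three hunks is the new dropout after the output projection; the out_height line shown as context fixes the flattened width fed to ScaledLinear. With, say, 80-dimensional input features (an example value, not from the diff):

    in_channels, layer3_channels = 80, 128
    out_height = ((in_channels - 1) // 2 - 1) // 2    # ((80 - 1)//2 - 1)//2 = 19
    linear_in = out_height * layer3_channels          # 19 * 128 = 2432
    # self.out = ScaledLinear(2432, out_channels); the new self.dropout(x) then
    # acts on its (N, ((T-1)//2 - 1)//2, out_channels) output in forward().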
@@ -1251,14 +1389,13 @@ def _test_random_combine_main():
 
 def _test_conformer_main():
     feature_dim = 50
-    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     feature_dim = 50
     # Just make sure the forward pass runs.
 
     c = Conformer(
-        num_features=feature_dim, d_model=128, nhead=4
+        num_features=feature_dim, d_model=(64,96,128), nhead=(4,4)
     )
     batch_size = 5
     seq_len = 20
@@ -1271,8 +1408,6 @@ def _test_conformer_main():
     f  # to remove flake8 warnings
-
-
-
+
 if __name__ == "__main__":
     logging.getLogger().setLevel(logging.INFO)
     torch.set_num_threads(1)
 
@@ -91,30 +91,38 @@ LRSchedulerType = Union[
 def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
-        type=int,
-        default=24,
-        help="Number of conformer encoder layers..",
+        type=str,
+        default="12,12",
+        help="Number of conformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
-        "--dim-feedforward",
-        type=int,
-        default=1536,
-        help="Feedforward dimension of the conformer encoder layer.",
+        "--feedforward-dims",
+        type=str,
+        default="1536,1536",
+        help="Feedforward dimension of the conformer encoder layers, comma separated.",
     )
 
     parser.add_argument(
         "--nhead",
-        type=int,
-        default=8,
-        help="Number of attention heads in the conformer encoder layer.",
+        type=str,
+        default="8,8",
+        help="Number of attention heads in the conformer encoder layers.",
     )
 
     parser.add_argument(
-        "--encoder-dim",
+        "--encoder-dims",
+        type=str,
+        default="320,512,512",
+        help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, "
+        "and the output dim of the encoder",
+    )
+
+    parser.add_argument(
+        "--conformer-subsampling-factor",
         type=int,
-        default=384,
-        help="Attention dimension in the conformer encoder layer.",
+        default=4,
+        help="Subsampling factor for 2nd stack of encoder layers.",
     )
 
     parser.add_argument(
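The scalar model hyperparameters become comma-separated strings so each encoder stack can get its own value. A self-contained sketch of how they parse, with flag names and defaults taken from the hunk above:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--num-encoder-layers", type=str, default="12,12")
    parser.add_argument("--feedforward-dims", type=str, default="1536,1536")
    parser.add_argument("--nhead", type=str, default="8,8")
    parser.add_argument("--encoder-dims", type=str, default="320,512,512")
    parser.add_argument("--conformer-subsampling-factor", type=int, default=4)
    args = parser.parse_args([])

    def to_int_list(s: str):
        # same helper that get_encoder_model() defines further down
        return list(map(int, s.split(",")))

    print(to_int_list(args.encoder_dims))   # [320, 512, 512]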
@@ -401,13 +409,16 @@ def get_params() -> AttributeDict:
 
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
+    def to_int_list(s: str):
+        return list(map(int, s.split(',')))
     encoder = Conformer(
         num_features=params.feature_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.encoder_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
+        subsampling_factor=params.subsamplng_factor,
+        conformer_subsampling_factor=params.conformer_subsamplng_factor,
+        d_model=to_int_list(params.encoder_dims),
+        nhead=to_int_list(params.nhead),
+        feedforward_dims=to_int_list(params.feedforward_dims),
+        num_encoder_layers=to_int_list(params.num_encoder_layers),
     )
     return encoder
 
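Two details in this hunk do not line up with the Conformer signature added earlier: the keyword feedforward_dims= (Conformer expects dim_feedforward=) and the spellings params.subsamplng_factor / params.conformer_subsamplng_factor. A hedged sketch of how the call would presumably need to read (a fragment relying on params, as elsewhere in this file):

    def to_int_list(s: str):
        return list(map(int, s.split(",")))

    encoder = Conformer(
        num_features=params.feature_dim,
        subsampling_factor=params.subsampling_factor,
        conformer_subsampling_factor=params.conformer_subsampling_factor,
        d_model=to_int_list(params.encoder_dims),
        nhead=to_int_list(params.nhead),
        dim_feedforward=to_int_list(params.feedforward_dims),
        num_encoder_layers=to_int_list(params.num_encoder_layers),
    )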
@@ -424,7 +435,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=params.encoder_dim,
+        encoder_dim=int(params.encoder_dims.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -441,7 +452,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=params.encoder_dim,
+        encoder_dim=int(params.encoder_dims.split(',')[-1]),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,