First draft of model rework

Daniel Povey 2022-04-02 20:03:21 +08:00
parent eec597fdd5
commit 8be10d3d6c
5 changed files with 49 additions and 50 deletions

View File

@@ -32,8 +32,6 @@ class Conformer(EncoderInterface):
"""
Args:
num_features (int): Number of input features
output_dim (int): Model output dimension. If not equal to the encoder dimension,
we will project to the output.
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension, also the output dimension
nhead (int): number of attention heads
@@ -47,7 +45,6 @@ class Conformer(EncoderInterface):
def __init__(
self,
num_features: int,
output_dim: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
@@ -83,11 +80,6 @@ class Conformer(EncoderInterface):
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
if output_dim == d_model:
self.encoder_output_layer = nn.Identity()
else:
self.encoder_output_layer = ScaledLinear(d_model, output_dim,
initial_speed=0.5)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -123,7 +115,6 @@ class Conformer(EncoderInterface):
x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
warmup=warmup) # (T, N, C)
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
return x, lengths
@@ -1116,7 +1107,7 @@ class Noam(object):
if __name__ == '__main__':
feature_dim = 50
c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
batch_size = 5
seq_len = 20
# Just make sure the forward pass runs.
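
For reference, a minimal usage sketch of the reworked encoder, assuming the Conformer above is importable from this recipe's conformer.py; with output_dim gone, d_model is now also the encoder's output dimension, and the concrete sizes below are only illustrative.

import torch
from conformer import Conformer

# d_model doubles as the output dimension; there is no separate
# output_dim / encoder_output_layer any more.
encoder = Conformer(num_features=80, d_model=512, nhead=8)

x = torch.randn(5, 100, 80)        # (N, T, num_features)
x_lens = torch.full((5,), 100)     # (N,)
encoder_out, out_lens = encoder(x, x_lens, warmup=1.0)
# encoder_out: (N, T', d_model), with T' reduced by the subsampling factor of 4.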

View File

@@ -46,8 +46,8 @@ class Decoder(nn.Module):
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
embedding_dim:
Dimension of the input embedding.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
@@ -63,23 +63,22 @@ class Decoder(nn.Module):
initial_speed = 0.5
self.embedding = ScaledEmbedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
embedding_dim=decoder_dim,
padding_idx=blank_id,
initial_speed=initial_speed
)
self.blank_id = blank_id
self.output_linear = ScaledLinear(embedding_dim, embedding_dim)
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = ScaledConv1d(
in_channels=embedding_dim,
out_channels=embedding_dim,
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=embedding_dim,
groups=decoder_dim,
bias=False,
)
@@ -92,7 +91,7 @@ class Decoder(nn.Module):
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, embedding_dim).
Return a tensor of shape (N, U, decoder_dim).
"""
y = y.to(torch.int64)
embedding_out = self.embedding(y)
@@ -108,5 +107,5 @@ class Decoder(nn.Module):
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = self.output_linear(F.relu(embedding_out))
embedding_out = F.relu(embedding_out)
return embedding_out
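
A similar sketch for the reworked Decoder (decoder.py is assumed to be this recipe's module; blank_id, context_size and the need_pad flag follow the padding behaviour described in the docstring above, and the concrete values are only illustrative).

import torch
from decoder import Decoder

decoder = Decoder(
    vocab_size=500,
    decoder_dim=512,    # renamed from embedding_dim; also the output dimension
    blank_id=0,
    context_size=2,
)

y = torch.randint(low=1, high=500, size=(5, 10))   # (N, U) token ids
decoder_out = decoder(y, need_pad=True)
# decoder_out: (N, U, decoder_dim); the trailing ScaledLinear is gone, so the
# output is simply relu(conv(embedding)).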

View File

@@ -20,11 +20,19 @@ import torch.nn.functional as F
from scaling import ScaledLinear
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
def __init__(self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int):
super().__init__()
self.output_linear = ScaledLinear(input_dim, output_dim,
initial_speed=0.5)
# We don't bother giving the 'initial_speed' arg to the decoder
# submodules, because it does not affect the initial convergence of the
# system (only the simple joiner is involved in that).
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim)
self.output_linear = ScaledLinear(joiner_dim, vocab_size)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
@@ -41,7 +49,7 @@ class Joiner(nn.Module):
assert encoder_out.ndim == decoder_out.ndim == 4
assert encoder_out.shape == decoder_out.shape
logit = encoder_out + decoder_out
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
logit = self.output_linear(torch.tanh(logit))
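
And for the reworked Joiner: encoder and decoder outputs are first projected to a common joiner_dim, then combined and mapped to vocab_size. A sketch with 4-D pruned inputs, as produced by k2.do_rnnt_pruning; the dims are illustrative.

import torch
from joiner import Joiner

joiner = Joiner(encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500)

B, T, s_range = 5, 100, 5
am_pruned = torch.randn(B, T, s_range, 512)   # (B, T, prune_range, encoder_dim)
lm_pruned = torch.randn(B, T, s_range, 512)   # (B, T, prune_range, decoder_dim)

# Note: the shape assert in forward() requires the two inputs to have the
# same shape, so encoder_dim == decoder_dim in this sketch.
logits = joiner(am_pruned, lm_pruned)
# logits: (B, T, prune_range, vocab_size), unnormalized (no log-softmax applied).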

View File

@@ -34,23 +34,25 @@ class Transducer(nn.Module):
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
embedding_dim: int,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int
):
"""
Args:
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
is (N, U) and its output shape is (N, U, decoder_dim). It should contain
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
@@ -61,17 +63,10 @@ class Transducer(nn.Module):
self.decoder = decoder
self.joiner = joiner
self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size,
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
initial_speed=0.5)
self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size,
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size,
initial_speed=0.5)
with torch.no_grad():
# Initialize the two projections to be the same; this will be
# convenient for the real joiner, which adds the encoder
# (acoustic-model/am) and decoder (language-model/lm) embeddings
self.simple_lm_proj.weight[:] = self.simple_am_proj.weight
self.simple_lm_proj.bias[:] = self.simple_am_proj.bias
def forward(
self,
@@ -133,7 +128,7 @@ class Transducer(nn.Module):
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, C]
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
@@ -167,13 +162,13 @@ class Transducer(nn.Module):
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, C]
# lm_pruned : [B, T, prune_range, C]
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=encoder_out, lm=decoder_out, ranges=ranges
)
# logits : [B, T, prune_range, C]
# logits : [B, T, prune_range, vocab_size]
logits = self.joiner(am_pruned, lm_pruned)
pruned_loss = k2.rnnt_loss_pruned(
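
To summarize how the renamed dimensions flow through the pruned-RNN-T path, a shape-level sketch with plain tensors; nn.Linear stands in for ScaledLinear here, and all sizes follow the new train.py defaults and are only illustrative.

import torch
import torch.nn as nn

B, T, S, V = 5, 100, 20, 500
encoder_dim = decoder_dim = 512

encoder_out = torch.randn(B, T, encoder_dim)       # from the Conformer
decoder_out = torch.randn(B, S + 1, decoder_dim)   # from the Decoder (with SOS)

# The simple projections now map encoder_dim / decoder_dim (rather than a
# shared embedding_dim) straight to the vocabulary; they feed the simple
# RNN-T loss, whose gradients define the pruning ranges.
simple_am_proj = nn.Linear(encoder_dim, V)
simple_lm_proj = nn.Linear(decoder_dim, V)
am = simple_am_proj(encoder_out)    # (B, T, vocab_size)
lm = simple_lm_proj(decoder_out)    # (B, S + 1, vocab_size)

# k2.do_rnnt_pruning(am=encoder_out, lm=decoder_out, ranges=ranges) then
# gathers am_pruned (B, T, prune_range, encoder_dim) and lm_pruned
# (B, T, prune_range, decoder_dim); the reworked Joiner turns these into
# logits of shape (B, T, prune_range, vocab_size) for k2.rnnt_loss_pruned.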

View File

@@ -268,7 +268,7 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- attention_dim: Hidden dim for multi-head attention model.
- encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layers of the transformer decoder.
@@ -287,12 +287,14 @@ def get_params() -> AttributeDict:
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"attention_dim": 512,
"encoder_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
# parameters for decoder
"embedding_dim": 512,
"decoder_dim": 512,
# parameters for joiner
"joiner_dim": 512,
# parameters for Noam
"warm_step": 60000, # For the 100h subset, use 8k
"model_warm_step": 4000, # arg given to model, not for lrate
@@ -309,7 +311,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
num_features=params.feature_dim,
output_dim=params.embedding_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
d_model=params.encoder_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
@@ -329,8 +331,10 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.embedding_dim,
output_dim=params.vocab_size,
encoder_dim=params.encoder_dim,
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
)
return joiner
@@ -344,7 +348,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
embedding_dim=params.embedding_dim,
encoder_dim=params.encoder_dim,
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
)
return model
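
Putting the train.py pieces together, a hedged sketch of the full model wiring with the renamed dimensions; defaults follow get_params above, the module names in the imports are assumed to match this recipe's files, and vocab_size, blank_id and context_size are illustrative. Note the sketch drops output_dim, which the reworked Conformer constructor no longer accepts.

# Module names below are assumptions about this recipe's file layout.
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer

encoder_dim, decoder_dim, joiner_dim, vocab_size = 512, 512, 512, 500

encoder = Conformer(
    num_features=80,          # feature_dim
    subsampling_factor=4,
    d_model=encoder_dim,      # was attention_dim
    nhead=8,
    dim_feedforward=2048,
    num_encoder_layers=12,
)
decoder = Decoder(
    vocab_size=vocab_size,
    decoder_dim=decoder_dim,  # was embedding_dim
    blank_id=0,
    context_size=2,
)
joiner = Joiner(
    encoder_dim=encoder_dim,
    decoder_dim=decoder_dim,
    joiner_dim=joiner_dim,
    vocab_size=vocab_size,
)
model = Transducer(
    encoder=encoder,
    decoder=decoder,
    joiner=joiner,
    encoder_dim=encoder_dim,
    decoder_dim=decoder_dim,
    joiner_dim=joiner_dim,
    vocab_size=vocab_size,
)
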
@@ -748,7 +754,7 @@ def run(rank, world_size, args):
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
model_size=params.encoder_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
)
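
Since the optimizer's model_size now comes from params.encoder_dim, the peak learning rate scales with the encoder dimension. A sketch assuming the conventional Noam schedule; the factor value is illustrative, while warm_step matches the get_params default above.

def noam_lr(step: int, model_size: int = 512, factor: float = 5.0,
            warm_step: int = 60000) -> float:
    # Standard Noam schedule: linear warm-up for warm_step steps, then
    # decay as step ** -0.5; a larger model_size lowers the whole curve.
    return factor * model_size ** -0.5 * min(step ** -0.5, step * warm_step ** -1.5)

# The peak is reached around step == warm_step, e.g.
# noam_lr(60000) ≈ 5.0 * 512 ** -0.5 * 60000 ** -0.5 ≈ 9.0e-4.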