mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
First draft of model rework
This commit is contained in:
parent
eec597fdd5
commit
8be10d3d6c
@ -32,8 +32,6 @@ class Conformer(EncoderInterface):
|
||||
"""
|
||||
Args:
|
||||
num_features (int): Number of input features
|
||||
output_dim (int): Model output dimension. If not equal to the encoder dimension,
|
||||
we will project to the output.
|
||||
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
||||
d_model (int): attention dimension, also the output dimension
|
||||
nhead (int): number of head
|
||||
@ -47,7 +45,6 @@ class Conformer(EncoderInterface):
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
output_dim: int,
|
||||
subsampling_factor: int = 4,
|
||||
d_model: int = 256,
|
||||
nhead: int = 4,
|
||||
@ -83,11 +80,6 @@ class Conformer(EncoderInterface):
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
|
||||
if output_dim == d_model:
|
||||
self.encoder_output_layer = nn.Identity()
|
||||
else:
|
||||
self.encoder_output_layer = ScaledLinear(d_model, output_dim,
|
||||
initial_speed=0.5)
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||
@ -123,7 +115,6 @@ class Conformer(EncoderInterface):
|
||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
|
||||
warmup=warmup) # (T, N, C)
|
||||
|
||||
x = self.encoder_output_layer(x)
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
|
||||
return x, lengths
|
||||
@ -1116,7 +1107,7 @@ class Noam(object):
|
||||
|
||||
if __name__ == '__main__':
|
||||
feature_dim = 50
|
||||
c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
|
||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||
batch_size = 5
|
||||
seq_len = 20
|
||||
# Just make sure the forward pass runs.
|
||||
|
@ -46,8 +46,8 @@ class Decoder(nn.Module):
|
||||
Args:
|
||||
vocab_size:
|
||||
Number of tokens of the modeling unit including blank.
|
||||
embedding_dim:
|
||||
Dimension of the input embedding.
|
||||
decoder_dim:
|
||||
Dimension of the input embedding, and of the decoder output.
|
||||
blank_id:
|
||||
The ID of the blank symbol.
|
||||
context_size:
|
||||
@ -63,23 +63,22 @@ class Decoder(nn.Module):
|
||||
initial_speed = 0.5
|
||||
self.embedding = ScaledEmbedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=embedding_dim,
|
||||
embedding_dim=decoder_dim,
|
||||
padding_idx=blank_id,
|
||||
initial_speed=initial_speed
|
||||
)
|
||||
self.blank_id = blank_id
|
||||
self.output_linear = ScaledLinear(embedding_dim, embedding_dim)
|
||||
|
||||
assert context_size >= 1, context_size
|
||||
self.context_size = context_size
|
||||
self.vocab_size = vocab_size
|
||||
if context_size > 1:
|
||||
self.conv = ScaledConv1d(
|
||||
in_channels=embedding_dim,
|
||||
out_channels=embedding_dim,
|
||||
in_channels=decoder_dim,
|
||||
out_channels=decoder_dim,
|
||||
kernel_size=context_size,
|
||||
padding=0,
|
||||
groups=embedding_dim,
|
||||
groups=decoder_dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
@ -92,7 +91,7 @@ class Decoder(nn.Module):
|
||||
True to left pad the input. Should be True during training.
|
||||
False to not pad the input. Should be False during inference.
|
||||
Returns:
|
||||
Return a tensor of shape (N, U, embedding_dim).
|
||||
Return a tensor of shape (N, U, decoder_dim).
|
||||
"""
|
||||
y = y.to(torch.int64)
|
||||
embedding_out = self.embedding(y)
|
||||
@ -108,5 +107,5 @@ class Decoder(nn.Module):
|
||||
assert embedding_out.size(-1) == self.context_size
|
||||
embedding_out = self.conv(embedding_out)
|
||||
embedding_out = embedding_out.permute(0, 2, 1)
|
||||
embedding_out = self.output_linear(F.relu(embedding_out))
|
||||
embedding_out = F.relu(embedding_out)
|
||||
return embedding_out
|
||||
|
@ -20,11 +20,19 @@ import torch.nn.functional as F
|
||||
from scaling import ScaledLinear
|
||||
|
||||
class Joiner(nn.Module):
|
||||
def __init__(self, input_dim: int, output_dim: int):
|
||||
def __init__(self,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int):
|
||||
super().__init__()
|
||||
|
||||
self.output_linear = ScaledLinear(input_dim, output_dim,
|
||||
initial_speed=0.5)
|
||||
# We don't bother giving the 'initial_speed' arg to the decoder
|
||||
# submodules, because it does not affect the initial convergence of the
|
||||
# system (only the simple joiner is involved in that).
|
||||
self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
|
||||
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim)
|
||||
self.output_linear = ScaledLinear(joiner_dim, vocab_size)
|
||||
|
||||
def forward(
|
||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||
@ -41,7 +49,7 @@ class Joiner(nn.Module):
|
||||
assert encoder_out.ndim == decoder_out.ndim == 4
|
||||
assert encoder_out.shape == decoder_out.shape
|
||||
|
||||
logit = encoder_out + decoder_out
|
||||
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
|
||||
|
||||
logit = self.output_linear(torch.tanh(logit))
|
||||
|
||||
|
@ -34,23 +34,25 @@ class Transducer(nn.Module):
|
||||
encoder: EncoderInterface,
|
||||
decoder: nn.Module,
|
||||
joiner: nn.Module,
|
||||
embedding_dim: int,
|
||||
encoder_dim: int,
|
||||
decoder_dim: int,
|
||||
joiner_dim: int,
|
||||
vocab_size: int
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, C) and
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, C). It should contain
|
||||
is (N, U) and its output shape is (N, U, decoder_dim). It should contain
|
||||
one attribute: `blank_id`.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
|
||||
output shape is (N, T, U, C). Note that its output contains
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
|
||||
output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
"""
|
||||
super().__init__()
|
||||
@ -61,17 +63,10 @@ class Transducer(nn.Module):
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size,
|
||||
self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
|
||||
initial_speed=0.5)
|
||||
self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size,
|
||||
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size,
|
||||
initial_speed=0.5)
|
||||
with torch.no_grad():
|
||||
# Initialize the two projections to be the same; this will be
|
||||
# convenient for the real joiner, which adds the endcoder
|
||||
# (acoustic-model/am) and decoder (language-model/lm) embeddings
|
||||
self.simple_lm_proj.weight[:] = self.simple_am_proj.weight
|
||||
self.simple_lm_proj.bias[:] = self.simple_am_proj.bias
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -133,7 +128,7 @@ class Transducer(nn.Module):
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, C]
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
@ -167,13 +162,13 @@ class Transducer(nn.Module):
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, C]
|
||||
# lm_pruned : [B, T, prune_range, C]
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=encoder_out, lm=decoder_out, ranges=ranges
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, C]
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
logits = self.joiner(am_pruned, lm_pruned)
|
||||
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
|
@ -268,7 +268,7 @@ def get_params() -> AttributeDict:
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- attention_dim: Hidden dim for multi-head attention model.
|
||||
- encoder_dim: Hidden dim for multi-head attention model.
|
||||
|
||||
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
||||
|
||||
@ -287,12 +287,14 @@ def get_params() -> AttributeDict:
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
"attention_dim": 512,
|
||||
"encoder_dim": 512,
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
# parameters for decoder
|
||||
"embedding_dim": 512,
|
||||
"decoder_dim": 512,
|
||||
# parameters for joiner
|
||||
"joiner_dim": 512,
|
||||
# parameters for Noam
|
||||
"warm_step": 60000, # For the 100h subset, use 8k
|
||||
"model_warm_step": 4000, # arg given to model, not for lrate
|
||||
@ -309,7 +311,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.embedding_dim,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
d_model=params.attention_dim,
|
||||
d_model=params.encoder_dim,
|
||||
nhead=params.nhead,
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
@ -329,8 +331,10 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.embedding_dim,
|
||||
output_dim=params.vocab_size,
|
||||
encoder_dim=params.encoder_dim
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
)
|
||||
return joiner
|
||||
|
||||
@ -344,7 +348,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
embedding_dim=params.embedding_dim,
|
||||
encoder_dim=params.encoder_dim
|
||||
decoder_dim=params.decoder_dim,
|
||||
joiner_dim=params.joiner_dim,
|
||||
vocab_size=params.vocab_size,
|
||||
)
|
||||
return model
|
||||
@ -748,7 +754,7 @@ def run(rank, world_size, args):
|
||||
|
||||
optimizer = Noam(
|
||||
model.parameters(),
|
||||
model_size=params.attention_dim,
|
||||
model_size=params.encoder_dim,
|
||||
factor=params.lr_factor,
|
||||
warm_step=params.warm_step,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user