Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

First draft of model rework

Commit 8be10d3d6c, parent eec597fdd5
@@ -32,8 +32,6 @@ class Conformer(EncoderInterface):
     """
     Args:
         num_features (int): Number of input features
-        output_dim (int): Model output dimension. If not equal to the encoder dimension,
-            we will project to the output.
        subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
         d_model (int): attention dimension, also the output dimension
         nhead (int): number of heads

@@ -47,7 +45,6 @@ class Conformer(EncoderInterface):
     def __init__(
         self,
         num_features: int,
-        output_dim: int,
         subsampling_factor: int = 4,
         d_model: int = 256,
         nhead: int = 4,

@@ -83,11 +80,6 @@ class Conformer(EncoderInterface):
         )
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)

-        if output_dim == d_model:
-            self.encoder_output_layer = nn.Identity()
-        else:
-            self.encoder_output_layer = ScaledLinear(d_model, output_dim,
-                                                     initial_speed=0.5)

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0

@@ -123,7 +115,6 @@ class Conformer(EncoderInterface):
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
                          warmup=warmup)  # (T, N, C)

-        x = self.encoder_output_layer(x)
         x = x.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

         return x, lengths

@@ -1116,7 +1107,7 @@ class Noam(object):

 if __name__ == '__main__':
     feature_dim = 50
-    c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
+    c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
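With output_dim gone, the encoder's contract is simply (x, x_lens, warmup) -> ((N, T', d_model), (N,)). A self-contained sketch of that contract; StandInEncoder is hypothetical, and the real Conformer applies convolutional subsampling followed by conformer layers:

import torch
import torch.nn as nn

class StandInEncoder(nn.Module):
    """Hypothetical stand-in mirroring the reworked encoder interface."""
    def __init__(self, num_features: int, d_model: int,
                 subsampling_factor: int = 4):
        super().__init__()
        self.subsampling_factor = subsampling_factor
        self.proj = nn.Linear(num_features, d_model)

    def forward(self, x: torch.Tensor, x_lens: torch.Tensor,
                warmup: float = 1.0):
        x = x[:, :: self.subsampling_factor, :]   # crude stand-in subsampling
        lengths = x_lens // self.subsampling_factor
        # No output projection any more: the result is already (N, T', d_model).
        return self.proj(x), lengths

enc = StandInEncoder(num_features=50, d_model=128)
y, y_lens = enc(torch.randn(5, 20, 50), torch.full((5,), 20))
print(y.shape)  # torch.Size([5, 5, 128])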

@@ -46,8 +46,8 @@ class Decoder(nn.Module):
         Args:
           vocab_size:
             Number of tokens of the modeling unit including blank.
-          embedding_dim:
-            Dimension of the input embedding.
+          decoder_dim:
+            Dimension of the input embedding, and of the decoder output.
           blank_id:
             The ID of the blank symbol.
           context_size:

@@ -63,23 +63,22 @@ class Decoder(nn.Module):
         initial_speed = 0.5
         self.embedding = ScaledEmbedding(
             num_embeddings=vocab_size,
-            embedding_dim=embedding_dim,
+            embedding_dim=decoder_dim,
             padding_idx=blank_id,
             initial_speed=initial_speed
         )
         self.blank_id = blank_id
-        self.output_linear = ScaledLinear(embedding_dim, embedding_dim)

         assert context_size >= 1, context_size
         self.context_size = context_size
         self.vocab_size = vocab_size
         if context_size > 1:
             self.conv = ScaledConv1d(
-                in_channels=embedding_dim,
-                out_channels=embedding_dim,
+                in_channels=decoder_dim,
+                out_channels=decoder_dim,
                 kernel_size=context_size,
                 padding=0,
-                groups=embedding_dim,
+                groups=decoder_dim,
                 bias=False,
             )

@@ -92,7 +91,7 @@ class Decoder(nn.Module):
             True to left pad the input. Should be True during training.
             False to not pad the input. Should be False during inference.
         Returns:
-          Return a tensor of shape (N, U, embedding_dim).
+          Return a tensor of shape (N, U, decoder_dim).
         """
         y = y.to(torch.int64)
         embedding_out = self.embedding(y)

@@ -108,5 +107,5 @@ class Decoder(nn.Module):
         assert embedding_out.size(-1) == self.context_size
         embedding_out = self.conv(embedding_out)
         embedding_out = embedding_out.permute(0, 2, 1)
-        embedding_out = self.output_linear(F.relu(embedding_out))
+        embedding_out = F.relu(embedding_out)
         return embedding_out
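The decoder now ends in a bare ReLU, with no trailing projection. A minimal, self-contained sketch of the reworked stateless decoder, with plain nn.Embedding/nn.Conv1d standing in for the Scaled* variants:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyDecoder(nn.Module):
    """Sketch of the stateless decoder: embedding + depthwise conv + ReLU."""
    def __init__(self, vocab_size: int, decoder_dim: int,
                 blank_id: int = 0, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, decoder_dim,
                                      padding_idx=blank_id)
        self.context_size = context_size
        if context_size > 1:
            # Depthwise conv over the last `context_size` labels.
            self.conv = nn.Conv1d(decoder_dim, decoder_dim,
                                  kernel_size=context_size,
                                  groups=decoder_dim, bias=False)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        out = self.embedding(y.to(torch.int64))   # (N, U, decoder_dim)
        if self.context_size > 1:
            out = out.permute(0, 2, 1)            # (N, decoder_dim, U)
            out = F.pad(out, (self.context_size - 1, 0))  # left-pad (training)
            out = self.conv(out).permute(0, 2, 1)
        return F.relu(out)                        # no output_linear any more

dec = TinyDecoder(vocab_size=500, decoder_dim=512)
print(dec(torch.randint(0, 500, (5, 7))).shape)  # torch.Size([5, 7, 512])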

@@ -20,11 +20,19 @@ import torch.nn.functional as F
 from scaling import ScaledLinear

 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, output_dim: int):
+    def __init__(self,
+                 encoder_dim: int,
+                 decoder_dim: int,
+                 joiner_dim: int,
+                 vocab_size: int):
         super().__init__()

-        self.output_linear = ScaledLinear(input_dim, output_dim,
-                                          initial_speed=0.5)
+        # We don't bother giving the 'initial_speed' arg to the decoder
+        # submodules, because it does not affect the initial convergence of the
+        # system (only the simple joiner is involved in that).
+        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
+        self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim)
+        self.output_linear = ScaledLinear(joiner_dim, vocab_size)

     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor

@@ -41,7 +49,7 @@ class Joiner(nn.Module):
         assert encoder_out.ndim == decoder_out.ndim == 4
         assert encoder_out.shape == decoder_out.shape

-        logit = encoder_out + decoder_out
+        logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)

         logit = self.output_linear(torch.tanh(logit))
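The encoder and decoder activations are now mapped into a shared joiner space before the additive combination. A minimal sketch with plain nn.Linear standing in for ScaledLinear; shapes follow the pruned path used in the model code below:

import torch
import torch.nn as nn

class TinyJoiner(nn.Module):
    """Sketch of the reworked joiner: project both sides, add, tanh, project."""
    def __init__(self, encoder_dim: int, decoder_dim: int,
                 joiner_dim: int, vocab_size: int):
        super().__init__()
        self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
        self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
        self.output_linear = nn.Linear(joiner_dim, vocab_size)

    def forward(self, encoder_out, decoder_out):
        # Both inputs arrive pruned to the same (B, T, s_range, *) layout.
        logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        return self.output_linear(torch.tanh(logit))

j = TinyJoiner(encoder_dim=512, decoder_dim=512, joiner_dim=512, vocab_size=500)
am = torch.randn(5, 20, 4, 512)  # (B, T, prune_range, encoder_dim)
lm = torch.randn(5, 20, 4, 512)  # (B, T, prune_range, decoder_dim)
print(j(am, lm).shape)           # torch.Size([5, 20, 4, 500])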

@@ -34,23 +34,25 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
-        embedding_dim: int,
+        encoder_dim: int,
+        decoder_dim: int,
+        joiner_dim: int,
         vocab_size: int
     ):
         """
         Args:
           encoder:
             It is the transcription network in the paper. It accepts
-            two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
-            It returns two tensors: `logits` of shape (N, T, C) and
+            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
             `logit_lens` of shape (N,).
           decoder:
             It is the prediction network in the paper. Its input shape
-            is (N, U) and its output shape is (N, U, C). It should contain
+            is (N, U) and its output shape is (N, U, decoder_dim). It should contain
             one attribute: `blank_id`.
           joiner:
-            It has two inputs with shapes: (N, T, C) and (N, U, C). Its
-            output shape is (N, T, U, C). Note that its output contains
+            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). Its
+            output shape is (N, T, U, vocab_size). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
         """
         super().__init__()

@@ -61,17 +63,10 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner

-        self.simple_am_proj = ScaledLinear(embedding_dim, vocab_size,
+        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size,
                                            initial_speed=0.5)
-        self.simple_lm_proj = ScaledLinear(embedding_dim, vocab_size,
+        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size,
                                            initial_speed=0.5)
-        with torch.no_grad():
-            # Initialize the two projections to be the same; this will be
-            # convenient for the real joiner, which adds the encoder
-            # (acoustic-model/am) and decoder (language-model/lm) embeddings
-            self.simple_lm_proj.weight[:] = self.simple_am_proj.weight
-            self.simple_lm_proj.bias[:] = self.simple_am_proj.bias
-

     def forward(
         self,
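For reference, how the two renamed projections are used: they map the encoder and decoder outputs straight to vocab_size for the "simple" joiner that drives pruning. A sketch with nn.Linear standing in for ScaledLinear and dims taken from the get_params defaults below:

import torch
import torch.nn as nn

encoder_dim, decoder_dim, vocab_size = 512, 512, 500
simple_am_proj = nn.Linear(encoder_dim, vocab_size)  # acoustic-model side
simple_lm_proj = nn.Linear(decoder_dim, vocab_size)  # language-model side

encoder_out = torch.randn(5, 20, encoder_dim)  # (B, T, encoder_dim)
decoder_out = torch.randn(5, 8, decoder_dim)   # (B, S + 1, decoder_dim)
am = simple_am_proj(encoder_out)               # (B, T, vocab_size)
lm = simple_lm_proj(decoder_out)               # (B, S + 1, vocab_size)
# am and lm feed the simple pruned-RNN-T loss, whose gradients determine the
# pruning ranges that the full joiner is evaluated on.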

@@ -133,7 +128,7 @@ class Transducer(nn.Module):
         # sos_y_padded: [B, S + 1], start with SOS.
         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

-        # decoder_out: [B, S + 1, C]
+        # decoder_out: [B, S + 1, decoder_dim]
         decoder_out = self.decoder(sos_y_padded)

         # Note: y does not start with SOS

@@ -167,13 +162,13 @@ class Transducer(nn.Module):
             s_range=prune_range,
         )

-        # am_pruned : [B, T, prune_range, C]
-        # lm_pruned : [B, T, prune_range, C]
+        # am_pruned : [B, T, prune_range, encoder_dim]
+        # lm_pruned : [B, T, prune_range, decoder_dim]
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
             am=encoder_out, lm=decoder_out, ranges=ranges
         )

-        # logits : [B, T, prune_range, C]
+        # logits : [B, T, prune_range, vocab_size]
         logits = self.joiner(am_pruned, lm_pruned)

         pruned_loss = k2.rnnt_loss_pruned(

@@ -268,7 +268,7 @@ def get_params() -> AttributeDict:

         - subsampling_factor: The subsampling factor for the model.

-        - attention_dim: Hidden dim for multi-head attention model.
+        - encoder_dim: Hidden dim for multi-head attention model.

         - num_decoder_layers: Number of decoder layers of transformer decoder.

@@ -287,12 +287,14 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "attention_dim": 512,
+            "encoder_dim": 512,
             "nhead": 8,
             "dim_feedforward": 2048,
             "num_encoder_layers": 12,
             # parameters for decoder
-            "embedding_dim": 512,
+            "decoder_dim": 512,
+            # parameters for joiner
+            "joiner_dim": 512,
             # parameters for Noam
             "warm_step": 60000,  # For the 100h subset, use 8k
             "model_warm_step": 4000,  # arg given to model, not for lrate
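These values are read as attributes elsewhere in the recipe (e.g. params.encoder_dim). A minimal stand-in for icefall's AttributeDict, assuming only dict storage plus attribute reads:

class AttributeDict(dict):
    # Sketch: attribute access falls through to dict lookups.
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

params = AttributeDict({"encoder_dim": 512, "decoder_dim": 512,
                        "joiner_dim": 512})
print(params.encoder_dim, params.joiner_dim)  # 512 512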

@@ -309,7 +311,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         num_features=params.feature_dim,
         output_dim=params.embedding_dim,
         subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
+        d_model=params.encoder_dim,
         nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,

@@ -329,8 +331,10 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:

 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.embedding_dim,
-        output_dim=params.vocab_size,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
+        vocab_size=params.vocab_size,
     )
     return joiner

@@ -344,7 +348,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        embedding_dim=params.embedding_dim,
+        encoder_dim=params.encoder_dim,
+        decoder_dim=params.decoder_dim,
+        joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
     )
     return model

@@ -748,7 +754,7 @@ def run(rank, world_size, args):

     optimizer = Noam(
         model.parameters(),
-        model_size=params.attention_dim,
+        model_size=params.encoder_dim,
         factor=params.lr_factor,
         warm_step=params.warm_step,
     )
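The Noam schedule's model_size is now keyed to the encoder dimension. For reference, a sketch of the standard Noam rate that this optimizer implements (the formula from "Attention Is All You Need"; the factor default here is an assumed example, the recipe reads it from params.lr_factor):

def noam_lr(step: int, model_size: int = 512, factor: float = 5.0,
            warm_step: int = 60000) -> float:
    # lr = factor * model_size^-0.5 * min(step^-0.5, step * warm_step^-1.5)
    # (step starts at 1, so there is no division by zero)
    return factor * model_size ** -0.5 * min(
        step ** -0.5, step * warm_step ** -1.5
    )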