Initial refactoring to remove unnecessary vocab_size

This commit is contained in:
Daniel Povey 2022-03-30 21:40:22 +08:00
parent 74121ac478
commit 709c387ce6
5 changed files with 31 additions and 24 deletions

View File

@@ -32,9 +32,10 @@ class Conformer(EncoderInterface):
"""
Args:
num_features (int): Number of input features
output_dim (int): Number of output dimension
output_dim (int): Model output dimension. If not equal to the encoder dimension,
we will project to the output.
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
d_model (int): attention dimension
d_model (int): attention dimension, also the output dimension
nhead (int): number of heads
dim_feedforward (int): feedforward dimension
num_encoder_layers (int): number of encoder layers
@@ -42,7 +43,6 @@ class Conformer(EncoderInterface):
cnn_module_kernel (int): Kernel size of convolution module
vgg_frontend (bool): whether to use vgg frontend.
"""
def __init__(
self,
num_features: int,
@@ -59,7 +59,6 @@ class Conformer(EncoderInterface):
super(Conformer, self).__init__()
self.num_features = num_features
self.output_dim = output_dim
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
    raise NotImplementedError("Support only 'subsampling_factor=4'.")
@@ -83,7 +82,11 @@ class Conformer(EncoderInterface):
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
self.encoder_output_layer = ScaledLinear(d_model, output_dim)
if output_dim == d_model:
    self.encoder_output_layer = Identity()
else:
    self.encoder_output_layer = ScaledLinear(d_model, output_dim,
                                             initial_speed=0.5)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -101,9 +104,9 @@
to turn modules on sequentially.
Returns:
Return a tuple containing 2 tensors:
- logits, its shape is (batch_size, output_seq_len, output_dim)
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
- embeddings: its shape is (batch_size, output_seq_len, d_model)
- lengths, a tensor of shape (batch_size,) containing the number
of frames in `embeddings` before padding.
"""
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
@@ -117,10 +120,10 @@
x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
warmup=warmup) # (T, N, C)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return logits, lengths
return x, lengths
class ConformerEncoderLayer(nn.Module):
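For reference, a minimal sketch of the output-layer change above; torch.nn.Identity and torch.nn.Linear stand in for the recipe's Identity and ScaledLinear classes, and the dimensions are made up (both are assumptions, this is not part of the diff):

# Sketch only, not part of this diff.
import torch
import torch.nn as nn

d_model, output_dim = 512, 512
encoder_output_layer = (
    nn.Identity() if output_dim == d_model else nn.Linear(d_model, output_dim)
)

x = torch.randn(10, 4, d_model)   # (T, N, C) encoder output
x = encoder_output_layer(x)       # no-op when the dimensions already match
x = x.permute(1, 0, 2)            # (T, N, C) -> (N, T, C)
print(x.shape)                    # torch.Size([4, 10, 512])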

View File

@@ -68,6 +68,7 @@ class Decoder(nn.Module):
initial_speed=initial_speed
)
self.blank_id = blank_id
self.output_linear = ScaledLinear(embedding_dim, embedding_dim)
assert context_size >= 1, context_size
self.context_size = context_size
@@ -81,8 +82,6 @@ class Decoder(nn.Module):
groups=embedding_dim,
bias=False,
)
self.output_linear = ScaledLinear(embedding_dim, vocab_size)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""

View File

@@ -20,11 +20,10 @@ import torch.nn.functional as F
from scaling import ScaledLinear
class Joiner(nn.Module):
def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.inner_linear = ScaledLinear(input_dim, inner_dim)
self.output_linear = ScaledLinear(inner_dim, output_dim)
self.output_linear = ScaledLinear(input_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
@@ -43,8 +42,6 @@ class Joiner(nn.Module):
logit = encoder_out + decoder_out
logit = self.inner_linear(torch.tanh(logit))
output = self.output_linear(F.relu(logit))
logit = self.output_linear(torch.tanh(logit))
return output
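A self-contained sketch of the simplified Joiner interface above, with nn.Linear standing in for ScaledLinear (an assumption): the inner_linear stage is gone and a single projection is applied to tanh of the summed encoder/decoder activations.

# Sketch only, not part of this diff.
import torch
import torch.nn as nn

class JoinerSketch(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
        # Both inputs must already share the last dimension input_dim.
        logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))

joiner = JoinerSketch(input_dim=512, output_dim=500)
print(joiner(torch.randn(4, 10, 512), torch.randn(4, 10, 512)).shape)  # (4, 10, 500)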

View File

@@ -19,6 +19,7 @@ import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos
@@ -33,6 +34,8 @@ class Transducer(nn.Module):
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
embedding_dim: int,
vocab_size: int
):
"""
Args:
@@ -58,6 +61,10 @@ class Transducer(nn.Module):
self.decoder = decoder
self.joiner = joiner
# could perhaps separate this into 2 linear projections, one
# for lm and one for am.
self.simple_joiner = nn.Linear(embedding_dim, vocab_size)
def forward(
self,
x: torch.Tensor,
@@ -133,8 +140,8 @@ class Transducer(nn.Module):
boundary[:, 3] = x_lens
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=decoder_out,
am=encoder_out,
lm=self.simple_joiner(decoder_out),
am=self.simple_joiner(encoder_out),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
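A sketch of the new simple_joiner data flow, with illustrative shapes and hypothetical dimensions: both the decoder and encoder embeddings are projected to vocab_size by the shared linear layer, and those projections are what get passed as lm and am to k2.rnnt_loss_smoothed.

# Sketch only, not part of this diff.
import torch
import torch.nn as nn

embedding_dim, vocab_size = 512, 500
simple_joiner = nn.Linear(embedding_dim, vocab_size)

encoder_out = torch.randn(4, 100, embedding_dim)  # (N, T, C) from the encoder
decoder_out = torch.randn(4, 8, embedding_dim)    # (N, U+1, C) from the decoder
am = simple_joiner(encoder_out)                   # (N, T, vocab_size)
lm = simple_joiner(decoder_out)                   # (N, U+1, vocab_size)
print(am.shape, lm.shape)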

View File

@@ -306,7 +306,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.vocab_size,
output_dim=params.embedding_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
@@ -328,8 +328,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.vocab_size,
inner_dim=params.embedding_dim,
input_dim=params.embedding_dim,
output_dim=params.vocab_size,
)
return joiner
@@ -344,6 +343,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder=encoder,
decoder=decoder,
joiner=joiner,
embedding_dim=params.embedding_dim,
vocab_size=params.vocab_size,
)
return model
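To summarize the dimension flow after this commit, a short sketch with hypothetical hyper-parameter values: vocab_size no longer sizes the encoder or decoder outputs; it enters only at the joiner output and the model's simple_joiner.

# Sketch only, not part of this diff.
feature_dim = 80       # acoustic features into the Conformer
attention_dim = 512    # d_model inside the encoder
embedding_dim = 512    # shared encoder/decoder output dimension
vocab_size = 500       # only used for the joiner / simple_joiner outputs

assert embedding_dim == attention_dim  # lets encoder_output_layer be Identity()
print(f"encoder: {feature_dim}-dim features -> {embedding_dim}-dim output; "
      f"joiner: {embedding_dim} -> {vocab_size} logits")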