Initial refactoring to remove unnecessary vocab_size

Daniel Povey 2022-03-30 21:40:22 +08:00
parent 74121ac478
commit 709c387ce6
5 changed files with 31 additions and 24 deletions


@@ -32,9 +32,10 @@ class Conformer(EncoderInterface):
     """
     Args:
         num_features (int): Number of input features
-        output_dim (int): Number of output dimension
+        output_dim (int): Model output dimension. If not equal to the encoder dimension,
+           we will project to the output.
         subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-        d_model (int): attention dimension
+        d_model (int): attention dimension, also the output dimension
         nhead (int): number of head
         dim_feedforward (int): feedforward dimention
         num_encoder_layers (int): number of encoder layers
@@ -42,7 +43,6 @@ class Conformer(EncoderInterface):
         cnn_module_kernel (int): Kernel size of convolution module
         vgg_frontend (bool): whether to use vgg frontend.
     """
-
     def __init__(
         self,
         num_features: int,
@@ -59,7 +59,6 @@ class Conformer(EncoderInterface):
         super(Conformer, self).__init__()

         self.num_features = num_features
-        self.output_dim = output_dim
         self.subsampling_factor = subsampling_factor
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
@@ -83,7 +82,11 @@ class Conformer(EncoderInterface):
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
                                         aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))

-        self.encoder_output_layer = ScaledLinear(d_model, output_dim)
+        if output_dim == d_model:
+            self.encoder_output_layer = Identity()
+        else:
+            self.encoder_output_layer = ScaledLinear(d_model, output_dim,
+                                                     initial_speed=0.5)

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -101,9 +104,9 @@ class Conformer(EncoderInterface):
             to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
-            - logits, its shape is (batch_size, output_seq_len, output_dim)
-            - logit_lens, a tensor of shape (batch_size,) containing the number
-              of frames in `logits` before padding.
+            - embeddings: its shape is (batch_size, output_seq_len, d_model)
+            - lengths, a tensor of shape (batch_size,) containing the number
+              of frames in `embeddings` before padding.
         """
         x = self.encoder_embed(x)
         x, pos_emb = self.encoder_pos(x)
@@ -117,10 +120,10 @@ class Conformer(EncoderInterface):
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
                          warmup=warmup)  # (T, N, C)

-        logits = self.encoder_output_layer(x)
-        logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
+        x = self.encoder_output_layer(x)
+        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

-        return logits, lengths
+        return x, lengths


 class ConformerEncoderLayer(nn.Module):
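
The net effect on the Conformer is that the final projection only exists when the requested output_dim differs from d_model. A minimal sketch of that logic, with plain torch.nn modules standing in for the repo's ScaledLinear and Identity, and illustrative sizes rather than the recipe's actual values:

import torch
import torch.nn as nn

# Stand-in for the new output-projection logic; sizes are illustrative.
d_model = 512      # encoder attention dimension
output_dim = 512   # requested model output dimension (params.embedding_dim)

if output_dim == d_model:
    # No projection needed: the encoder already emits output_dim features.
    encoder_output_layer = nn.Identity()
else:
    # The diff uses ScaledLinear(d_model, output_dim, initial_speed=0.5) here.
    encoder_output_layer = nn.Linear(d_model, output_dim)

x = torch.randn(100, 8, d_model)   # (T, N, C), as produced by the encoder
x = encoder_output_layer(x)        # (T, N, output_dim)
x = x.permute(1, 0, 2)             # (T, N, C) -> (N, T, C)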


@@ -68,6 +68,7 @@ class Decoder(nn.Module):
             initial_speed=initial_speed
         )
         self.blank_id = blank_id
+        self.output_linear = ScaledLinear(embedding_dim, embedding_dim)

         assert context_size >= 1, context_size
         self.context_size = context_size
@@ -81,8 +82,6 @@ class Decoder(nn.Module):
                 groups=embedding_dim,
                 bias=False,
             )

-        self.output_linear = ScaledLinear(embedding_dim, vocab_size)
-
     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """


@@ -20,11 +20,10 @@ import torch.nn.functional as F
 from scaling import ScaledLinear


 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, output_dim: int):
         super().__init__()
-        self.inner_linear = ScaledLinear(input_dim, inner_dim)
-        self.output_linear = ScaledLinear(inner_dim, output_dim)
+        self.output_linear = ScaledLinear(input_dim, output_dim)

     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
@@ -43,8 +42,6 @@ class Joiner(nn.Module):
         logit = encoder_out + decoder_out

-        logit = self.inner_linear(torch.tanh(logit))
-
-        output = self.output_linear(F.relu(logit))
+        logit = self.output_linear(torch.tanh(logit))

         return output
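
The joiner therefore collapses to a single projection after tanh, instead of inner_linear -> relu -> output_linear. A self-contained stand-in (nn.Linear instead of ScaledLinear, made-up sizes) that shows the shapes involved:

import torch
import torch.nn as nn

class SimpleJoiner(nn.Module):
    """Stand-in for the simplified Joiner: add, tanh, one projection."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.output_linear = nn.Linear(input_dim, output_dim)

    def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
        # The two inputs must be broadcastable to (N, T, U, input_dim).
        logit = encoder_out + decoder_out
        return self.output_linear(torch.tanh(logit))

joiner = SimpleJoiner(input_dim=512, output_dim=500)
enc = torch.randn(2, 50, 1, 512)   # (N, T, 1, C)
dec = torch.randn(2, 1, 10, 512)   # (N, 1, U, C)
print(joiner(enc, dec).shape)      # torch.Size([2, 50, 10, 500])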


@@ -19,6 +19,7 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear

 from icefall.utils import add_sos
@@ -33,6 +34,8 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
+        embedding_dim: int,
+        vocab_size: int
     ):
         """
         Args:
@@ -58,6 +61,10 @@
         self.decoder = decoder
         self.joiner = joiner

+        # could perhaps separate this into 2 linear projections, one
+        # for lm and one for am.
+        self.simple_joiner = nn.Linear(embedding_dim, vocab_size)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -133,8 +140,8 @@
         boundary[:, 3] = x_lens

         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=decoder_out,
-            am=encoder_out,
+            lm=self.simple_joiner(decoder_out),
+            am=self.simple_joiner(encoder_out),
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
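
The new simple_joiner exists because the lm and am arguments of k2.rnnt_loss_smoothed are per-symbol scores: with the decoder and encoder now both emitting embedding_dim features, one linear layer projects each of them to vocab_size before the loss. A shape-only sketch with illustrative sizes:

import torch
import torch.nn as nn

embedding_dim, vocab_size = 512, 500
simple_joiner = nn.Linear(embedding_dim, vocab_size)

encoder_out = torch.randn(2, 50, embedding_dim)   # (N, T, C) from the encoder
decoder_out = torch.randn(2, 11, embedding_dim)   # (N, U + 1, C) from the decoder

am = simple_joiner(encoder_out)   # (N, T, vocab_size), passed as am=...
lm = simple_joiner(decoder_out)   # (N, U + 1, vocab_size), passed as lm=...

As the in-code comment notes, this shared projection could later be split into separate lm and am projections.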


@@ -306,7 +306,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.vocab_size,
+        output_dim=params.embedding_dim,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -328,8 +328,7 @@

 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
+        input_dim=params.embedding_dim,
         output_dim=params.vocab_size,
     )
     return joiner
@@ -344,6 +343,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        embedding_dim=params.embedding_dim,
+        vocab_size=params.vocab_size,
     )
     return model
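
Putting the train.py changes together: the encoder, decoder and joiner now all meet in embedding_dim, and only the Joiner (plus the new simple_joiner) projects to vocab_size. A hedged summary of the dimension flow, with made-up sizes rather than the recipe's actual defaults:

# Illustrative sizes only; not the recipe's defaults.
feature_dim = 80      # params.feature_dim
attention_dim = 512   # params.attention_dim (Conformer d_model)
embedding_dim = 512   # params.embedding_dim
vocab_size = 500      # params.vocab_size

# encoder = Conformer(num_features=feature_dim, output_dim=embedding_dim,
#                     d_model=attention_dim, ...)
#           -> output layer is Identity() when embedding_dim == attention_dim
# joiner  = Joiner(input_dim=embedding_dim, output_dim=vocab_size)
# model   = Transducer(encoder=encoder, decoder=decoder, joiner=joiner,
#                      embedding_dim=embedding_dim, vocab_size=vocab_size)
#           -> simple_joiner = nn.Linear(embedding_dim, vocab_size)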