Initial refactoring to remove unnecessary vocab_size
commit 709c387ce6
parent 74121ac478
conformer.py

@@ -32,9 +32,10 @@ class Conformer(EncoderInterface):
     """
     Args:
       num_features (int): Number of input features
-      output_dim (int): Number of output dimension
+      output_dim (int): Model output dimension. If not equal to the encoder dimension,
+        we will project to the output.
       subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
-      d_model (int): attention dimension
+      d_model (int): attention dimension, also the output dimension
       nhead (int): number of head
       dim_feedforward (int): feedforward dimention
       num_encoder_layers (int): number of encoder layers
@@ -42,7 +43,6 @@ class Conformer(EncoderInterface):
       cnn_module_kernel (int): Kernel size of convolution module
       vgg_frontend (bool): whether to use vgg frontend.
     """

     def __init__(
         self,
         num_features: int,
@@ -59,7 +59,6 @@ class Conformer(EncoderInterface):
         super(Conformer, self).__init__()

         self.num_features = num_features
-        self.output_dim = output_dim
         self.subsampling_factor = subsampling_factor
         if subsampling_factor != 4:
             raise NotImplementedError("Support only 'subsampling_factor=4'.")
@@ -83,7 +82,11 @@ class Conformer(EncoderInterface):
         self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
                                         aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))

-        self.encoder_output_layer = ScaledLinear(d_model, output_dim)
+        if output_dim == d_model:
+            self.encoder_output_layer = Identity()
+        else:
+            self.encoder_output_layer = ScaledLinear(d_model, output_dim,
+                                                     initial_speed=0.5)

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
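
Note on the hunk above: when output_dim equals d_model the projection collapses to a no-op, so the encoder hands back raw d_model-dimensional embeddings; the rename from logits to x in the forward() hunks below reflects the same shift. The following is a minimal, self-contained sketch of this conditional-projection pattern; it substitutes torch.nn's Linear and Identity for icefall's ScaledLinear and Identity and omits initial_speed, which is specific to icefall's scaled-module machinery.

    import torch
    import torch.nn as nn

    class ProjectedEncoderOutput(nn.Module):
        """Skip the output projection when the dimensions already match."""

        def __init__(self, d_model: int, output_dim: int):
            super().__init__()
            if output_dim == d_model:
                # No projection needed: pass embeddings through unchanged.
                self.encoder_output_layer = nn.Identity()
            else:
                self.encoder_output_layer = nn.Linear(d_model, output_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (T, N, C), as in the Conformer forward above.
            return self.encoder_output_layer(x)

    x = torch.randn(50, 8, 512)
    assert ProjectedEncoderOutput(512, 512)(x).shape == (50, 8, 512)  # Identity path
    assert ProjectedEncoderOutput(512, 256)(x).shape == (50, 8, 256)  # projection path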
@@ -101,9 +104,9 @@ class Conformer(EncoderInterface):
             to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
-            - logits, its shape is (batch_size, output_seq_len, output_dim)
-            - logit_lens, a tensor of shape (batch_size,) containing the number
-              of frames in `logits` before padding.
+            - embeddings: its shape is (batch_size, output_seq_len, d_model)
+            - lengths, a tensor of shape (batch_size,) containing the number
+              of frames in `embeddings` before padding.
         """
         x = self.encoder_embed(x)
         x, pos_emb = self.encoder_pos(x)
@@ -117,10 +120,10 @@ class Conformer(EncoderInterface):
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
                          warmup=warmup)  # (T, N, C)

-        logits = self.encoder_output_layer(x)
-        logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
+        x = self.encoder_output_layer(x)
+        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

-        return logits, lengths
+        return x, lengths


 class ConformerEncoderLayer(nn.Module):
decoder.py

@@ -68,6 +68,7 @@ class Decoder(nn.Module):
             initial_speed=initial_speed
         )
         self.blank_id = blank_id
+        self.output_linear = ScaledLinear(embedding_dim, embedding_dim)

         assert context_size >= 1, context_size
         self.context_size = context_size
@@ -81,8 +82,6 @@ class Decoder(nn.Module):
                 groups=embedding_dim,
                 bias=False,
             )
-        self.output_linear = ScaledLinear(embedding_dim, vocab_size)
-

     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
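
The two decoder hunks above change output_linear from an (embedding_dim, vocab_size) projection into an (embedding_dim, embedding_dim) one, so the decoder now emits embeddings and leaves the projection to vocab_size to the joiner. A shape-level sketch of that contract, with hypothetical sizes (the real Decoder also contains an embedding table and a depthwise Conv1d over the left context, omitted here):

    import torch
    import torch.nn as nn

    embedding_dim, vocab_size = 512, 500

    # Post-refactor: the decoder head stays in embedding space...
    decoder_output_linear = nn.Linear(embedding_dim, embedding_dim)
    # ...and only the joiner's head produces vocab-size logits.
    joiner_output_linear = nn.Linear(embedding_dim, vocab_size)

    decoder_state = torch.randn(8, 10, embedding_dim)       # (N, U, embedding_dim)
    decoder_out = decoder_output_linear(decoder_state)      # (N, U, embedding_dim)
    logits = joiner_output_linear(torch.tanh(decoder_out))  # (N, U, vocab_size)
    assert logits.shape == (8, 10, vocab_size)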
joiner.py

@@ -20,11 +20,10 @@ import torch.nn.functional as F
 from scaling import ScaledLinear

 class Joiner(nn.Module):
-    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
+    def __init__(self, input_dim: int, output_dim: int):
         super().__init__()

-        self.inner_linear = ScaledLinear(input_dim, inner_dim)
-        self.output_linear = ScaledLinear(inner_dim, output_dim)
+        self.output_linear = ScaledLinear(input_dim, output_dim)

     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
@@ -43,8 +42,6 @@ class Joiner(nn.Module):

         logit = encoder_out + decoder_out

-        logit = self.inner_linear(torch.tanh(logit))
-
-        output = self.output_linear(F.relu(logit))
+        logit = self.output_linear(torch.tanh(logit))

         return output
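
Taken together, these hunks collapse the joiner head from inner_linear + ReLU + output_linear down to a single tanh + output_linear, with input_dim now the shared embedding dimension instead of vocab_size. Below is a self-contained sketch of the refactored module under those assumptions, using nn.Linear in place of ScaledLinear; it returns the renamed logit tensor (the trailing `return output` in the hunk above looks like a leftover of the old name):

    import torch
    import torch.nn as nn

    class Joiner(nn.Module):
        def __init__(self, input_dim: int, output_dim: int):
            super().__init__()
            # Single projection from embedding space to vocab logits.
            self.output_linear = nn.Linear(input_dim, output_dim)

        def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
            # Broadcast-add encoder and decoder states, squash, project.
            logit = encoder_out + decoder_out
            logit = self.output_linear(torch.tanh(logit))
            return logit

    joiner = Joiner(input_dim=512, output_dim=500)
    enc = torch.randn(8, 50, 1, 512)   # (N, T, 1, C)
    dec = torch.randn(8, 1, 10, 512)   # (N, 1, U, C)
    assert joiner(enc, dec).shape == (8, 50, 10, 500)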
model.py

@@ -19,6 +19,7 @@ import k2
 import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
+from scaling import ScaledLinear

 from icefall.utils import add_sos

@@ -33,6 +34,8 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
+        embedding_dim: int,
+        vocab_size: int
     ):
         """
         Args:
@@ -58,6 +61,10 @@ class Transducer(nn.Module):
         self.decoder = decoder
         self.joiner = joiner

+        # could perhaps separate this into 2 linear projections, one
+        # for lm and one for am.
+        self.simple_joiner = nn.Linear(embedding_dim, vocab_size)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -133,8 +140,8 @@ class Transducer(nn.Module):
         boundary[:, 3] = x_lens

         simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=decoder_out,
-            am=encoder_out,
+            lm=self.simple_joiner(decoder_out),
+            am=self.simple_joiner(encoder_out),
             symbols=y_padded,
             termination_symbol=blank_id,
             lm_only_scale=lm_scale,
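
For context on the model.py changes: the pruned-transducer "simple" loss wants vocab-size lm and am scores, so the new simple_joiner projects the embedding-dimension decoder and encoder outputs down to vocab_size just before k2.rnnt_loss_smoothed. A rough, shape-level sketch of that step with hypothetical sizes (k2 itself is not called here; am/lm mirror the keyword arguments in the hunk above):

    import torch
    import torch.nn as nn

    embedding_dim, vocab_size = 512, 500
    N, T, U = 8, 50, 11  # U = label length + 1 (sos prepended by add_sos)

    simple_joiner = nn.Linear(embedding_dim, vocab_size)

    encoder_out = torch.randn(N, T, embedding_dim)
    decoder_out = torch.randn(N, U, embedding_dim)

    # These correspond to the am=... and lm=... arguments passed to
    # k2.rnnt_loss_smoothed above: per-frame and per-token vocab-size scores.
    am = simple_joiner(encoder_out)  # (N, T, vocab_size)
    lm = simple_joiner(decoder_out)  # (N, U, vocab_size)
    assert am.shape == (N, T, vocab_size)
    assert lm.shape == (N, U, vocab_size)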
train.py

@@ -306,7 +306,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.vocab_size,
+        output_dim=params.embedding_dim,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -328,8 +328,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:

 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
+        input_dim=params.embedding_dim,
         output_dim=params.vocab_size,
     )
     return joiner
@@ -344,6 +343,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
+        embedding_dim=params.embedding_dim,
+        vocab_size=params.vocab_size,
     )
     return model
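
The train.py hunks wire the new dimensions together: the encoder now projects to params.embedding_dim instead of params.vocab_size, the joiner consumes that same embedding_dim and emits vocab_size, and the Transducer receives both sizes so it can build simple_joiner. A hypothetical check of that dimension contract; the field names (feature_dim, embedding_dim, vocab_size) follow the diff, while a plain dict and nn.Linear stand in for AttributeDict and the real modules:

    import torch
    import torch.nn as nn

    params = {"feature_dim": 80, "embedding_dim": 512, "vocab_size": 500}

    encoder_proj = nn.Linear(params["feature_dim"], params["embedding_dim"])
    joiner_proj = nn.Linear(params["embedding_dim"], params["vocab_size"])
    simple_joiner = nn.Linear(params["embedding_dim"], params["vocab_size"])

    feats = torch.randn(8, 100, params["feature_dim"])
    enc = encoder_proj(feats)  # encoder output now lives in embedding space
    # Both heads agree on the shared embedding_dim input and vocab_size output.
    assert joiner_proj(enc).shape[-1] == params["vocab_size"]
    assert simple_joiner(enc).shape[-1] == params["vocab_size"]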