mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Initial refactoring to remove unnecessary vocab_size
This commit is contained in:
parent
74121ac478
commit
709c387ce6
@ -32,9 +32,10 @@ class Conformer(EncoderInterface):
|
|||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
num_features (int): Number of input features
|
num_features (int): Number of input features
|
||||||
output_dim (int): Number of output dimension
|
output_dim (int): Model output dimension. If not equal to the encoder dimension,
|
||||||
|
we will project to the output.
|
||||||
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
||||||
d_model (int): attention dimension
|
d_model (int): attention dimension, also the output dimension
|
||||||
nhead (int): number of attention heads
|
nhead (int): number of attention heads
|
||||||
dim_feedforward (int): feedforward dimension
|
dim_feedforward (int): feedforward dimension
|
||||||
num_encoder_layers (int): number of encoder layers
|
num_encoder_layers (int): number of encoder layers
|
||||||
@ -42,7 +43,6 @@ class Conformer(EncoderInterface):
|
|||||||
cnn_module_kernel (int): Kernel size of convolution module
|
cnn_module_kernel (int): Kernel size of convolution module
|
||||||
vgg_frontend (bool): whether to use vgg frontend.
|
vgg_frontend (bool): whether to use vgg frontend.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_features: int,
|
num_features: int,
|
||||||
@ -59,7 +59,6 @@ class Conformer(EncoderInterface):
|
|||||||
super(Conformer, self).__init__()
|
super(Conformer, self).__init__()
|
||||||
|
|
||||||
self.num_features = num_features
|
self.num_features = num_features
|
||||||
self.output_dim = output_dim
|
|
||||||
self.subsampling_factor = subsampling_factor
|
self.subsampling_factor = subsampling_factor
|
||||||
if subsampling_factor != 4:
|
if subsampling_factor != 4:
|
||||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||||
@ -83,7 +82,11 @@ class Conformer(EncoderInterface):
|
|||||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
|
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers,
|
||||||
aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
|
aux_layers=list(range(0, num_encoder_layers-1, aux_layer_period)))
|
||||||
|
|
||||||
self.encoder_output_layer = ScaledLinear(d_model, output_dim)
|
if output_dim == d_model:
|
||||||
|
self.encoder_output_layer = Identity()
|
||||||
|
else:
|
||||||
|
self.encoder_output_layer = ScaledLinear(d_model, output_dim,
|
||||||
|
initial_speed=0.5)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||||
@ -101,9 +104,9 @@ class Conformer(EncoderInterface):
|
|||||||
to turn modules on sequentially.
|
to turn modules on sequentially.
|
||||||
Returns:
|
Returns:
|
||||||
Return a tuple containing 2 tensors:
|
Return a tuple containing 2 tensors:
|
||||||
- logits, its shape is (batch_size, output_seq_len, output_dim)
|
- embeddings: its shape is (batch_size, output_seq_len, output_dim)
|
||||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
- lengths, a tensor of shape (batch_size,) containing the number
|
||||||
of frames in `logits` before padding.
|
of frames in `embeddings` before padding.
|
||||||
"""
|
"""
|
||||||
x = self.encoder_embed(x)
|
x = self.encoder_embed(x)
|
||||||
x, pos_emb = self.encoder_pos(x)
|
x, pos_emb = self.encoder_pos(x)
|
||||||
@ -117,10 +120,10 @@ class Conformer(EncoderInterface):
|
|||||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
|
x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
|
||||||
warmup=warmup) # (T, N, C)
|
warmup=warmup) # (T, N, C)
|
||||||
|
|
||||||
logits = self.encoder_output_layer(x)
|
x = self.encoder_output_layer(x)
|
||||||
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
return logits, lengths
|
return x, lengths
|
||||||
|
|
||||||
|
|
||||||
class ConformerEncoderLayer(nn.Module):
|
class ConformerEncoderLayer(nn.Module):
|
||||||
|
@ -68,6 +68,7 @@ class Decoder(nn.Module):
|
|||||||
initial_speed=initial_speed
|
initial_speed=initial_speed
|
||||||
)
|
)
|
||||||
self.blank_id = blank_id
|
self.blank_id = blank_id
|
||||||
|
self.output_linear = ScaledLinear(embedding_dim, embedding_dim)
|
||||||
|
|
||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
self.context_size = context_size
|
self.context_size = context_size
|
||||||
@ -81,8 +82,6 @@ class Decoder(nn.Module):
|
|||||||
groups=embedding_dim,
|
groups=embedding_dim,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
self.output_linear = ScaledLinear(embedding_dim, vocab_size)
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
|
@ -20,11 +20,10 @@ import torch.nn.functional as F
|
|||||||
from scaling import ScaledLinear
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
class Joiner(nn.Module):
|
class Joiner(nn.Module):
|
||||||
def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
|
def __init__(self, input_dim: int, output_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.inner_linear = ScaledLinear(input_dim, inner_dim)
|
self.output_linear = ScaledLinear(input_dim, output_dim)
|
||||||
self.output_linear = ScaledLinear(inner_dim, output_dim)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
||||||
@ -43,8 +42,6 @@ class Joiner(nn.Module):
|
|||||||
|
|
||||||
logit = encoder_out + decoder_out
|
logit = encoder_out + decoder_out
|
||||||
|
|
||||||
logit = self.inner_linear(torch.tanh(logit))
|
logit = self.output_linear(torch.tanh(logit))
|
||||||
|
|
||||||
output = self.output_linear(F.relu(logit))
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
@ -19,6 +19,7 @@ import k2
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
|
|
||||||
@ -33,6 +34,8 @@ class Transducer(nn.Module):
|
|||||||
encoder: EncoderInterface,
|
encoder: EncoderInterface,
|
||||||
decoder: nn.Module,
|
decoder: nn.Module,
|
||||||
joiner: nn.Module,
|
joiner: nn.Module,
|
||||||
|
embedding_dim: int,
|
||||||
|
vocab_size: int
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -58,6 +61,10 @@ class Transducer(nn.Module):
|
|||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.joiner = joiner
|
self.joiner = joiner
|
||||||
|
|
||||||
|
# could perhaps separate this into 2 linear projections, one
|
||||||
|
# for lm and one for am.
|
||||||
|
self.simple_joiner = nn.Linear(embedding_dim, vocab_size)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -133,8 +140,8 @@ class Transducer(nn.Module):
|
|||||||
boundary[:, 3] = x_lens
|
boundary[:, 3] = x_lens
|
||||||
|
|
||||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||||
lm=decoder_out,
|
lm=self.simple_joiner(decoder_out),
|
||||||
am=encoder_out,
|
am=self.simple_joiner(encoder_out),
|
||||||
symbols=y_padded,
|
symbols=y_padded,
|
||||||
termination_symbol=blank_id,
|
termination_symbol=blank_id,
|
||||||
lm_only_scale=lm_scale,
|
lm_only_scale=lm_scale,
|
||||||
|
@ -306,7 +306,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
# TODO: We can add an option to switch between Conformer and Transformer
|
# TODO: We can add an option to switch between Conformer and Transformer
|
||||||
encoder = Conformer(
|
encoder = Conformer(
|
||||||
num_features=params.feature_dim,
|
num_features=params.feature_dim,
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.embedding_dim,
|
||||||
subsampling_factor=params.subsampling_factor,
|
subsampling_factor=params.subsampling_factor,
|
||||||
d_model=params.attention_dim,
|
d_model=params.attention_dim,
|
||||||
nhead=params.nhead,
|
nhead=params.nhead,
|
||||||
@ -328,8 +328,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.vocab_size,
|
input_dim=params.embedding_dim,
|
||||||
inner_dim=params.embedding_dim,
|
|
||||||
output_dim=params.vocab_size,
|
output_dim=params.vocab_size,
|
||||||
)
|
)
|
||||||
return joiner
|
return joiner
|
||||||
@ -344,6 +343,8 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
joiner=joiner,
|
joiner=joiner,
|
||||||
|
embedding_dim=params.embedding_dim,
|
||||||
|
vocab_size=params.vocab_size,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user