Add 1-d convolution to text embedding module; reduce batch size

Daniel Povey 2023-05-16 20:05:52 +08:00
parent 399a79ace6
commit a405106d2f
2 changed files with 46 additions and 4 deletions


@@ -20,8 +20,50 @@
 import torch
 from torch import nn, Tensor
 from subformer import Subformer
+from scaling import Balancer
+
+
+class TextEmbedder(nn.Module):
+    def __init__(self,
+                 vocab_size: int,
+                 embedding_dim: int):
+        super().__init__()
+        self.embed = nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=embedding_dim)
+        self.conv1 = nn.Conv1d(embedding_dim,
+                               embedding_dim,
+                               groups=embedding_dim,
+                               kernel_size=2)
+        self.balancer = Balancer(embedding_dim,
+                                 channel_dim=-1,
+                                 min_positive=0.1,
+                                 min_abs=1.0,
+                                 max_abs=2.0)
+        self.activation1 = nn.ReLU()
+        self.out_proj = nn.Linear(embedding_dim,
+                                  embedding_dim,
+                                  bias=False)
+
+    def forward(self,
+                text: Tensor) -> Tensor:
+        """
+        Args:
+            text: Tensor of shape (seq_len, batch_size), containing integer indexes
+                  0 <= text < vocab_size.
+        Returns:
+            Tensor of shape (seq_len, batch_size, embedding_dim)
+        """
+        x = self.embed(text)  # (seq_len, batch_size, embedding_dim)
+        x = torch.cat((torch.zeros_like(x[0:1]), x), dim=0)  # pad one zero frame on the left (causal)
+        x = x.permute(1, 2, 0)  # N,C,L, i.e. (batch_size, embedding_dim, seq_len)
+        x = self.conv1(x)
+        x = x.permute(2, 0, 1)  # (seq_len, batch_size, embedding_dim)
+        x = self.balancer(x)  # make sure no channel has all zeros.
+        x = self.activation1(x)
+        x = self.out_proj(x)
+        return x
+
+
 class SubformerLM(nn.Module):
     def __init__(self,
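The new forward() left-pads the embedded sequence by one zero frame before a kernel_size=2 depthwise convolution, so the output length equals the input length ((seq_len + 1) - 2 + 1 = seq_len) and each output frame depends only on the current and previous byte. Below is a minimal self-contained sketch of that shape arithmetic; it is not part of the commit, omits the repo's Balancer, and uses made-up sizes:

# Sketch (not from the commit): causal depthwise Conv1d via one-frame left padding.
import torch
from torch import nn

embedding_dim = 8
conv = nn.Conv1d(embedding_dim,
                 embedding_dim,
                 groups=embedding_dim,   # depthwise: one filter per channel
                 kernel_size=2)

x = torch.randn(5, 3, embedding_dim)                  # (seq_len=5, batch=3, channels)
x = torch.cat((torch.zeros_like(x[0:1]), x), dim=0)   # pad one zero frame at the start -> (6, 3, 8)
x = x.permute(1, 2, 0)                                # (batch, channels, seq_len + 1)
y = conv(x)                                           # (batch, channels, seq_len): (6 - 2 + 1) = 5
assert y.shape == (3, embedding_dim, 5)               # sequence length is preserved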


@@ -64,7 +64,7 @@ from subformer import Subformer
 from scaling import ScheduledFloat
 from lhotse.utils import fix_random_seed
 from decoder import Decoder
-from model import SubformerLM
+from model import SubformerLM, TextEmbedder
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch import nn
@@ -395,7 +395,7 @@ def get_params() -> AttributeDict:
             "warm_step": 2000,
             "env_info": get_env_info(),
             "bytes_per_segment": 2048,
-            "batch_size": 20,
+            "batch_size": 18,
             "train_file_list": "train.txt",
             "valid_file_list": "valid.txt",
             "num_workers": 4,
@@ -411,8 +411,8 @@ def _to_int_tuple(s: str):
 def get_encoder_embed(params: AttributeDict) -> nn.Module:
-    return nn.Embedding(
-        num_embeddings=256,  # we encode the text as UTF-8 bytes
+    return TextEmbedder(
+        vocab_size=256,  # we encode the text as UTF-8 bytes
         embedding_dim=_to_int_tuple(params.encoder_dim)[0],
     )
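For reference, a hypothetical usage sketch of the new encoder embed module on UTF-8 byte input. This is not part of the commit: it assumes TextEmbedder (and its scaling.Balancer dependency) can be imported from the recipe's model.py, and the 384-dim embedding size is only an illustrative stand-in for _to_int_tuple(params.encoder_dim)[0]:

# Sketch (not from the commit): run TextEmbedder on a batch of UTF-8 bytes.
import torch
from model import TextEmbedder  # assumes the recipe's model.py is on the path

embedder = TextEmbedder(vocab_size=256, embedding_dim=384)  # 384 is an assumed size

text = "hello world".encode("utf-8")                  # raw UTF-8 bytes, values in [0, 256)
tokens = torch.tensor(list(text), dtype=torch.long)   # (seq_len,)
tokens = tokens.unsqueeze(1)                          # (seq_len, batch_size=1)

out = embedder(tokens)
print(out.shape)                                      # torch.Size([11, 1, 384])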