Add 1-d convolution to text embedding module; reduce batch size

Daniel Povey 2023-05-16 20:05:52 +08:00
parent 399a79ace6
commit a405106d2f
2 changed files with 46 additions and 4 deletions

View File

@@ -20,8 +20,50 @@
 import torch
 from torch import nn, Tensor
 from subformer import Subformer
+from scaling import Balancer
+
+
+class TextEmbedder(nn.Module):
+    def __init__(self,
+                 vocab_size: int,
+                 embedding_dim: int):
+        super().__init__()
+        self.embed = nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=embedding_dim)
+        self.conv1 = nn.Conv1d(embedding_dim,
+                               embedding_dim,
+                               groups=embedding_dim,
+                               kernel_size=2)
+        self.balancer = Balancer(embedding_dim,
+                                 channel_dim=-1,
+                                 min_positive=0.1,
+                                 min_abs=1.0,
+                                 max_abs=2.0)
+        self.activation1 = nn.ReLU()
+        self.out_proj = nn.Linear(embedding_dim,
+                                  embedding_dim,
+                                  bias=False)
+
+    def forward(self,
+                text: Tensor) -> Tensor:
+        """
+        Args:
+          text: Tensor of shape (seq_len, batch_size), containing integer indexes
+                0 <= text < vocab_size.
+        Returns:
+          Tensor of shape (seq_len, batch_size, embedding_dim)
+        """
+        x = self.embed(text)  # (seq_len, batch_size, embedding_dim)
+        x = torch.cat((torch.zeros_like(x[0:1]), x), dim=0)  # pad on the left so the conv is causal
+        x = x.permute(1, 2, 0)  # (N, C, L), i.e. (batch_size, embedding_dim, seq_len + 1)
+        x = self.conv1(x)
+        x = x.permute(2, 0, 1)  # (seq_len, batch_size, embedding_dim)
+        x = self.balancer(x)  # make sure no channel has all zeros.
+        x = self.activation1(x)
+        x = self.out_proj(x)
+        return x
+
+
 class SubformerLM(nn.Module):
     def __init__(self,

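The left zero-pad together with kernel_size=2 is what makes the depthwise convolution causal: output position t sees only input tokens t-1 and t, so the language model never looks at future bytes. Below is a self-contained sketch of that padding pattern in plain PyTorch (tensor sizes are illustrative, not taken from the commit), checking that perturbing the last input frame leaves every earlier output frame unchanged:

import torch
from torch import nn

seq_len, batch_size, embedding_dim = 16, 2, 8          # illustrative sizes
conv = nn.Conv1d(embedding_dim, embedding_dim,
                 groups=embedding_dim, kernel_size=2)   # depthwise, like conv1 above

def causal_conv(x):
    # x: (seq_len, batch_size, embedding_dim) -> same shape
    x = torch.cat((torch.zeros_like(x[0:1]), x), dim=0)  # left-pad one frame
    return conv(x.permute(1, 2, 0)).permute(2, 0, 1)

x = torch.randn(seq_len, batch_size, embedding_dim)
x2 = x.clone()
x2[-1] += 1.0                                # change only the final input frame
y, y2 = causal_conv(x), causal_conv(x2)
assert torch.allclose(y[:-1], y2[:-1])       # earlier outputs are unaffected
assert not torch.allclose(y[-1], y2[-1])     # only the last output moves
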
View File

@@ -64,7 +64,7 @@ from subformer import Subformer
 from scaling import ScheduledFloat
 from lhotse.utils import fix_random_seed
 from decoder import Decoder
-from model import SubformerLM
+from model import SubformerLM, TextEmbedder
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch import nn
@@ -395,7 +395,7 @@ def get_params() -> AttributeDict:
             "warm_step": 2000,
             "env_info": get_env_info(),
             "bytes_per_segment": 2048,
-            "batch_size": 20,
+            "batch_size": 18,
             "train_file_list": "train.txt",
             "valid_file_list": "valid.txt",
             "num_workers": 4,
@@ -411,8 +411,8 @@ def _to_int_tuple(s: str):

 def get_encoder_embed(params: AttributeDict) -> nn.Module:
-    return nn.Embedding(
-        num_embeddings=256,  # we encode the text as UTF-8 bytes
+    return TextEmbedder(
+        vocab_size=256,  # we encode the text as UTF-8 bytes
         embedding_dim=_to_int_tuple(params.encoder_dim)[0],
     )
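
With this change the encoder embedding is a TextEmbedder rather than a bare nn.Embedding, but its input/output contract is the same. A hedged smoke-test sketch (it assumes this repo's model.py and its subformer/scaling dependencies are importable, and uses 512 as a stand-in for the first entry of params.encoder_dim):

import torch
from model import TextEmbedder             # requires subformer.py and scaling.py on the path

embed = TextEmbedder(vocab_size=256, embedding_dim=512)   # what get_encoder_embed now builds
text = torch.randint(0, 256, (2048, 18))   # (bytes_per_segment, batch_size) from the defaults above
print(embed(text).shape)                   # expected: torch.Size([2048, 18, 512])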