From a405106d2f40a4969543ccb17c8f232fb5563ed8 Mon Sep 17 00:00:00 2001
From: Daniel Povey
Date: Tue, 16 May 2023 20:05:52 +0800
Subject: [PATCH] Add 1-d convolution to text embedding module; reduce batch size

---
 egs/libriheavy/LM/zipformer1/model.py | 43 ++++++++++++++++++++++++++++
 egs/libriheavy/LM/zipformer1/train.py |  8 ++---
 2 files changed, 47 insertions(+), 4 deletions(-)

diff --git a/egs/libriheavy/LM/zipformer1/model.py b/egs/libriheavy/LM/zipformer1/model.py
index 1c0bdb9f9..60ee6d6bb 100644
--- a/egs/libriheavy/LM/zipformer1/model.py
+++ b/egs/libriheavy/LM/zipformer1/model.py
@@ -20,8 +20,51 @@ import torch
 from torch import nn, Tensor
 
 from subformer import Subformer
+from scaling import Balancer
 
 
+class TextEmbedder(nn.Module):
+    def __init__(self,
+                 vocab_size: int,
+                 embedding_dim: int):
+        super().__init__()
+        self.embed = nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=embedding_dim)
+        self.conv1 = nn.Conv1d(embedding_dim,
+                               embedding_dim,
+                               groups=embedding_dim,
+                               kernel_size=2)
+        self.balancer = Balancer(embedding_dim,
+                                 channel_dim=-1,
+                                 min_positive=0.1,
+                                 min_abs=1.0,
+                                 max_abs=2.0)
+        self.activation1 = nn.ReLU()
+        self.out_proj = nn.Linear(embedding_dim,
+                                  embedding_dim,
+                                  bias=False)
+
+    def forward(self,
+                text: Tensor) -> Tensor:
+        """
+        Args:
+          text: Tensor of shape (seq_len, batch_size), containing integer
+            indexes 0 <= text < vocab_size.
+        Returns:
+          Tensor of shape (seq_len, batch_size, embedding_dim)
+        """
+        x = self.embed(text)  # (seq_len, batch_size, embedding_dim)
+
+        x = torch.cat((torch.zeros_like(x[0:1]), x), dim=0)  # left-pad one zero frame so the conv is causal
+        x = x.permute(1, 2, 0)  # N,C,H, i.e. (batch_size, embedding_dim, seq_len)
+        x = self.conv1(x)
+        x = x.permute(2, 0, 1)  # (seq_len, batch_size, embedding_dim)
+        x = self.balancer(x)  # make sure no channel has all zeros.
+        x = self.activation1(x)
+        x = self.out_proj(x)
+        return x
+
 
 class SubformerLM(nn.Module):
     def __init__(self,
diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index 278c5bdd6..e9daa1d8f 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -64,7 +64,7 @@ from subformer import Subformer
 from scaling import ScheduledFloat
 from lhotse.utils import fix_random_seed
 from decoder import Decoder
-from model import SubformerLM
+from model import SubformerLM, TextEmbedder
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch import nn
@@ -395,7 +395,7 @@ def get_params() -> AttributeDict:
         "warm_step": 2000,
         "env_info": get_env_info(),
         "bytes_per_segment": 2048,
-        "batch_size": 20,
+        "batch_size": 18,
         "train_file_list": "train.txt",
         "valid_file_list": "valid.txt",
         "num_workers": 4,
@@ -411,8 +411,8 @@ def _to_int_tuple(s: str):
 
 
 def get_encoder_embed(params: AttributeDict) -> nn.Module:
-    return nn.Embedding(
-        num_embeddings=256,  # we encode the text as UTF-8 bytes
+    return TextEmbedder(
+        vocab_size=256,  # we encode the text as UTF-8 bytes
         embedding_dim=_to_int_tuple(params.encoder_dim)[0],
     )
 
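
A note on the new module (not part of the patch): left-padding one zero frame
makes the kernel_size=2 depthwise conv in TextEmbedder causal and
length-preserving. A minimal standalone sketch of the shape bookkeeping, with
made-up sizes:

    import torch
    from torch import nn

    seq_len, batch_size, embedding_dim = 7, 3, 16
    # groups == channels: each channel mixes only its own current and
    # previous frame
    conv = nn.Conv1d(embedding_dim, embedding_dim,
                     groups=embedding_dim, kernel_size=2)

    x = torch.randn(seq_len, batch_size, embedding_dim)
    x = torch.cat((torch.zeros_like(x[0:1]), x), dim=0)  # (seq_len + 1, N, C)
    x = x.permute(1, 2, 0)  # (N, C, seq_len + 1), the layout Conv1d expects
    x = conv(x)             # kernel_size=2 consumes the extra frame
    x = x.permute(2, 0, 1)  # back to (seq_len, N, C)
    assert x.shape == (seq_len, batch_size, embedding_dim)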