Mirror of https://github.com/k2-fsa/icefall.git
Add 1-d convolution to text embedding module; reduce batch size
commit a405106d2f
parent 399a79ace6
@@ -20,8 +20,50 @@
 import torch
 from torch import nn, Tensor
 from subformer import Subformer
 from scaling import Balancer


+class TextEmbedder(nn.Module):
+    def __init__(self,
+                 vocab_size: int,
+                 embedding_dim: int):
+        super().__init__()
+        self.embed = nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=embedding_dim)
+        # Depthwise 1-d convolution (groups=embedding_dim) over the sequence
+        # dimension; with kernel_size=2 each frame sees itself and its left
+        # neighbor.
+        self.conv1 = nn.Conv1d(embedding_dim,
+                               embedding_dim,
+                               groups=embedding_dim,
+                               kernel_size=2)
+        self.balancer = Balancer(embedding_dim,
+                                 channel_dim=-1,
+                                 min_positive=0.1,
+                                 min_abs=1.0,
+                                 max_abs=2.0)
+        self.activation1 = nn.ReLU()
+        self.out_proj = nn.Linear(embedding_dim,
+                                  embedding_dim,
+                                  bias=False)
+
+    def forward(self,
+                text: Tensor) -> Tensor:
+        """
+        Args:
+          text: Tensor of shape (seq_len, batch_size), containing integer indexes
+            0 <= text < vocab_size.
+        Returns:
+          Tensor of shape (seq_len, batch_size, embedding_dim)
+        """
+        x = self.embed(text)  # (seq_len, batch_size, embedding_dim)
+
+        # Left-pad with one zero frame so the kernel_size=2 convolution is
+        # causal and preserves seq_len.
+        x = torch.cat((torch.zeros_like(x[0:1]), x), dim=0)
+        x = x.permute(1, 2, 0)  # N,C,H, i.e. (batch_size, embedding_dim, seq_len)
+        x = self.conv1(x)
+        x = x.permute(2, 0, 1)  # (seq_len, batch_size, embedding_dim)
+        x = self.balancer(x)  # make sure no channel has all zeros.
+        x = self.activation1(x)
+        x = self.out_proj(x)
+        return x


 class SubformerLM(nn.Module):

     def __init__(self,
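For orientation, a minimal usage sketch of the new module, outside the diff. It assumes illustrative values (vocab_size=256, matching the UTF-8 byte vocabulary used in get_encoder_embed below, and embedding_dim=512), and it stubs icefall's scaling.Balancer with nn.Identity() on the assumption that the balancer is a training-time activation regularizer only:

# Sketch of the TextEmbedder above. TextEmbedderSketch is a hypothetical
# name; Balancer is replaced by nn.Identity() (assumption: training-time
# regularizer only) and embedding_dim=512 is an illustrative choice.
import torch
from torch import nn

class TextEmbedderSketch(nn.Module):
    def __init__(self, vocab_size: int = 256, embedding_dim: int = 512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        # Depthwise conv: each channel mixes frame t with frame t-1.
        self.conv1 = nn.Conv1d(embedding_dim, embedding_dim,
                               groups=embedding_dim, kernel_size=2)
        self.balancer = nn.Identity()  # stand-in for scaling.Balancer
        self.activation1 = nn.ReLU()
        self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        x = self.embed(text)                                 # (T, N, C)
        x = torch.cat((torch.zeros_like(x[0:1]), x), dim=0)  # left-pad: (T+1, N, C)
        x = self.conv1(x.permute(1, 2, 0))                   # (N, C, T)
        x = x.permute(2, 0, 1)                               # back to (T, N, C)
        return self.out_proj(self.activation1(self.balancer(x)))

torch.manual_seed(0)
m = TextEmbedderSketch()
text = torch.randint(0, 256, (100, 4))   # (seq_len=100, batch_size=4)
print(m(text).shape)                      # torch.Size([100, 4, 512])

# Causality check: perturbing the byte at t=50 leaves outputs at t < 50
# intact, because output[t] depends only on input[t-1] and input[t].
text2 = text.clone()
text2[50, 0] = (text2[50, 0] + 1) % 256
assert torch.equal(m(text)[:50, 0], m(text2)[:50, 0])

The left zero-pad plus kernel_size=2 keeps the output length equal to seq_len while letting each position see its predecessor, i.e. a causal convolution, so the embedder remains usable in a left-to-right language model.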
@@ -64,7 +64,7 @@ from subformer import Subformer
 from scaling import ScheduledFloat
 from lhotse.utils import fix_random_seed
 from decoder import Decoder
-from model import SubformerLM
+from model import SubformerLM, TextEmbedder
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch import nn
@@ -395,7 +395,7 @@ def get_params() -> AttributeDict:
             "warm_step": 2000,
             "env_info": get_env_info(),
             "bytes_per_segment": 2048,
-            "batch_size": 20,
+            "batch_size": 18,
             "train_file_list": "train.txt",
             "valid_file_list": "valid.txt",
             "num_workers": 4,
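With bytes_per_segment unchanged at 2048, dropping batch_size from 20 to 18 trims each training step by two segments, roughly a 10% cut in bytes processed per step; the commit message gives no motivation, but making room for the extra convolution's activations in GPU memory is a plausible reason (an assumption, not stated in the commit). The arithmetic:

# Per-step workload before and after this commit (values from get_params()):
bytes_per_segment = 2048
print(bytes_per_segment * 20)  # 40960 bytes/step before
print(bytes_per_segment * 18)  # 36864 bytes/step after, a 10% reduction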
@@ -411,8 +411,8 @@ def _to_int_tuple(s: str):


 def get_encoder_embed(params: AttributeDict) -> nn.Module:
-    return nn.Embedding(
-        num_embeddings=256,  # we encode the text as UTF-8 bytes
+    return TextEmbedder(
+        vocab_size=256,  # we encode the text as UTF-8 bytes
         embedding_dim=_to_int_tuple(params.encoder_dim)[0],
     )
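get_encoder_embed now hands the first encoder stack's dimension to the new TextEmbedder. A sketch of that wiring, assuming the usual icefall definition of _to_int_tuple (reproduced here as an assumption) and an illustrative encoder_dim string not taken from this recipe:

# Hypothetical values; _to_int_tuple splits a comma-separated dim string.
def _to_int_tuple(s: str):
    return tuple(map(int, s.split(",")))

encoder_dim = "512,768,1024,768,512"           # illustrative only
embedding_dim = _to_int_tuple(encoder_dim)[0]  # -> 512, the first stack's dim
# get_encoder_embed would then build:
# TextEmbedder(vocab_size=256, embedding_dim=512)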