Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
Add 1-d convolution to text embedding module; reduce batch size
This commit is contained in:
parent 399a79ace6
commit a405106d2f
@@ -20,8 +20,50 @@
 import torch
 from torch import nn, Tensor
 from subformer import Subformer
+from scaling import Balancer
+
+
+class TextEmbedder(nn.Module):
+    def __init__(self,
+                 vocab_size: int,
+                 embedding_dim: int):
+        super().__init__()
+        self.embed = nn.Embedding(
+            num_embeddings=vocab_size,
+            embedding_dim=embedding_dim)
+        # Depthwise 1-d convolution over the time axis (groups=embedding_dim).
+        self.conv1 = nn.Conv1d(embedding_dim,
+                               embedding_dim,
+                               groups=embedding_dim,
+                               kernel_size=2)
+        self.balancer = Balancer(embedding_dim,
+                                 channel_dim=-1,
+                                 min_positive=0.1,
+                                 min_abs=1.0,
+                                 max_abs=2.0)
+        self.activation1 = nn.ReLU()
+        self.out_proj = nn.Linear(embedding_dim,
+                                  embedding_dim,
+                                  bias=False)
+
+    def forward(self,
+                text: Tensor) -> Tensor:
+        """
+        Args:
+          text: Tensor of shape (seq_len, batch_size), containing integer indexes
+            0 <= text < vocab_size.
+        Returns:
+          Tensor of shape (seq_len, batch_size, embedding_dim)
+        """
+        x = self.embed(text)  # (seq_len, batch_size, embedding_dim)
+
+        # Pad one zero frame on the left so the kernel-size-2 conv is causal
+        # and the output length stays equal to seq_len.
+        x = torch.cat((torch.zeros_like(x[0:1]), x), dim=0)
+        x = x.permute(1, 2, 0)  # N,C,H, i.e. (batch_size, embedding_dim, seq_len)
+        x = self.conv1(x)
+        x = x.permute(2, 0, 1)  # (seq_len, batch_size, embedding_dim)
+        x = self.balancer(x)  # make sure no channel has all zeros.
+        x = self.activation1(x)
+        x = self.out_proj(x)
+        return x
+
+
 class SubformerLM(nn.Module):
     def __init__(self,
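A minimal shape check for the new module (not part of the commit; the embedding size 512 below is only an assumed value standing in for the recipe's encoder dimension). Because one zero frame is prepended before the kernel-size-2 depthwise convolution, output position t mixes only bytes t-1 and t, and the sequence length is preserved:

import torch
from model import TextEmbedder  # the class added in this commit

embedder = TextEmbedder(vocab_size=256, embedding_dim=512)  # 512 is an illustrative dim
text = torch.randint(0, 256, (100, 4))  # (seq_len, batch_size) of UTF-8 byte values
out = embedder(text)
print(out.shape)  # torch.Size([100, 4, 512])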
@@ -64,7 +64,7 @@ from subformer import Subformer
 from scaling import ScheduledFloat
 from lhotse.utils import fix_random_seed
 from decoder import Decoder
-from model import SubformerLM
+from model import SubformerLM, TextEmbedder
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch import nn
@@ -395,7 +395,7 @@ def get_params() -> AttributeDict:
             "warm_step": 2000,
             "env_info": get_env_info(),
             "bytes_per_segment": 2048,
-            "batch_size": 20,
+            "batch_size": 18,
             "train_file_list": "train.txt",
             "valid_file_list": "valid.txt",
             "num_workers": 4,
@@ -411,8 +411,8 @@ def _to_int_tuple(s: str):
 
 
 def get_encoder_embed(params: AttributeDict) -> nn.Module:
-    return nn.Embedding(
-        num_embeddings=256,  # we encode the text as UTF-8 bytes
+    return TextEmbedder(
+        vocab_size=256,  # we encode the text as UTF-8 bytes
         embedding_dim=_to_int_tuple(params.encoder_dim)[0],
     )
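TextEmbedder takes the same (seq_len, batch_size) byte tensor and returns the same (seq_len, batch_size, embedding_dim) output as the nn.Embedding it replaces, so get_encoder_embed is the only call site that changes. A hedged sketch of that equivalence, again assuming an illustrative embedding dimension of 512:

import torch
from torch import nn
from model import TextEmbedder

old_embed = nn.Embedding(num_embeddings=256, embedding_dim=512)
new_embed = TextEmbedder(vocab_size=256, embedding_dim=512)

text = torch.randint(0, 256, (50, 2))  # (seq_len, batch_size)
print(old_embed(text).shape)  # torch.Size([50, 2, 512])
print(new_embed(text).shape)  # torch.Size([50, 2, 512])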