Bug fixes

Daniel Povey 2023-05-16 23:08:13 +08:00
parent a405106d2f
commit 610b2270aa


@@ -27,6 +27,7 @@ class TextEmbedder(nn.Module):
     def __init__(self,
                  vocab_size: int,
                  embedding_dim: int):
+        super().__init__()
         self.embed = nn.Embedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim)
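
The one line added in this hunk, super().__init__(), fixes a classic missing-initializer bug: without it, the very first module assignment in __init__ fails. A minimal sketch reproducing the failure (hypothetical class, not from this repo):

import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        # Missing super().__init__(); the assignment below raises
        # AttributeError: cannot assign module before Module.__init__() call
        self.embed = nn.Embedding(10, 4)

try:
    Broken()
except AttributeError as e:
    print(e)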
@@ -34,12 +35,23 @@ class TextEmbedder(nn.Module):
                                embedding_dim,
                                groups=embedding_dim,
                                kernel_size=2)
-        self.balancer = Balancer(embedding_dim,
-                                 channel_dim=-1,
-                                 min_positive=0.1,
-                                 min_abs=1.0,
-                                 max_abs=2.0)
+        self.balancer1 = Balancer(embedding_dim,
+                                  channel_dim=1,
+                                  min_positive=0.1,
+                                  min_abs=1.0,
+                                  max_abs=2.0)
+        self.activation1 = nn.ReLU()
+        self.conv2 = nn.Conv1d(embedding_dim,
+                               embedding_dim,
+                               kernel_size=2)
+        self.balancer2 = Balancer(embedding_dim,
+                                  channel_dim=1,
+                                  min_positive=0.1,
+                                  min_abs=1.0,
+                                  max_abs=2.0)
+        self.activation2 = nn.ReLU()
         self.out_proj = nn.Linear(embedding_dim,
                                   embedding_dim,
                                   bias=False)
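
Note the channel_dim change from -1 to 1: the balancers now run while the tensor is still in nn.Conv1d's (batch, channels, time) layout (see the forward pass below), where the channel axis is dim 1. A small shape sketch with made-up sizes:

import torch

seq_len, batch_size, embedding_dim = 5, 2, 8          # made-up sizes
x = torch.randn(seq_len, batch_size, embedding_dim)   # layout used outside the conv stack
x = x.permute(1, 2, 0)                   # (batch_size, embedding_dim, seq_len): channels at dim 1
x = torch.nn.functional.pad(x, (1, 0))   # left-pad the time axis, as in forward() below
assert x.shape == (batch_size, embedding_dim, seq_len + 1)
# A kernel_size=2 conv then brings the time axis back to seq_len, so the
# left-padding makes the convolution causal without changing sequence length.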
@@ -55,12 +67,16 @@ class TextEmbedder(nn.Module):
         """
         x = self.embed(text) # (seq_len, batch_size, embedding_dim)
-        x = torch.cat((torch.zeros_like(x[0:1], x)), dim=0) # pad
         x = x.permute(1, 2, 0) # N,C,H, i.e. (batch_size, embedding_dim, seq_len)
+        x = torch.nn.functional.pad(x, (1, 0))
         x = self.conv1(x)
-        x = x.permute(2, 0, 1) # (seq_len, batch_size, embedding_dim)
-        x = self.balancer(x) # make sure no channel has all zeros.
+        x = self.balancer1(x) # make sure no channel has all zeros.
+        x = self.activation1(x)
+        x = torch.nn.functional.pad(x, (1, 0))
+        x = self.conv2(x)
+        x = self.balancer2(x)
+        x = self.activation2(x)
+        x = x.permute(2, 0, 1) # (seq_len, batch_size, embedding_dim)
         x = self.out_proj(x)
         return x
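
For orientation, a minimal usage sketch (not part of the commit; the sizes are made-up, and it assumes TextEmbedder and its Balancer dependency are importable from the surrounding codebase):

import torch

embedder = TextEmbedder(vocab_size=256, embedding_dim=192)  # hypothetical sizes
text = torch.randint(0, 256, (10, 4))  # (seq_len, batch_size) token ids
out = embedder(text)
print(out.shape)  # torch.Size([10, 4, 192]); the causal left-padding keeps seq_len unchanged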