mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Bug fixes
This commit is contained in:
parent
a405106d2f
commit
610b2270aa
@ -27,6 +27,7 @@ class TextEmbedder(nn.Module):
|
||||
def __init__(self,
|
||||
vocab_size: int,
|
||||
embedding_dim: int):
|
||||
super().__init__()
|
||||
self.embed = nn.Embedding(
|
||||
num_embeddings=vocab_size,
|
||||
embedding_dim=embedding_dim)
|
||||
@ -34,12 +35,23 @@ class TextEmbedder(nn.Module):
|
||||
embedding_dim,
|
||||
groups=embedding_dim,
|
||||
kernel_size=2)
|
||||
self.balancer = Balancer(embedding_dim,
|
||||
channel_dim=-1,
|
||||
self.balancer1 = Balancer(embedding_dim,
|
||||
channel_dim=1,
|
||||
min_positive=0.1,
|
||||
min_abs=1.0,
|
||||
max_abs=2.0)
|
||||
self.activation1 = nn.ReLU()
|
||||
self.conv2 = nn.Conv1d(embedding_dim,
|
||||
embedding_dim,
|
||||
kernel_size=2)
|
||||
|
||||
self.balancer2 = Balancer(embedding_dim,
|
||||
channel_dim=1,
|
||||
min_positive=0.1,
|
||||
min_abs=1.0,
|
||||
max_abs=2.0)
|
||||
self.activation2 = nn.ReLU()
|
||||
|
||||
self.out_proj = nn.Linear(embedding_dim,
|
||||
embedding_dim,
|
||||
bias=False)
|
||||
@ -55,12 +67,16 @@ class TextEmbedder(nn.Module):
|
||||
"""
|
||||
x = self.embed(text) # (seq_len, batch_size, embedding_dim)
|
||||
|
||||
x = torch.cat((torch.zeros_like(x[0:1], x)), dim=0) # pad
|
||||
x = x.permute(1, 2, 0) # N,C,H, i.e. (batch_size, embedding_dim, seq_len)
|
||||
x = torch.nn.functional.pad(x, (1, 0))
|
||||
x = self.conv1(x)
|
||||
x = x.permute(2, 0, 1) # (seq_len, batch_size, embedding_dim)
|
||||
x = self.balancer(x) # make sure no channel has all zeros.
|
||||
x = self.balancer1(x) # make sure no channel has all zeros.
|
||||
x = self.activation1(x)
|
||||
x = torch.nn.functional.pad(x, (1, 0))
|
||||
x = self.conv2(x)
|
||||
x = self.balancer2(x)
|
||||
x = self.activation2(x)
|
||||
x = x.permute(2, 0, 1) # (seq_len, batch_size, embedding_dim)
|
||||
x = self.out_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user