diff --git a/egs/libriheavy/LM/zipformer1/model.py b/egs/libriheavy/LM/zipformer1/model.py
index 60ee6d6bb..7ee6987aa 100644
--- a/egs/libriheavy/LM/zipformer1/model.py
+++ b/egs/libriheavy/LM/zipformer1/model.py
@@ -27,6 +27,7 @@ class TextEmbedder(nn.Module):
     def __init__(self, vocab_size: int, embedding_dim: int):
+        super().__init__()
         self.embed = nn.Embedding(
             num_embeddings=vocab_size,
             embedding_dim=embedding_dim)
@@ -34,12 +35,23 @@ class TextEmbedder(nn.Module):
                                embedding_dim,
                                groups=embedding_dim,
                                kernel_size=2)
-        self.balancer = Balancer(embedding_dim,
-                                 channel_dim=-1,
-                                 min_positive=0.1,
-                                 min_abs=1.0,
-                                 max_abs=2.0)
+        self.balancer1 = Balancer(embedding_dim,
+                                  channel_dim=1,
+                                  min_positive=0.1,
+                                  min_abs=1.0,
+                                  max_abs=2.0)
         self.activation1 = nn.ReLU()
+        self.conv2 = nn.Conv1d(embedding_dim,
+                               embedding_dim,
+                               kernel_size=2)
+
+        self.balancer2 = Balancer(embedding_dim,
+                                  channel_dim=1,
+                                  min_positive=0.1,
+                                  min_abs=1.0,
+                                  max_abs=2.0)
+        self.activation2 = nn.ReLU()
+
         self.out_proj = nn.Linear(embedding_dim,
                                   embedding_dim, bias=False)
@@ -55,12 +67,16 @@ class TextEmbedder(nn.Module):
         """
         x = self.embed(text)  # (seq_len, batch_size, embedding_dim)
-        x = torch.cat((torch.zeros_like(x[0:1], x)), dim=0)  # pad
         x = x.permute(1, 2, 0)  # N,C,H, i.e. (batch_size, embedding_dim, seq_len)
+        x = torch.nn.functional.pad(x, (1, 0))
         x = self.conv1(x)
-        x = x.permute(2, 0, 1)  # (seq_len, batch_size, embedding_dim)
-        x = self.balancer(x)  # make sure no channel has all zeros.
+        x = self.balancer1(x)  # make sure no channel has all zeros.
         x = self.activation1(x)
+        x = torch.nn.functional.pad(x, (1, 0))
+        x = self.conv2(x)
+        x = self.balancer2(x)
+        x = self.activation2(x)
+        x = x.permute(2, 0, 1)  # (seq_len, batch_size, embedding_dim)
         x = self.out_proj(x)
         return x
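
For reference, below is a minimal standalone sketch (not the repo's code) of what TextEmbedder looks like after this patch: two kernel-size-2 convolutions, each preceded by a left pad of one frame so the module stays causal and length-preserving, with the permutes moved so the convs and Balancers see (batch, channel, time). The class name TextEmbedderSketch is illustrative, and icefall's Balancer (from the scaling module) is replaced with nn.Identity() so the sketch runs without the repo.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TextEmbedderSketch(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        # Depthwise conv over time; kernel_size=2 plus a left pad of 1
        # keeps the output causal and the same length as the input.
        self.conv1 = nn.Conv1d(embedding_dim, embedding_dim,
                               groups=embedding_dim, kernel_size=2)
        self.balancer1 = nn.Identity()  # stand-in for Balancer(..., channel_dim=1)
        self.activation1 = nn.ReLU()
        self.conv2 = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=2)
        self.balancer2 = nn.Identity()  # stand-in for Balancer(..., channel_dim=1)
        self.activation2 = nn.ReLU()
        self.out_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        # text: (seq_len, batch_size) of token ids
        x = self.embed(text)    # (seq_len, batch_size, embedding_dim)
        x = x.permute(1, 2, 0)  # (batch_size, embedding_dim, seq_len)
        x = F.pad(x, (1, 0))    # left-pad the time axis by one frame
        x = self.conv1(x)
        x = self.balancer1(x)
        x = self.activation1(x)
        x = F.pad(x, (1, 0))    # left-pad again before the second conv
        x = self.conv2(x)
        x = self.balancer2(x)
        x = self.activation2(x)
        x = x.permute(2, 0, 1)  # back to (seq_len, batch_size, embedding_dim)
        return self.out_proj(x)


if __name__ == "__main__":
    m = TextEmbedderSketch(vocab_size=500, embedding_dim=64)
    tokens = torch.randint(0, 500, (17, 3))  # (seq_len, batch_size)
    print(m(tokens).shape)                   # torch.Size([17, 3, 64])

Padding before each conv (rather than padding once with zeros in the token-embedding space, as the removed torch.cat line attempted) means output frame t only sees tokens at positions <= t, which is what an autoregressive LM embedder needs.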