diff --git a/egs/libriheavy/LM/zipformer1/decoder.py b/egs/libriheavy/LM/zipformer1/decoder.py
index d705d0413..12c42070b 100644
--- a/egs/libriheavy/LM/zipformer1/decoder.py
+++ b/egs/libriheavy/LM/zipformer1/decoder.py
@@ -21,38 +21,21 @@ import random
 
 import torch
 from torch import nn, Tensor
 
+
 class Decoder(nn.Module):
     """
     """
     def __init__(self,
                  embed_dim: int,
-                 hidden_dim: int,
                  vocab_size: int):
         """
         A 'decoder' that computes the probability of symbols in a language modeling task.
         """
         super().__init__()
+        self.out_proj = nn.Linear(embed_dim,
+                                  vocab_size)
 
-        self.to_hidden = nn.Linear(
-            embed_dim, hidden_dim, bias=False,
-        )
-        # no padding, will manually pad on the left so it is causal.
-        self.depthwise_conv = nn.Conv1d(
-            in_channels=hidden_dim,
-            out_channels=hidden_dim,
-            groups=hidden_dim,
-            kernel_size=3
-        )
-        self.activation = nn.Tanh()
-        self.hidden_to_vocab = nn.Linear(
-            hidden_dim, vocab_size,
-        )
-
-        self.bypass = nn.Linear(
-            embed_dim, vocab_size, bias=False,
-        )
-
     def forward(self,
                 labels: Tensor,
@@ -71,17 +54,12 @@ class Decoder(nn.Module):
 
         assert batch_size == _batch_size
 
-        bypass = self.bypass(encoder_embed)
+        x = self.out_proj(encoder_embed)
 
-        x = self.to_hidden(encoder_embed)  # (seq_len, batch_size, hidden_dim)
-        x = x.permute(1, 2, 0)  # (N,C,H) = (batch_size, hidden_dim, seq_len)
-        x = torch.nn.functional.pad(x, (2, 0))  # pad left with 2 frames.
-        x = self.depthwise_conv(x)
-        x = x.permute(0, 2, 1)  # (batch_size, seq_len, hidden_dim)
-        x = self.activation(x)
-        x = self.hidden_to_vocab(x)  # (batch_size, seq_len, vocab_size)
+        x = x.transpose(0, 1)
+
+        # x: (batch_size, seq_len, vocab_size)
 
-        x = x + bypass.transpose(0, 1)
         x = x.log_softmax(dim=-1)
 
         logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)
 
diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index 173058096..b825d6b85 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -175,13 +175,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
     )
 
-    parser.add_argument(
-        "--decoder-hidden-dim",
-        type=int,
-        default=1536,
-        help="Hidden dimension in decoder",
-    )
-
 
 
 
@@ -446,7 +439,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         embed_dim=max(_to_int_tuple(params.encoder_dim)),
-        hidden_dim=params.decoder_hidden_dim,
         vocab_size=256, # bytes
     )
     return decoder
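
Note: after this change the decoder reduces to a single linear projection from the encoder embedding to the 256-way byte vocabulary, followed by log-softmax and a gather of the label log-probabilities; the bypass branch and the causal depthwise-conv hidden layer are gone. The following is a minimal standalone sketch of that behaviour, not the file itself: the tensor layouts and the out_proj / transpose / gather steps mirror the diff above, while the class name SimpleDecoder, the example dimensions, and the dummy inputs are illustrative only.

# Standalone sketch of the simplified decoder (assumptions noted above).
import torch
from torch import nn, Tensor


class SimpleDecoder(nn.Module):
    """Byte-level LM head: one linear projection to the vocabulary."""

    def __init__(self, embed_dim: int, vocab_size: int):
        super().__init__()
        self.out_proj = nn.Linear(embed_dim, vocab_size)

    def forward(self, labels: Tensor, encoder_embed: Tensor) -> Tensor:
        # labels:        (batch_size, seq_len), integer byte values
        # encoder_embed: (seq_len, batch_size, embed_dim)
        x = self.out_proj(encoder_embed)   # (seq_len, batch_size, vocab_size)
        x = x.transpose(0, 1)              # (batch_size, seq_len, vocab_size)
        x = x.log_softmax(dim=-1)
        # Pick out the log-probability assigned to each reference byte.
        logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
        return logprobs                    # (batch_size, seq_len)


# Example usage: 2 sequences, 5 positions, 512-dim embeddings, 256-byte vocab.
decoder = SimpleDecoder(embed_dim=512, vocab_size=256)
labels = torch.randint(0, 256, (2, 5))
encoder_embed = torch.randn(5, 2, 512)
print(decoder(labels, encoder_embed).shape)  # torch.Size([2, 5])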