diff --git a/egs/libriheavy/LM/zipformer1/decoder.py b/egs/libriheavy/LM/zipformer1/decoder.py
index 12c42070b..d705d0413 100644
--- a/egs/libriheavy/LM/zipformer1/decoder.py
+++ b/egs/libriheavy/LM/zipformer1/decoder.py
@@ -21,21 +21,38 @@ import random
 import torch
 from torch import nn, Tensor

-
 class Decoder(nn.Module):
     """
     """
     def __init__(self,
                  embed_dim: int,
+                 hidden_dim: int,
                  vocab_size: int):
         """
         A 'decoder' that computes the probability of symbols in a language
         modeling task.
         """
         super().__init__()
-        self.out_proj = nn.Linear(embed_dim,
-                                  vocab_size)
+        self.to_hidden = nn.Linear(
+            embed_dim, hidden_dim, bias=False,
+        )
+        # no padding, will manually pad on the left so it is causal.
+        self.depthwise_conv = nn.Conv1d(
+            in_channels=hidden_dim,
+            out_channels=hidden_dim,
+            groups=hidden_dim,
+            kernel_size=3
+        )
+        self.activation = nn.Tanh()
+        self.hidden_to_vocab = nn.Linear(
+            hidden_dim, vocab_size,
+        )
+
+        self.bypass = nn.Linear(
+            embed_dim, vocab_size, bias=False,
+        )
+


     def forward(self,
                 labels: Tensor,
@@ -54,12 +71,17 @@ class Decoder(nn.Module):

         assert batch_size == _batch_size

-        x = self.out_proj(encoder_embed)
+        bypass = self.bypass(encoder_embed)

-        x = x.transpose(0, 1)
-
-        # x: (batch_size, seq_len, vocab_size)
+        x = self.to_hidden(encoder_embed)  # (seq_len, batch_size, hidden_dim)
+        x = x.permute(1, 2, 0)  # (N,C,H) = (batch_size, hidden_dim, seq_len)
+        x = torch.nn.functional.pad(x, (2, 0))  # pad left with 2 frames.
+        x = self.depthwise_conv(x)
+        x = x.permute(0, 2, 1)  # (batch_size, seq_len, hidden_dim)
+        x = self.activation(x)
+        x = self.hidden_to_vocab(x)  # (batch_size, seq_len, vocab_size)
+        x = x + bypass.transpose(0, 1)

         x = x.log_softmax(dim=-1)

         logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)
diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py
index e9daa1d8f..a1f75ff2d 100755
--- a/egs/libriheavy/LM/zipformer1/train.py
+++ b/egs/libriheavy/LM/zipformer1/train.py
@@ -175,6 +175,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
     )

+    parser.add_argument(
+        "--decoder-hidden-dim",
+        type=int,
+        default=1536,
+        help="Hidden dimension in decoder",
+    )
+


@@ -439,6 +446,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         embed_dim=max(_to_int_tuple(params.encoder_dim)),
+        hidden_dim=params.decoder_hidden_dim,
         vocab_size=256,  # bytes
     )
     return decoder
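
For reference, below is a minimal standalone sketch of the decoder head this diff introduces: bottleneck projection, causal depthwise Conv1d (kernel_size=3, left-padded by 2 frames), Tanh, projection to the vocabulary, plus a linear bypass added before log_softmax. Module and tensor names follow the diff; the class name DecoderSketch, the toy dimensions, and the __main__ causality check are illustrative assumptions, not part of the patch.

# Sketch of the causal-conv decoder head from the diff above.
# The toy sizes under __main__ are illustrative, not the training config.
import torch
from torch import nn, Tensor


class DecoderSketch(nn.Module):
    def __init__(self, embed_dim: int, hidden_dim: int, vocab_size: int):
        super().__init__()
        self.to_hidden = nn.Linear(embed_dim, hidden_dim, bias=False)
        # No padding here; forward() pads on the left so the conv is causal.
        self.depthwise_conv = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            groups=hidden_dim,
            kernel_size=3,
        )
        self.activation = nn.Tanh()
        self.hidden_to_vocab = nn.Linear(hidden_dim, vocab_size)
        self.bypass = nn.Linear(embed_dim, vocab_size, bias=False)

    def forward(self, encoder_embed: Tensor) -> Tensor:
        # encoder_embed: (seq_len, batch_size, embed_dim)
        bypass = self.bypass(encoder_embed)
        x = self.to_hidden(encoder_embed)        # (seq_len, batch, hidden)
        x = x.permute(1, 2, 0)                   # (batch, hidden, seq_len)
        x = torch.nn.functional.pad(x, (2, 0))   # left-pad 2 frames -> causal
        x = self.depthwise_conv(x)               # (batch, hidden, seq_len)
        x = x.permute(0, 2, 1)                   # (batch, seq_len, hidden)
        x = self.activation(x)
        x = self.hidden_to_vocab(x)              # (batch, seq_len, vocab)
        x = x + bypass.transpose(0, 1)
        return x.log_softmax(dim=-1)             # (batch, seq_len, vocab)


if __name__ == "__main__":
    seq_len, batch, embed, hidden, vocab = 7, 2, 8, 16, 256
    dec = DecoderSketch(embed, hidden, vocab)
    logprobs = dec(torch.randn(seq_len, batch, embed))
    assert logprobs.shape == (batch, seq_len, vocab)
    # Causality check: perturbing frame 4 must not change outputs at frames 0..3.
    a = torch.randn(seq_len, batch, embed)
    b = a.clone()
    b[4] += 1.0
    assert torch.allclose(dec(a)[:, :4], dec(b)[:, :4], atol=1e-6)

Because the kernel covers positions t-2..t after the left padding, each output frame depends only on the current and previous two frames, so the head stays causal for language modeling; the bypass keeps a direct linear path from embed_dim to the vocabulary alongside the nonlinearity.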