Add depthwise conv to decoder

Daniel Povey 2023-05-17 11:26:41 +08:00
parent 610b2270aa
commit 30ace76fbc
2 changed files with 37 additions and 7 deletions


@@ -21,21 +21,38 @@ import random
 import torch
 from torch import nn, Tensor
 
 class Decoder(nn.Module):
     """
     """
     def __init__(self,
                  embed_dim: int,
+                 hidden_dim: int,
                  vocab_size: int):
         """
         A 'decoder' that computes the probability of symbols in a language modeling task.
         """
         super().__init__()
-        self.out_proj = nn.Linear(embed_dim,
-                                  vocab_size)
+        self.to_hidden = nn.Linear(
+            embed_dim, hidden_dim, bias=False,
+        )
+        # no padding, will manually pad on the left so it is causal.
+        self.depthwise_conv = nn.Conv1d(
+            in_channels=hidden_dim,
+            out_channels=hidden_dim,
+            groups=hidden_dim,
+            kernel_size=3
+        )
+        self.activation = nn.Tanh()
+        self.hidden_to_vocab = nn.Linear(
+            hidden_dim, vocab_size,
+        )
+        self.bypass = nn.Linear(
+            embed_dim, vocab_size, bias=False,
+        )
 
     def forward(self,
                 labels: Tensor,
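Aside (not part of the commit): with groups equal to the channel count, the Conv1d above learns one independent 3-tap filter per channel, so it adds very few parameters compared to the linear layers. A standalone sketch, using the 1536 default for the hidden dimension that the second file below introduces:

    from torch import nn

    hidden_dim = 1536  # matches the --decoder-hidden-dim default added below
    conv = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim,
                     groups=hidden_dim, kernel_size=3)
    print(conv.weight.shape)  # torch.Size([1536, 1, 3]): one width-3 filter per channel
    print(conv.bias.shape)    # torch.Size([1536])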
@@ -54,12 +71,17 @@ class Decoder(nn.Module):
         assert batch_size == _batch_size
 
-        x = self.out_proj(encoder_embed)
-        x = x.transpose(0, 1)
-        # x: (batch_size, seq_len, vocab_size)
+        bypass = self.bypass(encoder_embed)
+        x = self.to_hidden(encoder_embed)  # (seq_len, batch_size, hidden_dim)
+        x = x.permute(1, 2, 0)  # (N,C,H) = (batch_size, hidden_dim, seq_len)
+        x = torch.nn.functional.pad(x, (2, 0))  # pad left with 2 frames.
+        x = self.depthwise_conv(x)
+        x = x.permute(0, 2, 1)  # (batch_size, seq_len, hidden_dim)
+        x = self.activation(x)
+        x = self.hidden_to_vocab(x)  # (batch_size, seq_len, vocab_size)
+        x = x + bypass.transpose(0, 1)
         x = x.log_softmax(dim=-1)
         logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)
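The left padding of kernel_size - 1 = 2 frames is what keeps the convolution causal and length-preserving. A small smoke test of that property (a sketch with arbitrary sizes, not part of the commit):

    import torch
    from torch import nn

    batch_size, hidden_dim, seq_len = 2, 8, 5
    conv = nn.Conv1d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3)

    x = torch.randn(batch_size, hidden_dim, seq_len)
    y = conv(torch.nn.functional.pad(x, (2, 0)))
    assert y.shape == (batch_size, hidden_dim, seq_len)  # sequence length preserved

    # Perturbing the last frame must leave all earlier outputs unchanged.
    x2 = x.clone()
    x2[:, :, -1] += 1.0
    y2 = conv(torch.nn.functional.pad(x2, (2, 0)))
    assert torch.allclose(y[:, :, :-1], y2[:, :, :-1])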


@@ -175,6 +175,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
     )
 
+    parser.add_argument(
+        "--decoder-hidden-dim",
+        type=int,
+        default=1536,
+        help="Hidden dimension in decoder",
+    )
+
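For reference, the new flag behaves like any other argparse option; a minimal sketch of parsing it outside the full script:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--decoder-hidden-dim", type=int, default=1536,
                        help="Hidden dimension in decoder")

    print(parser.parse_args([]).decoder_hidden_dim)                                 # 1536 (the default)
    print(parser.parse_args(["--decoder-hidden-dim", "2048"]).decoder_hidden_dim)   # 2048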
@@ -439,6 +446,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 
 def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         embed_dim=max(_to_int_tuple(params.encoder_dim)),
+        hidden_dim=params.decoder_hidden_dim,
         vocab_size=256,  # bytes
     )
     return decoder
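Rough parameter-count comparison of the old single-projection head versus the new hidden-layer head at the default hidden_dim of 1536 (a sketch, not part of the commit; embed_dim = 512 is a hypothetical stand-in for max(_to_int_tuple(params.encoder_dim))):

    from torch import nn

    embed_dim = 512    # hypothetical; the real value is max(_to_int_tuple(params.encoder_dim))
    hidden_dim = 1536  # --decoder-hidden-dim default
    vocab_size = 256   # bytes

    def num_params(*modules: nn.Module) -> int:
        return sum(p.numel() for m in modules for p in m.parameters())

    # Before this commit: one projection straight to the vocabulary.
    old = num_params(nn.Linear(embed_dim, vocab_size))

    # After this commit: hidden projection + depthwise conv + output projection + bypass.
    new = num_params(
        nn.Linear(embed_dim, hidden_dim, bias=False),
        nn.Conv1d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3),
        nn.Linear(hidden_dim, vocab_size),
        nn.Linear(embed_dim, vocab_size, bias=False),
    )
    print(old, new)  # 131328 vs. 1317120 with these hypothetical sizes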