mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add depthwise conv to decoder
This commit is contained in:
parent
610b2270aa
commit
30ace76fbc
@ -21,21 +21,38 @@ import random
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn, Tensor
|
from torch import nn, Tensor
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
"""
|
"""
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
embed_dim: int,
|
embed_dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
vocab_size: int):
|
vocab_size: int):
|
||||||
"""
|
"""
|
||||||
A 'decoder' that computes the probability of symbols in a language modeling task.
|
A 'decoder' that computes the probability of symbols in a language modeling task.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.out_proj = nn.Linear(embed_dim,
|
|
||||||
vocab_size)
|
|
||||||
|
|
||||||
|
|
||||||
|
self.to_hidden = nn.Linear(
|
||||||
|
embed_dim, hidden_dim, bias=False,
|
||||||
|
)
|
||||||
|
# no padding, will manually pad on the left so it is causal.
|
||||||
|
self.depthwise_conv = nn.Conv1d(
|
||||||
|
in_channels=hidden_dim,
|
||||||
|
out_channels=hidden_dim,
|
||||||
|
groups=hidden_dim,
|
||||||
|
kernel_size=3
|
||||||
|
)
|
||||||
|
self.activation = nn.Tanh()
|
||||||
|
self.hidden_to_vocab = nn.Linear(
|
||||||
|
hidden_dim, vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bypass = nn.Linear(
|
||||||
|
embed_dim, vocab_size, bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
labels: Tensor,
|
labels: Tensor,
|
||||||
@ -54,12 +71,17 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
assert batch_size == _batch_size
|
assert batch_size == _batch_size
|
||||||
|
|
||||||
x = self.out_proj(encoder_embed)
|
bypass = self.bypass(encoder_embed)
|
||||||
|
|
||||||
x = x.transpose(0, 1)
|
x = self.to_hidden(encoder_embed) # (seq_len, batch_size, hidden_dim)
|
||||||
|
x = x.permute(1, 2, 0) # (N,C,H) = (batch_size, hidden_dim, seq_len)
|
||||||
# x: (batch_size, seq_len, vocab_size)
|
x = torch.nn.functional.pad(x, (2, 0)) # pad left with 2 frames.
|
||||||
|
x = self.depthwise_conv(x)
|
||||||
|
x = x.permute(0, 2, 1) # (batch_size, seq_len, hidden_dim)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.hidden_to_vocab(x) # (batch_size, seq_len, vocab_size)
|
||||||
|
|
||||||
|
x = x + bypass.transpose(0, 1)
|
||||||
x = x.log_softmax(dim=-1)
|
x = x.log_softmax(dim=-1)
|
||||||
|
|
||||||
logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) # (batch_size, seq_len)
|
logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) # (batch_size, seq_len)
|
||||||
|
|||||||
@ -175,6 +175,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
|
help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-hidden-dim",
|
||||||
|
type=int,
|
||||||
|
default=1536,
|
||||||
|
help="Hidden dimension in decoder",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -439,6 +446,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||||
decoder = Decoder(
|
decoder = Decoder(
|
||||||
embed_dim=max(_to_int_tuple(params.encoder_dim)),
|
embed_dim=max(_to_int_tuple(params.encoder_dim)),
|
||||||
|
hidden_dim=params.decoder_hidden_dim,
|
||||||
vocab_size=256, # bytes
|
vocab_size=256, # bytes
|
||||||
)
|
)
|
||||||
return decoder
|
return decoder
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user