Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Reverse zlm9..zlm12

commit eb64130787
parent 5d7517e382
@@ -21,38 +21,21 @@ import random

import torch

from torch import nn, Tensor


class Decoder(nn.Module):
    """
    """
    def __init__(self,
                 embed_dim: int,
                 hidden_dim: int,
                 vocab_size: int):
        """
        A 'decoder' that computes the probability of symbols in a language modeling task.
        """
        super().__init__()
        self.out_proj = nn.Linear(embed_dim,
                                  vocab_size)

        self.to_hidden = nn.Linear(
            embed_dim, hidden_dim, bias=False,
        )
        # no padding, will manually pad on the left so it is causal.
        self.depthwise_conv = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            groups=hidden_dim,
            kernel_size=3
        )
        self.activation = nn.Tanh()
        self.hidden_to_vocab = nn.Linear(
            hidden_dim, vocab_size,
        )

        self.bypass = nn.Linear(
            embed_dim, vocab_size, bias=False,
        )

    def forward(self,
                labels: Tensor,
@@ -71,17 +54,12 @@ class Decoder(nn.Module):

        assert batch_size == _batch_size

        bypass = self.bypass(encoder_embed)
        x = self.out_proj(encoder_embed)

        x = self.to_hidden(encoder_embed)  # (seq_len, batch_size, hidden_dim)
        x = x.permute(1, 2, 0)  # (N, C, H) = (batch_size, hidden_dim, seq_len)
        x = torch.nn.functional.pad(x, (2, 0))  # pad left with 2 frames.
        x = self.depthwise_conv(x)
        x = x.permute(0, 2, 1)  # (batch_size, seq_len, hidden_dim)
        x = self.activation(x)
        x = self.hidden_to_vocab(x)  # (batch_size, seq_len, vocab_size)
        x = x.transpose(0, 1)

        # x: (batch_size, seq_len, vocab_size)

        x = x + bypass.transpose(0, 1)
        x = x.log_softmax(dim=-1)

        logprobs = torch.gather(x, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)  # (batch_size, seq_len)
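For reference, the shape bookkeeping in the branch touched above can be exercised on its own. The following is a minimal sketch, not part of the commit: it reproduces the to_hidden -> pad -> depthwise_conv -> tanh -> hidden_to_vocab path plus the bypass projection and the final gather of per-label log-probabilities. All concrete dimensions, the label layout, and the standalone variable names are illustrative assumptions.

import torch
from torch import nn

seq_len, batch_size, embed_dim, hidden_dim, vocab_size = 10, 4, 512, 1536, 256

to_hidden = nn.Linear(embed_dim, hidden_dim, bias=False)
depthwise_conv = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim,
                           groups=hidden_dim, kernel_size=3)
hidden_to_vocab = nn.Linear(hidden_dim, vocab_size)
bypass = nn.Linear(embed_dim, vocab_size, bias=False)

encoder_embed = torch.randn(seq_len, batch_size, embed_dim)

x = to_hidden(encoder_embed)                   # (seq_len, batch_size, hidden_dim)
x = x.permute(1, 2, 0)                         # (batch_size, hidden_dim, seq_len)
x = torch.nn.functional.pad(x, (2, 0))         # left-pad by kernel_size - 1 so the conv is causal
x = depthwise_conv(x)                          # back to (batch_size, hidden_dim, seq_len)
x = x.permute(0, 2, 1)                         # (batch_size, seq_len, hidden_dim)
x = torch.tanh(x)
x = hidden_to_vocab(x)                         # (batch_size, seq_len, vocab_size)
x = x + bypass(encoder_embed).transpose(0, 1)  # add the bypass branch, same shape
logprobs = x.log_softmax(dim=-1)

labels = torch.randint(0, vocab_size, (batch_size, seq_len))  # assumed (batch, seq) layout
per_token = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
print(per_token.shape)                         # torch.Size([4, 10]) = (batch_size, seq_len)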
@@ -175,13 +175,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):

        help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
    )

    parser.add_argument(
        "--decoder-hidden-dim",
        type=int,
        default=1536,
        help="Hidden dimension in decoder",
    )

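The flag above is an ordinary argparse option, so it can be checked in isolation. A minimal, hypothetical example, assuming the other arguments registered by add_model_arguments all have defaults:

import argparse

parser = argparse.ArgumentParser()
add_model_arguments(parser)  # the function this hunk modifies
args = parser.parse_args(["--decoder-hidden-dim", "2048"])
print(args.decoder_hidden_dim)  # 2048; falls back to 1536 when the flag is omitted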
@@ -446,7 +439,6 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:

def get_decoder_model(params: AttributeDict) -> nn.Module:
    decoder = Decoder(
        embed_dim=max(_to_int_tuple(params.encoder_dim)),
        hidden_dim=params.decoder_hidden_dim,
        vocab_size=256,  # bytes
    )
    return decoder
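For context, the embed_dim expression above picks the widest encoder stack from a comma-separated spec. A hypothetical illustration; the encoder_dim value and the _to_int_tuple stand-in are assumptions, not code from this commit:

def _to_int_tuple(s: str) -> tuple:
    # assumed behavior: "512" -> (512,); "192,256,384" -> (192, 256, 384)
    return tuple(int(x) for x in s.split(","))

encoder_dim = "192,256,384,512,384,256"  # example per-stack dims
print(max(_to_int_tuple(encoder_dim)))   # 512, used as the decoder's embed_dim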