Fix conflicts

pkufool 2022-01-27 19:29:23 +08:00
commit 5d4fd85715
6 changed files with 33 additions and 29 deletions

View File

@@ -83,18 +83,18 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, vocab_size).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        embeding_out = self.output_linear(F.relu(embeding_out))
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        embedding_out = self.output_linear(F.relu(embedding_out))
+        return embedding_out
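For reference, this Decoder is the "stateless" prediction network of the transducer: an embedding, a 1-D convolution over the last context_size labels, and a linear projection to the vocabulary. Below is a minimal sketch of the corrected embedding_out flow in the need_pad=True (training) case; the layer definitions and toy dimensions are illustrative stand-ins, not the actual __init__ of this class:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative stand-ins for what Decoder.__init__ builds.
vocab_size, embedding_dim, context_size = 500, 512, 2
embedding = nn.Embedding(vocab_size, embedding_dim)
conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size, groups=embedding_dim)
output_linear = nn.Linear(embedding_dim, vocab_size)

y = torch.randint(0, vocab_size, (8, 10))        # (N, U) label sequences
embedding_out = embedding(y)                      # (N, U, embedding_dim)
embedding_out = embedding_out.permute(0, 2, 1)    # (N, embedding_dim, U) for Conv1d
# Training path (need_pad=True): left-pad so the conv output keeps length U.
embedding_out = F.pad(embedding_out, pad=(context_size - 1, 0))
embedding_out = conv(embedding_out)               # (N, embedding_dim, U)
embedding_out = embedding_out.permute(0, 2, 1)    # back to (N, U, embedding_dim)
logits = output_linear(F.relu(embedding_out))     # (N, U, vocab_size)
assert logits.shape == (8, 10, vocab_size)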

View File

@@ -38,6 +38,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -507,6 +508,7 @@ def train_one_epoch(
         optimizer.zero_grad()
         loss.backward()
+        clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

         if batch_idx % params.log_interval == 0:
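The added clip_grad_norm_ call rescales the gradients in place so that their global norm does not exceed 5.0 before the optimizer step; the third argument, 2.0, is the norm type (L2). A small self-contained sketch of the same pattern on a toy model (the model and optimizer here are illustrative, not the recipe's):

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(10, 10)                                 # toy stand-in for the Transducer
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 10)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
# clip_grad_norm_(parameters, max_norm, norm_type): scales all gradients
# so their combined L2 norm is at most 5.0, guarding against exploding gradients.
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()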

View File

@@ -82,17 +82,17 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
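This second copy of the decoder returns the (N, U, embedding_dim) tensor without a final projection to the vocabulary. Its else branch above is the decoding path: the caller passes exactly context_size previous labels with need_pad=False, and the convolution collapses them to a single output frame. A toy illustration, again with made-up layer definitions:

import torch
import torch.nn as nn

# Illustrative stand-ins; the real Decoder builds these in __init__.
vocab_size, embedding_dim, context_size = 500, 512, 2
embedding = nn.Embedding(vocab_size, embedding_dim)
conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size, groups=embedding_dim)

# Inference (need_pad=False): each stream supplies exactly context_size labels.
y = torch.randint(0, vocab_size, (8, context_size))   # (N, context_size)
embedding_out = embedding(y).permute(0, 2, 1)          # (N, embedding_dim, context_size)
assert embedding_out.size(-1) == context_size          # the assert hit when need_pad is False
embedding_out = conv(embedding_out)                    # (N, embedding_dim, 1): one frame per stream
embedding_out = embedding_out.permute(0, 2, 1)         # (N, 1, embedding_dim)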

View File

@@ -48,6 +48,7 @@ from pathlib import Path
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
     return params


-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder


-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner


-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
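The remaining hunks in this and the next two files add the same -> nn.Module return annotation to the model factory helpers, which is why import torch.nn as nn is now needed. A rough, simplified sketch of the pattern; the bodies below are stand-ins, not the actual Conformer/Decoder/Joiner/Transducer construction:

import torch.nn as nn

def get_encoder_model(params) -> nn.Module:
    # Stand-in for Conformer(num_features=params.feature_dim, ...).
    return nn.Linear(params["feature_dim"], params["encoder_out_dim"])

def get_decoder_model(params) -> nn.Module:
    # Stand-in for Decoder(vocab_size=params.vocab_size, ...).
    return nn.Embedding(params["vocab_size"], params["encoder_out_dim"])

def get_transducer_model(params) -> nn.Module:
    # The real helper also builds the Joiner and wraps everything in Transducer;
    # here the parts are just bundled to show the annotated return type.
    return nn.ModuleDict(
        {
            "encoder": get_encoder_model(params),
            "decoder": get_decoder_model(params),
        }
    )

params = {"feature_dim": 80, "encoder_out_dim": 512, "vocab_size": 500}
model = get_transducer_model(params)
assert isinstance(model, nn.Module)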

View File

@@ -49,6 +49,7 @@ from typing import List
 import kaldifeat
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 import torchaudio
 from beam_search import beam_search, greedy_search
 from conformer import Conformer
@@ -148,7 +149,7 @@ def get_params() -> AttributeDict:
     return params


-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -162,7 +163,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -172,7 +173,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder


-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -180,7 +181,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner


-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

View File

@@ -213,7 +213,7 @@ def get_params() -> AttributeDict:
     return params


-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -228,7 +228,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -238,7 +238,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder


-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -246,7 +246,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner


-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)