Fix conflicts

pkufool 2022-01-27 19:29:23 +08:00
commit 5d4fd85715
6 changed files with 33 additions and 29 deletions

View File

@@ -83,18 +83,18 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, vocab_size).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        embeding_out = self.output_linear(F.relu(embeding_out))
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        embedding_out = self.output_linear(F.relu(embedding_out))
+        return embedding_out

View File

@@ -38,6 +38,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam
@@ -507,6 +508,7 @@ def train_one_epoch(
         optimizer.zero_grad()
         loss.backward()
+        clip_grad_norm_(model.parameters(), 5.0, 2.0)
         optimizer.step()

         if batch_idx % params.log_interval == 0:
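This hunk adds gradient clipping between the backward pass and the optimizer step. In clip_grad_norm_(model.parameters(), 5.0, 2.0), the positional arguments are max_norm=5.0 and norm_type=2.0, so all gradients are rescaled in place whenever their global L2 norm exceeds 5.0. A tiny runnable illustration; the Linear model and squared loss below are placeholders, not the script's Transducer:

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(10, 10)  # stand-in for the transducer model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 10)).pow(2).sum()

optimizer.zero_grad()
loss.backward()
# Rescales every gradient by min(1, 5.0 / total_norm), leaving them
# untouched when the global norm is already below 5.0.
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()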

View File

@@ -82,17 +82,17 @@ class Decoder(nn.Module):
         Returns:
           Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
+        embedding_out = self.embedding(y)
         if self.context_size > 1:
-            embeding_out = embeding_out.permute(0, 2, 1)
+            embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embeding_out = F.pad(
-                    embeding_out, pad=(self.context_size - 1, 0)
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
                 )
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
-                assert embeding_out.size(-1) == self.context_size
-            embeding_out = self.conv(embeding_out)
-            embeding_out = embeding_out.permute(0, 2, 1)
-        return embeding_out
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
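This second Decoder differs from the first only in its return value: there is no output_linear, so it returns the convolution output of shape (N, U, embedding_dim) directly. The need_pad=False branch is the one greedy search exercises at inference, passing exactly the last context_size labels and getting a single output frame back. A usage sketch, reusing the hypothetical Decoder class from the note after the first hunk (with its output_linear, the final dimension is vocab_size rather than embedding_dim):

import torch

decoder = Decoder(vocab_size=500, embedding_dim=512, context_size=2)
decoder.eval()

hyp = [0, 0]  # running hypothesis, seeded with blank_id == 0 (illustrative)
decoder_input = torch.tensor([hyp[-2:]])  # exactly context_size labels: (1, 2)
with torch.no_grad():
    decoder_out = decoder(decoder_input, need_pad=False)
print(decoder_out.shape)  # one frame per call: torch.Size([1, 1, 500])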

View File

@@ -48,6 +48,7 @@ from pathlib import Path
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -133,7 +134,7 @@ def get_params() -> AttributeDict:
     return params

-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder

-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder

-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner

-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
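The hunks in this file and in the two files below make the same mechanical change: every get_*_model builder gains a -> nn.Module return annotation (plus import torch.nn as nn where it was missing). A self-contained sketch of the pattern; AttributeDict here is a minimal hypothetical stand-in for icefall's helper, and plain Linear/Embedding layers replace the real Conformer, Decoder, and Joiner:

import torch.nn as nn


class AttributeDict(dict):
    # Minimal stand-in: a dict whose keys read as attributes.
    def __getattr__(self, key):
        return self[key]


def get_encoder_model(params: AttributeDict) -> nn.Module:
    # Stand-in for Conformer(num_features=..., output_dim=...).
    return nn.Linear(params.feature_dim, params.encoder_out_dim)


def get_decoder_model(params: AttributeDict) -> nn.Module:
    # Stand-in for Decoder(vocab_size=..., embedding_dim=...).
    return nn.Embedding(params.vocab_size, params.encoder_out_dim)


def get_joiner_model(params: AttributeDict) -> nn.Module:
    # Stand-in for Joiner(input_dim=..., output_dim=...).
    return nn.Linear(params.encoder_out_dim, params.vocab_size)


params = AttributeDict(feature_dim=80, encoder_out_dim=512, vocab_size=500)
encoder = get_encoder_model(params)
assert isinstance(encoder, nn.Module)

Annotating the return type lets static checkers and IDE tooling treat the three builders uniformly, which pays off when get_transducer_model composes them into a single module.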

View File

@@ -49,6 +49,7 @@ from typing import List
 import kaldifeat
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 import torchaudio
 from beam_search import beam_search, greedy_search
 from conformer import Conformer
@@ -148,7 +149,7 @@ def get_params() -> AttributeDict:
     return params

-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
         output_dim=params.encoder_out_dim,
@@ -162,7 +163,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder

-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -172,7 +173,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder

-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -180,7 +181,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner

-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

View File

@@ -213,7 +213,7 @@ def get_params() -> AttributeDict:
     return params

-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -228,7 +228,7 @@ def get_encoder_model(params: AttributeDict):
     return encoder

-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.encoder_out_dim,
@@ -238,7 +238,7 @@ def get_decoder_model(params: AttributeDict):
     return decoder

-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.encoder_out_dim,
         output_dim=params.vocab_size,
@@ -246,7 +246,7 @@ def get_joiner_model(params: AttributeDict):
     return joiner

-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)