Remove batchnorm, weight decay, and SOS.

This commit is contained in:
Fangjun Kuang 2021-12-23 14:19:49 +08:00
parent 5b6699a835
commit 35d63de820
12 changed files with 16 additions and 44 deletions

View File

@ -396,7 +396,7 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()

View File

@ -194,7 +194,7 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()

View File

@ -208,7 +208,7 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()

View File

@ -564,7 +564,7 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()

View File

@ -56,7 +56,6 @@ class Conformer(Transformer):
cnn_module_kernel: int = 31,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
super(Conformer, self).__init__(
num_features=num_features,
@ -69,7 +68,6 @@ class Conformer(Transformer):
dropout=dropout,
normalize_before=normalize_before,
vgg_frontend=vgg_frontend,
use_feat_batchnorm=use_feat_batchnorm,
)
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
@ -107,11 +105,6 @@ class Conformer(Transformer):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x, pos_emb = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
groups=channels,
bias=bias,
)
self.norm = nn.BatchNorm1d(channels)
self.norm = nn.LayerNorm(channels)
self.pointwise_conv2 = nn.Conv1d(
channels,
channels,
@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
# 1D Depthwise Conv
x = self.depthwise_conv(x)
x = self.activation(self.norm(x))
# x is (batch, channels, time)
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
x = self.activation(x)
x = self.pointwise_conv2(x) # (batch, channel, time)

View File

@ -129,7 +129,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
@ -149,7 +148,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder

View File

@ -20,13 +20,14 @@ import torch.nn.functional as F
class Decoder(nn.Module):
"""This class implements the stateless decoder from the following paper:
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network.
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""

View File

@ -119,7 +119,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
@ -138,7 +137,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder

View File

@ -16,7 +16,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
@ -48,7 +47,7 @@ class Joiner(nn.Module):
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
logit = torch.tanh(logit)
output = self.output_linear(logit)

View File

@ -126,7 +126,6 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
"env_info": get_env_info(),
@ -145,7 +144,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder

View File

@ -171,15 +171,10 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
@ -201,11 +196,9 @@ def get_params() -> AttributeDict:
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"use_feat_batchnorm": True,
# parameters for decoder
"context_size": 2, # tri-gram
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"env_info": get_env_info(),
}
@ -225,7 +218,6 @@ def get_encoder_model(params: AttributeDict):
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
)
return encoder
@ -568,7 +560,7 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
@ -593,7 +585,6 @@ def run(rank, world_size, args):
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints:

View File

@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.output_dim = output_dim
@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
- logit_lens, a tensor of shape (batch_size,) containing the number
of frames in `logits` before padding.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)