mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 10:16:14 +00:00
Remove batchnorm, weight decay, and SOS.
This commit is contained in:
parent
5b6699a835
commit
35d63de820
@ -396,7 +396,7 @@ def main():
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
|
@ -194,7 +194,7 @@ def main():
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
|
@ -208,7 +208,7 @@ def main():
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
|
@ -564,7 +564,7 @@ def run(rank, world_size, args):
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
|
@ -56,7 +56,6 @@ class Conformer(Transformer):
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
use_feat_batchnorm: bool = False,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
num_features=num_features,
|
||||
@ -69,7 +68,6 @@ class Conformer(Transformer):
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
vgg_frontend=vgg_frontend,
|
||||
use_feat_batchnorm=use_feat_batchnorm,
|
||||
)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
@ -107,11 +105,6 @@ class Conformer(Transformer):
|
||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `logits` before padding.
|
||||
"""
|
||||
if self.use_feat_batchnorm:
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.feat_batchnorm(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
@ -873,7 +866,7 @@ class ConvolutionModule(nn.Module):
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = nn.BatchNorm1d(channels)
|
||||
self.norm = nn.LayerNorm(channels)
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
@ -903,7 +896,12 @@ class ConvolutionModule(nn.Module):
|
||||
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.activation(self.norm(x))
|
||||
# x is (batch, channels, time)
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
|
@ -129,7 +129,6 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"context_size": 2, # tri-gram
|
||||
"env_info": get_env_info(),
|
||||
@ -149,7 +148,6 @@ def get_encoder_model(params: AttributeDict):
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
@ -20,13 +20,14 @@ import torch.nn.functional as F
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""This class implements the stateless decoder from the following paper:
|
||||
"""This class modifies the stateless decoder from the following paper:
|
||||
|
||||
RNN-transducer with stateless prediction network
|
||||
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
|
||||
|
||||
It removes the recurrent connection from the decoder, i.e., the prediction
|
||||
network.
|
||||
network. Different from the above paper, it adds an extra Conv1d
|
||||
right after the embedding layer.
|
||||
|
||||
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
|
||||
"""
|
||||
|
@ -119,7 +119,6 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"context_size": 2, # tri-gram
|
||||
"env_info": get_env_info(),
|
||||
@ -138,7 +137,6 @@ def get_encoder_model(params: AttributeDict):
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
@ -16,7 +16,6 @@
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Joiner(nn.Module):
|
||||
@ -48,7 +47,7 @@ class Joiner(nn.Module):
|
||||
# Now decoder_out is (N, 1, U, C)
|
||||
|
||||
logit = encoder_out + decoder_out
|
||||
logit = F.relu(logit)
|
||||
logit = torch.tanh(logit)
|
||||
|
||||
output = self.output_linear(logit)
|
||||
|
||||
|
@ -126,7 +126,6 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"context_size": 2, # tri-gram
|
||||
"env_info": get_env_info(),
|
||||
@ -145,7 +144,6 @@ def get_encoder_model(params: AttributeDict):
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
@ -171,15 +171,10 @@ def get_params() -> AttributeDict:
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- use_feat_batchnorm: Whether to do batch normalization for the
|
||||
input features.
|
||||
|
||||
- attention_dim: Hidden dim for multi-head attention model.
|
||||
|
||||
- num_decoder_layers: Number of decoder layer of transformer decoder.
|
||||
|
||||
- weight_decay: The weight_decay for the optimizer.
|
||||
|
||||
- warm_step: The warm_step for Noam optimizer.
|
||||
"""
|
||||
params = AttributeDict(
|
||||
@ -201,11 +196,9 @@ def get_params() -> AttributeDict:
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
# parameters for decoder
|
||||
"context_size": 2, # tri-gram
|
||||
# parameters for Noam
|
||||
"weight_decay": 1e-6,
|
||||
"warm_step": 80000, # For the 100h subset, use 8k
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
@ -225,7 +218,6 @@ def get_encoder_model(params: AttributeDict):
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
return encoder
|
||||
|
||||
@ -568,7 +560,7 @@ def run(rank, world_size, args):
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <sos/eos> are defined in local/train_bpe_model.py
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
@ -593,7 +585,6 @@ def run(rank, world_size, args):
|
||||
model_size=params.attention_dim,
|
||||
factor=params.lr_factor,
|
||||
warm_step=params.warm_step,
|
||||
weight_decay=params.weight_decay,
|
||||
)
|
||||
|
||||
if checkpoints and "optimizer" in checkpoints:
|
||||
|
@ -39,7 +39,6 @@ class Transformer(EncoderInterface):
|
||||
dropout: float = 0.1,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
use_feat_batchnorm: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
@ -65,13 +64,8 @@ class Transformer(EncoderInterface):
|
||||
If True, use pre-layer norm; False to use post-layer norm.
|
||||
vgg_frontend:
|
||||
True to use vgg style frontend for subsampling.
|
||||
use_feat_batchnorm:
|
||||
True to use batchnorm for the input layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.use_feat_batchnorm = use_feat_batchnorm
|
||||
if use_feat_batchnorm:
|
||||
self.feat_batchnorm = nn.BatchNorm1d(num_features)
|
||||
|
||||
self.num_features = num_features
|
||||
self.output_dim = output_dim
|
||||
@ -131,11 +125,6 @@ class Transformer(EncoderInterface):
|
||||
- logit_lens, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `logits` before padding.
|
||||
"""
|
||||
if self.use_feat_batchnorm:
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.feat_batchnorm(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
|
||||
x = self.encoder_embed(x)
|
||||
x = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
Loading…
x
Reference in New Issue
Block a user