Add madam optimizer.

Author: Fangjun Kuang, 2021-08-26 15:11:34 +08:00
Parent: 69a2bd5179
Commit: d09784fb8b
3 changed files with 1146 additions and 14 deletions


@@ -4,8 +4,10 @@ statistics=true
 max-line-length = 80
 per-file-ignores =
     # line too long
-    egs/librispeech/ASR/conformer_ctc/conformer.py: E501,
+    egs/librispeech/ASR/conformer_ctc*/conformer.py: E501,
 exclude =
     .git,
-    **/data/**
+    **/data/**,
+    egs/librispeech/ASR/conformer_ctc_embedding_scale/embedding.py,
+    egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py

File diff suppressed because it is too large.
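The suppressed file is presumably the new optimizer itself (the madam.py that the flake8 exclude list above now references). Since its contents are not visible here, what follows is only a minimal sketch of the interface that the train.py hunks below rely on: Foam(model.parameters(), max_lrate=..., warm_step=...) plus step(), zero_grad(), and state_dict()/load_state_dict(). The class name FoamSketch, the Adam base, and the linear warm-up followed by inverse-square-root decay are assumptions for illustration, not the actual Madam/Foam update rule.

# NOTE: a minimal sketch only; the real madam.py diff is suppressed above.
# The train.py hunks below only reveal the constructor call
#     Foam(model.parameters(), max_lrate=..., warm_step=...)
# Everything else here is assumed, not the actual Madam/Foam implementation.
from typing import Optional

import torch


class FoamSketch:
    """Hypothetical stand-in for the Foam optimizer wrapper."""

    def __init__(self, params, max_lrate: float = 5.0e-04, warm_step: int = 25000):
        self.max_lrate = max_lrate
        self.warm_step = warm_step
        self._step_count = 0
        # Base optimizer; the real Foam may wrap a different update rule.
        self.optimizer = torch.optim.Adam(params, lr=0.0)

    def rate(self, step: Optional[int] = None) -> float:
        """Learning rate at `step` under the assumed schedule."""
        step = self._step_count if step is None else step
        step = max(step, 1)
        warmup = min(step / self.warm_step, 1.0)  # ramp linearly to max_lrate
        decay = 1.0 if step <= self.warm_step else (self.warm_step / step) ** 0.5
        return self.max_lrate * warmup * decay

    def step(self) -> None:
        # Update the learning rate of every param group, then take an Adam step.
        self._step_count += 1
        lr = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = lr
        self.optimizer.step()

    def zero_grad(self) -> None:
        self.optimizer.zero_grad()

    def state_dict(self) -> dict:
        return {"step": self._step_count, "optimizer": self.optimizer.state_dict()}

    def load_state_dict(self, state_dict: dict) -> None:
        self._step_count = state_dict["step"]
        self.optimizer.load_state_dict(state_dict["optimizer"])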


@@ -30,10 +30,10 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from madam import Foam
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
@ -111,13 +111,9 @@ def get_params() -> AttributeDict:
- lang_dir: It contains language related input files such as
"lexicon.txt"
- lr: It specifies the initial learning rate
- feature_dim: The model input dim. It has to match the one used
in computing features.
- weight_decay: The weight_decay for the optimizer.
- subsampling_factor: The subsampling factor for the model.
- best_train_loss: Best training loss so far. It is used to select
@@ -150,10 +146,9 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc_embedding_scale/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
-            "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -175,7 +170,8 @@ def get_params() -> AttributeDict:
             "mmi_loss": False,
             "use_feat_batchnorm": True,
             "lr_factor": 5.0,
-            "warm_step": 80000,
+            "max_lrate": 5.0e-04,
+            "warm_step": 25000,
         }
     )
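One effect of this hunk: the peak learning rate, which the Noam schedule only implies through lr_factor, the model's attention_dim, and warm_step, is now stated directly as max_lrate. A small illustration of the difference (attention_dim is not shown in this hunk; 512 is assumed for the sake of the example):

# Peak learning rate of the standard Noam schedule, reached at step == warm_step:
#     peak = lr_factor * attention_dim ** -0.5 * warm_step ** -0.5
# attention_dim = 512 is an assumed value, used only for illustration.
def noam_peak_lr(lr_factor: float, attention_dim: int, warm_step: int) -> float:
    return lr_factor * attention_dim ** -0.5 * warm_step ** -0.5


print(noam_peak_lr(5.0, 512, 80000))  # ~7.8e-4 (old setting, peak only implicit)
print(noam_peak_lr(5.0, 512, 25000))  # ~1.4e-3 if Noam kept the new warm_step
print(5.0e-04)                        # the new params state the peak explicitly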
@@ -657,12 +653,10 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])

-    optimizer = Noam(
+    optimizer = Foam(
         model.parameters(),
-        model_size=params.attention_dim,
-        factor=params.lr_factor,
+        max_lrate=params.max_lrate,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
     )

     if checkpoints:
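The hunk ends at "if checkpoints:", i.e. the optimizer state is restored when training resumes. The lines that follow are not shown in this diff, so below is a generic, self-contained PyTorch save/restore round trip (torch.optim.Adam standing in for Foam) that illustrates the interface a drop-in replacement for Noam has to keep for resuming to work: state_dict() and load_state_dict().

# Generic PyTorch checkpoint round trip; the code after "if checkpoints:" in
# train.py is not shown in this diff.  torch.optim.Adam stands in for Foam.
import io

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=5.0e-04)

# Save model and optimizer state, as a checkpoint helper would.
buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, buffer)
buffer.seek(0)

# Resume: restore both states before continuing training.
checkpoints = torch.load(buffer)
if checkpoints:
    model.load_state_dict(checkpoints["model"])
    optimizer.load_state_dict(checkpoints["optimizer"])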