Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-13 12:02:21 +00:00)
Add madam optimizer.

commit d09784fb8b
parent 69a2bd5179
.flake8 | 6

@@ -4,8 +4,10 @@ statistics=true
 max-line-length = 80
 per-file-ignores =
     # line too long
-    egs/librispeech/ASR/conformer_ctc/conformer.py: E501,
+    egs/librispeech/ASR/conformer_ctc*/conformer.py: E501,

 exclude =
   .git,
-  **/data/**
+  **/data/**,
+  egs/librispeech/ASR/conformer_ctc_embedding_scale/embedding.py,
+  egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py
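The per-file-ignores entry is widened from the single conformer_ctc recipe to a glob, so the new conformer_ctc_embedding_scale copy of conformer.py gets the same E501 exemption, and the new embedding.py and madam.py files are excluded from linting entirely. A quick sketch of what the widened pattern covers, assuming flake8 matches these patterns with fnmatch-style globbing (the pattern and paths are taken from the hunk above):

# Assumption: flake8 applies fnmatch-style globbing to per-file-ignores patterns.
from fnmatch import fnmatch

pattern = "egs/librispeech/ASR/conformer_ctc*/conformer.py"

# Both the existing recipe and the new conformer_ctc_embedding_scale recipe
# should now fall under the E501 exemption.
for path in [
    "egs/librispeech/ASR/conformer_ctc/conformer.py",
    "egs/librispeech/ASR/conformer_ctc_embedding_scale/conformer.py",
]:
    print(path, fnmatch(path, pattern))  # prints True for both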
egs/librispeech/ASR/conformer_ctc_embedding_scale/madam.py | 1136 (new file)

File diff suppressed because it is too large.
@@ -30,10 +30,10 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from madam import Foam
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
@@ -111,13 +111,9 @@ def get_params() -> AttributeDict:
         - lang_dir: It contains language related input files such as
           "lexicon.txt"

-        - lr: It specifies the initial learning rate
-
         - feature_dim: The model input dim. It has to match the one used
           in computing features.

-        - weight_decay: The weight_decay for the optimizer.
-
         - subsampling_factor: The subsampling factor for the model.

         - best_train_loss: Best training loss so far. It is used to select
@@ -150,10 +146,9 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc_embedding_scale/exp"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
-            "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -175,7 +170,8 @@ def get_params() -> AttributeDict:
             "mmi_loss": False,
             "use_feat_batchnorm": True,
             "lr_factor": 5.0,
-            "warm_step": 80000,
+            "max_lrate": 5.0e-04,
+            "warm_step": 25000,
         }
     )

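The schedule hyperparameters change here: the configuration now carries an explicit peak learning rate (max_lrate = 5.0e-04) and a much shorter warmup (25000 steps instead of 80000). Since the madam.py diff is suppressed above, the exact schedule inside Foam is not visible in this commit; the following is only a sketch of a Noam-style warmup/decay reparameterized around these two values:

def noam_style_lrate(step: int, max_lrate: float, warm_step: int) -> float:
    # Assumed shape: rise linearly to max_lrate over warm_step steps,
    # then decay proportionally to 1/sqrt(step).
    step = max(step, 1)
    return max_lrate * min(step / warm_step, (warm_step / step) ** 0.5)


# With the values from this hunk, the peak of 5.0e-04 is reached at step 25000:
print(noam_style_lrate(25000, 5.0e-04, 25000))   # 0.0005
print(noam_style_lrate(100000, 5.0e-04, 25000))  # 0.00025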
@@ -657,12 +653,10 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])

-    optimizer = Noam(
+    optimizer = Foam(
         model.parameters(),
-        model_size=params.attention_dim,
-        factor=params.lr_factor,
+        max_lrate=params.max_lrate,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
     )

     if checkpoints:
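At the call site the Noam wrapper is swapped for Foam, constructed from the new max_lrate and warm_step values; model_size, factor, and weight_decay are no longer passed. The sketch below is a hypothetical stand-in, not the Foam implementation in madam.py (whose diff is suppressed above): it only illustrates the idea of a wrapper that drives a base optimizer with the schedule sketched earlier, with torch.optim.Adam standing in for the Madam update rule.

import torch


class WarmupWrapper:
    """Hypothetical illustration: apply a Noam-style lr schedule, then step
    the wrapped optimizer. Not the Foam class from madam.py."""

    def __init__(
        self, optimizer: torch.optim.Optimizer, max_lrate: float, warm_step: int
    ):
        self.optimizer = optimizer
        self.max_lrate = max_lrate
        self.warm_step = warm_step
        self._step = 0

    def step(self) -> None:
        # Set the learning rate for this step, then delegate to the base optimizer.
        self._step += 1
        lrate = self.max_lrate * min(
            self._step / self.warm_step, (self.warm_step / self._step) ** 0.5
        )
        for group in self.optimizer.param_groups:
            group["lr"] = lrate
        self.optimizer.step()

    def zero_grad(self) -> None:
        self.optimizer.zero_grad()


# Usage mirroring the call site in the hunk above (values from get_params()):
model = torch.nn.Linear(4, 2)
optimizer = WarmupWrapper(
    torch.optim.Adam(model.parameters()), max_lrate=5.0e-04, warm_step=25000
)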