Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-13 20:12:24 +00:00
Add some debugging code to train.py:
This commit is contained in:
parent abadc71415
commit c810e67342
egs/librispeech/ASR/conformer_ctc/madam.py (new file, 1279 lines)
File diff suppressed because it is too large
egs/librispeech/ASR/conformer_ctc/train.py

@@ -33,7 +33,7 @@ from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam
+from madam import Gloam

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
@@ -150,10 +150,9 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc/exp_gloam_2e-4_0.85"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
-            "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -174,8 +173,10 @@ def get_params() -> AttributeDict:
             "is_espnet_structure": True,
             "mmi_loss": False,
             "use_feat_batchnorm": True,
-            "lr_factor": 5.0,
-            "warm_step": 80000,
+            "max_lrate": 2.0e-04,
+            "first_decay_epoch": 1,
+            "decay_per_epoch": 0.85,
+            "warm_step": 40000,
         }
     )
@@ -296,6 +297,7 @@ def compute_loss(
       function enables autograd during computation; when it is False, it
       disables autograd.
     """
+    try:
         device = graph_compiler.device
         feature = batch["inputs"]
         # at entry, feature is [N, T, C]
@@ -303,6 +305,7 @@ def compute_loss(
         feature = feature.to(device)

         supervisions = batch["supervisions"]

         with torch.set_grad_enabled(is_training):
             nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
             # nnet_output is [N, T, C]
@@ -364,6 +367,12 @@ def compute_loss(
         assert loss.requires_grad == is_training

         return loss, ctc_loss.detach(), att_loss.detach()
+    except RuntimeError as e:
+        print(f"Runtime error. feature.shape = {feature.shape}, supervisions = {supervisions}")
+        raise e


 def compute_validation_loss(
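The change to compute_loss above is the "debugging code" of the commit title: the whole loss computation is wrapped in try/except so that, when a batch triggers a RuntimeError (typically an out-of-memory error or a shape mismatch), the shape of the offending batch and its supervisions are printed before the exception is re-raised. A minimal standalone sketch of the same pattern follows; compute_loss_debug and the placeholder model/loss are illustrative stand-ins, not code from icefall.

import torch
import torch.nn as nn


def compute_loss_debug(model: nn.Module, batch: dict, is_training: bool) -> torch.Tensor:
    # Same idea as the diff: keep the whole computation inside a try block
    # so that failing batches can be identified from the training log.
    try:
        feature = batch["inputs"]              # [N, T, C], as in the diff
        supervisions = batch["supervisions"]
        with torch.set_grad_enabled(is_training):
            nnet_output = model(feature)
            loss = nnet_output.sum()           # placeholder loss
        assert loss.requires_grad == is_training
        return loss
    except RuntimeError as e:
        # Print enough context to reproduce the failure, then re-raise so
        # the error is not silently swallowed.
        print(
            f"Runtime error. feature.shape = {feature.shape}, "
            f"supervisions = {supervisions}"
        )
        raise e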
@@ -657,12 +666,13 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])

-    optimizer = Noam(
+    # Remember: with Gloam, you need to call set_epoch() on every epoch.
+    optimizer = Gloam(
         model.parameters(),
-        model_size=params.attention_dim,
-        factor=params.lr_factor,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
+        max_lrate=params.max_lrate,
+        first_decay_epoch=params.first_decay_epoch,
+        decay_per_epoch=params.decay_per_epoch,
     )

     if checkpoints:
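The diff of madam.py is suppressed above, so the actual Gloam update rule and learning-rate schedule are not visible on this page. Purely as a hedged illustration of what the new hyperparameters suggest, the sketch below combines a linear warmup over warm_step optimizer steps up to max_lrate with a geometric decay of decay_per_epoch per epoch once first_decay_epoch is reached; the function name and the exact formula are assumptions, not the implementation in madam.py. Note also that the Noam arguments model_size, factor and weight_decay are dropped in the hunk above, and that the next hunk shows the one API obligation the comment warns about: optimizer.set_epoch(epoch) must be called once per epoch.

def illustrative_gloam_rate(
    step: int,
    epoch: int,
    max_lrate: float = 2.0e-04,
    warm_step: int = 40000,
    first_decay_epoch: int = 1,
    decay_per_epoch: float = 0.85,
) -> float:
    """Hypothetical rate shape suggested by the new hyperparameters.

    This is NOT the real Gloam schedule (madam.py's diff is suppressed);
    it only sketches how warmup and per-epoch decay could combine.
    """
    warmup = min(1.0, step / warm_step)      # linear ramp up to max_lrate
    decay = decay_per_epoch ** max(0, epoch - first_decay_epoch)
    return max_lrate * warmup * decay


# Early in epoch 0 the rate is still warming up; by epoch 5 it has been
# scaled down by 0.85 ** 4 ~= 0.52:
print(illustrative_gloam_rate(step=10_000, epoch=0))   # 5e-05
print(illustrative_gloam_rate(step=60_000, epoch=5))   # ~1.04e-04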
@@ -673,6 +683,7 @@ def run(rank, world_size, args):
     valid_dl = librispeech.valid_dataloaders()

     for epoch in range(params.start_epoch, params.num_epochs):
+        optimizer.set_epoch(epoch)  # specific to Gloam
         train_dl.sampler.set_epoch(epoch)

         cur_lr = optimizer._rate