Add some debugging code to train.py:

Daniel Povey 2021-09-09 14:03:04 +08:00
parent abadc71415
commit c810e67342
2 changed files with 1361 additions and 71 deletions

File diff suppressed because it is too large

train.py

@@ -33,7 +33,7 @@ from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
-from transformer import Noam
+from madam import Gloam

 from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
@@ -150,10 +150,9 @@ def get_params() -> AttributeDict:
     """
     params = AttributeDict(
         {
-            "exp_dir": Path("conformer_ctc/exp"),
+            "exp_dir": Path("conformer_ctc/exp_gloam_2e-4_0.85"),
             "lang_dir": Path("data/lang_bpe"),
             "feature_dim": 80,
-            "weight_decay": 1e-6,
             "subsampling_factor": 4,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -174,8 +173,10 @@ def get_params() -> AttributeDict:
             "is_espnet_structure": True,
             "mmi_loss": False,
             "use_feat_batchnorm": True,
-            "lr_factor": 5.0,
-            "warm_step": 80000,
+            "max_lrate": 2.0e-04,
+            "first_decay_epoch": 1,
+            "decay_per_epoch": 0.85,
+            "warm_step": 40000,
         }
     )
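The new get_params() entries replace the Noam-style lr_factor/warm_step pair with a schedule driven by max_lrate, first_decay_epoch and decay_per_epoch (the experiment directory name exp_gloam_2e-4_0.85 encodes the first and last of these). The actual schedule lives in madam.py, whose diff is suppressed above; the sketch below is only a plausible reading of the parameter names, not the real Gloam code, and every identifier in it is an assumption.

# Hypothetical sketch only: how the four parameters above could combine into a
# learning rate -- warm-up towards max_lrate over warm_step optimizer steps,
# then a geometric per-epoch decay once first_decay_epoch is reached.
def assumed_lrate(step: int, epoch: int,
                  max_lrate: float = 2.0e-04,
                  warm_step: int = 40000,
                  first_decay_epoch: int = 1,
                  decay_per_epoch: float = 0.85) -> float:
    warmup = min(1.0, step / warm_step)               # ramp from 0 to 1
    n_decays = max(0, epoch - first_decay_epoch + 1)  # epochs past the threshold
    return max_lrate * warmup * (decay_per_epoch ** n_decays)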
@@ -296,6 +297,7 @@ def compute_loss(
       function enables autograd during computation; when it is False, it
       disables autograd.
     """
+    try:
         device = graph_compiler.device
         feature = batch["inputs"]
         # at entry, feature is [N, T, C]
@@ -303,6 +305,7 @@ def compute_loss(
         feature = feature.to(device)
         supervisions = batch["supervisions"]
+
         with torch.set_grad_enabled(is_training):
             nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
             # nnet_output is [N, T, C]
@@ -364,6 +367,12 @@ def compute_loss(
         assert loss.requires_grad == is_training

         return loss, ctc_loss.detach(), att_loss.detach()
+    except RuntimeError as e:
+        print(f"Runtime error. feature.shape = {feature.shape}, supervisions = {supervisions}")
+        raise e


 def compute_validation_loss(
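The debugging code the commit message refers to is this try/except: compute_loss() now reports the batch geometry before re-raising, so shape- or memory-related RuntimeErrors can be traced back to the offending batch. A minimal standalone sketch of the same idiom follows; model, batch and the loss here are placeholders, not the icefall objects.

import torch

def compute_loss_debug(model: torch.nn.Module, batch: dict) -> torch.Tensor:
    feature = batch["inputs"]
    supervisions = batch["supervisions"]
    try:
        nnet_output = model(feature)
        return nnet_output.sum()  # stand-in for the real CTC/attention loss
    except RuntimeError:
        # Log what was being processed, then propagate the original error.
        print(f"Runtime error. feature.shape = {feature.shape}, "
              f"supervisions = {supervisions}")
        raise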
@@ -657,12 +666,13 @@ def run(rank, world_size, args):
     if world_size > 1:
         model = DDP(model, device_ids=[rank])

-    optimizer = Noam(
+    # Remember: with Gloam, you need to call set_epoch() on every epoch.
+    optimizer = Gloam(
         model.parameters(),
-        model_size=params.attention_dim,
-        factor=params.lr_factor,
         warm_step=params.warm_step,
-        weight_decay=params.weight_decay,
+        max_lrate=params.max_lrate,
+        first_decay_epoch=params.first_decay_epoch,
+        decay_per_epoch=params.decay_per_epoch,
     )

     if checkpoints:
@@ -673,6 +683,7 @@ def run(rank, world_size, args):
     valid_dl = librispeech.valid_dataloaders()

     for epoch in range(params.start_epoch, params.num_epochs):
+        optimizer.set_epoch(epoch)  # specific to Gloam
         train_dl.sampler.set_epoch(epoch)
         cur_lr = optimizer._rate
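The only contract train.py relies on is visible at the call sites: Gloam is constructed with warm_step, max_lrate, first_decay_epoch and decay_per_epoch; set_epoch() must be called at the top of every epoch; and the current learning rate is read back from _rate. Below is a hypothetical stand-in honouring just that contract (the real class is in madam.py, whose diff is suppressed; SketchGloam and its internals are assumptions, reusing the schedule guessed at earlier).

import torch

class SketchGloam(torch.optim.Adam):
    """Hypothetical stand-in for madam.Gloam, built only from its call sites."""

    def __init__(self, params, warm_step=40000, max_lrate=2.0e-04,
                 first_decay_epoch=1, decay_per_epoch=0.85):
        super().__init__(params, lr=0.0)
        self.warm_step = warm_step
        self.max_lrate = max_lrate
        self.first_decay_epoch = first_decay_epoch
        self.decay_per_epoch = decay_per_epoch
        self._num_steps = 0
        self._epoch = 0
        self._rate = 0.0

    def set_epoch(self, epoch: int) -> None:
        # train.py calls this once per epoch, before any step() of that epoch.
        self._epoch = epoch

    def step(self, closure=None):
        # Recompute the learning rate, push it into the param groups, then
        # take an ordinary Adam step.
        self._num_steps += 1
        warmup = min(1.0, self._num_steps / self.warm_step)
        decay = self.decay_per_epoch ** max(
            0, self._epoch - self.first_decay_epoch + 1
        )
        self._rate = self.max_lrate * warmup * decay
        for group in self.param_groups:
            group["lr"] = self._rate
        return super().step(closure)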