diff --git a/egs/librispeech/ASR/conformer_lm/train.py b/egs/librispeech/ASR/conformer_lm/train.py
index 72bc654ea..66602ea1d 100755
--- a/egs/librispeech/ASR/conformer_lm/train.py
+++ b/egs/librispeech/ASR/conformer_lm/train.py
@@ -140,7 +140,7 @@ def get_params() -> AttributeDict:
             "eos_sym": 1,
             "start_epoch": 0,
             "num_epochs": 20,
-            "num_valid_batches": 100,
+            "num_valid_batches": 200,
             "symbols_per_batch": 5000,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
@@ -288,8 +288,9 @@ def compute_loss(
 
     with torch.set_grad_enabled(is_training):
         memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
-        tgt_nll = model.decoder_nll(memory, pos_emb, src_symbols,
-                                    tgt_symbols, src_key_padding_mask)
+        decoder_nll_func = model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll
+        tgt_nll = decoder_nll_func(memory, pos_emb, src_symbols,
+                                   tgt_symbols, src_key_padding_mask)
         loss = (tgt_nll * tgt_weights).sum()
 
         assert loss.requires_grad == is_training
@@ -312,6 +313,8 @@ def compute_validation_loss(
     tot_loss = 0.0
     tot_frames = 0.0
     for batch_idx, batch in enumerate(valid_dl):
+        if batch_idx == params.num_valid_batches:
+            break
         batch = tuple(x.to(device) for x in batch)
 
         # `batch` is actually a tuple.. we'll unpack it later.
@@ -319,8 +322,6 @@ def compute_validation_loss(
         num_frames = batch[4].sum()
 
         assert loss.requires_grad is False
-        assert ctc_loss.requires_grad is False
-        assert att_loss.requires_grad is False
 
         loss_cpu = loss.detach().cpu().item()
         num_frames_cpu = num_frames.cpu().item()
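
For reviewers unfamiliar with the DDP detail behind the compute_loss change: torch.nn.parallel.DistributedDataParallel only proxies __call__/forward, so a custom method such as decoder_nll is reachable only on the wrapped module via the .module attribute. Below is a minimal, self-contained sketch of that pattern; the TinyLM class and get_decoder_nll helper are hypothetical names for illustration, not code from this PR.

    import torch
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel as DDP


    class TinyLM(nn.Module):
        """Hypothetical stand-in for the conformer LM; not the PR's model."""

        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, x):
            return self.proj(x)

        def decoder_nll(self, memory):
            # Stand-in for the real decoder_nll(memory, pos_emb, ...) call.
            return memory.sum()


    def get_decoder_nll(model):
        # Same pattern as the patch: DDP forwards only __call__/forward,
        # so custom methods must be looked up on the wrapped .module.
        return model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll


    model = TinyLM()  # when training is distributed, this would be DDP(TinyLM(), ...)
    nll = get_decoder_nll(model)(torch.randn(2, 8))

On the validation loop: the early break caps evaluation at params.num_valid_batches batches while keeping the loop body unchanged (equivalent to iterating itertools.islice(valid_dl, params.num_valid_batches)). The two removed assertions referenced ctc_loss and att_loss, which are not defined anywhere in this LM script's compute_validation_loss, so they appear to be leftovers from the ASR recipe.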