Various bug fixes

This commit is contained in:
Daniel Povey 2021-08-23 23:45:03 +08:00
parent 7711fba867
commit 9576d6574f

View File

@@ -140,7 +140,7 @@ def get_params() -> AttributeDict:
"eos_sym": 1, "eos_sym": 1,
"start_epoch": 0, "start_epoch": 0,
"num_epochs": 20, "num_epochs": 20,
"num_valid_batches": 100, "num_valid_batches": 200,
"symbols_per_batch": 5000, "symbols_per_batch": 5000,
"best_train_loss": float("inf"), "best_train_loss": float("inf"),
"best_valid_loss": float("inf"), "best_valid_loss": float("inf"),
@@ -288,8 +288,9 @@ def compute_loss(
with torch.set_grad_enabled(is_training): with torch.set_grad_enabled(is_training):
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask) memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
tgt_nll = model.decoder_nll(memory, pos_emb, src_symbols, decoder_nll_func = model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll
tgt_symbols, src_key_padding_mask) tgt_nll = decoder_nll_func(memory, pos_emb, src_symbols,
tgt_symbols, src_key_padding_mask)
loss = (tgt_nll * tgt_weights).sum() loss = (tgt_nll * tgt_weights).sum()
assert loss.requires_grad == is_training assert loss.requires_grad == is_training
@@ -312,6 +313,8 @@ def compute_validation_loss(
tot_loss = 0.0 tot_loss = 0.0
tot_frames = 0.0 tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl): for batch_idx, batch in enumerate(valid_dl):
if batch_idx == params.num_valid_batches:
break
batch = tuple(x.to(device) for x in batch) batch = tuple(x.to(device) for x in batch)
# `batch` is actually a tuple.. we'll unpack it later. # `batch` is actually a tuple.. we'll unpack it later.
@@ -319,8 +322,6 @@ def compute_validation_loss(
num_frames = batch[4].sum() num_frames = batch[4].sum()
assert loss.requires_grad is False assert loss.requires_grad is False
assert ctc_loss.requires_grad is False
assert att_loss.requires_grad is False
loss_cpu = loss.detach().cpu().item() loss_cpu = loss.detach().cpu().item()
num_frames_cpu = num_frames.cpu().item() num_frames_cpu = num_frames.cpu().item()