mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
Various bug fixes
This commit is contained in:
parent
7711fba867
commit
9576d6574f
@ -140,7 +140,7 @@ def get_params() -> AttributeDict:
|
||||
"eos_sym": 1,
|
||||
"start_epoch": 0,
|
||||
"num_epochs": 20,
|
||||
"num_valid_batches": 100,
|
||||
"num_valid_batches": 200,
|
||||
"symbols_per_batch": 5000,
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
@ -288,7 +288,8 @@ def compute_loss(
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
|
||||
tgt_nll = model.decoder_nll(memory, pos_emb, src_symbols,
|
||||
decoder_nll_func = model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll
|
||||
tgt_nll = decoder_nll_func(memory, pos_emb, src_symbols,
|
||||
tgt_symbols, src_key_padding_mask)
|
||||
loss = (tgt_nll * tgt_weights).sum()
|
||||
|
||||
@ -312,6 +313,8 @@ def compute_validation_loss(
|
||||
tot_loss = 0.0
|
||||
tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
if batch_idx == params.num_valid_batches:
|
||||
break
|
||||
batch = tuple(x.to(device) for x in batch)
|
||||
|
||||
# `batch` is actually a tuple.. we'll unpack it later.
|
||||
@ -319,8 +322,6 @@ def compute_validation_loss(
|
||||
num_frames = batch[4].sum()
|
||||
|
||||
assert loss.requires_grad is False
|
||||
assert ctc_loss.requires_grad is False
|
||||
assert att_loss.requires_grad is False
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
num_frames_cpu = num_frames.cpu().item()
|
||||
|
Loading…
x
Reference in New Issue
Block a user