mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
Various bug fixes
This commit is contained in:
parent
7711fba867
commit
9576d6574f
@ -140,7 +140,7 @@ def get_params() -> AttributeDict:
|
|||||||
"eos_sym": 1,
|
"eos_sym": 1,
|
||||||
"start_epoch": 0,
|
"start_epoch": 0,
|
||||||
"num_epochs": 20,
|
"num_epochs": 20,
|
||||||
"num_valid_batches": 100,
|
"num_valid_batches": 200,
|
||||||
"symbols_per_batch": 5000,
|
"symbols_per_batch": 5000,
|
||||||
"best_train_loss": float("inf"),
|
"best_train_loss": float("inf"),
|
||||||
"best_valid_loss": float("inf"),
|
"best_valid_loss": float("inf"),
|
||||||
@ -288,8 +288,9 @@ def compute_loss(
|
|||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
|
memory, pos_emb = model(masked_src_symbols, src_key_padding_mask)
|
||||||
tgt_nll = model.decoder_nll(memory, pos_emb, src_symbols,
|
decoder_nll_func = model.module.decoder_nll if isinstance(model, DDP) else model.decoder_nll
|
||||||
tgt_symbols, src_key_padding_mask)
|
tgt_nll = decoder_nll_func(memory, pos_emb, src_symbols,
|
||||||
|
tgt_symbols, src_key_padding_mask)
|
||||||
loss = (tgt_nll * tgt_weights).sum()
|
loss = (tgt_nll * tgt_weights).sum()
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
@ -312,6 +313,8 @@ def compute_validation_loss(
|
|||||||
tot_loss = 0.0
|
tot_loss = 0.0
|
||||||
tot_frames = 0.0
|
tot_frames = 0.0
|
||||||
for batch_idx, batch in enumerate(valid_dl):
|
for batch_idx, batch in enumerate(valid_dl):
|
||||||
|
if batch_idx == params.num_valid_batches:
|
||||||
|
break
|
||||||
batch = tuple(x.to(device) for x in batch)
|
batch = tuple(x.to(device) for x in batch)
|
||||||
|
|
||||||
# `batch` is actually a tuple.. we'll unpack it later.
|
# `batch` is actually a tuple.. we'll unpack it later.
|
||||||
@ -319,8 +322,6 @@ def compute_validation_loss(
|
|||||||
num_frames = batch[4].sum()
|
num_frames = batch[4].sum()
|
||||||
|
|
||||||
assert loss.requires_grad is False
|
assert loss.requires_grad is False
|
||||||
assert ctc_loss.requires_grad is False
|
|
||||||
assert att_loss.requires_grad is False
|
|
||||||
|
|
||||||
loss_cpu = loss.detach().cpu().item()
|
loss_cpu = loss.detach().cpu().item()
|
||||||
num_frames_cpu = num_frames.cpu().item()
|
num_frames_cpu = num_frames.cpu().item()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user