minor fixes to LSTM streaming model (#537)

Repository: mirror of https://github.com/k2-fsa/icefall.git
Commit: 0598291ff1
Parent: cdea2d26d4
@@ -505,9 +505,6 @@ def load_checkpoint_if_available(
     if "cur_epoch" in saved_params:
         params["start_epoch"] = saved_params["cur_epoch"]
 
-    if "cur_batch_idx" in saved_params:
-        params["cur_batch_idx"] = saved_params["cur_batch_idx"]
-
     return saved_params
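The hunk above stops restoring the per-epoch batch index when a checkpoint is loaded; only the epoch counter is carried over. A minimal sketch of that resume bookkeeping, assuming saved_params is a plain dict produced by torch.load and "checkpoint.pt" is a placeholder filename rather than icefall's exact API:

import torch


def load_start_epoch(params: dict, filename: str = "checkpoint.pt") -> dict:
    # Load the checkpoint dict on CPU and copy over the scalar bookkeeping
    # that training still resumes from.
    saved_params = torch.load(filename, map_location="cpu")
    if "cur_epoch" in saved_params:
        params["start_epoch"] = saved_params["cur_epoch"]
    # After this commit the per-epoch batch index ("cur_batch_idx") is no
    # longer copied back, so a resumed run restarts its epoch from batch 0.
    return saved_params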
@@ -615,8 +612,6 @@ def compute_loss(
             warmup=warmup,
             reduction="none",
         )
-        simple_loss[0] = float("inf")
-        pruned_loss[1] = float("nan")
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
         is_finite = simple_loss_is_finite & pruned_loss_is_finite
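The two deleted assignments forced an inf and a nan into the per-utterance loss tensors; they were debugging leftovers used to exercise the torch.isfinite check that follows them. A self-contained sketch of that filtering, assuming one loss value per utterance (reduction="none"); the masking itself is an assumption about how the surrounding code uses is_finite:

import torch


def mask_non_finite(simple_loss: torch.Tensor, pruned_loss: torch.Tensor):
    simple_loss_is_finite = torch.isfinite(simple_loss)
    pruned_loss_is_finite = torch.isfinite(pruned_loss)
    is_finite = simple_loss_is_finite & pruned_loss_is_finite
    if not torch.all(is_finite):
        # Keep only utterances whose losses are finite so one bad sample
        # cannot poison the whole batch with inf/nan gradients.
        simple_loss = simple_loss[simple_loss_is_finite]
        pruned_loss = pruned_loss[pruned_loss_is_finite]
    return simple_loss.sum(), pruned_loss.sum()


# The removed lines injected bad values exactly like this to test the path above.
s = torch.tensor([float("inf"), 1.0, 2.0])
p = torch.tensor([0.5, float("nan"), 1.5])
print(mask_non_finite(s, p))  # -> (tensor(3.), tensor(2.))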
@@ -769,13 +764,7 @@ def train_one_epoch(
 
     tot_loss = MetricsTracker()
 
-    cur_batch_idx = params.get("cur_batch_idx", 0)
-
     for batch_idx, batch in enumerate(train_dl):
-        if batch_idx < cur_batch_idx:
-            continue
-        cur_batch_idx = batch_idx
-
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
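The deleted lines implemented a mid-epoch resume by fast-forwarding the dataloader past batches already consumed before the checkpoint; this commit drops that behaviour. A toy sketch of the removed pattern, where the range-based loader and the cur_batch_idx value are stand-ins for the real train_dl and saved state:

def train_one_epoch_resumable(loader, cur_batch_idx: int = 0):
    seen = []
    for batch_idx, batch in enumerate(loader):
        if batch_idx < cur_batch_idx:
            # Fast-forward past batches consumed before the mid-epoch checkpoint.
            continue
        cur_batch_idx = batch_idx
        seen.append(batch)  # stand-in for the real forward/backward step
    return seen


print(train_one_epoch_resumable(range(5), cur_batch_idx=3))  # -> [3, 4]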
@@ -821,7 +810,6 @@ def train_one_epoch(
             params.batch_idx_train > 0
             and params.batch_idx_train % params.save_every_n == 0
         ):
-            params.cur_batch_idx = batch_idx
             save_checkpoint_with_global_batch_idx(
                 out_dir=params.exp_dir,
                 global_batch_idx=params.batch_idx_train,
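For context, the condition above writes an intermediate checkpoint every save_every_n steps, keyed by the global batch counter rather than the per-epoch index that this commit stops persisting. A hedged sketch of that gate; the checkpoint-{N}.pt file layout and the params dict are assumptions, not icefall's exact helper:

from pathlib import Path

import torch


def maybe_save_checkpoint(model: torch.nn.Module, params: dict, out_dir: Path) -> None:
    # Save every save_every_n training steps, skipping step 0.
    if params["batch_idx_train"] > 0 and params["batch_idx_train"] % params["save_every_n"] == 0:
        out_dir.mkdir(parents=True, exist_ok=True)
        filename = out_dir / f"checkpoint-{params['batch_idx_train']}.pt"
        torch.save(
            {"model": model.state_dict(), "batch_idx_train": params["batch_idx_train"]},
            filename,
        )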
@@ -834,7 +822,6 @@ def train_one_epoch(
                 scaler=scaler,
                 rank=rank,
             )
-            del params.cur_batch_idx
             remove_checkpoints(
                 out_dir=params.exp_dir,
                 topk=params.keep_last_k,
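remove_checkpoints(out_dir=..., topk=params.keep_last_k) prunes old batch-level checkpoints so only the most recent ones remain. A rough sketch of that rotation under the assumption that files are named checkpoint-{global_batch_idx}.pt; the naming and sorting key are illustrative, not icefall's exact implementation:

from pathlib import Path


def remove_old_checkpoints(out_dir: Path, keep_last_k: int) -> None:
    # Sort batch-level checkpoints by their global batch index.
    ckpts = sorted(
        out_dir.glob("checkpoint-*.pt"),
        key=lambda p: int(p.stem.split("-")[1]),
    )
    # Delete everything except the newest keep_last_k files.
    stale = ckpts[:-keep_last_k] if keep_last_k > 0 else ckpts
    for p in stale:
        p.unlink()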