minor fixes to LSTM streaming model (#537)

This commit is contained in:
Fangjun Kuang 2022-08-20 09:50:50 +08:00 committed by GitHub
parent cdea2d26d4
commit 0598291ff1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -505,9 +505,6 @@ def load_checkpoint_if_available(
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@@ -615,8 +612,6 @@ def compute_loss(
warmup=warmup,
reduction="none",
)
simple_loss[0] = float("inf")
pruned_loss[1] = float("nan")
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
is_finite = simple_loss_is_finite & pruned_loss_is_finite
@@ -769,13 +764,7 @@ def train_one_epoch(
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
@@ -821,7 +810,6 @@ def train_one_epoch(
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
@@ -834,7 +822,6 @@ def train_one_epoch(
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,