Minor fixes

pkufool 2022-04-11 15:40:14 +08:00
parent 4ebe821769
commit 22474e9abe


@@ -28,9 +28,6 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --exp-dir pruned_transducer_stateless2/exp \
   --full-libri 1 \
   --max-duration 300
 """
@@ -938,14 +935,15 @@ def scan_pessimistic_batches_for_oom(
             # warmup = 0.0 is so that the derivs for the pruned loss stay zero
             # (i.e. are not remembered by the decaying-average in adam), because
             # we want to avoid these params being subject to shrinkage in adam.
-            loss, _ = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup = 0.0
-            )
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup = 0.0
+                )
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
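
The second hunk wraps the forward pass of the OOM scan in torch.cuda.amp.autocast, presumably so that when --use-fp16 is enabled the scan exercises the same mixed-precision path (and memory footprint) as the real training loop. A minimal self-contained sketch of that standard autocast pattern, with a hypothetical stand-in model and data (nothing below is code from this repository):

    import torch

    model = torch.nn.Linear(10, 1).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # GradScaler is the usual companion to autocast in a real training loop;
    # the OOM scan above calls loss.backward() directly instead.
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    x = torch.randn(8, 10, device="cuda")
    y = torch.randn(8, 1, device="cuda")

    with torch.cuda.amp.autocast(enabled=True):
        # The forward pass runs in float16 where safe, float32 elsewhere.
        loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)         # unscales the grads, then steps the optimizer
    scaler.update()
    optimizer.zero_grad()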