fix warmup when scan_pessimistic_batches_for_oom

yaozengwei 2022-06-18 19:58:05 +08:00
parent 74c14f5f5d
commit d7f4920206
6 changed files with 18 additions and 12 deletions
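In short, the hunks below thread a warmup argument into scan_pessimistic_batches_for_oom so that the OOM sanity check runs with the same warmup value as the training that follows, instead of always using the hard-coded warmup=0.0. A minimal, self-contained sketch of the pattern (the stub compute_loss and the toy batches are illustrative only, not the recipe code):

def compute_loss(batch, is_training, warmup):
    # Stand-in for the recipe's loss function; pretend warmup scales the loss.
    return sum(batch) * (1.0 + warmup)


def scan_pessimistic_batches_for_oom(batches, warmup):
    # Forward the caller's warmup instead of hard-coding 0.0, so the scan
    # exercises the same code path (and memory behaviour) as real training.
    for batch in batches:
        loss = compute_loss(batch=batch, is_training=True, warmup=warmup)
        print(f"warmup={warmup} loss={loss:.1f}")


# Call-site pattern from run(): a fresh run scans with warmup=0.0, a resumed
# run with warmup=1.0.  Some recipes count start_epoch from 0 and others from
# 1, which is why the condition is "== 0" in some files and "== 1" in others.
start_epoch = 1
scan_pessimistic_batches_for_oom(
    batches=[[1.0, 2.0], [3.0, 4.0]],
    warmup=0.0 if start_epoch == 1 else 1.0,
)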

View File

@@ -1018,6 +1018,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
+ warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@@ -1078,6 +1079,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
+ warmup: float,
):
from lhotse.dataset import find_pessimistic_batches
@@ -1098,7 +1100,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
- warmup=0.0,
+ warmup=warmup,
)
loss.backward()
optimizer.step()
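For reference, a sketch of how the patched helper reads once this first file's hunks are applied. It is reconstructed from the diff: compute_loss is defined elsewhere in the recipe's train.py, the find_pessimistic_batches call is assumed from the surrounding recipe (only its import appears in the diff), the recipe's logging and OOM try/except are omitted, and params is an AttributeDict left untyped here.

import torch
import sentencepiece as spm
from torch import nn


def scan_pessimistic_batches_for_oom(
    model: nn.Module,
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    params,          # AttributeDict in the recipe
    warmup: float,   # new argument added by this commit
):
    from lhotse.dataset import find_pessimistic_batches

    # Scan the most memory-hungry batches lhotse can find for this sampler
    # (call shape assumed from the recipe; only the import is shown in the diff).
    batches, _ = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        # `criterion` is used for logging in the recipe; kept here for shape.
        batch = train_dl.dataset[cuts]
        with torch.cuda.amp.autocast(enabled=params.use_fp16):
            # compute_loss is the recipe's loss function; the only change in
            # this commit is forwarding `warmup` instead of hard-coding 0.0.
            loss, _ = compute_loss(
                params=params,
                model=model,
                sp=sp,
                batch=batch,
                is_training=True,
                warmup=warmup,
            )
        loss.backward()
        optimizer.step()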

View File

@@ -883,6 +883,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
+ warmup=0.0 if params.start_epoch == 0 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@@ -973,6 +974,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
+ warmup: float,
):
from lhotse.dataset import find_pessimistic_batches
@@ -983,9 +985,6 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- # warmup = 0.0 is so that the derivs for the pruned loss stay zero
- # (i.e. are not remembered by the decaying-average in adam), because
- # we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
@@ -993,7 +992,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
- warmup=0.0,
+ warmup=warmup,
)
loss.backward()
optimizer.step()

View File

@@ -1001,6 +1001,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
+ warmup=0.0 if params.start_epoch == 0 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@@ -1061,6 +1062,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
+ warmup: float,
):
from lhotse.dataset import find_pessimistic_batches
@@ -1071,9 +1073,6 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
- # warmup = 0.0 is so that the derivs for the pruned loss stay zero
- # (i.e. are not remembered by the decaying-average in adam), because
- # we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
@@ -1081,7 +1080,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
- warmup=0.0,
+ warmup=warmup,
)
loss.backward()
optimizer.step()

View File

@@ -932,6 +932,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
+ warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@@ -992,6 +993,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
+ warmup: float,
):
from lhotse.dataset import find_pessimistic_batches
@@ -1012,7 +1014,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
- warmup=0.0,
+ warmup=warmup,
)
loss.backward()
optimizer.step()

View File

@@ -980,6 +980,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
+ warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@@ -1072,6 +1073,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
+ warmup: float,
):
from lhotse.dataset import find_pessimistic_batches
@@ -1092,7 +1094,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
- warmup=0.0,
+ warmup=warmup,
)
loss.backward()
optimizer.step()

View File

@@ -988,6 +988,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
sp=sp,
params=params,
+ warmup=0.0 if params.start_epoch == 1 else 1.0,
)
scaler = GradScaler(enabled=params.use_fp16)
@@ -1048,6 +1049,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
params: AttributeDict,
+ warmup: float,
):
from lhotse.dataset import find_pessimistic_batches
@@ -1068,7 +1070,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
- warmup=0.0,
+ warmup=warmup,
)
loss.backward()
optimizer.step()