Fix warmup (#435)

* fix warmup when scan_pessimistic_batches_for_oom

* delete comments
Zengwei Yao 2022-06-20 13:40:01 +08:00 committed by GitHub
parent ab788980c9
commit a42d96dfe0
6 changed files with 20 additions and 26 deletions
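
Before this commit, scan_pessimistic_batches_for_oom always ran its trial training steps with warmup=0.0; the comments deleted below explain the original reasoning (keeping the pruned-loss derivatives at zero so they are not remembered by Adam's decaying averages). The fix threads a warmup argument through from run(), chosen so that the OOM scan exercises the model in the same warmup state as the first real training batch: 0.0 for a fresh run, 1.0 when resuming from a checkpoint. The condition differs across the recipes below only because some count epochs from 0 and others from 1.

A minimal sketch of the selection logic (warmup_for_oom_scan is a hypothetical helper name; the commit inlines this expression at each call site):

def warmup_for_oom_scan(start_epoch: int, first_epoch: int) -> float:
    # A fresh run starts training at warmup=0.0; a run resumed from a
    # checkpoint is assumed to be past the warmup period, so the scan
    # should see warmup=1.0. first_epoch is 0 or 1 depending on whether
    # the recipe counts epochs from 0 or from 1.
    return 0.0 if start_epoch == first_epoch else 1.0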

View File

@@ -1018,6 +1018,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -1078,6 +1079,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1088,9 +1090,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1098,7 +1097,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()
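
Pieced together from the hunks above, the patched scanner reads roughly as follows. The parameters before optimizer, the model=model line between the hunks, and the cleanup after optimizer.step() are not visible in this diff, so they are filled in here as assumptions:

def scan_pessimistic_batches_for_oom(
    model: nn.Module,
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    params: AttributeDict,
    warmup: float,
):
    from lhotse.dataset import find_pessimistic_batches

    # Attempt one full training step on the batches most likely to cause
    # OOM, under the same warmup value real training will start with.
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                loss, _ = compute_loss(
                    params=params,
                    model=model,  # assumed context line between the hunks
                    sp=sp,
                    batch=batch,
                    is_training=True,
                    warmup=warmup,
                )
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()  # assumed; not shown in the diff
        except RuntimeError:
            # The recipes log a hint about reducing max_duration when the
            # error is a CUDA OOM, then re-raise (handler not shown above).
            raise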

View File

@@ -883,6 +883,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 0 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -973,6 +974,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -983,9 +985,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -993,7 +992,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

View File

@@ -1001,6 +1001,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 0 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -1061,6 +1062,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1071,9 +1073,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1081,7 +1080,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

View File

@@ -932,6 +932,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -992,6 +993,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1002,9 +1004,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1012,7 +1011,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

View File

@@ -980,6 +980,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -1072,6 +1073,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1082,9 +1084,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1092,7 +1091,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

View File

@@ -74,9 +74,9 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from lhotse.dataset.collation import collate_custom_field
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -376,7 +376,7 @@ def get_params() -> AttributeDict:
             "distillation_layer": 5,  # 0-based index
             # Since output rate of hubert is 50, while that of encoder is 8,
             # two successive codebook_index are concatenated together.
-            # Detailed in function Transducer::concat_sucessive_codebook_indexes.
+            # Detailed in function Transducer::concat_sucessive_codebook_indexes
             "num_codebooks": 16,  # used to construct distillation loss
         }
     )
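
The comment touched in this hunk points at Transducer::concat_sucessive_codebook_indexes, which is not part of this diff. As a rough illustration of the idea it describes (a hypothetical standalone function, not the actual implementation): the HuBERT teacher emits codebook indexes at a higher frame rate than the encoder, so each pair of successive frames is concatenated to make the time axes line up.

import torch

def concat_successive_codebook_indexes(cb: torch.Tensor) -> torch.Tensor:
    # cb: (N, T, C) codebook indexes at the teacher's frame rate, with C
    # codebooks per frame. Merging each pair of successive frames halves
    # the time axis and doubles the last one: (N, T // 2, 2 * C).
    N, T, C = cb.shape
    T = T - (T % 2)  # drop a trailing odd frame, if any
    return cb[:, :T, :].reshape(N, T // 2, 2 * C)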
@@ -988,6 +988,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -1048,6 +1049,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1058,9 +1060,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1068,7 +1067,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()