Fix warmup (#435)

* fix warmup when scan_pessimistic_batches_for_oom

* delete comments
Author: Zengwei Yao, 2022-06-20 13:40:01 +08:00 (committed by GitHub)
Parent: ab788980c9
Commit: a42d96dfe0
6 changed files with 20 additions and 26 deletions
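
For context on why the hard-coded warmup=0.0 mattered: in these recipes the warmup value gates how strongly the pruned part of the transducer loss contributes, which is what the comment deleted in each file below alludes to. The sketch here is only an illustration of that gating, with an assumed helper name, an assumed 0.5 simple-loss scale, and a single 1.0 threshold; the recipes' real compute_loss/model code may use a different schedule. With warmup=0.0 the pruned-loss branch contributes nothing, so an OOM scan run that way need not exercise the same compute and memory path as post-warm-up training, which is presumably why the scan now receives 1.0 when training resumes from a checkpoint.

import torch


# Illustrative only (assumed names and scales): how a warmup factor can gate
# the pruned loss relative to the simple loss in a pruned-transducer setup.
def combine_losses(
    simple_loss: torch.Tensor,
    pruned_loss: torch.Tensor,
    warmup: float,
    simple_loss_scale: float = 0.5,
) -> torch.Tensor:
    # warmup < 1.0: the pruned-loss branch is switched off and receives zero
    # gradients; warmup >= 1.0: both terms are active, as in normal training.
    pruned_loss_scale = 0.0 if warmup < 1.0 else 1.0
    return simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss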

File 1 of 6:

@@ -1018,6 +1018,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -1078,6 +1079,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1088,9 +1090,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1098,7 +1097,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()
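
The remaining five files below apply the same change; note that two of them treat params.start_epoch == 0, rather than 1, as a fresh start, in line with their own epoch numbering. A self-contained toy of the selection logic follows; the helper name and the fresh_start_epoch keyword are hypothetical, and only the 0.0-versus-1.0 choice mirrors the diff.

# Hypothetical helper, not from the icefall recipes: choose the warmup value
# the OOM scan should use, given which epoch number means "training from scratch".
def warmup_for_oom_scan(start_epoch: int, fresh_start_epoch: int = 1) -> float:
    # Fresh run: keep the previous behaviour (warmup=0.0, pruned loss inactive).
    # Resumed run: use 1.0 so the trial batches match what training will do.
    return 0.0 if start_epoch == fresh_start_epoch else 1.0


assert warmup_for_oom_scan(1) == 0.0                       # training from scratch
assert warmup_for_oom_scan(5) == 1.0                       # resuming a checkpoint
assert warmup_for_oom_scan(0, fresh_start_epoch=0) == 0.0  # recipes that start at 0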

File 2 of 6:

@@ -883,6 +883,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 0 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -973,6 +974,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -983,9 +985,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -993,7 +992,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

File 3 of 6:

@@ -1001,6 +1001,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 0 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -1061,6 +1062,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1071,9 +1073,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1081,7 +1080,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

File 4 of 6:

@@ -932,6 +932,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -992,6 +993,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1002,9 +1004,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1012,7 +1011,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

File 5 of 6:

@@ -980,6 +980,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -1072,6 +1073,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1082,9 +1084,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1092,7 +1091,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

File 6 of 6:

@@ -74,9 +74,9 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from lhotse.dataset.collation import collate_custom_field
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -376,7 +376,7 @@ def get_params() -> AttributeDict:
             "distillation_layer": 5,  # 0-based index
             # Since output rate of hubert is 50, while that of encoder is 8,
             # two successive codebook_index are concatenated together.
-            # Detailed in function Transducer::concat_sucessive_codebook_indexes.
+            # Detailed in function Transducer::concat_sucessive_codebook_indexes
             "num_codebooks": 16,  # used to construct distillation loss
         }
     )
@@ -988,6 +988,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )

     scaler = GradScaler(enabled=params.use_fp16)
@@ -1048,6 +1049,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches

@@ -1058,9 +1060,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1068,7 +1067,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()