mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Fix warmup (#435)
* fix warmup when scan_pessimistic_batches_for_oom * delete comments
This commit is contained in:
parent
ab788980c9
commit
a42d96dfe0
@ -1018,6 +1018,7 @@ def run(rank, world_size, args):
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
@ -1078,6 +1079,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
params: AttributeDict,
|
||||
warmup: float,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
@ -1088,9 +1090,6 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
@ -1098,7 +1097,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=0.0,
|
||||
warmup=warmup,
|
||||
)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -883,6 +883,7 @@ def run(rank, world_size, args):
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
@ -973,6 +974,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
params: AttributeDict,
|
||||
warmup: float,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
@ -983,9 +985,6 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
@ -993,7 +992,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=0.0,
|
||||
warmup=warmup,
|
||||
)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -1001,6 +1001,7 @@ def run(rank, world_size, args):
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
warmup=0.0 if params.start_epoch == 0 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
@ -1061,6 +1062,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
params: AttributeDict,
|
||||
warmup: float,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
@ -1071,9 +1073,6 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
@ -1081,7 +1080,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=0.0,
|
||||
warmup=warmup,
|
||||
)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -932,6 +932,7 @@ def run(rank, world_size, args):
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
@ -992,6 +993,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
params: AttributeDict,
|
||||
warmup: float,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
@ -1002,9 +1004,6 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
@ -1012,7 +1011,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=0.0,
|
||||
warmup=warmup,
|
||||
)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -980,6 +980,7 @@ def run(rank, world_size, args):
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
@ -1072,6 +1073,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
params: AttributeDict,
|
||||
warmup: float,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
@ -1082,9 +1084,6 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
@ -1092,7 +1091,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=0.0,
|
||||
warmup=warmup,
|
||||
)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -74,9 +74,9 @@ from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut, MonoCut
|
||||
from lhotse.dataset.collation import collate_custom_field
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from lhotse.dataset.collation import collate_custom_field
|
||||
from model import Transducer
|
||||
from optim import Eden, Eve
|
||||
from torch import Tensor
|
||||
@ -376,7 +376,7 @@ def get_params() -> AttributeDict:
|
||||
"distillation_layer": 5, # 0-based index
|
||||
# Since output rate of hubert is 50, while that of encoder is 8,
|
||||
# two successive codebook_index are concatenated together.
|
||||
# Detailed in function Transducer::concat_sucessive_codebook_indexes.
|
||||
# Detailed in function Transducer::concat_sucessive_codebook_indexes
|
||||
"num_codebooks": 16, # used to construct distillation loss
|
||||
}
|
||||
)
|
||||
@ -988,6 +988,7 @@ def run(rank, world_size, args):
|
||||
optimizer=optimizer,
|
||||
sp=sp,
|
||||
params=params,
|
||||
warmup=0.0 if params.start_epoch == 1 else 1.0,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
@ -1048,6 +1049,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
params: AttributeDict,
|
||||
warmup: float,
|
||||
):
|
||||
from lhotse.dataset import find_pessimistic_batches
|
||||
|
||||
@ -1058,9 +1060,6 @@ def scan_pessimistic_batches_for_oom(
|
||||
for criterion, cuts in batches.items():
|
||||
batch = train_dl.dataset[cuts]
|
||||
try:
|
||||
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
|
||||
# (i.e. are not remembered by the decaying-average in adam), because
|
||||
# we want to avoid these params being subject to shrinkage in adam.
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, _ = compute_loss(
|
||||
params=params,
|
||||
@ -1068,7 +1067,7 @@ def scan_pessimistic_batches_for_oom(
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=0.0,
|
||||
warmup=warmup,
|
||||
)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
Loading…
x
Reference in New Issue
Block a user