Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
Fix warmup (#435)
* fix warmup when scan_pessimistic_batches_for_oom
* delete comments
This commit is contained in:
parent ab788980c9
commit a42d96dfe0
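The change threads a warmup value from run() into scan_pessimistic_batches_for_oom and on to compute_loss, instead of hard-coding warmup=0.0 during the OOM sanity check. Below is a minimal, self-contained sketch of that call flow; the *_stub functions and the SimpleNamespace params object are hypothetical stand-ins rather than the icefall implementation, and only the warmup plumbing visible in the diff is kept.

# Hypothetical sketch (not icefall code): how the warmup value now flows
# from run() through scan_pessimistic_batches_for_oom into compute_loss.
from types import SimpleNamespace


def compute_loss_stub(*, warmup: float) -> float:
    # Stand-in for compute_loss(); only the warmup argument matters here.
    return warmup


def scan_pessimistic_batches_for_oom_stub(*, warmup: float) -> float:
    # Before the fix this always called compute_loss(..., warmup=0.0);
    # after the fix it forwards the warmup value it receives.
    return compute_loss_stub(warmup=warmup)


def run_stub(start_epoch: int) -> float:
    params = SimpleNamespace(start_epoch=start_epoch)
    # Same selection as in the diff (for recipes whose epochs start at 1):
    # a fresh run uses warmup=0.0, a resumed run is treated as warmed up.
    return scan_pessimistic_batches_for_oom_stub(
        warmup=0.0 if params.start_epoch == 1 else 1.0
    )


assert run_stub(start_epoch=1) == 0.0
assert run_stub(start_epoch=5) == 1.0

Some of the recipes touched by this commit count epochs from 0, so their hunks compare against params.start_epoch == 0 instead of 1, as the diff below shows.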
@@ -1018,6 +1018,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
@@ -1078,6 +1079,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1088,9 +1090,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1098,7 +1097,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

@@ -883,6 +883,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 0 else 1.0,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
@@ -973,6 +974,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -983,9 +985,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -993,7 +992,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

@@ -1001,6 +1001,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 0 else 1.0,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
@@ -1061,6 +1062,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1071,9 +1073,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1081,7 +1080,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

@@ -932,6 +932,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
@@ -992,6 +993,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1002,9 +1004,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1012,7 +1011,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

@@ -980,6 +980,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
@@ -1072,6 +1073,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1082,9 +1084,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
        try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1092,7 +1091,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()

@@ -74,9 +74,9 @@ from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut, MonoCut
+from lhotse.dataset.collation import collate_custom_field
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-from lhotse.dataset.collation import collate_custom_field
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -376,7 +376,7 @@ def get_params() -> AttributeDict:
             "distillation_layer": 5,  # 0-based index
             # Since output rate of hubert is 50, while that of encoder is 8,
             # two successive codebook_index are concatenated together.
-            # Detailed in function Transducer::concat_sucessive_codebook_indexes.
+            # Detailed in function Transducer::concat_sucessive_codebook_indexes
             "num_codebooks": 16,  # used to construct distillation loss
         }
     )
@@ -988,6 +988,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             sp=sp,
             params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
         )
 
     scaler = GradScaler(enabled=params.use_fp16)
@@ -1048,6 +1049,7 @@ def scan_pessimistic_batches_for_oom(
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
     params: AttributeDict,
+    warmup: float,
 ):
     from lhotse.dataset import find_pessimistic_batches
 
@@ -1058,9 +1060,6 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            # warmup = 0.0 is so that the derivs for the pruned loss stay zero
-            # (i.e. are not remembered by the decaying-average in adam), because
-            # we want to avoid these params being subject to shrinkage in adam.
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, _ = compute_loss(
                     params=params,
@@ -1068,7 +1067,7 @@ def scan_pessimistic_batches_for_oom(
                     sp=sp,
                     batch=batch,
                     is_training=True,
-                    warmup=0.0,
+                    warmup=warmup,
                 )
             loss.backward()
             optimizer.step()