From a42d96dfe047e28a9cd5463b33246f4456e92cdb Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 20 Jun 2022 13:40:01 +0800 Subject: [PATCH] Fix warmup (#435) * fix warmup when scan_pessimistic_batches_for_oom * delete comments --- .../ASR/conv_emformer_transducer_stateless/train.py | 7 +++---- .../ASR/pruned_transducer_stateless2/train.py | 7 +++---- .../ASR/pruned_transducer_stateless3/train.py | 7 +++---- .../ASR/pruned_transducer_stateless4/train.py | 7 +++---- .../ASR/pruned_transducer_stateless5/train.py | 7 +++---- .../ASR/pruned_transducer_stateless6/train.py | 11 +++++------ 6 files changed, 20 insertions(+), 26 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index 106f3e511..acaf1397f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -1018,6 +1018,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -1078,6 +1079,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1088,9 +1090,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1098,7 +1097,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 36ee7ca74..55f32e119 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -883,6 +883,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 0 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -973,6 +974,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -983,9 +985,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -993,7 +992,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 92eae78d1..be9fa8f8b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -1001,6 +1001,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 0 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -1061,6 +1062,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1071,9 +1073,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1081,7 +1080,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 48c0e683d..0fece2464 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -932,6 +932,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -992,6 +993,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1002,9 +1004,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1012,7 +1011,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index e77eb19ff..eaf893997 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -980,6 +980,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -1072,6 +1073,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1082,9 +1084,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1092,7 +1091,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 315c01c8e..9e9fc1440 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -74,9 +74,9 @@ from conformer import Conformer from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut, MonoCut +from lhotse.dataset.collation import collate_custom_field from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from lhotse.dataset.collation import collate_custom_field from model import Transducer from optim import Eden, Eve from torch import Tensor @@ -376,7 +376,7 @@ def get_params() -> AttributeDict: "distillation_layer": 5, # 0-based index # Since output rate of hubert is 50, while that of encoder is 8, # two successive codebook_index are concatenated together. - # Detailed in function Transducer::concat_sucessive_codebook_indexes. + # Detailed in function Transducer::concat_sucessive_codebook_indexes "num_codebooks": 16, # used to construct distillation loss } ) @@ -988,6 +988,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -1048,6 +1049,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1058,9 +1060,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1068,7 +1067,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step()