From 86a3a2f73b2f067689d9dbf54d47a55a9b9ad62f Mon Sep 17 00:00:00 2001 From: marcoyang Date: Thu, 22 Feb 2024 20:54:20 +0800 Subject: [PATCH] minor fix --- egs/librispeech/ASR/zipformer_adapter/train.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/egs/librispeech/ASR/zipformer_adapter/train.py b/egs/librispeech/ASR/zipformer_adapter/train.py index 5c45c9a1f..7f81ddd96 100755 --- a/egs/librispeech/ASR/zipformer_adapter/train.py +++ b/egs/librispeech/ASR/zipformer_adapter/train.py @@ -1018,6 +1018,12 @@ def train_one_epoch( be set to 0. """ model.train() + # set modules except adapters to eval mode + for name, m in model.named_modules(): + if "adapter" in name: + m.training = True + else: + m.training = False tot_loss = MetricsTracker() @@ -1159,12 +1165,6 @@ def train_one_epoch( valid_dl=valid_dl, world_size=world_size, ) - model.train() - for name, m in model.named_modules(): - if "adapter" in name: - m.training = True - else: - m.training = False logging.info( f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" ) @@ -1175,6 +1175,8 @@ def train_one_epoch( valid_info.write_summary( tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train ) + model.train() + # set modules except adapters to eval mode for name, m in model.named_modules(): if "adapter" in name: m.training = True