minor fix

This commit is contained in:
marcoyang 2024-02-22 20:54:20 +08:00
parent e3a8f14969
commit 86a3a2f73b

View File

@ -1018,6 +1018,12 @@ def train_one_epoch(
be set to 0. be set to 0.
""" """
model.train() model.train()
# set modules except adapters to eval mode
for name, m in model.named_modules():
if "adapter" in name:
m.training = True
else:
m.training = False
tot_loss = MetricsTracker() tot_loss = MetricsTracker()
@ -1159,12 +1165,6 @@ def train_one_epoch(
valid_dl=valid_dl, valid_dl=valid_dl,
world_size=world_size, world_size=world_size,
) )
model.train()
for name, m in model.named_modules():
if "adapter" in name:
m.training = True
else:
m.training = False
logging.info( logging.info(
f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}" f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
) )
@ -1175,6 +1175,8 @@ def train_one_epoch(
valid_info.write_summary( valid_info.write_summary(
tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
) )
model.train()
# set modules except adapters to eval mode
for name, m in model.named_modules(): for name, m in model.named_modules():
if "adapter" in name: if "adapter" in name:
m.training = True m.training = True