mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-05 07:04:18 +00:00
minor fix
This commit is contained in:
parent
e3a8f14969
commit
86a3a2f73b
@ -1018,6 +1018,12 @@ def train_one_epoch(
|
|||||||
be set to 0.
|
be set to 0.
|
||||||
"""
|
"""
|
||||||
model.train()
|
model.train()
|
||||||
|
# set modules except adapters to eval mode
|
||||||
|
for name, m in model.named_modules():
|
||||||
|
if "adapter" in name:
|
||||||
|
m.training = True
|
||||||
|
else:
|
||||||
|
m.training = False
|
||||||
|
|
||||||
tot_loss = MetricsTracker()
|
tot_loss = MetricsTracker()
|
||||||
|
|
||||||
@ -1159,12 +1165,6 @@ def train_one_epoch(
|
|||||||
valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
)
|
)
|
||||||
model.train()
|
|
||||||
for name, m in model.named_modules():
|
|
||||||
if "adapter" in name:
|
|
||||||
m.training = True
|
|
||||||
else:
|
|
||||||
m.training = False
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
|
f"Validation on {valid_set}: Epoch {params.cur_epoch}, validation: {valid_info}"
|
||||||
)
|
)
|
||||||
@ -1175,6 +1175,8 @@ def train_one_epoch(
|
|||||||
valid_info.write_summary(
|
valid_info.write_summary(
|
||||||
tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
|
tb_writer, f"train/{valid_set}_valid_", params.batch_idx_train
|
||||||
)
|
)
|
||||||
|
model.train()
|
||||||
|
# set modules except adapters to eval mode
|
||||||
for name, m in model.named_modules():
|
for name, m in model.named_modules():
|
||||||
if "adapter" in name:
|
if "adapter" in name:
|
||||||
m.training = True
|
m.training = True
|
||||||
|
Loading…
x
Reference in New Issue
Block a user