add freeze param in utils.py

commit 974c1fff08
parent fdff6b3b3a
Author: marcoyang
Date:   2023-09-20 19:05:12 +08:00

utils.py

@@ -625,9 +625,7 @@ def write_error_stats(
             f"{cut_id}:\t"
             + " ".join(
                 (
-                    ref_word
-                    if ref_word == hyp_word
-                    else f"({ref_word}->{hyp_word})"
+                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                     for ref_word, hyp_word in ali
                 )
             ),
@@ -637,9 +635,7 @@ def write_error_stats(
     print("", file=f)
     print("SUBSTITUTIONS: count ref -> hyp", file=f)
-    for count, (ref, hyp) in sorted(
-        [(v, k) for k, v in subs.items()], reverse=True
-    ):
+    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
         print(f"{count} {ref} -> {hyp}", file=f)
     print("", file=f)
@@ -653,9 +649,7 @@ def write_error_stats(
         print(f"{count} {hyp}", file=f)
     print("", file=f)
-    print(
-        "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
-    )
+    print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
     for _, word, counts in sorted(
         [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
     ):
@@ -1380,7 +1374,10 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa
 def get_parameter_groups_with_lrs(
-    model: nn.Module, lr: float, include_names: bool = False
+    model: nn.Module,
+    lr: float,
+    include_names: bool = False,
+    freeze_modules: List[str] = [],
 ) -> List[dict]:
     """
     This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
@@ -1404,6 +1401,8 @@ def get_parameter_groups_with_lrs(
     ... ]
     """
+    named_modules = list(model.named_modules())
+
     # flat_lr_scale just contains the lr_scale explicitly specified
     # for each prefix of the name, e.g. 'encoder.layers.3', these need
     # to be multiplied for all prefix of the name of any given parameter.
@@ -1423,6 +1422,15 @@ def get_parameter_groups_with_lrs(
         split_name = name.split(".")
         # caution: as a special case, if the name is '', split_name will be [ '' ].
         prefix = split_name[0]
+        if prefix == "module":  # DDP wraps the model, prefixing all names with "module."
+            module_name = split_name[1]
+            if module_name in freeze_modules:
+                logging.info(f"Remove {name} from parameters")
+                continue
+        else:
+            if prefix in freeze_modules:
+                logging.info(f"Remove {name} from parameters")
+                continue
         cur_lr = lr * flat_lr_scale[prefix]
         if prefix != "":
             cur_lr *= flat_lr_scale[""]
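
A minimal usage sketch for the new freeze_modules argument, assuming a toy model whose top-level submodules are named "encoder" and "decoder" and icefall's ScaledAdam optimizer; the module names, the base_lr value, and the import paths are illustrative, not part of this commit.

    import torch.nn as nn

    from icefall.utils import get_parameter_groups_with_lrs  # the function patched above
    from optim import ScaledAdam  # import path as used in icefall recipes (assumed)

    class TinyModel(nn.Module):
        # Hypothetical model: the top-level submodule names ("encoder",
        # "decoder") are what freeze_modules is matched against.
        def __init__(self) -> None:
            super().__init__()
            self.encoder = nn.Linear(8, 8)
            self.decoder = nn.Linear(8, 8)

    model = TinyModel()
    base_lr = 0.045  # illustrative value

    # Parameters under model.encoder are skipped when building the groups,
    # so the optimizer never updates them.
    param_groups = get_parameter_groups_with_lrs(
        model,
        lr=base_lr,
        include_names=True,
        freeze_modules=["encoder"],
    )
    optimizer = ScaledAdam(param_groups, lr=base_lr, clipping_scale=2.0)

Note that this freezes by omission: the skipped parameters keep requires_grad=True, so gradients are still computed for them in the backward pass but never applied. Setting p.requires_grad = False on the frozen modules as well would also skip the gradient computation.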