add freeze param in utils.py

This commit is contained in:
marcoyang 2023-09-20 19:05:12 +08:00
parent fdff6b3b3a
commit 974c1fff08

View File

@ -542,14 +542,14 @@ def write_error_stats(
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0 num_corr = 0
ERR = "*" ERR = "*"
if compute_CER: if compute_CER:
for i, res in enumerate(results): for i, res in enumerate(results):
cut_id, ref, hyp = res cut_id, ref, hyp = res
ref = list("".join(ref)) ref = list("".join(ref))
hyp = list("".join(hyp)) hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp) results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results: for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali: for ref_word, hyp_word in ali:
@ -625,9 +625,7 @@ def write_error_stats(
f"{cut_id}:\t" f"{cut_id}:\t"
+ " ".join( + " ".join(
( (
ref_word ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali for ref_word, hyp_word in ali
) )
), ),
@ -637,9 +635,7 @@ def write_error_stats(
print("", file=f) print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f) print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted( for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
[(v, k) for k, v in subs.items()], reverse=True
):
print(f"{count} {ref} -> {hyp}", file=f) print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f) print("", file=f)
@ -653,9 +649,7 @@ def write_error_stats(
print(f"{count} {hyp}", file=f) print(f"{count} {hyp}", file=f)
print("", file=f) print("", file=f)
print( print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
)
for _, word, counts in sorted( for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
): ):
@ -1380,7 +1374,10 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa
def get_parameter_groups_with_lrs( def get_parameter_groups_with_lrs(
model: nn.Module, lr: float, include_names: bool = False model: nn.Module,
lr: float,
include_names: bool = False,
freeze_modules: List[str] = [],
) -> List[dict]: ) -> List[dict]:
""" """
This is for use with the ScaledAdam optimizers (more recent versions that accept lists of This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
@ -1404,6 +1401,8 @@ def get_parameter_groups_with_lrs(
... ] ... ]
""" """
named_modules = list(model.named_modules())
# flat_lr_scale just contains the lr_scale explicitly specified # flat_lr_scale just contains the lr_scale explicitly specified
# for each prefix of the name, e.g. 'encoder.layers.3', these need # for each prefix of the name, e.g. 'encoder.layers.3', these need
# to be multiplied for all prefix of the name of any given parameter. # to be multiplied for all prefix of the name of any given parameter.
@ -1423,6 +1422,15 @@ def get_parameter_groups_with_lrs(
split_name = name.split(".") split_name = name.split(".")
# caution: as a special case, if the name is '', split_name will be [ '' ]. # caution: as a special case, if the name is '', split_name will be [ '' ].
prefix = split_name[0] prefix = split_name[0]
if prefix == "module": # DDP
module_name = split_name[1]
if module_name in freeze_modules:
logging.info(f"Remove {name} from parameters")
continue
else:
if prefix in freeze_modules:
logging.info(f"Remove {name} from parameters")
continue
cur_lr = lr * flat_lr_scale[prefix] cur_lr = lr * flat_lr_scale[prefix]
if prefix != "": if prefix != "":
cur_lr *= flat_lr_scale[""] cur_lr *= flat_lr_scale[""]