add freeze param in utils.py

This commit is contained in:
marcoyang 2023-09-20 19:05:12 +08:00
parent fdff6b3b3a
commit 974c1fff08

View File

@ -542,14 +542,14 @@ def write_error_stats(
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
if compute_CER:
for i, res in enumerate(results):
cut_id, ref, hyp = res
ref = list("".join(ref))
hyp = list("".join(hyp))
results[i] = (cut_id, ref, hyp)
for cut_id, ref, hyp in results:
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
for ref_word, hyp_word in ali:
@ -625,9 +625,7 @@ def write_error_stats(
f"{cut_id}:\t"
+ " ".join(
(
ref_word
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
@ -637,9 +635,7 @@ def write_error_stats(
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted(
[(v, k) for k, v in subs.items()], reverse=True
):
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
@ -653,9 +649,7 @@ def write_error_stats(
print(f"{count} {hyp}", file=f)
print("", file=f)
print(
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
@ -1380,7 +1374,10 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa
def get_parameter_groups_with_lrs(
model: nn.Module, lr: float, include_names: bool = False
model: nn.Module,
lr: float,
include_names: bool = False,
freeze_modules: List[str] = [],
) -> List[dict]:
"""
This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
@ -1404,6 +1401,8 @@ def get_parameter_groups_with_lrs(
... ]
"""
named_modules = list(model.named_modules())
# flat_lr_scale just contains the lr_scale explicitly specified
# for each prefix of the name, e.g. 'encoder.layers.3', these need
# to be multiplied for all prefix of the name of any given parameter.
@ -1423,6 +1422,15 @@ def get_parameter_groups_with_lrs(
split_name = name.split(".")
# caution: as a special case, if the name is '', split_name will be [ '' ].
prefix = split_name[0]
if prefix == "module": # DDP
module_name = split_name[1]
if module_name in freeze_modules:
logging.info(f"Remove {name} from parameters")
continue
else:
if prefix in freeze_modules:
logging.info(f"Remove {name} from parameters")
continue
cur_lr = lr * flat_lr_scale[prefix]
if prefix != "":
cur_lr *= flat_lr_scale[""]