mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 08:04:18 +00:00
add freeze param in utils.py
This commit is contained in:
parent
fdff6b3b3a
commit
974c1fff08
@ -625,9 +625,7 @@ def write_error_stats(
|
||||
f"{cut_id}:\t"
|
||||
+ " ".join(
|
||||
(
|
||||
ref_word
|
||||
if ref_word == hyp_word
|
||||
else f"({ref_word}->{hyp_word})"
|
||||
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
||||
for ref_word, hyp_word in ali
|
||||
)
|
||||
),
|
||||
@ -637,9 +635,7 @@ def write_error_stats(
|
||||
print("", file=f)
|
||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||
|
||||
for count, (ref, hyp) in sorted(
|
||||
[(v, k) for k, v in subs.items()], reverse=True
|
||||
):
|
||||
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||
print(f"{count} {ref} -> {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
@ -653,9 +649,7 @@ def write_error_stats(
|
||||
print(f"{count} {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print(
|
||||
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
|
||||
)
|
||||
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||
for _, word, counts in sorted(
|
||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||
):
|
||||
@ -1380,7 +1374,10 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa
|
||||
|
||||
|
||||
def get_parameter_groups_with_lrs(
|
||||
model: nn.Module, lr: float, include_names: bool = False
|
||||
model: nn.Module,
|
||||
lr: float,
|
||||
include_names: bool = False,
|
||||
freeze_modules: List[str] = [],
|
||||
) -> List[dict]:
|
||||
"""
|
||||
This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
|
||||
@ -1404,6 +1401,8 @@ def get_parameter_groups_with_lrs(
|
||||
... ]
|
||||
|
||||
"""
|
||||
named_modules = list(model.named_modules())
|
||||
|
||||
# flat_lr_scale just contains the lr_scale explicitly specified
|
||||
# for each prefix of the name, e.g. 'encoder.layers.3', these need
|
||||
# to be multiplied for all prefix of the name of any given parameter.
|
||||
@ -1423,6 +1422,15 @@ def get_parameter_groups_with_lrs(
|
||||
split_name = name.split(".")
|
||||
# caution: as a special case, if the name is '', split_name will be [ '' ].
|
||||
prefix = split_name[0]
|
||||
if prefix == "module": # DDP
|
||||
module_name = split_name[1]
|
||||
if module_name in freeze_modules:
|
||||
logging.info(f"Remove {name} from parameters")
|
||||
continue
|
||||
else:
|
||||
if prefix in freeze_modules:
|
||||
logging.info(f"Remove {name} from parameters")
|
||||
continue
|
||||
cur_lr = lr * flat_lr_scale[prefix]
|
||||
if prefix != "":
|
||||
cur_lr *= flat_lr_scale[""]
|
||||
|
Loading…
x
Reference in New Issue
Block a user