mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-07 16:14:17 +00:00
add freeze param in utils.py
This commit is contained in:
parent
fdff6b3b3a
commit
974c1fff08
@ -625,9 +625,7 @@ def write_error_stats(
|
|||||||
f"{cut_id}:\t"
|
f"{cut_id}:\t"
|
||||||
+ " ".join(
|
+ " ".join(
|
||||||
(
|
(
|
||||||
ref_word
|
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
||||||
if ref_word == hyp_word
|
|
||||||
else f"({ref_word}->{hyp_word})"
|
|
||||||
for ref_word, hyp_word in ali
|
for ref_word, hyp_word in ali
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@ -637,9 +635,7 @@ def write_error_stats(
|
|||||||
print("", file=f)
|
print("", file=f)
|
||||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||||
|
|
||||||
for count, (ref, hyp) in sorted(
|
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||||
[(v, k) for k, v in subs.items()], reverse=True
|
|
||||||
):
|
|
||||||
print(f"{count} {ref} -> {hyp}", file=f)
|
print(f"{count} {ref} -> {hyp}", file=f)
|
||||||
|
|
||||||
print("", file=f)
|
print("", file=f)
|
||||||
@ -653,9 +649,7 @@ def write_error_stats(
|
|||||||
print(f"{count} {hyp}", file=f)
|
print(f"{count} {hyp}", file=f)
|
||||||
|
|
||||||
print("", file=f)
|
print("", file=f)
|
||||||
print(
|
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||||
"PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f
|
|
||||||
)
|
|
||||||
for _, word, counts in sorted(
|
for _, word, counts in sorted(
|
||||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||||
):
|
):
|
||||||
@ -1380,7 +1374,10 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa
|
|||||||
|
|
||||||
|
|
||||||
def get_parameter_groups_with_lrs(
|
def get_parameter_groups_with_lrs(
|
||||||
model: nn.Module, lr: float, include_names: bool = False
|
model: nn.Module,
|
||||||
|
lr: float,
|
||||||
|
include_names: bool = False,
|
||||||
|
freeze_modules: List[str] = [],
|
||||||
) -> List[dict]:
|
) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
|
This is for use with the ScaledAdam optimizers (more recent versions that accept lists of
|
||||||
@ -1404,6 +1401,8 @@ def get_parameter_groups_with_lrs(
|
|||||||
... ]
|
... ]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
named_modules = list(model.named_modules())
|
||||||
|
|
||||||
# flat_lr_scale just contains the lr_scale explicitly specified
|
# flat_lr_scale just contains the lr_scale explicitly specified
|
||||||
# for each prefix of the name, e.g. 'encoder.layers.3', these need
|
# for each prefix of the name, e.g. 'encoder.layers.3', these need
|
||||||
# to be multiplied for all prefix of the name of any given parameter.
|
# to be multiplied for all prefix of the name of any given parameter.
|
||||||
@ -1423,6 +1422,15 @@ def get_parameter_groups_with_lrs(
|
|||||||
split_name = name.split(".")
|
split_name = name.split(".")
|
||||||
# caution: as a special case, if the name is '', split_name will be [ '' ].
|
# caution: as a special case, if the name is '', split_name will be [ '' ].
|
||||||
prefix = split_name[0]
|
prefix = split_name[0]
|
||||||
|
if prefix == "module": # DDP
|
||||||
|
module_name = split_name[1]
|
||||||
|
if module_name in freeze_modules:
|
||||||
|
logging.info(f"Remove {name} from parameters")
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if prefix in freeze_modules:
|
||||||
|
logging.info(f"Remove {name} from parameters")
|
||||||
|
continue
|
||||||
cur_lr = lr * flat_lr_scale[prefix]
|
cur_lr = lr * flat_lr_scale[prefix]
|
||||||
if prefix != "":
|
if prefix != "":
|
||||||
cur_lr *= flat_lr_scale[""]
|
cur_lr *= flat_lr_scale[""]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user