diff --git a/icefall/utils.py b/icefall/utils.py index d4a12d68d..26e6936bb 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -542,14 +542,14 @@ def write_error_stats( words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) num_corr = 0 ERR = "*" - + if compute_CER: for i, res in enumerate(results): cut_id, ref, hyp = res ref = list("".join(ref)) hyp = list("".join(hyp)) results[i] = (cut_id, ref, hyp) - + for cut_id, ref, hyp in results: ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode) for ref_word, hyp_word in ali: @@ -625,9 +625,7 @@ def write_error_stats( f"{cut_id}:\t" + " ".join( ( - ref_word - if ref_word == hyp_word - else f"({ref_word}->{hyp_word})" + ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" for ref_word, hyp_word in ali ) ), @@ -637,9 +635,7 @@ def write_error_stats( print("", file=f) print("SUBSTITUTIONS: count ref -> hyp", file=f) - for count, (ref, hyp) in sorted( - [(v, k) for k, v in subs.items()], reverse=True - ): + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): print(f"{count} {ref} -> {hyp}", file=f) print("", file=f) @@ -653,9 +649,7 @@ def write_error_stats( print(f"{count} {hyp}", file=f) print("", file=f) - print( - "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f - ) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) for _, word, counts in sorted( [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True ): @@ -1380,7 +1374,10 @@ def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, floa def get_parameter_groups_with_lrs( - model: nn.Module, lr: float, include_names: bool = False + model: nn.Module, + lr: float, + include_names: bool = False, + freeze_modules: List[str] = [], ) -> List[dict]: """ This is for use with the ScaledAdam optimizers (more recent versions that accept lists of @@ -1404,6 +1401,8 @@ def get_parameter_groups_with_lrs( ... ] """ + named_modules = list(model.named_modules()) + # flat_lr_scale just contains the lr_scale explicitly specified # for each prefix of the name, e.g. 'encoder.layers.3', these need # to be multiplied for all prefix of the name of any given parameter. @@ -1423,6 +1422,15 @@ def get_parameter_groups_with_lrs( split_name = name.split(".") # caution: as a special case, if the name is '', split_name will be [ '' ]. prefix = split_name[0] + if prefix == "module": # DDP + module_name = split_name[1] + if module_name in freeze_modules: + logging.info(f"Remove {name} from parameters") + continue + else: + if prefix in freeze_modules: + logging.info(f"Remove {name} from parameters") + continue cur_lr = lr * flat_lr_scale[prefix] if prefix != "": cur_lr *= flat_lr_scale[""]