diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py
index 8c121255b..21c7b4fac 100644
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/model_with_BERT.py
@@ -48,6 +48,7 @@ class PromptedTransducer(nn.Module):
         use_BERT: bool = True,
         text_encoder_type: str = "BERT",
         text_encoder_adapter: bool = False,
+        freeze_text_encoder: bool = True,
         context_fuser: nn.Module = None,
     ):
         """
@@ -112,6 +113,7 @@ class PromptedTransducer(nn.Module):
             if text_encoder_type in ("BERT", "BERT-UNCASED")
             else self.text_encoder.config.dim
         )
+        self.freeze_text_encoder = freeze_text_encoder
 
         if text_encoder_adapter:
             self.text_encoder_adapter = nn.Sequential(
@@ -180,6 +182,8 @@ class PromptedTransducer(nn.Module):
               lm_scale * lm_probs + am_scale * am_probs +
               (1-lm_scale-am_scale) * combined_probs
         """
+        if self.freeze_text_encoder:
+            self.text_encoder.eval()
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
         assert y.num_axes == 2, y.num_axes
diff --git a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
index 56ed27a6a..412df50b6 100755
--- a/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
+++ b/egs/libriheavy/ASR/zipformer_prompt_asr/train_bert_encoder.py
@@ -849,9 +849,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
+
     text_encoder = get_text_encoder(params)  # This should be a cased BERT base model
     num_param = sum([p.numel() for p in text_encoder.parameters()])
     logging.info(f"Num params in text encoder: {num_param}")
+
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
@@ -867,6 +869,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         text_encoder_type=params.text_encoder_type,
         text_encoder_adapter=params.text_encoder_adapter,
+        freeze_text_encoder=params.freeze_text_encoder,
         context_fuser=None,
     )
 
@@ -1618,14 +1621,15 @@ def run(rank, world_size, args):
         valid_cuts, text_sampling_func=naive_triplet_text_sampling
     )
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         tokenizer=tokenizer,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1717,7 +1721,7 @@ def scan_pessimistic_batches_for_oom(
     train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
-    tokenizer,
+    tokenizer: spm.SentencePieceProcessor,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
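
Note (not part of the patch): a minimal sketch of the freezing pattern that the new freeze_text_encoder flag relies on. The TextEncoder class below is a hypothetical stand-in for the BERT model loaded in model_with_BERT.py; the patch itself only re-asserts eval() inside forward(), and excluding the frozen parameters from gradient updates is assumed to be handled elsewhere in the training setup.

# Sketch only; TextEncoder is a stand-in, not the model used in the recipe.
import torch
import torch.nn as nn


class TextEncoder(nn.Module):
    """Stand-in for a pre-trained text encoder (e.g. a BERT base model)."""

    def __init__(self, dim: int = 768):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.proj(x))


encoder = TextEncoder()

# Freezing has two independent parts:
# 1. Exclude the encoder's weights from gradient updates.
for p in encoder.parameters():
    p.requires_grad_(False)

# 2. Keep dropout / batch-norm layers in inference mode even while the rest
#    of the model trains. Calling .eval() at the start of forward(), as the
#    patch does, guards against an outer model.train() switching it back.
encoder.eval()

x = torch.randn(2, 10, 768)
with torch.no_grad():
    out = encoder(x)
print(out.shape)  # torch.Size([2, 10, 768])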