freeze BERT option

marcoyang 2023-09-21 10:24:14 +08:00
parent 21cc1dfff4
commit ae3149cb7f
2 changed files with 17 additions and 9 deletions

View File

@@ -48,6 +48,7 @@ class PromptedTransducer(nn.Module):
         use_BERT: bool = True,
         text_encoder_type: str = "BERT",
         text_encoder_adapter: bool = False,
+        freeze_text_encoder: bool = True,
         context_fuser: nn.Module = None,
     ):
         """
@@ -112,6 +113,7 @@ class PromptedTransducer(nn.Module):
             if text_encoder_type in ("BERT", "BERT-UNCASED")
             else self.text_encoder.config.dim
         )
+        self.freeze_text_encoder = freeze_text_encoder

         if text_encoder_adapter:
             self.text_encoder_adapter = nn.Sequential(
@@ -180,6 +182,8 @@ class PromptedTransducer(nn.Module):
             lm_scale * lm_probs + am_scale * am_probs +
             (1-lm_scale-am_scale) * combined_probs
         """
+        if self.freeze_text_encoder:
+            self.text_encoder.eval()
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
         assert y.num_axes == 2, y.num_axes
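Note on the change above: calling self.text_encoder.eval() inside forward() only disables dropout (and similar train-time behaviour) for the BERT encoder; on its own it does not stop gradient updates to the BERT weights. A minimal sketch of what fully freezing the text encoder would typically also involve is given below; the helper name freeze_text_encoder_weights is hypothetical and not part of this commit.

import torch.nn as nn

def freeze_text_encoder_weights(text_encoder: nn.Module) -> None:
    # Hypothetical helper (not in this commit): switch the BERT text encoder
    # to eval mode and exclude its weights from gradient updates.
    text_encoder.eval()
    for p in text_encoder.parameters():
        p.requires_grad_(False)

Parameters frozen this way can also be filtered out of the optimizer, e.g. by passing only [p for p in model.parameters() if p.requires_grad].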

View File

@@ -849,9 +849,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
     text_encoder = get_text_encoder(params)  # This should be a cased BERT base model
     num_param = sum([p.numel() for p in text_encoder.parameters()])
     logging.info(f"Num params in text encoder: {num_param}")
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -867,6 +869,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         text_encoder_type=params.text_encoder_type,
         text_encoder_adapter=params.text_encoder_adapter,
+        freeze_text_encoder=params.freeze_text_encoder,
         context_fuser=None,
     )
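The params.freeze_text_encoder value used above presumably comes from a command-line option defined elsewhere in the training script, which this commit does not show. A sketch of such an option, assuming icefall's usual str2bool argparse helper (the flag name, help text, and add_model_arguments wiring are guesses for illustration):

import argparse

from icefall.utils import str2bool

def add_model_arguments(parser: argparse.ArgumentParser) -> None:
    # Hypothetical flag wiring, not part of this commit.
    parser.add_argument(
        "--freeze-text-encoder",
        type=str2bool,
        default=True,
        help="Whether to keep the BERT text encoder frozen during training.",
    )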
@@ -1618,14 +1621,15 @@ def run(rank, world_size, args):
         valid_cuts, text_sampling_func=naive_triplet_text_sampling
     )

-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         tokenizer=tokenizer,
+    #         params=params,
+    #     )

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
@@ -1717,7 +1721,7 @@ def scan_pessimistic_batches_for_oom(
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
-    tokenizer,
+    tokenizer: spm.SentencePieceProcessor,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
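Since the call to scan_pessimistic_batches_for_oom in run() is commented out by this commit, the newly annotated tokenizer parameter is currently unused there. If the OOM scan is re-enabled later, the call would need to pass the tokenizer to match the updated signature; a sketch based on the commented-out block above:

if not params.print_diagnostics:
    scan_pessimistic_batches_for_oom(
        model=model,
        train_dl=train_dl,
        optimizer=optimizer,
        sp=sp,
        tokenizer=tokenizer,
        params=params,
    )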