freeze BERT option

This commit is contained in:
marcoyang 2023-09-21 10:24:14 +08:00
parent 21cc1dfff4
commit ae3149cb7f
2 changed files with 17 additions and 9 deletions

View File

@ -48,6 +48,7 @@ class PromptedTransducer(nn.Module):
use_BERT: bool = True,
text_encoder_type: str = "BERT",
text_encoder_adapter: bool = False,
freeze_text_encoder: bool = True,
context_fuser: nn.Module = None,
):
"""
@ -112,6 +113,7 @@ class PromptedTransducer(nn.Module):
if text_encoder_type in ("BERT", "BERT-UNCASED")
else self.text_encoder.config.dim
)
self.freeze_text_encoder = freeze_text_encoder
if text_encoder_adapter:
self.text_encoder_adapter = nn.Sequential(
@ -180,6 +182,8 @@ class PromptedTransducer(nn.Module):
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
if self.freeze_text_encoder:
self.text_encoder.eval()
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes

View File

@ -849,9 +849,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
text_encoder = get_text_encoder(params) # This should be a cased BERT base model
num_param = sum([p.numel() for p in text_encoder.parameters()])
logging.info(f"Num params in text encoder: {num_param}")
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
@ -867,6 +869,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
text_encoder_type=params.text_encoder_type,
text_encoder_adapter=params.text_encoder_adapter,
freeze_text_encoder=params.freeze_text_encoder,
context_fuser=None,
)
@ -1618,14 +1621,15 @@ def run(rank, world_size, args):
valid_cuts, text_sampling_func=naive_triplet_text_sampling
)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
# if not params.print_diagnostics:
# scan_pessimistic_batches_for_oom(
# model=model,
# train_dl=train_dl,
# optimizer=optimizer,
# sp=sp,
# tokenizer=tokenizer,
# params=params,
# )
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
@ -1717,7 +1721,7 @@ def scan_pessimistic_batches_for_oom(
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
sp: spm.SentencePieceProcessor,
tokenizer,
tokenizer: spm.SentencePieceProcessor,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches