freeze BERT option

commit ae3149cb7f
parent 21cc1dfff4
@@ -48,6 +48,7 @@ class PromptedTransducer(nn.Module):
         use_BERT: bool = True,
         text_encoder_type: str = "BERT",
         text_encoder_adapter: bool = False,
+        freeze_text_encoder: bool = True,
         context_fuser: nn.Module = None,
     ):
         """
@@ -112,6 +113,7 @@ class PromptedTransducer(nn.Module):
             if text_encoder_type in ("BERT", "BERT-UNCASED")
             else self.text_encoder.config.dim
         )
+        self.freeze_text_encoder = freeze_text_encoder
 
         if text_encoder_adapter:
             self.text_encoder_adapter = nn.Sequential(
@@ -180,6 +182,8 @@ class PromptedTransducer(nn.Module):
             lm_scale * lm_probs + am_scale * am_probs +
             (1-lm_scale-am_scale) * combined_probs
         """
+        if self.freeze_text_encoder:
+            self.text_encoder.eval()
         assert x.ndim == 3, x.shape
         assert x_lens.ndim == 1, x_lens.shape
         assert y.num_axes == 2, y.num_axes
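Note that text_encoder.eval() only disables dropout and puts any batch-norm layers into inference mode; by itself it does not stop gradients from reaching the BERT weights. A fuller freeze usually also turns off requires_grad, as in this minimal sketch (the helper name is hypothetical, not part of this commit):

import torch.nn as nn

def freeze_module(module: nn.Module) -> None:
    # Hypothetical helper, not part of this commit.
    module.eval()  # disable dropout / batch-norm running-stat updates
    for p in module.parameters():
        p.requires_grad_(False)  # keep the weights out of the backward pass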
@@ -849,9 +849,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
 
     text_encoder = get_text_encoder(params)  # This should be a cased BERT base model
+    num_param = sum([p.numel() for p in text_encoder.parameters()])
+    logging.info(f"Num params in text encoder: {num_param}")
 
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
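The logging added here uses the standard numel() idiom for counting parameters. The same pattern can separate total from trainable parameters, which makes a frozen text encoder visible in the logs; a small sketch (the function name is illustrative):

import torch.nn as nn

def count_params(model: nn.Module):
    # Total vs. trainable counts; a frozen text encoder shows up as a
    # large gap between the two numbers.
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable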
@@ -867,6 +869,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         text_encoder_type=params.text_encoder_type,
         text_encoder_adapter=params.text_encoder_adapter,
+        freeze_text_encoder=params.freeze_text_encoder,
         context_fuser=None,
     )
 
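params.freeze_text_encoder presumably mirrors a command-line flag. A sketch of how such a flag is typically declared in icefall training scripts, using icefall's str2bool argparse helper (the exact flag name and help text are assumptions):

import argparse
from icefall.utils import str2bool

parser = argparse.ArgumentParser()
parser.add_argument(
    "--freeze-text-encoder",
    type=str2bool,
    default=True,
    help="If True, keep the BERT text encoder frozen during training.",
)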
@@ -1618,14 +1621,15 @@ def run(rank, world_size, args):
         valid_cuts, text_sampling_func=naive_triplet_text_sampling
     )
 
-    if not params.print_diagnostics:
-        scan_pessimistic_batches_for_oom(
-            model=model,
-            train_dl=train_dl,
-            optimizer=optimizer,
-            sp=sp,
-            params=params,
-        )
+    # if not params.print_diagnostics:
+    #     scan_pessimistic_batches_for_oom(
+    #         model=model,
+    #         train_dl=train_dl,
+    #         optimizer=optimizer,
+    #         sp=sp,
+    #         tokenizer=tokenizer,
+    #         params=params,
+    #     )
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
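For context, scan_pessimistic_batches_for_oom (skipped here after this change) pre-runs the worst-case batches so that memory failures surface before real training starts. A condensed, hypothetical sketch of the idea, assuming a compute_loss(model, cuts) helper that is not part of this commit:

from lhotse.dataset import find_pessimistic_batches

def dry_run_worst_batches(model, train_dl, compute_loss):
    # For each batch-size criterion, lhotse returns the batch that
    # maximises it; run one forward/backward pass on each as a dry run.
    batches, _ = find_pessimistic_batches(train_dl.sampler)
    for criterion, cuts in batches.items():
        loss = compute_loss(model, cuts)
        loss.backward()
        model.zero_grad()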
@@ -1717,7 +1721,7 @@ def scan_pessimistic_batches_for_oom(
     train_dl: torch.utils.data.DataLoader,
     optimizer: torch.optim.Optimizer,
     sp: spm.SentencePieceProcessor,
-    tokenizer,
+    tokenizer: spm.SentencePieceProcessor,
     params: AttributeDict,
 ):
     from lhotse.dataset import find_pessimistic_batches
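The signature now types tokenizer as a SentencePiece processor, matching sp. For reference, a minimal sketch of loading and using one (the model path is a placeholder):

import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="path/to/bpe.model")  # placeholder path
ids = sp.encode("HELLO WORLD", out_type=int)  # transcript -> token ids
text = sp.decode(ids)  # token ids -> transcript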