mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-08 08:34:19 +00:00
freeze BERT option
This commit is contained in:
parent
21cc1dfff4
commit
ae3149cb7f
@ -48,6 +48,7 @@ class PromptedTransducer(nn.Module):
|
|||||||
use_BERT: bool = True,
|
use_BERT: bool = True,
|
||||||
text_encoder_type: str = "BERT",
|
text_encoder_type: str = "BERT",
|
||||||
text_encoder_adapter: bool = False,
|
text_encoder_adapter: bool = False,
|
||||||
|
freeze_text_encoder: bool = True,
|
||||||
context_fuser: nn.Module = None,
|
context_fuser: nn.Module = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -112,6 +113,7 @@ class PromptedTransducer(nn.Module):
|
|||||||
if text_encoder_type in ("BERT", "BERT-UNCASED")
|
if text_encoder_type in ("BERT", "BERT-UNCASED")
|
||||||
else self.text_encoder.config.dim
|
else self.text_encoder.config.dim
|
||||||
)
|
)
|
||||||
|
self.freeze_text_encoder = freeze_text_encoder
|
||||||
|
|
||||||
if text_encoder_adapter:
|
if text_encoder_adapter:
|
||||||
self.text_encoder_adapter = nn.Sequential(
|
self.text_encoder_adapter = nn.Sequential(
|
||||||
@ -180,6 +182,8 @@ class PromptedTransducer(nn.Module):
|
|||||||
lm_scale * lm_probs + am_scale * am_probs +
|
lm_scale * lm_probs + am_scale * am_probs +
|
||||||
(1-lm_scale-am_scale) * combined_probs
|
(1-lm_scale-am_scale) * combined_probs
|
||||||
"""
|
"""
|
||||||
|
if self.freeze_text_encoder:
|
||||||
|
self.text_encoder.eval()
|
||||||
assert x.ndim == 3, x.shape
|
assert x.ndim == 3, x.shape
|
||||||
assert x_lens.ndim == 1, x_lens.shape
|
assert x_lens.ndim == 1, x_lens.shape
|
||||||
assert y.num_axes == 2, y.num_axes
|
assert y.num_axes == 2, y.num_axes
|
||||||
|
@ -849,9 +849,11 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder_embed = get_encoder_embed(params)
|
encoder_embed = get_encoder_embed(params)
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
|
|
||||||
text_encoder = get_text_encoder(params) # This should be a cased BERT base model
|
text_encoder = get_text_encoder(params) # This should be a cased BERT base model
|
||||||
num_param = sum([p.numel() for p in text_encoder.parameters()])
|
num_param = sum([p.numel() for p in text_encoder.parameters()])
|
||||||
logging.info(f"Num params in text encoder: {num_param}")
|
logging.info(f"Num params in text encoder: {num_param}")
|
||||||
|
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
|
||||||
@ -867,6 +869,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
|
|||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
text_encoder_type=params.text_encoder_type,
|
text_encoder_type=params.text_encoder_type,
|
||||||
text_encoder_adapter=params.text_encoder_adapter,
|
text_encoder_adapter=params.text_encoder_adapter,
|
||||||
|
freeze_text_encoder=params.freeze_text_encoder,
|
||||||
context_fuser=None,
|
context_fuser=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1618,14 +1621,15 @@ def run(rank, world_size, args):
|
|||||||
valid_cuts, text_sampling_func=naive_triplet_text_sampling
|
valid_cuts, text_sampling_func=naive_triplet_text_sampling
|
||||||
)
|
)
|
||||||
|
|
||||||
if not params.print_diagnostics:
|
# if not params.print_diagnostics:
|
||||||
scan_pessimistic_batches_for_oom(
|
# scan_pessimistic_batches_for_oom(
|
||||||
model=model,
|
# model=model,
|
||||||
train_dl=train_dl,
|
# train_dl=train_dl,
|
||||||
optimizer=optimizer,
|
# optimizer=optimizer,
|
||||||
sp=sp,
|
# sp=sp,
|
||||||
params=params,
|
# tokenizer=tokenizer,
|
||||||
)
|
# params=params,
|
||||||
|
# )
|
||||||
|
|
||||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||||
if checkpoints and "grad_scaler" in checkpoints:
|
if checkpoints and "grad_scaler" in checkpoints:
|
||||||
@ -1717,7 +1721,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
sp: spm.SentencePieceProcessor,
|
sp: spm.SentencePieceProcessor,
|
||||||
tokenizer,
|
tokenizer: spm.SentencePieceProcessor,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
):
|
):
|
||||||
from lhotse.dataset import find_pessimistic_batches
|
from lhotse.dataset import find_pessimistic_batches
|
||||||
|
Loading…
x
Reference in New Issue
Block a user