Support fine-tuning mono-lingual Whisper models; add ScaledAdam as an optimizer option

marcoyang 2024-03-29 15:29:50 +08:00
parent f208431f5c
commit 6b2bd0fb52


@@ -80,6 +80,7 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    get_parameter_groups_with_lrs,
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
@@ -145,6 +146,14 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        default="adam",
+        choices=["scaledadam", "adam"],
+        help="Which optimizer to use."
+    )
+
     parser.add_argument(
         "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
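
For context, this is how the new flag resolves at parse time; a minimal sketch, not part of the commit, using only the argument defined above:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        choices=["scaledadam", "adam"],
        help="Which optimizer to use.",
    )

    print(parser.parse_args([]).optimizer)                              # "adam"
    print(parser.parse_args(["--optimizer", "scaledadam"]).optimizer)   # "scaledadam"
    # Any other value, e.g. --optimizer sgd, makes argparse exit with an error
    # because of the `choices` constraint.
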
@@ -463,23 +472,33 @@ def compute_loss(
         torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
     ]
-    # 50256 is the index of <pad> for all whisper models
+    if params.is_multilingual:
+        # 50256 is the index of <pad> for multi-lingual whisper models
+        pad_idx = 50256
+    else:
+        # choose a symbol that is not used in the en-whisper model as the padding symbol
+        pad_idx = 50363
+    assert tokenizer.eot != pad_idx, "EOT symbol should be different from pad symbol"
+
     prev_outputs_tokens = _batch_tensors(
-        [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
+        [tokens[:-1] for tokens in text_tokens_list], pad_value=pad_idx
     )
     target_tokens = _batch_tensors(
-        [tokens[1:] for tokens in text_tokens_list], pad_value=50256
+        [tokens[1:] for tokens in text_tokens_list], pad_value=pad_idx
     )
     target_lengths = torch.LongTensor(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )

     decoder_criterion = LabelSmoothingLoss(
-        ignore_index=50256, label_smoothing=0.1, reduction="sum"
+        ignore_index=pad_idx, label_smoothing=0.1, reduction="sum"
     )

-    # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
-    ignore_prefix_size = 3
+    # ignore the prefix tokens, which are:
+    # 1. Multi-lingual model: <|startoftranscript|>, <|lang_id|>, <|transcribe|>, <|notimestamps|>
+    # 2. Mono-lingual model: <|startoftranscript|>, <|notimestamps|>
+    ignore_prefix_size = len(tokenizer.sot_sequence_including_notimestamps) - 1
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
         text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
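
To make the padding and shifting above concrete, here is a minimal sketch, not part of the commit, that approximates _batch_tensors with torch.nn.utils.rnn.pad_sequence and uses made-up token ids:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    pad_idx = 50363  # the mono-lingual choice from the diff above
    # made-up token ids; only the lengths matter here
    text_tokens_list = [
        torch.LongTensor([1, 2, 3, 4, 5]),
        torch.LongTensor([1, 2, 6]),
    ]

    # teacher-forcing inputs are the tokens shifted right, targets shifted left,
    # both padded to the longest sequence in the batch with pad_idx
    prev_outputs_tokens = pad_sequence(
        [t[:-1] for t in text_tokens_list], batch_first=True, padding_value=pad_idx
    )
    target_tokens = pad_sequence(
        [t[1:] for t in text_tokens_list], batch_first=True, padding_value=pad_idx
    )
    target_lengths = torch.LongTensor([t.shape[0] - 1 for t in text_tokens_list])

    print(prev_outputs_tokens)  # [[1, 2, 3, 4], [1, 2, 50363, 50363]]
    print(target_tokens)        # [[2, 3, 4, 5], [2, 6, 50363, 50363]]
    print(target_lengths)       # [4, 2]
    # LabelSmoothingLoss(ignore_index=pad_idx, ...) then skips the padded positions.
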
@@ -581,9 +600,10 @@ def train_one_epoch(
       be set to 0.
     """

     model.train()
-    for name, module in model.named_modules():
-        if name.startswith(params.freeze_modules):
-            module.eval()
+    if params.freeze_modules is not None:
+        for name, module in model.named_modules():
+            if name.startswith(params.freeze_modules):
+                module.eval()

     tot_loss = MetricsTracker()
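
The new guard matters because params.freeze_modules may be unset; when it is set (presumably to a tuple of module-name prefixes), matching submodules are kept in eval mode while the rest of the model trains. A minimal sketch, not part of the commit, with a toy model:

    import torch.nn as nn

    model = nn.ModuleDict(
        {
            "encoder": nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1)),
            "decoder": nn.Linear(4, 4),
        }
    )
    freeze_modules = ("encoder",)  # str.startswith also accepts a tuple of prefixes

    model.train()
    if freeze_modules is not None:
        for name, module in model.named_modules():
            if name.startswith(freeze_modules):
                module.eval()

    print(model["encoder"].training, model["decoder"].training)  # False True
    # Note: eval() only changes layers such as Dropout/BatchNorm; by itself it does
    # not stop gradient updates (that presumably relies on requires_grad elsewhere).
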
@@ -753,6 +773,7 @@ def run(rank, world_size, args):
     num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")

+    params.is_multilingual = model.is_multilingual
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,
         num_languages=model.num_languages,
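
Recording model.is_multilingual on params is what lets compute_loss() pick the matching pad index later. A minimal sketch, not part of the commit; the checkpoint name and the language/task arguments are assumptions for illustration, and model.num_languages requires a recent openai-whisper release:

    import whisper

    model = whisper.load_model("tiny.en")   # a mono-lingual (English-only) checkpoint
    is_multilingual = model.is_multilingual  # False for *.en checkpoints

    tokenizer = whisper.tokenizer.get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language="en",
        task="transcribe",
    )

    pad_idx = 50256 if is_multilingual else 50363
    assert tokenizer.eot != pad_idx  # same sanity check as in compute_loss()
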
@@ -777,7 +798,14 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")

     model.to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+    if params.optimizer == "adam":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+    else:
+        optimizer = ScaledAdam(
+            get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+            lr=params.base_lr,  # should have no effect
+            clipping_scale=2.0,
+        )
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

     if checkpoints and "optimizer" in checkpoints:
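
For reference, the ScaledAdam branch mirrors how other icefall recipes construct their optimizer. A minimal sketch, not part of the commit; it assumes ScaledAdam and Eden are importable from the recipe-local optim.py, and the toy model and scheduler settings are illustrative only:

    import torch.nn as nn
    from optim import Eden, ScaledAdam

    from icefall.utils import get_parameter_groups_with_lrs

    model = nn.Linear(80, 80)  # toy stand-in for the whisper model
    base_lr = 1e-5

    optimizer = ScaledAdam(
        get_parameter_groups_with_lrs(model, lr=base_lr, include_names=True),
        lr=base_lr,          # should have no effect; the groups carry their own lrs
        clipping_scale=2.0,  # clip update norms, as in the diff
    )
    scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)  # illustrative values
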