support fine-tuning mono-lingual whisper model; add ScaledAdam as an option

This commit is contained in:
marcoyang 2024-03-29 15:29:50 +08:00
parent f208431f5c
commit 6b2bd0fb52

View File

@ -80,6 +80,7 @@ from icefall.hooks import register_inf_check_hooks
from icefall.utils import (
AttributeDict,
MetricsTracker,
get_parameter_groups_with_lrs,
filter_uneven_sized_batch,
setup_logger,
str2bool,
@ -145,6 +146,14 @@ def get_parser():
""",
)
parser.add_argument(
"--optimizer",
type=str,
default="adam",
choices=["scaledadam", "adam"],
help="Which optimizer to use."
)
parser.add_argument(
"--base-lr", type=float, default=1e-5, help="The base learning rate."
)
@ -463,23 +472,33 @@ def compute_loss(
torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
]
# 50256 is the index of <pad> for all whisper models
if params.is_multilingual:
# 50256 is the index of <pad> for multi-lingual whisper models
pad_idx = 50256
else:
# choose a symbol that is not used in en-whisper model as padding symbol
pad_idx = 50363
assert tokenizer.eot != pad_idx, "EOT symbol should be different from pad symbol"
prev_outputs_tokens = _batch_tensors(
[tokens[:-1] for tokens in text_tokens_list], pad_value=50256
[tokens[:-1] for tokens in text_tokens_list], pad_value=pad_idx
)
target_tokens = _batch_tensors(
[tokens[1:] for tokens in text_tokens_list], pad_value=50256
[tokens[1:] for tokens in text_tokens_list], pad_value=pad_idx
)
target_lengths = torch.LongTensor(
[tokens.shape[0] - 1 for tokens in text_tokens_list]
)
decoder_criterion = LabelSmoothingLoss(
ignore_index=50256, label_smoothing=0.1, reduction="sum"
ignore_index=pad_idx, label_smoothing=0.1, reduction="sum"
)
# ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|>
ignore_prefix_size = 3
# ignore the prefix tokens, which are:
# 1. Multi-lingual model: <|startoftranscript|>, <|lang_id|>, <|transcibe|>, <|notimestampes|>
# 2. Mono-lingual model: <|startoftranscript|>, <|notimestampes|>
ignore_prefix_size = len(tokenizer.sot_sequence_including_notimestamps) - 1
with torch.set_grad_enabled(is_training):
encoder_out = model.encoder(feature)
text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
@ -581,6 +600,7 @@ def train_one_epoch(
be set to 0.
"""
model.train()
if params.freeze_modules is not None:
for name, module in model.named_modules():
if name.startswith(params.freeze_modules):
module.eval()
@ -753,6 +773,7 @@ def run(rank, world_size, args):
num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")
params.is_multilingual = model.is_multilingual
tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
@ -777,7 +798,14 @@ def run(rank, world_size, args):
logging.info(f"Device: {device}")
model.to(device)
if params.optimizer == "adam":
optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
else:
optimizer = ScaledAdam(
get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
lr=params.base_lr, # should have no effect
clipping_scale=2.0,
)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
if checkpoints and "optimizer" in checkpoints: