Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
support fine-tuning mono-lingual whisper model; add ScaledAdam as an option
parent f208431f5c
commit 6b2bd0fb52
@@ -80,6 +80,7 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    get_parameter_groups_with_lrs,
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
@@ -145,6 +146,14 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        default="adam",
+        choices=["scaledadam", "adam"],
+        help="Which optimizer to use."
+    )
+
     parser.add_argument(
         "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
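For reference, the new option parses like any other argparse choice flag; a minimal sketch using a bare ArgumentParser (the recipe's full get_parser() defines many more options):

import argparse

# Stripped-down parser containing only the option added above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--optimizer",
    type=str,
    default="adam",
    choices=["scaledadam", "adam"],
    help="Which optimizer to use.",
)

args = parser.parse_args(["--optimizer", "scaledadam"])
print(args.optimizer)  # -> "scaledadam"; any other value triggers an argparse usage error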
@@ -463,23 +472,33 @@ def compute_loss(
         torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
     ]
 
-    # 50256 is the index of <pad> for all whisper models
+    if params.is_multilingual:
+        # 50256 is the index of <pad> for multi-lingual whisper models
+        pad_idx = 50256
+    else:
+        # choose a symbol that is not used in the en-whisper model as the padding symbol
+        pad_idx = 50363
+
+    assert tokenizer.eot != pad_idx, "EOT symbol should be different from pad symbol"
+
     prev_outputs_tokens = _batch_tensors(
-        [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
+        [tokens[:-1] for tokens in text_tokens_list], pad_value=pad_idx
     )
     target_tokens = _batch_tensors(
-        [tokens[1:] for tokens in text_tokens_list], pad_value=50256
+        [tokens[1:] for tokens in text_tokens_list], pad_value=pad_idx
     )
     target_lengths = torch.LongTensor(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )
 
     decoder_criterion = LabelSmoothingLoss(
-        ignore_index=50256, label_smoothing=0.1, reduction="sum"
+        ignore_index=pad_idx, label_smoothing=0.1, reduction="sum"
     )
 
-    # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
-    ignore_prefix_size = 3
+    # ignore the prefix tokens, which are:
+    # 1. Multi-lingual model: <|startoftranscript|>, <|lang_id|>, <|transcribe|>, <|notimestamps|>
+    # 2. Mono-lingual model: <|startoftranscript|>, <|notimestamps|>
+    ignore_prefix_size = len(tokenizer.sot_sequence_including_notimestamps) - 1
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
         text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
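A minimal sketch of the padding and prefix logic above, assuming the openai-whisper package is installed; pick_pad_idx is a hypothetical helper, and the token ids are taken from the patch itself:

import whisper

def pick_pad_idx(tokenizer, is_multilingual: bool) -> int:
    # The patch pads with 50256 for multi-lingual models and with 50363, which
    # it treats as an unused symbol in the English-only vocabulary, for
    # mono-lingual ones.
    pad_idx = 50256 if is_multilingual else 50363
    assert tokenizer.eot != pad_idx, "EOT symbol should be different from pad symbol"
    return pad_idx

tokenizer = whisper.tokenizer.get_tokenizer(multilingual=False)
print(pick_pad_idx(tokenizer, is_multilingual=False))          # 50363 (eot is 50256 here)
# Prefix length ignored in the loss: the SOT sequence minus its final token,
# i.e. 3 for multi-lingual models and 1 for mono-lingual ones.
print(len(tokenizer.sot_sequence_including_notimestamps) - 1)  # 1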
@@ -581,6 +600,7 @@ def train_one_epoch(
       be set to 0.
     """
     model.train()
+    if params.freeze_modules is not None:
         for name, module in model.named_modules():
             if name.startswith(params.freeze_modules):
                 module.eval()
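The freeze logic above only switches the matched sub-modules to eval mode, so their dropout and normalization layers behave as at inference time; a toy illustration with made-up module names rather than whisper's real module graph:

import torch.nn as nn

model = nn.ModuleDict(
    {
        "encoder": nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.1)),
        "decoder": nn.Linear(4, 4),
    }
)
freeze_modules = ("encoder",)  # stand-in for params.freeze_modules

model.train()
for name, module in model.named_modules():
    if name.startswith(freeze_modules):  # str.startswith accepts a tuple of prefixes
        module.eval()

print(model["encoder"][1].training, model["decoder"].training)  # False True

Note that eval() by itself does not stop gradient updates; excluding frozen parameters from the optimizer (or setting requires_grad=False on them) is presumably handled elsewhere in train.py.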
@@ -753,6 +773,7 @@ def run(rank, world_size, args):
     num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")
 
+    params.is_multilingual = model.is_multilingual
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,
         num_languages=model.num_languages,
@@ -777,7 +798,14 @@
     logging.info(f"Device: {device}")
     model.to(device)
 
+    if params.optimizer == "adam":
         optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+    else:
+        optimizer = ScaledAdam(
+            get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+            lr=params.base_lr,  # should have no effect
+            clipping_scale=2.0,
+        )
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
     if checkpoints and "optimizer" in checkpoints:
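Condensed, the optimizer dispatch added above amounts to the sketch below; build_optimizer is a hypothetical helper, and the ScaledAdam import path is an assumption based on the recipe-local optim module used elsewhere in icefall:

import torch

from icefall.utils import get_parameter_groups_with_lrs
from optim import ScaledAdam  # recipe-local module; exact import path assumed

def build_optimizer(model: torch.nn.Module, optimizer_name: str, base_lr: float):
    if optimizer_name == "adam":
        return torch.optim.AdamW(model.parameters(), lr=base_lr)
    # ScaledAdam receives one learning rate per parameter group from
    # get_parameter_groups_with_lrs, so the top-level lr is effectively a
    # placeholder (hence the "should have no effect" comment in the diff).
    return ScaledAdam(
        get_parameter_groups_with_lrs(model, lr=base_lr, include_names=True),
        lr=base_lr,
        clipping_scale=2.0,
    )

Whichever branch runs, the result is wrapped in the same Eden schedule (scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)), so only the optimizer construction differs between the two choices.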