From 6b2bd0fb5234d57edd949359e1326cbe3fda4973 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Fri, 29 Mar 2024 15:29:50 +0800
Subject: [PATCH] Support fine-tuning mono-lingual Whisper models; add
 ScaledAdam as an optimizer option

---
 egs/librispeech/ASR/whisper/train.py | 48 ++++++++++++++++++++++------
 1 file changed, 38 insertions(+), 10 deletions(-)

diff --git a/egs/librispeech/ASR/whisper/train.py b/egs/librispeech/ASR/whisper/train.py
index db6f2e182..40fa921a0 100755
--- a/egs/librispeech/ASR/whisper/train.py
+++ b/egs/librispeech/ASR/whisper/train.py
@@ -80,6 +80,7 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
     filter_uneven_sized_batch,
+    get_parameter_groups_with_lrs,
     setup_logger,
     str2bool,
@@ -145,6 +146,14 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        default="adam",
+        choices=["scaledadam", "adam"],
+        help="Which optimizer to use.",
+    )
+
     parser.add_argument(
         "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
@@ -463,23 +472,33 @@ def compute_loss(
         torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
     ]
 
-    # 50256 is the index of <pad> for all whisper models
+    if params.is_multilingual:
+        # 50256 is the index of <pad> for multi-lingual whisper models
+        pad_idx = 50256
+    else:
+        # choose a symbol not used by the English-only whisper model as the padding symbol
+        pad_idx = 50363
+
+    assert tokenizer.eot != pad_idx, "EOT symbol should be different from pad symbol"
+
     prev_outputs_tokens = _batch_tensors(
-        [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
+        [tokens[:-1] for tokens in text_tokens_list], pad_value=pad_idx
     )
     target_tokens = _batch_tensors(
-        [tokens[1:] for tokens in text_tokens_list], pad_value=50256
+        [tokens[1:] for tokens in text_tokens_list], pad_value=pad_idx
     )
     target_lengths = torch.LongTensor(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )
 
     decoder_criterion = LabelSmoothingLoss(
-        ignore_index=50256, label_smoothing=0.1, reduction="sum"
+        ignore_index=pad_idx, label_smoothing=0.1, reduction="sum"
     )
 
-    # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|>
-    ignore_prefix_size = 3
+    # ignore the prefix tokens, which are:
+    # 1. Multi-lingual model: <|startoftranscript|>, <|lang_id|>, <|transcribe|>, <|notimestamps|>
+    # 2. Mono-lingual model: <|startoftranscript|>, <|notimestamps|>
+    ignore_prefix_size = len(tokenizer.sot_sequence_including_notimestamps) - 1
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
         text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
@@ -581,9 +600,10 @@ def train_one_epoch(
           be set to 0.
     """
     model.train()
-    for name, module in model.named_modules():
-        if name.startswith(params.freeze_modules):
-            module.eval()
+    if params.freeze_modules is not None:
+        for name, module in model.named_modules():
+            if name.startswith(params.freeze_modules):
+                module.eval()
 
     tot_loss = MetricsTracker()
 
@@ -753,6 +773,7 @@ def run(rank, world_size, args):
     num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")
 
+    params.is_multilingual = model.is_multilingual
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,
         num_languages=model.num_languages,
@@ -777,7 +798,14 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
 
     model.to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+    if params.optimizer == "adam":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+    else:
+        optimizer = ScaledAdam(
+            get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+            lr=params.base_lr,  # should have no effect
+            clipping_scale=2.0,
+        )
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
     if checkpoints and "optimizer" in checkpoints:
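
--
A few illustrative sketches for reviewers follow; they are not part of the patch.

The new ignore_prefix_size is derived from the tokenizer's prompt prefix rather
than hard-coded to 3. A minimal sketch, assuming the openai-whisper package, of
the lengths involved:

    import whisper

    # Multi-lingual tokenizer: the decoder prompt is
    # <|startoftranscript|>, <|lang_id|>, <|transcribe|>, <|notimestamps|>
    multi = whisper.tokenizer.get_tokenizer(
        multilingual=True, language="en", task="transcribe"
    )
    print(len(multi.sot_sequence_including_notimestamps))  # 4 -> ignore 3 targets

    # Mono-lingual (English-only) tokenizer: the prompt is just
    # <|startoftranscript|>, <|notimestamps|>
    mono = whisper.tokenizer.get_tokenizer(multilingual=False)
    print(len(mono.sot_sequence_including_notimestamps))  # 2 -> ignore 1 target

The "- 1" accounts for the teacher-forcing shift: targets are tokens[1:], so one
prefix token has already been dropped before the loss is computed.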
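How the pad symbol interacts with the shifted inputs and targets, sketched with
torch.nn.utils.rnn.pad_sequence standing in for the script's _batch_tensors
helper (the token values here are made up):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    pad_idx = 50363  # the mono-lingual pad symbol chosen by the patch
    tokens = [torch.LongTensor([9, 2, 3, 4, 7]), torch.LongTensor([9, 5, 7])]

    # Teacher forcing: inputs drop the last token, targets drop the first.
    prev_outputs = pad_sequence(
        [t[:-1] for t in tokens], batch_first=True, padding_value=pad_idx
    )
    targets = pad_sequence(
        [t[1:] for t in tokens], batch_first=True, padding_value=pad_idx
    )
    # Padded positions hold pad_idx, which LabelSmoothingLoss skips via
    # ignore_index=pad_idx. This is why the patch asserts pad_idx != tokenizer.eot:
    # otherwise the real end-of-text targets would be ignored as well.
    print(prev_outputs)
    print(targets)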
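On the "# should have no effect" comment in the ScaledAdam branch: each group
returned by get_parameter_groups_with_lrs presumably carries its own "lr" entry,
and in torch optimizers a per-group value always overrides the constructor
default. A generic torch illustration of that override rule, not icefall code:

    import torch

    model = torch.nn.Linear(4, 4)
    groups = [
        {"params": [model.weight], "lr": 1e-5},
        {"params": [model.bias], "lr": 1e-4},
    ]
    # The per-group "lr" wins over the constructor default.
    opt = torch.optim.AdamW(groups, lr=999.0)
    print([g["lr"] for g in opt.param_groups])  # [1e-05, 0.0001]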
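What the guarded freeze_modules loop in train_one_epoch does, as a standalone
sketch with hypothetical module names. Note that eval() only switches layer
behaviour (dropout, batchnorm); it does not by itself stop parameter updates:

    import torch.nn as nn

    model = nn.ModuleDict({"encoder": nn.Linear(4, 4), "decoder": nn.Linear(4, 4)})
    freeze_modules = ("encoder",)  # hypothetical value of params.freeze_modules

    model.train()
    if freeze_modules is not None:
        # str.startswith accepts a tuple, so several name prefixes
        # can be matched at once.
        for name, module in model.named_modules():
            if name.startswith(freeze_modules):
                module.eval()

    print(model["encoder"].training)  # False
    print(model["decoder"].training)  # True

The added None guard matters because mono-lingual fine-tuning runs may not
freeze anything, and str.startswith(None) would raise a TypeError.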