Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-09 01:52:41 +00:00
support fine-tuning mono-lingual whisper model; add ScaledAdam as an option
commit 6b2bd0fb52
parent f208431f5c
@@ -80,6 +80,7 @@ from icefall.hooks import register_inf_check_hooks
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
+    get_parameter_groups_with_lrs,
     filter_uneven_sized_batch,
     setup_logger,
     str2bool,
@@ -145,6 +146,14 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        default="adam",
+        choices=["scaledadam", "adam"],
+        help="Which optimizer to use."
+    )
+
     parser.add_argument(
         "--base-lr", type=float, default=1e-5, help="The base learning rate."
     )
@@ -463,23 +472,33 @@ def compute_loss(
         torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
     ]
 
-    # 50256 is the index of <pad> for all whisper models
+    if params.is_multilingual:
+        # 50256 is the index of <pad> for multi-lingual whisper models
+        pad_idx = 50256
+    else:
+        # choose a symbol that is not used in the en-whisper model as the padding symbol
+        pad_idx = 50363
+
+    assert tokenizer.eot != pad_idx, "EOT symbol should be different from pad symbol"
+
     prev_outputs_tokens = _batch_tensors(
-        [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
+        [tokens[:-1] for tokens in text_tokens_list], pad_value=pad_idx
     )
     target_tokens = _batch_tensors(
-        [tokens[1:] for tokens in text_tokens_list], pad_value=50256
+        [tokens[1:] for tokens in text_tokens_list], pad_value=pad_idx
     )
     target_lengths = torch.LongTensor(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )
 
     decoder_criterion = LabelSmoothingLoss(
-        ignore_index=50256, label_smoothing=0.1, reduction="sum"
+        ignore_index=pad_idx, label_smoothing=0.1, reduction="sum"
     )
 
-    # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
-    ignore_prefix_size = 3
+    # ignore the prefix tokens, which are:
+    # 1. Multi-lingual model: <|startoftranscript|>, <|lang_id|>, <|transcribe|>, <|notimestamps|>
+    # 2. Mono-lingual model: <|startoftranscript|>, <|notimestamps|>
+    ignore_prefix_size = len(tokenizer.sot_sequence_including_notimestamps) - 1
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
         text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
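The replacement of the hard-coded prefix length 3 and pad index 50256 can be checked outside the recipe. A small sketch, assuming the openai-whisper package is installed (get_tokenizer and sot_sequence_including_notimestamps are its tokenizer API, which the recipe itself uses; the expected prefix lengths follow from the comments in the diff):

import whisper  # openai-whisper; provides the same tokenizer API the recipe uses

# Multi-lingual tokenizer: prefix <|startoftranscript|>, <|lang_id|>, <|transcribe|>, <|notimestamps|>
multi = whisper.tokenizer.get_tokenizer(True, language="zh", task="transcribe")
# Mono-lingual (English-only) tokenizer: prefix <|startoftranscript|>, <|notimestamps|>
mono = whisper.tokenizer.get_tokenizer(False)

# Targets are the inputs shifted by one (tokens[1:] above), so the number of
# prefix positions to ignore is the prefix length minus one.
print(len(multi.sot_sequence_including_notimestamps) - 1)  # expected: 3 for multi-lingual models
print(len(mono.sot_sequence_including_notimestamps) - 1)   # expected: 1 for mono-lingual models

# The new assert guards the pad choice: padding must never collide with
# <|endoftext|> (tokenizer.eot), otherwise genuine end-of-text labels would be ignored.
for label, tok, pad_idx in (("multi", multi, 50256), ("mono", mono, 50363)):
    print(label, tok.eot, tok.eot != pad_idx)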
@@ -581,6 +600,7 @@ def train_one_epoch(
       be set to 0.
     """
     model.train()
-    for name, module in model.named_modules():
-        if name.startswith(params.freeze_modules):
-            module.eval()
+    if params.freeze_modules is not None:
+        for name, module in model.named_modules():
+            if name.startswith(params.freeze_modules):
+                module.eval()
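The freeze_modules guard above only switches matching sub-modules to eval mode, which fixes BatchNorm statistics and disables dropout; it does not by itself exclude their weights from the update. A toy sketch of the same name-prefix matching, assuming params.freeze_modules is a string or tuple of prefixes (the forms str.startswith accepts); the requires_grad_ line is an extra illustration, not something this hunk does:

import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(0.1))
        self.decoder = torch.nn.Linear(4, 4)


toy = Toy()
toy.train()
freeze_modules = ("encoder",)  # assumed shape of params.freeze_modules

for name, module in toy.named_modules():
    if name.startswith(freeze_modules):
        module.eval()  # what the hunk does: dropout/BN inside "encoder" behave as in inference
        for p in module.parameters():
            p.requires_grad_(False)  # illustration only: also exclude them from gradient updates

print(toy.encoder[1].training)  # False: dropout is disabled in the frozen part
print({n: p.requires_grad for n, p in toy.named_parameters()})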
@@ -753,6 +773,7 @@ def run(rank, world_size, args):
     num_trainable = sum([p.numel() if p.requires_grad else 0 for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}. Total trainable parameters: {num_trainable}")
 
+    params.is_multilingual = model.is_multilingual
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,
         num_languages=model.num_languages,
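params.is_multilingual is copied from the loaded model, where is_multilingual is a property of openai-whisper model objects; the English-only "*.en" checkpoints report False and therefore take the mono-lingual pad/prefix path in compute_loss. A quick check, assuming the whisper package is installed (the checkpoints are downloaded on first use):

import whisper

for name in ("tiny.en", "tiny"):
    model = whisper.load_model(name)
    print(name, model.is_multilingual)  # expected: tiny.en -> False, tiny -> True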
@@ -777,7 +798,14 @@ def run(rank, world_size, args):
     logging.info(f"Device: {device}")
     model.to(device)
 
-    optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+    if params.optimizer == "adam":
+        optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+    else:
+        optimizer = ScaledAdam(
+            get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True),
+            lr=params.base_lr,  # should have no effect
+            clipping_scale=2.0,
+        )
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
     if checkpoints and "optimizer" in checkpoints:
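ScaledAdam, Eden, and get_parameter_groups_with_lrs are icefall components (the last is imported in the first hunk of this diff), so the branch above only runs inside a recipe. The following standalone sketch exercises the new --optimizer / --base-lr flags end to end with a toy model; the scaledadam branch is deliberately left unimplemented here so the sketch runs without icefall on the path:

import argparse

import torch

parser = argparse.ArgumentParser()
# Mirrors the two options handled by get_parser() after this commit.
parser.add_argument("--optimizer", type=str, default="adam", choices=["scaledadam", "adam"])
parser.add_argument("--base-lr", type=float, default=1e-5)
args = parser.parse_args(["--optimizer", "adam", "--base-lr", "1e-5"])

model = torch.nn.Linear(16, 16)  # toy stand-in for the whisper model

if args.optimizer == "adam":
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.base_lr)
else:
    # In the recipe this branch builds icefall's ScaledAdam over
    # get_parameter_groups_with_lrs(model, lr=args.base_lr, include_names=True),
    # exactly as in the hunk above; it is omitted here so the sketch stays
    # runnable without icefall installed.
    raise NotImplementedError("scaledadam requires icefall's ScaledAdam")

# One dummy step, just to show the flag reached a working optimizer.
loss = model(torch.randn(2, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(type(optimizer).__name__, args.base_lr)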