update train.py

marcoyang 2024-03-28 16:16:18 +08:00
parent 711859c21f
commit ebc0f3b052


@@ -23,7 +23,8 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
   --max-duration 200 \
   --exp-dir whisper/exp_large_v2 \
   --model-name large-v2 \
-  --manifest-dir data/fbank_whisper \
+  --full-libri True \
+  --manifest-dir data/fbank_whisper_80D \
   --deepspeed \
   --deepspeed_config ./whisper/ds_config_zero1.json
@@ -31,7 +32,8 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
 torchrun --nproc_per_node 8 ./whisper/train.py \
   --max-duration 200 \
   --exp-dir whisper/exp_medium \
-  --manifest-dir data/fbank_whisper \
+  --full-libri True \
+  --manifest-dir data/fbank_whisper_80D \
   --base-lr 1e-5 \
   --model-name medium
 """
@@ -53,7 +55,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import whisper
-from asr_datamodule import AishellAsrDataModule
+from asr_datamodule import LibriSpeechAsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
@ -147,7 +149,7 @@ def get_parser():
"--model-name", "--model-name",
type=str, type=str,
default="large-v2", default="large-v2",
choices=["large-v2", "large-v3", "medium", "small", "tiny"], choices=["large-v2", "large-v3", "medium", "medium.en", "small", "small.en", "tiny", "tiny.en"],
help="""The model name to use. help="""The model name to use.
""", """,
) )
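
The English-only .en checkpoints are now accepted alongside the multilingual ones. A minimal sketch, not part of this commit, of how one of the new names resolves in openai-whisper:

    import whisper

    # "medium.en" is one of the newly allowed choices; *.en checkpoints
    # are English-only, so is_multilingual reports False.
    model = whisper.load_model("medium.en")
    print(model.is_multilingual)  # False
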
@ -450,8 +452,7 @@ def compute_loss(
batch_idx_train = params.batch_idx_train batch_idx_train = params.batch_idx_train
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
# remove spaces in texts texts = [t[0] + t[1:].lower() for t in texts]
texts = [text.replace(" ", "") for text in texts]
text_tokens_list = [ text_tokens_list = [
list(tokenizer.sot_sequence_including_notimestamps) list(tokenizer.sot_sequence_including_notimestamps)
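
The Aishell recipe stripped spaces from Chinese transcripts; the LibriSpeech version instead keeps the leading capital and lowercases the remainder, since LibriSpeech transcripts are all uppercase. A quick illustration with made-up transcripts:

    # Illustrative inputs in LibriSpeech's all-uppercase style.
    texts = ["HELLO WORLD", "IT WAS A BRIGHT COLD DAY"]
    texts = [t[0] + t[1:].lower() for t in texts]
    print(texts)  # ['Hello world', 'It was a bright cold day']
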
@ -744,7 +745,7 @@ def run(rank, world_size, args):
tokenizer = whisper.tokenizer.get_tokenizer( tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual, model.is_multilingual,
num_languages=model.num_languages, num_languages=model.num_languages,
language="zh", language="en",
task="transcribe", task="transcribe",
) )
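
The tokenizer is now built for English. A standalone sketch of the same call, assuming any locally loadable Whisper checkpoint ("tiny" below is only for illustration):

    import whisper

    model = whisper.load_model("tiny")
    tokenizer = whisper.tokenizer.get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language="en",
        task="transcribe",
    )
    # The training loop prepends this prompt to every label sequence.
    print(tokenizer.sot_sequence_including_notimestamps)
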
@ -800,7 +801,19 @@ def run(rank, world_size, args):
if params.inf_check: if params.inf_check:
register_inf_check_hooks(model) register_inf_check_hooks(model)
aishell = AishellAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
if params.full_libri:
train_cuts = librispeech.train_all_shuf_cuts()
else:
train_cuts = librispeech.train_clean_100_cuts()
def remove_short_and_long_utt(c: Cut):
if c.duration < 1.0 or c.duration > 20.0:
return False
return True
train_cuts = train_cuts.filter(remove_short_and_long_utt)
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint # We only load the sampler's state dict when it loads a checkpoint
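
CutSet.filter in lhotse is lazy and accepts any predicate over cuts, so the named remove_short_and_long_utt above could equally be written inline. A sketch against a standalone manifest (the path is illustrative, not taken from this commit):

    from lhotse import CutSet

    # Hypothetical manifest path; in the recipe the CutSets come from
    # LibriSpeechAsrDataModule instead.
    cuts = CutSet.from_file("cuts_train.jsonl.gz")
    # Same bounds as remove_short_and_long_utt: keep 1.0 s <= duration <= 20.0 s.
    cuts = cuts.filter(lambda c: 1.0 <= c.duration <= 20.0)
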
@ -809,8 +822,16 @@ def run(rank, world_size, args):
else: else:
sampler_state_dict = None sampler_state_dict = None
train_dl = aishell.train_dataloaders(aishell.train_cuts()) train_dl = librispeech.train_dataloaders(
valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
# do this to prevent Whisper throwing the length mismatch error
valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints: if checkpoints and "grad_scaler" in checkpoints:
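
The "length mismatch" comment refers to Whisper's fixed 30-second input window: features always span exactly 3000 frames, so overly long cuts can break the frame/label alignment, hence the same duration filter on the validation cuts. A short sketch of that constraint using openai-whisper's own helpers:

    import torch
    import whisper

    audio = torch.randn(16000 * 25)           # 25 s of synthetic audio
    audio = whisper.pad_or_trim(audio)        # padded/trimmed to exactly 30 s
    mel = whisper.log_mel_spectrogram(audio)  # 80-dim mels (likely the "80D" in the new manifest dir)
    print(mel.shape)                          # torch.Size([80, 3000])
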
@ -911,7 +932,7 @@ def display_and_save_batch(
def main(): def main():
parser = get_parser() parser = get_parser()
AishellAsrDataModule.add_arguments(parser) LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)