Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
update train.py
This commit is contained in:
  parent 711859c21f
  commit ebc0f3b052
whisper/train.py

@@ -23,7 +23,8 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
   --max-duration 200 \
   --exp-dir whisper/exp_large_v2 \
   --model-name large-v2 \
-  --manifest-dir data/fbank_whisper \
+  --full-libri True \
+  --manifest-dir data/fbank_whisper_80D \
   --deepspeed \
   --deepspeed_config ./whisper/ds_config_zero1.json
@@ -31,7 +32,8 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
 torchrun --nproc_per_node 8 ./whisper/train.py \
   --max-duration 200 \
   --exp-dir whisper/exp_medium \
-  --manifest-dir data/fbank_whisper \
+  --full-libri True \
+  --manifest-dir data/fbank_whisper_80D \
   --base-lr 1e-5 \
   --model-name medium
 """
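Both usage examples now point at data/fbank_whisper_80D, which suggests 80-dimensional Whisper fbank features. As a hedged sketch of how such a manifest directory is typically produced with lhotse (the input manifest path, storage path, and job count below are assumptions, not part of this commit):

```python
# Hypothetical data-prep sketch: 80-dim Whisper fbank features via lhotse.
# All paths and num_jobs are assumptions inferred from the directory name above.
from lhotse import CutSet, WhisperFbank, WhisperFbankConfig

cuts = CutSet.from_file("data/manifests/librispeech_cuts_train-clean-100.jsonl.gz")
cuts = cuts.compute_and_store_features(
    extractor=WhisperFbank(WhisperFbankConfig(num_filters=80)),
    storage_path="data/fbank_whisper_80D/feats_train-clean-100",
    num_jobs=8,
)
cuts.to_file("data/fbank_whisper_80D/librispeech_cuts_train-clean-100.jsonl.gz")
```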
@@ -53,7 +55,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import whisper
-from asr_datamodule import AishellAsrDataModule
+from asr_datamodule import LibriSpeechAsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
@@ -147,7 +149,7 @@ def get_parser():
         "--model-name",
         type=str,
         default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+        choices=["large-v2", "large-v3", "medium", "medium.en", "small", "small.en", "tiny", "tiny.en"],
         help="""The model name to use.
         """,
     )
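The widened choices list adds the English-only checkpoints. These names match what the openai-whisper package itself exposes; a quick hedged check (assuming openai-whisper is installed):

```python
# Hedged check: the new choices correspond to checkpoint names known to
# openai-whisper; available_models() lists them all.
import whisper

print(whisper.available_models())  # includes "tiny.en", "small.en", "medium.en", ...
model = whisper.load_model("medium.en")  # English-only medium checkpoint
```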
@@ -450,8 +452,7 @@ def compute_loss(
     batch_idx_train = params.batch_idx_train

     texts = batch["supervisions"]["text"]
-    # remove spaces in texts
-    texts = [text.replace(" ", "") for text in texts]
+    texts = [t[0] + t[1:].lower() for t in texts]

     text_tokens_list = [
         list(tokenizer.sot_sequence_including_notimestamps)
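The dropped space-stripping normalization was the Aishell (Chinese) convention; the new expression keeps the first character's case and lowercases the rest, which fits LibriSpeech's all-uppercase transcripts. A minimal illustration (the sample strings are made up):

```python
# Illustration of the new normalization; the transcripts here are invented.
texts = ["HELLO WORLD", "IT WAS A BRIGHT COLD DAY IN APRIL"]
texts = [t[0] + t[1:].lower() for t in texts]
print(texts)  # ['Hello world', 'It was a bright cold day in april']
```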
@@ -744,7 +745,7 @@ def run(rank, world_size, args):
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,
         num_languages=model.num_languages,
-        language="zh",
+        language="en",
         task="transcribe",
     )

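Switching language="zh" to language="en" changes the decoder prompt that compute_loss prepends via sot_sequence_including_notimestamps. A hedged sketch of the effect (the multilingual/num_languages values below are assumptions for illustration):

```python
# Hedged sketch: how the language flag shows up in the decoder prompt.
# The multilingual/num_languages values are assumptions for illustration.
import whisper

tokenizer = whisper.tokenizer.get_tokenizer(
    multilingual=True, num_languages=99, language="en", task="transcribe"
)
# Roughly <|startoftranscript|><|en|><|transcribe|><|notimestamps|>
print(tokenizer.sot_sequence_including_notimestamps)
```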
@@ -800,7 +801,19 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    aishell = AishellAsrDataModule(args)
+    librispeech = LibriSpeechAsrDataModule(args)

+    if params.full_libri:
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        if c.duration < 1.0 or c.duration > 20.0:
+            return False
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
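Note that lhotse's CutSet.filter is lazy, so the duration predicate runs as the sampler iterates rather than up front. A hedged standalone sketch of the same filter (the manifest path is an assumption, not taken from this commit):

```python
# Hedged sketch of the duration filter on its own; the manifest path is an
# assumption, not taken from this commit.
from lhotse import CutSet

cuts = CutSet.from_file("data/fbank_whisper_80D/librispeech_cuts_train-clean-100.jsonl.gz")
# Equivalent to remove_short_and_long_utt above: keep 1 s <= duration <= 20 s.
cuts = cuts.filter(lambda c: 1.0 <= c.duration <= 20.0)
```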
@@ -809,8 +822,16 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    train_dl = aishell.train_dataloaders(aishell.train_cuts())
-    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+
+    # do this to prevent Whisper throwing the length mismatch error
+    valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
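The "length mismatch" comment refers to Whisper's fixed 30-second input window: the encoder expects a (n_mels, 3000) mel matrix, so cuts longer than that cannot be padded to the expected shape. A hedged illustration with openai-whisper's own helpers (the waveform is synthetic):

```python
# Hedged illustration of Whisper's fixed input length; the waveform is synthetic.
import torch
import whisper
from whisper.audio import N_FRAMES  # 3000 mel frames == 30 s

audio = torch.zeros(16000 * 25)           # a made-up 25 s waveform at 16 kHz
mel = whisper.log_mel_spectrogram(audio)  # shape (80, 2500)
mel = whisper.pad_or_trim(mel, N_FRAMES)  # padded to (80, 3000)
print(mel.shape)
```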
@@ -911,7 +932,7 @@ def display_and_save_batch(

 def main():
     parser = get_parser()
-    AishellAsrDataModule.add_arguments(parser)
+    LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
