mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

commit ebc0f3b052
parent 711859c21f

    update train.py
@@ -23,7 +23,8 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
   --max-duration 200 \
   --exp-dir whisper/exp_large_v2 \
   --model-name large-v2 \
-  --manifest-dir data/fbank_whisper \
+  --full-libri True \
+  --manifest-dir data/fbank_whisper_80D \
   --deepspeed \
   --deepspeed_config ./whisper/ds_config_zero1.json

@@ -31,7 +32,8 @@ torchrun --nproc_per_node 8 ./whisper/train.py \
 torchrun --nproc_per_node 8 ./whisper/train.py \
   --max-duration 200 \
   --exp-dir whisper/exp_medium \
-  --manifest-dir data/fbank_whisper \
+  --full-libri True \
+  --manifest-dir data/fbank_whisper_80D \
   --base-lr 1e-5 \
   --model-name medium
 """
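Both launch commands above now pass --full-libri True and point at the 80-dim fbank manifests. As a minimal sketch only, this is how such a boolean flag is typically registered in icefall recipes; the real registration lives in asr_datamodule.py, which this commit does not touch, and str2bool here is a stand-in for icefall.utils.str2bool:

import argparse

def str2bool(v: str) -> bool:
    # Minimal stand-in for icefall.utils.str2bool.
    if v.lower() in ("true", "t", "yes", "y", "1"):
        return True
    if v.lower() in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Not a boolean value: {v}")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--full-libri",
    type=str2bool,
    default=True,
    help="When True, train on all 960h of LibriSpeech; "
    "otherwise only on train-clean-100.",
)

print(parser.parse_args(["--full-libri", "True"]).full_libri)  # True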
@@ -53,7 +55,7 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 import whisper
-from asr_datamodule import AishellAsrDataModule
+from asr_datamodule import LibriSpeechAsrDataModule
 from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
@@ -147,7 +149,7 @@ def get_parser():
         "--model-name",
         type=str,
         default="large-v2",
-        choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+        choices=["large-v2", "large-v3", "medium", "medium.en", "small", "small.en", "tiny", "tiny.en"],
         help="""The model name to use.
         """,
     )
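The new ".en" choices name OpenAI's English-only checkpoints, which fit the move to LibriSpeech. A short sketch, assuming openai-whisper is installed; they load through the same whisper.load_model API as the multilingual models:

import whisper

model = whisper.load_model("medium.en")  # downloads the checkpoint on first use
print(model.is_multilingual)  # False: ".en" models are English-only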
@@ -450,8 +452,7 @@ def compute_loss(
     batch_idx_train = params.batch_idx_train

     texts = batch["supervisions"]["text"]
-    # remove spaces in texts
-    texts = [text.replace(" ", "") for text in texts]
+    texts = [t[0] + t[1:].lower() for t in texts]

     text_tokens_list = [
         list(tokenizer.sot_sequence_including_notimestamps)
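A stand-alone illustration (not from the diff) of this normalization change: the removed lines stripped spaces, which suits the unspaced Chinese transcripts of Aishell; the added line keeps spaces and lowercases everything after the first character, mapping LibriSpeech's all-caps transcripts to Whisper-style casing.

texts = ["HELLO WORLD", "A SECOND UTTERANCE"]

old_style = [text.replace(" ", "") for text in texts]  # Aishell behavior
new_style = [t[0] + t[1:].lower() for t in texts]      # LibriSpeech behavior

print(old_style)  # ['HELLOWORLD', 'ASECONDUTTERANCE']
print(new_style)  # ['Hello world', 'A second utterance']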
@@ -744,7 +745,7 @@ def run(rank, world_size, args):
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual,
         num_languages=model.num_languages,
-        language="zh",
+        language="en",
         task="transcribe",
     )

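The language argument is baked into the tokenizer's start-of-transcript sequence, so switching "zh" to "en" changes the prefix tokens prepended to every target in compute_loss. A sketch only, assuming openai-whisper is installed:

import whisper

model = whisper.load_model("tiny")
tokenizer = whisper.tokenizer.get_tokenizer(
    model.is_multilingual,
    num_languages=model.num_languages,
    language="en",
    task="transcribe",
)
# Roughly: <|startoftranscript|>, <|en|>, <|transcribe|>, <|notimestamps|>
print(tokenizer.sot_sequence_including_notimestamps)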
@@ -800,7 +801,19 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    aishell = AishellAsrDataModule(args)
+    librispeech = LibriSpeechAsrDataModule(args)

+    if params.full_libri:
+        train_cuts = librispeech.train_all_shuf_cuts()
+    else:
+        train_cuts = librispeech.train_clean_100_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        if c.duration < 1.0 or c.duration > 20.0:
+            return False
+        return True
+
+    train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
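A self-contained illustration of the remove_short_and_long_utt predicate added above. lhotse's CutSet.filter applies such a predicate lazily to real cuts; here a tiny stand-in class lets the sketch run without any datasets. The 1-20 s bounds are the ones in the diff:

from dataclasses import dataclass

@dataclass
class Cut:  # stand-in for lhotse.cut.Cut
    id: str
    duration: float  # seconds

def remove_short_and_long_utt(c: Cut) -> bool:
    if c.duration < 1.0 or c.duration > 20.0:
        return False
    return True

cuts = [Cut("too-short", 0.4), Cut("ok", 7.5), Cut("too-long", 25.0)]
print([c.id for c in cuts if remove_short_and_long_utt(c)])  # ['ok']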
@@ -809,8 +822,16 @@ def run(rank, world_size, args):
     else:
         sampler_state_dict = None

-    train_dl = aishell.train_dataloaders(aishell.train_cuts())
-    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
+    train_dl = librispeech.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = librispeech.dev_clean_cuts()
+    valid_cuts += librispeech.dev_other_cuts()
+
+    # do this to prevent Whisper throwing the length mismatch error
+    valid_cuts = valid_cuts.filter(remove_short_and_long_utt)
+    valid_dl = librispeech.valid_dataloaders(valid_cuts)

 scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
 if checkpoints and "grad_scaler" in checkpoints:
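The train_dataloaders call now threads sampler_state_dict through, so a run restarted with --start-batch resumes the sampler mid-epoch instead of re-seeing cuts, and validation uses dev-clean plus dev-other filtered by the same predicate. A sketch only of the resume pattern; the "sampler" checkpoint key follows icefall's convention, and the function body is an illustrative stand-in, not LibriSpeechAsrDataModule:

def train_dataloaders(cuts, sampler_state_dict=None):
    # Illustrative stand-in: a real lhotse sampler would call
    # sampler.load_state_dict(...) and skip the cuts already consumed
    # before the checkpoint was written.
    if sampler_state_dict is not None:
        print("resuming sampler at batch", sampler_state_dict["batch"])
    return iter(cuts)

checkpoints = {"sampler": {"batch": 1200}}  # illustrative checkpoint contents
sampler_state_dict = checkpoints.get("sampler")

train_dl = train_dataloaders(["cut-0", "cut-1"], sampler_state_dict)
print(next(train_dl))  # cut-0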
@@ -911,7 +932,7 @@ def display_and_save_batch(

 def main():
     parser = get_parser()
-    AishellAsrDataModule.add_arguments(parser)
+    LibriSpeechAsrDataModule.add_arguments(parser)
     args = parser.parse_args()
     args.exp_dir = Path(args.exp_dir)
