Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-05 07:04:18 +00:00

commit b6418acda2
parent 92895f774f

    support deepspeed to finetune large model
@@ -182,7 +182,7 @@ class AishellAsrDataModule:
         )

     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, rank = None, world_size = None
     ) -> DataLoader:
         """
         Args:
@@ -276,6 +276,8 @@ class AishellAsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=self.args.drop_last,
+                world_size=world_size,
+                rank=rank,
             )
         else:
             logging.info("Using SimpleCutSampler.")
@@ -300,7 +302,7 @@ class AishellAsrDataModule:

         return train_dl

-    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+    def valid_dataloaders(self, cuts_valid: CutSet, rank = None, world_size = None) -> DataLoader:
         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -325,6 +327,8 @@ class AishellAsrDataModule:
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
+            rank=rank,
+            world_size=world_size,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
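
Note on the new rank/world_size arguments above: when the process group is set up by an external launcher (here, DeepSpeed) rather than by icefall's setup_dist(), the Lhotse sampler is told its shard explicitly instead of inferring it from torch.distributed. A minimal, hypothetical sketch of that idea, not taken from the commit (durations and bucket count are placeholders):

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

def make_train_sampler(cuts: CutSet, rank: int, world_size: int) -> DynamicBucketingSampler:
    # Illustrative only: shard the sampler explicitly instead of relying on torch.distributed.
    return DynamicBucketingSampler(
        cuts,
        max_duration=200.0,     # placeholder; normally self.args.max_duration
        shuffle=True,
        num_buckets=30,         # placeholder; normally self.args.num_buckets
        drop_last=True,
        world_size=world_size,  # passed explicitly by the caller
        rank=rank,
    )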
@@ -109,20 +109,17 @@ def get_parser():
         default="beam-search",
         help="""Decoding method.
         Supported values are:
-          - (0) ctc-decoding. Use CTC decoding. It maps the tokens ids to
-            tokens using token symbol tabel directly.
-          - (1) 1best. Extract the best path from the decoding lattice as the
-            decoding result.
-          - (2) nbest. Extract n paths from the decoding lattice; the path
-            with the highest score is the decoding result.
-          - (3) attention-decoder. Extract n paths from the lattice,
-            the path with the highest score is the decoding result.
-          - (4) nbest-oracle. Its WER is the lower bound of any n-best
-            rescoring method can achieve. Useful for debugging n-best
-            rescoring method.
+          - beam-search
         """,
     )

+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=1,
+        help="beam size for beam search decoding",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -357,10 +354,9 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
+    setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")

-    #options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=10)
-    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=None)
+    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
     params.decoding_options = options
     params.cleaner = BasicTextNormalizer()
     params.normalizer = Normalizer()
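
For context only, a hedged sketch of how a DecodingOptions object with a beam size is typically consumed by openai-whisper; this is not part of the commit, and `mel` stands for a log-mel spectrogram tensor prepared elsewhere:

import whisper

model = whisper.load_model("large-v2")
options = whisper.DecodingOptions(
    task="transcribe", language="zh", without_timestamps=True, beam_size=5
)
# result = whisper.decode(model, mel, options)  # `mel` assumed to be an 80 x 3000 log-mel tensor
# print(result.text)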
egs/aishell/ASR/whisper/ds_config_zero1.json (new file, 32 lines)
@@ -0,0 +1,32 @@
+{
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "initial_scale_power": 16,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 2e8,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 2e8,
+    "contiguous_gradients": true
+  },
+  "scheduler": {
+    "type": "WarmupLR",
+    "params": {
+      "warmup_min_lr": 5e-6,
+      "warmup_max_lr": 1e-5,
+      "warmup_num_steps": 100
+    }
+  },
+  "gradient_accumulation_steps": 1,
+  "gradient_clipping": 5,
+  "steps_per_print": 50,
+  "train_micro_batch_size_per_gpu": 1,
+  "wall_clock_breakdown": false
+}
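
A brief, hedged note on how a config file like this is usually wired in (not shown in the commit itself): deepspeed.add_config_arguments(), which the training script switches to further below, registers --deepspeed and --deepspeed_config on the argparse parser, and deepspeed.initialize() later reads the JSON from there. A minimal sketch:

import argparse
import deepspeed

parser = argparse.ArgumentParser()
parser = deepspeed.add_config_arguments(parser)  # adds --deepspeed / --deepspeed_config
args = parser.parse_args(
    ["--deepspeed", "--deepspeed_config", "egs/aishell/ASR/whisper/ds_config_zero1.json"]
)
print(args.deepspeed, args.deepspeed_config)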
@@ -372,7 +372,7 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:

 def load_model(
     name: str,
-    device: Optional[Union[str, torch.device]] = None,
+    device: Optional[Union[str, torch.device]] = 'cpu',
     download_root: str = None,
     in_memory: bool = False,
 ) -> Whisper:
@@ -397,8 +397,8 @@ def load_model(
         The Whisper ASR model instance
     """

-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+    # if device is None:
+    #     device = "cuda" if torch.cuda.is_available() else "cpu"
     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
@@ -8,3 +8,4 @@ librosa
 openai-whisper
 zhconv
 WeTextProcessing
+deepspeed
@@ -42,6 +42,8 @@ import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
+import deepspeed
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

 import k2
 import optim
@@ -102,15 +104,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count

-def add_deepspeed_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--deepspeed-config",
-        type=str,
-        default=None,
-        help="Path to deepspeed json config file.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -251,7 +244,7 @@ def get_parser():
         help="Whether to use half precision training.",
     )

-    add_deepspeed_arguments(parser)
+    parser = deepspeed.add_config_arguments(parser)

     return parser

@@ -495,7 +488,6 @@ def compute_loss(
     feature = feature.transpose(1, 2)  # (N, C, T)
     # pad feature from B,80,T to B,80,3000
     #feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
-    #print(feature.shape, 23333333)
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)

@@ -629,24 +621,25 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                tokenizer=tokenizer,
-                model=model,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        #     logging.info("Computing validation loss")
+        #     valid_info = compute_validation_loss(
+        #         params=params,
+        #         tokenizer=tokenizer,
+        #         model=model,
+        #         valid_dl=valid_dl,
+        #         world_size=world_size,
+        #     )
+        #     model.train()
+        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+        #     logging.info(
+        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        #     )
+        #     if tb_writer is not None:
+        #         valid_info.write_summary(
+        #             tb_writer, "train/valid_", params.batch_idx_train
+        #         )

         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -661,13 +654,20 @@ def train_one_epoch(

             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
-            scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
-            scheduler.step_batch(params.batch_idx_train)
+            if params.deepspeed:
+                # deepspeed's backward() is different from torch's backward()
+                # in that it does not accept a loss tensor as input.
+                # It computes the loss internally.
+                model.backward(loss)
+                model.step()
+            else:
+                scaler.scale(loss).backward()
+                set_batch_count(model, params.batch_idx_train)
+                scheduler.step_batch(params.batch_idx_train)

             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
         except:  # noqa
             display_and_save_batch(batch, params=params)
             raise
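
A self-contained sketch of the branching the hunk above introduces, with illustrative names (engine, scaler, optimizer and use_deepspeed are assumptions, not the script's exact objects): under DeepSpeed the engine owns backward and the optimizer step, while the plain-PyTorch path keeps the GradScaler flow.

def training_step(loss, *, use_deepspeed, engine=None, scaler=None, optimizer=None):
    if use_deepspeed:
        engine.backward(loss)  # the DeepSpeed engine handles loss scaling and gradient all-reduce
        engine.step()          # optimizer step and gradient zeroing are managed by the engine
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()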
@@ -679,6 +679,7 @@ def train_one_epoch(
                 rank == 0
                 and params.batch_idx_train > 0
                 and params.batch_idx_train % params.average_period == 0
+                and not params.deepspeed
             ):
                 update_averaged_model(
                     params=params,
@@ -686,29 +687,28 @@ def train_one_epoch(
                     model_avg=model_avg,
                 )

-        if (
-            params.batch_idx_train > 0
-            and params.batch_idx_train % params.save_every_n == 0
-        ):
-            save_checkpoint_with_global_batch_idx(
-                out_dir=params.exp_dir,
-                global_batch_idx=params.batch_idx_train,
-                model=model,
-                model_avg=model_avg,
-                params=params,
-                optimizer=optimizer,
-                scheduler=scheduler,
-                sampler=train_dl.sampler,
-                scaler=scaler,
-                rank=rank,
-            )
-            remove_checkpoints(
-                out_dir=params.exp_dir,
-                topk=params.keep_last_k,
-                rank=rank,
-            )
-        if batch_idx % 100 == 0 and params.use_fp16:
+        # if (
+        #     params.batch_idx_train > 0
+        #     and params.batch_idx_train % params.save_every_n == 0
+        # ):
+        #     save_checkpoint_with_global_batch_idx(
+        #         out_dir=params.exp_dir,
+        #         global_batch_idx=params.batch_idx_train,
+        #         model=model,
+        #         model_avg=model_avg,
+        #         params=params,
+        #         optimizer=optimizer,
+        #         scheduler=scheduler,
+        #         sampler=train_dl.sampler,
+        #         scaler=scaler,
+        #         rank=rank,
+        #     )
+        #     remove_checkpoints(
+        #         out_dir=params.exp_dir,
+        #         topk=params.keep_last_k,
+        #         rank=rank,
+        #     )
+        if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
@@ -723,14 +723,14 @@ def train_one_epoch(
                 )
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
-            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0

             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                + (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
             )

             if tb_writer is not None:
@@ -774,37 +774,21 @@ def run(rank, world_size, args):

     fix_random_seed(params.seed)

-    setup_dist(use_ddp_launch=True)
-
     setup_logger(f"{params.exp_dir}/log/log-train")
-    logging.info("Training started")
+    logging.info(params)

-    if args.tensorboard and rank == 0:
-        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
-    else:
-        tb_writer = None
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
-    logging.info(f"Device: {device}")
-
-
     logging.info("About to create model")
-    #model = whisper.load_model("medium")
     # TODO download model only on rank 0
     # TODO may change compute validation loss using multiple cards
-    model = load_model("medium")
+    # model = load_model("medium")
+    model = load_model("large-v2")
     del model.alignment_heads
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")

     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual, language="zh", task="transcribe"
     )
-    logging.info(params)
-
-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")

     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
@@ -817,10 +801,12 @@ def run(rank, world_size, args):
             params=params, model=model, model_avg=model_avg
         )

+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    else:
+        device = torch.device("cpu")
+    logging.info(f"Device: {device}")
     model.to(device)
-    if world_size > 1:
-        logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)

     optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -837,6 +823,17 @@ def run(rank, world_size, args):
         logging.info("Loading scheduler state dict")
         scheduler.load_state_dict(checkpoints["scheduler"])

+    if world_size > 1:
+        if params.deepspeed:
+            logging.info("Using DeepSpeed")
+            model, optimizer, _, _ = deepspeed.initialize(
+                args=params, model=model, optimizer=optimizer,
+                model_parameters=model.parameters())
+        else:
+            logging.info("Using DDP")
+            setup_dist(use_ddp_launch=True)
+            model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
             2**22
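
A hypothetical, standalone version of the wrapping logic added above (names are illustrative; args is assumed to carry the flags registered by deepspeed.add_config_arguments()): deepspeed.initialize() returns an engine plus optimizer, dataloader and scheduler, whereas the DDP branch wraps the module in place.

import deepspeed
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_for_training(args, model, optimizer, rank, world_size, use_deepspeed):
    if world_size > 1 and use_deepspeed:
        # DeepSpeed builds its engine from the JSON config referenced by args.
        model, optimizer, _, _ = deepspeed.initialize(
            args=args,
            model=model,
            optimizer=optimizer,
            model_parameters=model.parameters(),
        )
    elif world_size > 1:
        # Plain DDP path: the module is wrapped, the optimizer is left untouched.
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    return model, optimizer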
@@ -846,51 +843,8 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)

-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
-        # You should use ../local/display_manifest_statistics.py to get
-        # an utterance duration distribution for your dataset to select
-        # the threshold
-        if c.duration < 1.0 or c.duration > 12.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
-            return False
-
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        # T = ((c.num_frames - 7) // 2 + 1) // 2
-        # tokens = sp.encode(c.supervisions[0].text, out_type=str)
-
-        # if T < len(tokens):
-        #     logging.warning(
-        #         f"Exclude cut with ID {c.id} from training. "
-        #         f"Number of frames (before subsampling): {c.num_frames}. "
-        #         f"Number of frames (after subsampling): {T}. "
-        #         f"Text: {c.supervisions[0].text}. "
-        #         f"Tokens: {tokens}. "
-        #         f"Number of tokens: {len(tokens)}"
-        #     )
-        #     return False
-
-        return True
-
-
     aishell = AishellAsrDataModule(args)

-
-
-
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
@@ -899,22 +853,19 @@ def run(rank, world_size, args):
         sampler_state_dict = None


-    train_dl = aishell.train_dataloaders(aishell.train_cuts())
-    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         graph_compiler=graph_compiler,
-    #         params=params,
-    #     )
+    train_dl = aishell.train_dataloaders(aishell.train_cuts(), rank=rank, world_size=world_size)
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts(), rank=rank, world_size=world_size)

     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])

+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
     logging.info(f"start training from epoch {params.start_epoch}")
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         scheduler.step_epoch(epoch - 1)
@@ -945,20 +896,28 @@ def run(rank, world_size, args):
             diagnostic.print_diagnostics()
             break

-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
+        if params.deepspeed:
+            model.save_checkpoint(save_dir=params.exp_dir,
+                                  tag=f"epoch-{params.cur_epoch}",
+                                  client_state={})
+            convert_zero_checkpoint_to_fp32_state_dict(
+                params.exp_dir, f"epoch-{params.cur_epoch}.pt",
+                tag=f"epoch-{params.cur_epoch}")
+        else:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )

     logging.info("Done!")

-    if world_size > 1:
+    if world_size > 1 and not params.deepspeed:
         torch.distributed.barrier()
         cleanup_dist()

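
A hedged follow-up, not part of the commit: assuming convert_zero_checkpoint_to_fp32_state_dict() above writes a consolidated "epoch-N.pt" under the experiment directory, that file can later be loaded like an ordinary PyTorch state dict (the path and epoch number below are placeholders).

import torch

state_dict = torch.load("exp_dir/epoch-1.pt", map_location="cpu")
# model.load_state_dict(state_dict)  # `model` would be the un-wrapped Whisper nn.Module used for decoding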
@@ -988,48 +947,6 @@ def display_and_save_batch(
     logging.info(f"features shape: {features.shape}")


-# def scan_pessimistic_batches_for_oom(
-#     model: Union[nn.Module, DDP],
-#     tokenizer: whisper.tokenizer.Tokenizer,
-#     train_dl: torch.utils.data.DataLoader,
-#     optimizer: torch.optim.Optimizer,
-#     params: AttributeDict,
-# ):
-#     from lhotse.dataset import find_pessimistic_batches
-
-#     logging.info(
-#         "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
-#     )
-#     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-#     for criterion, cuts in batches.items():
-#         batch = train_dl.dataset[cuts]
-#         try:
-#             with torch.cuda.amp.autocast(enabled=params.use_fp16):
-#                 loss, _ = compute_loss(
-#                     params=params,
-#                     tokenizer=tokenizer,
-#                     model=model,
-#                     batch=batch,
-#                     is_training=True,
-#                 )
-#             loss.backward()
-#             optimizer.zero_grad()
-#         except Exception as e:
-#             if "CUDA out of memory" in str(e):
-#                 logging.error(
-#                     "Your GPU ran out of memory with the current "
-#                     "max_duration setting. We recommend decreasing "
-#                     "max_duration and trying again.\n"
-#                     f"Failing criterion: {criterion} "
-#                     f"(={crit_values[criterion]}) ..."
-#                 )
-#             display_and_save_batch(batch, params=params)
-#             raise
-#         logging.info(
-#             f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-#         )
-
-
 def main():
     parser = get_parser()
     AishellAsrDataModule.add_arguments(parser)
@@ -1038,13 +955,10 @@ def main():

     world_size = get_world_size()
     rank = get_rank()
-    assert world_size >= 1

+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     run(rank=rank, world_size=world_size, args=args)

-
-    torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)

 if __name__ == "__main__":
     main()