Support DeepSpeed to fine-tune large models
Commit b6418acda2 (parent 92895f774f)
@@ -182,7 +182,7 @@ class AishellAsrDataModule:
         )
 
     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, rank=None, world_size=None
     ) -> DataLoader:
         """
         Args:
@@ -276,6 +276,8 @@ class AishellAsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=self.args.drop_last,
+                world_size=world_size,
+                rank=rank,
             )
         else:
             logging.info("Using SimpleCutSampler.")
@@ -300,7 +302,7 @@ class AishellAsrDataModule:
 
         return train_dl
 
-    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+    def valid_dataloaders(self, cuts_valid: CutSet, rank=None, world_size=None) -> DataLoader:
        transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -325,6 +327,8 @@ class AishellAsrDataModule:
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
+            rank=rank,
+            world_size=world_size,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
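The new rank and world_size parameters are threaded through to the lhotse samplers so that each process reads only its own shard even when the process group is managed by DeepSpeed rather than inferred from torch.distributed. A minimal sketch of the resulting sampler call (not from the commit; max_duration and num_buckets are placeholder values):

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler

def make_train_sampler(cuts_train: CutSet, rank: int, world_size: int):
    # Shard the cuts explicitly instead of letting the sampler probe
    # torch.distributed for its rank.
    return DynamicBucketingSampler(
        cuts_train,
        max_duration=200.0,     # seconds of audio per batch (placeholder)
        shuffle=True,
        num_buckets=30,         # placeholder
        drop_last=True,
        world_size=world_size,  # total number of training processes
        rank=rank,              # shard index of this process
    )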
@@ -109,20 +109,17 @@ def get_parser():
         default="beam-search",
         help="""Decoding method.
         Supported values are:
-          - (0) ctc-decoding. Use CTC decoding. It maps the token ids to
-            tokens using the token symbol table directly.
-          - (1) 1best. Extract the best path from the decoding lattice as the
-            decoding result.
-          - (2) nbest. Extract n paths from the decoding lattice; the path
-            with the highest score is the decoding result.
-          - (3) attention-decoder. Extract n paths from the lattice,
-            the path with the highest score is the decoding result.
-          - (4) nbest-oracle. Its WER is the lower bound that any n-best
-            rescoring method can achieve. Useful for debugging n-best
-            rescoring methods.
           - beam-search
         """,
     )
 
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=1,
+        help="beam size for beam search decoding",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -357,10 +354,9 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
+    setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
 
-    # options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=10)
-    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=None)
+    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
     params.decoding_options = options
     params.cleaner = BasicTextNormalizer()
     params.normalizer = Normalizer()
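With --beam-size now wired into whisper.DecodingOptions, a decoding call would look roughly like the sketch below (not from the commit; the model size and audio path are placeholders):

import whisper

model = whisper.load_model("large-v2")  # placeholder model choice
audio = whisper.pad_or_trim(whisper.load_audio("test.wav"))  # placeholder path
mel = whisper.log_mel_spectrogram(audio).to(model.device)

options = whisper.DecodingOptions(
    task="transcribe",
    language="zh",
    without_timestamps=True,
    beam_size=4,  # what the new --beam-size flag controls
)
result = whisper.decode(model, mel, options)
print(result.text)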
egs/aishell/ASR/whisper/ds_config_zero1.json (new file, 32 lines)
@@ -0,0 +1,32 @@
+{
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "initial_scale_power": 16,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 2e8,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 2e8,
+    "contiguous_gradients": true
+  },
+  "scheduler": {
+    "type": "WarmupLR",
+    "params": {
+      "warmup_min_lr": 5e-6,
+      "warmup_max_lr": 1e-5,
+      "warmup_num_steps": 100
+    }
+  },
+  "gradient_accumulation_steps": 1,
+  "gradient_clipping": 5,
+  "steps_per_print": 50,
+  "train_micro_batch_size_per_gpu": 1,
+  "wall_clock_breakdown": false
+}
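For reference, a minimal sketch (not part of the commit) of how this JSON drives training: deepspeed.initialize() wraps the model and the client optimizer in an engine whose fp16 loss scaling, ZeRO stage-1 partitioning, and WarmupLR schedule all come from the file. The toy model and learning rate are placeholders, and the script is assumed to run under the DeepSpeed launcher (e.g. `deepspeed sketch.py`):

import torch
import deepspeed

model = torch.nn.Linear(80, 10)  # stand-in for the Whisper model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # placeholder lr

engine, optimizer, _, scheduler = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    config="ds_config_zero1.json",  # the file added above
)

x = torch.randn(4, 80, device=engine.device, dtype=torch.half)  # fp16 is on
loss = engine(x).float().sum()
engine.backward(loss)  # the engine scales the loss and reduces gradients
engine.step()          # optimizer step + WarmupLR step + zero_grad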
@@ -372,7 +372,7 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
 
 def load_model(
     name: str,
-    device: Optional[Union[str, torch.device]] = None,
+    device: Optional[Union[str, torch.device]] = "cpu",
     download_root: str = None,
     in_memory: bool = False,
 ) -> Whisper:
@@ -397,8 +397,8 @@ def load_model(
         The Whisper ASR model instance
     """
 
-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+    # if device is None:
+    #     device = "cuda" if torch.cuda.is_available() else "cpu"
     if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
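Changing the default device from auto-detected CUDA to "cpu" lets every rank materialize the checkpoint in host memory first and place it on its own GPU afterwards, instead of all ranks loading onto cuda:0. A hedged usage sketch, assuming the patched load_model above is the one on the import path and with a placeholder rank:

import torch
from whisper import load_model  # i.e. the patched loader shown above

rank = 0  # placeholder; normally the local process rank
model = load_model("large-v2")  # with this change, lands on CPU by default
if torch.cuda.is_available():
    model = model.to(torch.device("cuda", rank))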
@@ -8,3 +8,4 @@ librosa
 openai-whisper
 zhconv
 WeTextProcessing
+deepspeed
@@ -42,6 +42,8 @@ import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
+import deepspeed
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 
 import k2
 import optim
@@ -102,15 +104,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count
 
 
-def add_deepspeed_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--deepspeed-config",
-        type=str,
-        default=None,
-        help="Path to deepspeed json config file.",
-    )
-
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -251,7 +244,7 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    add_deepspeed_arguments(parser)
+    parser = deepspeed.add_config_arguments(parser)
 
     return parser
 
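deepspeed.add_config_arguments() replaces the hand-rolled --deepspeed-config flag with DeepSpeed's standard CLI arguments (--deepspeed and --deepspeed_config, among others). A small sketch (not from the commit) of what the parser then accepts:

import argparse
import deepspeed

parser = argparse.ArgumentParser()
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args(
    ["--deepspeed", "--deepspeed_config", "ds_config_zero1.json"]
)
print(args.deepspeed, args.deepspeed_config)  # True ds_config_zero1.json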
@@ -495,7 +488,6 @@ def compute_loss(
     feature = feature.transpose(1, 2)  # (N, C, T)
     # pad feature from B,80,T to B,80,3000
     # feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
-    # print(feature.shape, 23333333)
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
@@ -629,24 +621,25 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                tokenizer=tokenizer,
-                model=model,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        #     logging.info("Computing validation loss")
+        #     valid_info = compute_validation_loss(
+        #         params=params,
+        #         tokenizer=tokenizer,
+        #         model=model,
+        #         valid_dl=valid_dl,
+        #         world_size=world_size,
+        #     )
+        #     model.train()
+        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+        #     logging.info(
+        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        #     )
+        #     if tb_writer is not None:
+        #         valid_info.write_summary(
+        #             tb_writer, "train/valid_", params.batch_idx_train
+        #         )
 
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -661,13 +654,20 @@ def train_one_epoch(
 
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
-            scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
-            scheduler.step_batch(params.batch_idx_train)
+            if params.deepspeed:
+                # The DeepSpeed engine owns loss scaling, gradient reduction
+                # and the optimizer/LR-scheduler step, so backward() and
+                # step() are called on the engine, not on scaler/optimizer.
+                model.backward(loss)
+                model.step()
+            else:
+                scaler.scale(loss).backward()
+                set_batch_count(model, params.batch_idx_train)
+                scheduler.step_batch(params.batch_idx_train)
 
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
+                scaler.step(optimizer)
+                scaler.update()
+                optimizer.zero_grad()
         except:  # noqa
             display_and_save_batch(batch, params=params)
             raise
@@ -679,6 +679,7 @@ def train_one_epoch(
             rank == 0
             and params.batch_idx_train > 0
             and params.batch_idx_train % params.average_period == 0
+            and not params.deepspeed
         ):
             update_averaged_model(
                 params=params,
@@ -686,29 +687,28 @@ def train_one_epoch(
                 model_avg=model_avg,
             )
 
-        if (
-            params.batch_idx_train > 0
-            and params.batch_idx_train % params.save_every_n == 0
-        ):
-            save_checkpoint_with_global_batch_idx(
-                out_dir=params.exp_dir,
-                global_batch_idx=params.batch_idx_train,
-                model=model,
-                model_avg=model_avg,
-                params=params,
-                optimizer=optimizer,
-                scheduler=scheduler,
-                sampler=train_dl.sampler,
-                scaler=scaler,
-                rank=rank,
-            )
-            remove_checkpoints(
-                out_dir=params.exp_dir,
-                topk=params.keep_last_k,
-                rank=rank,
-            )
-
-        if batch_idx % 100 == 0 and params.use_fp16:
+        # if (
+        #     params.batch_idx_train > 0
+        #     and params.batch_idx_train % params.save_every_n == 0
+        # ):
+        #     save_checkpoint_with_global_batch_idx(
+        #         out_dir=params.exp_dir,
+        #         global_batch_idx=params.batch_idx_train,
+        #         model=model,
+        #         model_avg=model_avg,
+        #         params=params,
+        #         optimizer=optimizer,
+        #         scheduler=scheduler,
+        #         sampler=train_dl.sampler,
+        #         scaler=scaler,
+        #         rank=rank,
+        #     )
+        #     remove_checkpoints(
+        #         out_dir=params.exp_dir,
+        #         topk=params.keep_last_k,
+        #         rank=rank,
+        #     )
+        if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
@@ -723,14 +723,14 @@ def train_one_epoch(
             )
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
-            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
 
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                + (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
             )
 
             if tb_writer is not None:
@@ -774,37 +774,21 @@ def run(rank, world_size, args):
 
     fix_random_seed(params.seed)
 
-    setup_dist(use_ddp_launch=True)
-
     setup_logger(f"{params.exp_dir}/log/log-train")
     logging.info("Training started")
 
-    if args.tensorboard and rank == 0:
-        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
-    else:
-        tb_writer = None
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
-    logging.info(f"Device: {device}")
-
-    logging.info(params)
-
     logging.info("About to create model")
-    # model = whisper.load_model("medium")
-    model = load_model("medium")
+    # TODO: download the model only on rank 0
+    # TODO: maybe compute the validation loss using multiple cards
+    # model = load_model("medium")
+    model = load_model("large-v2")
     del model.alignment_heads
-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
 
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual, language="zh", task="transcribe"
     )
+    logging.info(params)
 
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
 
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
@@ -817,10 +801,12 @@ def run(rank, world_size, args):
         params=params, model=model, model_avg=model_avg
     )
 
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    else:
+        device = torch.device("cpu")
+    logging.info(f"Device: {device}")
     model.to(device)
-    if world_size > 1:
-        logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -837,6 +823,17 @@ def run(rank, world_size, args):
         logging.info("Loading scheduler state dict")
         scheduler.load_state_dict(checkpoints["scheduler"])
 
+    if world_size > 1:
+        if params.deepspeed:
+            logging.info("Using DeepSpeed")
+            model, optimizer, _, _ = deepspeed.initialize(
+                args=params, model=model, optimizer=optimizer,
+                model_parameters=model.parameters())
+        else:
+            logging.info("Using DDP")
+            setup_dist(use_ddp_launch=True)
+            model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
             2**22
@@ -846,51 +843,8 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        # ../local/display_manifest_statistics.py
-        #
-        # You should use ../local/display_manifest_statistics.py to get
-        # an utterance duration distribution for your dataset to select
-        # the threshold
-        if c.duration < 1.0 or c.duration > 12.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
-            return False
-
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        # T = ((c.num_frames - 7) // 2 + 1) // 2
-        # tokens = sp.encode(c.supervisions[0].text, out_type=str)
-
-        # if T < len(tokens):
-        #     logging.warning(
-        #         f"Exclude cut with ID {c.id} from training. "
-        #         f"Number of frames (before subsampling): {c.num_frames}. "
-        #         f"Number of frames (after subsampling): {T}. "
-        #         f"Text: {c.supervisions[0].text}. "
-        #         f"Tokens: {tokens}. "
-        #         f"Number of tokens: {len(tokens)}"
-        #     )
-        #     return False
-
-        return True
-
-
     aishell = AishellAsrDataModule(args)
 
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
@@ -899,22 +853,19 @@ def run(rank, world_size, args):
         sampler_state_dict = None
 
 
-    train_dl = aishell.train_dataloaders(aishell.train_cuts())
-    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         graph_compiler=graph_compiler,
-    #         params=params,
-    #     )
+    train_dl = aishell.train_dataloaders(aishell.train_cuts(), rank=rank, world_size=world_size)
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts(), rank=rank, world_size=world_size)
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
 
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
     logging.info(f"start training from epoch {params.start_epoch}")
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         scheduler.step_epoch(epoch - 1)
@@ -945,20 +896,28 @@ def run(rank, world_size, args):
             diagnostic.print_diagnostics()
             break
 
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
+        if params.deepspeed:
+            model.save_checkpoint(save_dir=params.exp_dir,
+                                  tag=f"epoch-{params.cur_epoch}",
+                                  client_state={})
+            convert_zero_checkpoint_to_fp32_state_dict(
+                params.exp_dir, f"epoch-{params.cur_epoch}.pt",
+                tag=f"epoch-{params.cur_epoch}")
+        else:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
 
     logging.info("Done!")
 
-    if world_size > 1:
+    if world_size > 1 and not params.deepspeed:
         torch.distributed.barrier()
         cleanup_dist()
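model.save_checkpoint() writes per-rank ZeRO shards, and convert_zero_checkpoint_to_fp32_state_dict() then merges them into a single fp32 state dict saved as epoch-{N}.pt. A hedged sketch (not from the commit; the path and epoch tag are placeholders) of reloading the consolidated file:

import torch
from whisper import load_model  # i.e. the recipe's patched loader

model = load_model("large-v2")  # same architecture as used in training
state_dict = torch.load("exp_dir/epoch-1.pt", map_location="cpu")  # placeholder path
model.load_state_dict(state_dict)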
@@ -988,48 +947,6 @@ def display_and_save_batch(
     logging.info(f"features shape: {features.shape}")
 
 
-# def scan_pessimistic_batches_for_oom(
-#     model: Union[nn.Module, DDP],
-#     tokenizer: whisper.tokenizer.Tokenizer,
-#     train_dl: torch.utils.data.DataLoader,
-#     optimizer: torch.optim.Optimizer,
-#     params: AttributeDict,
-# ):
-#     from lhotse.dataset import find_pessimistic_batches
-
-#     logging.info(
-#         "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
-#     )
-#     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-#     for criterion, cuts in batches.items():
-#         batch = train_dl.dataset[cuts]
-#         try:
-#             with torch.cuda.amp.autocast(enabled=params.use_fp16):
-#                 loss, _ = compute_loss(
-#                     params=params,
-#                     tokenizer=tokenizer,
-#                     model=model,
-#                     batch=batch,
-#                     is_training=True,
-#                 )
-#             loss.backward()
-#             optimizer.zero_grad()
-#         except Exception as e:
-#             if "CUDA out of memory" in str(e):
-#                 logging.error(
-#                     "Your GPU ran out of memory with the current "
-#                     "max_duration setting. We recommend decreasing "
-#                     "max_duration and trying again.\n"
-#                     f"Failing criterion: {criterion} "
-#                     f"(={crit_values[criterion]}) ..."
-#                 )
-#             display_and_save_batch(batch, params=params)
-#             raise
-#         logging.info(
-#             f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-#         )
-
-
 def main():
     parser = get_parser()
     AishellAsrDataModule.add_arguments(parser)
@@ -1038,13 +955,10 @@ def main():
 
     world_size = get_world_size()
     rank = get_rank()
-    assert world_size >= 1
 
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
     run(rank=rank, world_size=world_size, args=args)
 
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()