support half precision training

This commit is contained in:
pkufool 2022-03-31 19:58:57 +08:00
parent 395a3f952b
commit 2ff81b2838

View File

@ -51,8 +51,8 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from transformer import Noam from transformer import Noam
@ -221,6 +221,13 @@ def get_parser():
""", """,
) )
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser return parser
@ -411,6 +418,7 @@ def save_checkpoint(
model: nn.Module, model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None, optimizer: Optional[torch.optim.Optimizer] = None,
sampler: Optional[CutSampler] = None, sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
rank: int = 0, rank: int = 0,
) -> None: ) -> None:
"""Save model, optimizer, scheduler and training stats to file. """Save model, optimizer, scheduler and training stats to file.
@ -424,6 +432,8 @@ def save_checkpoint(
The optimizer used in the training. The optimizer used in the training.
sampler: sampler:
The sampler for the training dataset. The sampler for the training dataset.
scaler:
The scaler used for mixed precision training.
""" """
if rank != 0: if rank != 0:
return return
@ -434,6 +444,7 @@ def save_checkpoint(
params=params, params=params,
optimizer=optimizer, optimizer=optimizer,
sampler=sampler, sampler=sampler,
scaler=scaler,
rank=rank, rank=rank,
) )
@ -473,6 +484,7 @@ def compute_loss(
feature = batch["inputs"] feature = batch["inputs"]
# at entry, feature is (N, T, C) # at entry, feature is (N, T, C)
assert feature.ndim == 3 assert feature.ndim == 3
feature = feature.to(device) feature = feature.to(device)
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
@ -547,6 +559,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor, sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader, train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader, valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None, tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1, world_size: int = 1,
rank: int = 0, rank: int = 0,
@ -568,6 +581,8 @@ def train_one_epoch(
Dataloader for the training dataset. Dataloader for the training dataset.
valid_dl: valid_dl:
Dataloader for the validation dataset. Dataloader for the validation dataset.
scaler:
The scaler used for mixed precision training. The scaler used for mixed precision training.
tb_writer: tb_writer:
Writer to write log messages to tensorboard. Writer to write log messages to tensorboard.
world_size: world_size:
@ -610,14 +625,17 @@ def train_one_epoch(
and tb_writer is not None and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0 and params.batch_idx_train % (params.log_interval * 5) == 0
): ):
deltas = optim_step_and_measure_param_change(model, optimizer) deltas = optim_step_and_measure_param_change(
model, optimizer, scaler=scaler
)
tb_writer.add_scalars( tb_writer.add_scalars(
"train/relative_param_change_per_minibatch", "train/relative_param_change_per_minibatch",
deltas, deltas,
global_step=params.batch_idx_train, global_step=params.batch_idx_train,
) )
else: else:
optimizer.step() scaler.step(optimizer)
scaler.update()
cur_batch_idx = params.get("cur_batch_idx", 0) cur_batch_idx = params.get("cur_batch_idx", 0)
@ -629,20 +647,23 @@ def train_one_epoch(
params.batch_idx_train += 1 params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"]) batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss( with torch.autocast(
params=params, device_type=model.device.type, enabled=params.use_fp16
model=model, ):
sp=sp, loss, loss_info = compute_loss(
batch=batch, params=params,
is_training=True, model=model,
) sp=sp,
batch=batch,
is_training=True,
)
# summary stats # summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances # NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far. # in the batch and there is no normalization to it so far.
loss.backward() scaler.scale(loss).backward()
maybe_log_weights("train/param_norms") maybe_log_weights("train/param_norms")
maybe_log_gradients("train/grad_norms") maybe_log_gradients("train/grad_norms")
@ -662,6 +683,7 @@ def train_one_epoch(
params=params, params=params,
optimizer=optimizer, optimizer=optimizer,
sampler=train_dl.sampler, sampler=train_dl.sampler,
scaler=scaler,
rank=rank, rank=rank,
) )
del params.cur_batch_idx del params.cur_batch_idx
@ -831,6 +853,11 @@ def run(rank, world_size, args):
params=params, params=params,
) )
scaler = GradScaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch) fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
@ -857,6 +884,7 @@ def run(rank, world_size, args):
tb_writer=tb_writer, tb_writer=tb_writer,
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,
scaler=scaler,
) )
save_checkpoint( save_checkpoint(
@ -864,6 +892,7 @@ def run(rank, world_size, args):
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
sampler=train_dl.sampler, sampler=train_dl.sampler,
scaler=scaler,
rank=rank, rank=rank,
) )
@ -891,15 +920,17 @@ def scan_pessimistic_batches_for_oom(
batch = train_dl.dataset[cuts] batch = train_dl.dataset[cuts]
try: try:
optimizer.zero_grad() optimizer.zero_grad()
loss, _ = compute_loss( with torch.autocast(
params=params, device_type=model.device.type, enabled=params.use_fp16
model=model, ):
sp=sp, loss, _ = compute_loss(
batch=batch, params=params,
is_training=True, model=model,
) sp=sp,
batch=batch,
is_training=True,
)
loss.backward() loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step() optimizer.step()
except RuntimeError as e: except RuntimeError as e:
if "CUDA out of memory" in str(e): if "CUDA out of memory" in str(e):