support half precision training

pkufool 2022-03-31 19:58:57 +08:00
parent 395a3f952b
commit 2ff81b2838


@@ -51,8 +51,8 @@ from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
@@ -221,6 +221,13 @@ def get_parser():
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
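For context, a minimal sketch of how the new flag behaves; the str2bool helper here is a stand-in mirroring icefall's utility of the same name, not the exact upstream implementation:

```python
import argparse

def str2bool(v: str) -> bool:
    # Accept common command-line spellings of booleans.
    if v.lower() in ("yes", "true", "t", "1"):
        return True
    if v.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Not a boolean: {v}")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-fp16",
    type=str2bool,
    default=False,
    help="Whether to use half precision training.",
)
args = parser.parse_args(["--use-fp16", "true"])
assert args.use_fp16 is True
```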
@@ -411,6 +418,7 @@ def save_checkpoint(
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@@ -424,6 +432,8 @@ def save_checkpoint(
The optimizer used in the training.
sampler:
The sampler for the training dataset.
scaler:
The scaler used for mixed precision training.
"""
if rank != 0:
return
@@ -434,6 +444,7 @@
params=params,
optimizer=optimizer,
sampler=sampler,
scaler=scaler,
rank=rank,
)
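save_checkpoint now also serializes the scaler, so the dynamic loss scale survives a restart. For reference, a GradScaler's state is small; a quick way to inspect it (the commented output is what recent PyTorch versions return on a CUDA machine, roughly):

```python
from torch.cuda.amp import GradScaler

scaler = GradScaler()  # enabled=True by default; needs a CUDA build
print(scaler.state_dict())
# On a CUDA machine, roughly:
# {'scale': 65536.0, 'growth_factor': 2.0, 'backoff_factor': 0.5,
#  'growth_interval': 2000, '_growth_tracker': 0}
```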
@@ -473,6 +484,7 @@ def compute_loss(
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
@@ -547,6 +559,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@@ -568,6 +581,8 @@
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mixed precision training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
@@ -610,14 +625,17 @@
and tb_writer is not None
and params.batch_idx_train % (params.log_interval * 5) == 0
):
deltas = optim_step_and_measure_param_change(model, optimizer)
deltas = optim_step_and_measure_param_change(
model, optimizer, scaler=scaler
)
tb_writer.add_scalars(
"train/relative_param_change_per_minibatch",
deltas,
global_step=params.batch_idx_train,
)
else:
optimizer.step()
scaler.step(optimizer)
scaler.update()
cur_batch_idx = params.get("cur_batch_idx", 0)
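This hunk is why optimizer.step() becomes scaler.step(optimizer): the scaler unscales the gradients, checks them for infs/NaNs, and skips the optimizer step entirely when the check fails; scaler.update() then adjusts the loss scale for the next iteration (the optim_step_and_measure_param_change branch receives the scaler so it can do the same). A self-contained sketch of the replacement pattern, runnable with or without a GPU:

```python
import torch
from torch.cuda.amp import GradScaler

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# With enabled=False the scaler is a transparent no-op, so the same
# loop serves fp32 and fp16 runs alike.
scaler = GradScaler(enabled=(device == "cuda"))

loss = model(torch.randn(2, 4, device=device)).sum()
scaler.scale(loss).backward()  # backward on the (possibly) scaled loss
scaler.step(optimizer)  # unscales grads; skips the step on inf/NaN grads
scaler.update()         # then grows or shrinks the loss scale
optimizer.zero_grad()
```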
@@ -629,20 +647,23 @@
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
with torch.autocast(
device_type=model.device.type, enabled=params.use_fp16
):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
loss.backward()
scaler.scale(loss).backward()
maybe_log_weights("train/param_norms")
maybe_log_gradients("train/grad_norms")
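Here the loss computation moves inside torch.autocast, which runs eligible ops (matmuls, convolutions) in half precision while keeping precision-sensitive ops in fp32; the backward call stays outside the context and operates on the scaled loss. A minimal illustration of that scoping (device_type is hardcoded here, whereas the diff reads it from model.device.type):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(8, 8, device=device)
w = torch.randn(8, 8, device=device, requires_grad=True)

with torch.autocast(device_type=device, enabled=(device == "cuda")):
    y = x @ w                # runs in float16 under CUDA autocast
    loss = y.float().sum()   # reductions are safer in fp32
print(y.dtype)   # torch.float16 on CUDA, torch.float32 on CPU
loss.backward()  # backward outside autocast
print(w.grad.dtype)  # torch.float32, matching the parameter dtype
```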
@@ -662,6 +683,7 @@
params=params,
optimizer=optimizer,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
@@ -831,6 +853,11 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs):
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
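Restoring the scaler on resume keeps the dynamic loss scale where the previous run left it; restarting from the default scale instead could trigger a burst of skipped steps until the scale backs off again. A self-contained sketch of the round-trip (the file name and flat checkpoint layout are illustrative; in the script, icefall's checkpoint helpers handle this under the "grad_scaler" key shown in the diff):

```python
import torch
from torch.cuda.amp import GradScaler

# Save side: mirror what save_checkpoint now does with scaler=scaler.
scaler = GradScaler(enabled=True)
torch.save({"grad_scaler": scaler.state_dict()}, "checkpoint.pt")

# Resume side: recreate the scaler, then restore its state.
resumed = GradScaler(enabled=True)
state = torch.load("checkpoint.pt")
if "grad_scaler" in state:
    resumed.load_state_dict(state["grad_scaler"])
```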
@@ -857,6 +884,7 @@ def run(rank, world_size, args):
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
scaler=scaler,
)
save_checkpoint(
@@ -864,6 +892,7 @@ def run(rank, world_size, args):
model=model,
optimizer=optimizer,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
@@ -891,15 +920,17 @@ def scan_pessimistic_batches_for_oom(
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
with torch.autocast(
device_type=model.device.type, enabled=params.use_fp16
):
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):