Add mixed precision support

pkufool 2022-04-11 15:27:24 +08:00
parent 34aad74a2c
commit 4ebe821769
2 changed files with 60 additions and 30 deletions

View File

@@ -141,9 +141,13 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=self.simple_lm_proj(decoder_out),
-            am=self.simple_am_proj(encoder_out),
-            symbols=y_padded,
-            termination_symbol=blank_id,
-            lm_only_scale=lm_scale,
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=lm.float(),
+                am=am.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
@@ -176,8 +180,9 @@ class Transducer(nn.Module):
         logits = self.joiner(am_pruned, lm_pruned,
                              project_input=False)
 
-        pruned_loss = k2.rnnt_loss_pruned(
-            logits=logits,
-            symbols=y_padded,
-            ranges=ranges,
-            termination_symbol=blank_id,
+        with torch.cuda.amp.autocast(enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
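Both hunks of this file follow the same pattern: the projections and the joiner still run under whatever autocast context the caller set up, while the k2 loss itself is evaluated in float32 by disabling autocast locally and casting its inputs with .float(). A minimal sketch of that pattern, where loss_fn stands in for k2.rnnt_loss_smoothed or k2.rnnt_loss_pruned and the helper name is illustrative rather than part of this recipe:

    import torch

    def fp32_loss_under_amp(logits: torch.Tensor, loss_fn) -> torch.Tensor:
        # `logits` may be float16 if the surrounding forward pass ran under
        # torch.cuda.amp.autocast; the loss is still computed in float32,
        # since autocast is switched off for this region and the input is
        # cast explicitly.
        with torch.cuda.amp.autocast(enabled=False):
            return loss_fn(logits.float())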

View File

@@ -58,6 +58,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eve, Eden
 from torch import Tensor
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -249,6 +250,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
     return parser
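The new flag is a plain boolean switch; later in this commit it drives both GradScaler(enabled=params.use_fp16) and autocast(enabled=params.use_fp16), so leaving it at its default keeps the recipe in float32. For reference, a self-contained sketch of the flag parsing, where str2bool is a minimal stand-in for the helper the script imports rather than the project's own implementation:

    import argparse

    def str2bool(v: str) -> bool:
        # Minimal stand-in: accept the usual truthy spellings.
        return str(v).lower() in ("yes", "true", "t", "y", "1")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use-fp16",
        type=str2bool,
        default=False,
        help="Whether to use half precision training.",
    )
    print(parser.parse_args(["--use-fp16", "true"]).use_fp16)  # True
    print(parser.parse_args([]).use_fp16)                      # False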
@@ -447,6 +455,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -460,6 +469,8 @@ def save_checkpoint(
         The optimizer used in the training.
       sampler:
         The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
     """
     if rank != 0:
         return
@@ -471,6 +482,7 @@ def save_checkpoint(
         optimizer=optimizer,
         scheduler=scheduler,
         sampler=sampler,
+        scaler=scaler,
         rank=rank,
     )
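save_checkpoint only forwards the scaler to the underlying checkpoint helper, which this diff does not show. The sketch below is just an illustration of the idea, assuming the helper stores scaler.state_dict() under a "grad_scaler" key, i.e. the key the resume code later in this commit reads back; the function name is hypothetical:

    from typing import Optional

    import torch
    from torch.cuda.amp import GradScaler

    def save_checkpoint_sketch(
        filename: str,
        model: torch.nn.Module,
        scaler: Optional[GradScaler] = None,
    ) -> None:
        # Persist the scaler state next to the model so that a resumed run
        # keeps the current loss scale instead of restarting its warm-up.
        checkpoint = {"model": model.state_dict()}
        if scaler is not None:
            checkpoint["grad_scaler"] = scaler.state_dict()
        torch.save(checkpoint, filename)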
@@ -599,6 +611,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -622,6 +635,8 @@ def train_one_epoch(
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -644,6 +659,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        loss, loss_info = compute_loss(
-            params=params,
-            model=model,
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
@@ -657,9 +673,10 @@
 
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
-        loss.backward()
+        scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
-        optimizer.step()
+        scaler.step(optimizer)
+        scaler.update()
         optimizer.zero_grad()
 
         if params.print_diagnostics and batch_idx == 5:
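This hunk is the standard torch.cuda.amp training step: the forward pass (and hence compute_loss) runs under autocast, the loss is scaled before backward, and the scaler unscales the gradients, skips the optimizer step if it finds inf/NaN, and adapts the scale factor. A stripped-down sketch of the same sequence, assuming model(batch) returns a scalar loss and omitting the scheduler.step_batch call of the real loop:

    import torch
    from torch.cuda.amp import GradScaler, autocast

    def amp_train_step(model, optimizer, batch, scaler: GradScaler, use_fp16: bool):
        with autocast(enabled=use_fp16):
            loss = model(batch)            # assumed to return a scalar loss
        scaler.scale(loss).backward()      # scale to avoid fp16 gradient underflow
        scaler.step(optimizer)             # unscale grads; skip the step on inf/NaN
        scaler.update()                    # adjust the loss scale for the next step
        optimizer.zero_grad()
        return loss.detach()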
@@ -676,6 +693,7 @@
                 optimizer=optimizer,
                 scheduler=scheduler,
                 sampler=train_dl.sampler,
+                scaler=scaler,
                 rank=rank,
             )
             del params.cur_batch_idx
@@ -841,7 +859,7 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)
 
-    if not params.print_diagnostics:
+    if not params.print_diagnostics and not params.use_fp16:
         scan_pessimistic_batches_for_oom(
             model=model,
             train_dl=train_dl,
@@ -850,6 +868,11 @@ def run(rank, world_size, args):
             params=params,
         )
 
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
     for epoch in range(params.start_epoch, params.num_epochs):
         scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
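Because the scaler is constructed with enabled=params.use_fp16, the unconditional scaler.scale/step/update calls in train_one_epoch remain correct when the flag is off: a disabled GradScaler is a transparent no-op and its state dict is empty, so fp32 runs behave exactly as before this commit. A small standalone check of that behaviour, plain PyTorch and independent of this recipe:

    import torch
    from torch.cuda.amp import GradScaler

    scaler = GradScaler(enabled=False)
    loss = torch.tensor(2.5, requires_grad=True)
    assert scaler.scale(loss) is loss  # no scaling applied when disabled
    assert scaler.state_dict() == {}   # nothing to checkpoint
    scaler.load_state_dict({})         # and nothing to restore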
@@ -869,6 +892,7 @@ def run(rank, world_size, args):
             sp=sp,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -884,6 +908,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             scheduler=scheduler,
             sampler=train_dl.sampler,
+            scaler=scaler,
             rank=rank,
         )