Support mixed precision training on the reworked model (#305)

* Add mixed precision support

* Minor fixes

* Minor fixes

* Minor fixes
Author: Wei Kang, 2022-04-11 16:49:54 +08:00 (committed via GitHub)
Commit: 7012fd65b5 (parent: 34aad74a2c)
2 changed files with 80 additions and 38 deletions

File 1 of 2: the Transducer model

@@ -141,17 +141,21 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=self.simple_lm_proj(decoder_out),
-            am=self.simple_am_proj(encoder_out),
-            symbols=y_padded,
-            termination_symbol=blank_id,
-            lm_only_scale=lm_scale,
-            am_only_scale=am_scale,
-            boundary=boundary,
-            reduction="sum",
-            return_grad=True,
-        )
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=lm.float(),
+                am=am.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
+                am_only_scale=am_scale,
+                boundary=boundary,
+                reduction="sum",
+                return_grad=True,
+            )
 
         # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
@@ -176,13 +180,14 @@ class Transducer(nn.Module):
         logits = self.joiner(am_pruned, lm_pruned,
                              project_input=False)
 
-        pruned_loss = k2.rnnt_loss_pruned(
-            logits=logits,
-            symbols=y_padded,
-            ranges=ranges,
-            termination_symbol=blank_id,
-            boundary=boundary,
-            reduction="sum",
-        )
+        with torch.cuda.amp.autocast(enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
+                boundary=boundary,
+                reduction="sum",
+            )
 
         return (simple_loss, pruned_loss)
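The k2 RNN-T losses are kept in float32 here even when the rest of the forward pass runs under autocast, presumably for numerical stability of their log-space reductions. Below is a minimal, self-contained sketch of that pattern with toy tensors and a stand-in loss function (not the icefall code; it assumes a CUDA device):

```python
import torch

def sensitive_loss(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a numerically sensitive loss such as the k2 RNN-T losses;
    # assumes float32 input for stability.
    return (x.exp().sum(dim=-1) + 1e-20).log().sum()

x = torch.randn(4, 8, device="cuda")
w = torch.randn(8, 8, device="cuda")

with torch.cuda.amp.autocast(enabled=True):
    h = x @ w                      # matmul runs in float16 under autocast
    assert h.dtype == torch.float16
    with torch.cuda.amp.autocast(enabled=False):
        # Disabling autocast and casting inputs with .float() forces this
        # block to run in float32, mirroring what the diff does around
        # k2.rnnt_loss_smoothed / k2.rnnt_loss_pruned.
        loss = sensitive_loss(h.float())
    assert loss.dtype == torch.float32
```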

File 2 of 2: pruned_transducer_stateless2/train.py

@@ -29,7 +29,16 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --full-libri 1 \
   --max-duration 300
 
+# For mixed precision training:
+./pruned_transducer_stateless2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --use-fp16 1 \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --full-libri 1 \
+  --max-duration 550
 
 """
@@ -58,6 +67,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eve, Eden
 from torch import Tensor
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -249,6 +259,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
     return parser
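For reference, a small sketch of how the new flag parses on the command line. The str2bool body below is only an illustrative stand-in for the helper the script actually imports; note that the CLI option is spelled with a hyphen, while argparse exposes it as args.use_fp16:

```python
import argparse

def str2bool(v: str) -> bool:
    # Illustrative stand-in for the project's str2bool helper:
    # maps common truthy/falsy strings to a bool for argparse.
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-fp16",
    type=str2bool,
    default=False,
    help="Whether to use half precision training.",
)

args = parser.parse_args(["--use-fp16", "1"])
assert args.use_fp16 is True
```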
@@ -447,6 +464,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -460,6 +478,8 @@ def save_checkpoint(
         The optimizer used in the training.
       sampler:
         The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
     """
     if rank != 0:
         return
@@ -471,6 +491,7 @@
         optimizer=optimizer,
         scheduler=scheduler,
         sampler=sampler,
+        scaler=scaler,
         rank=rank,
     )
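Passing the scaler into save_checkpoint lets its state (current loss scale, growth counters) survive restarts. A minimal sketch of that round trip using GradScaler's state_dict API, with a hypothetical file name and checkpoint layout; the "grad_scaler" key matches what run() loads further down in this diff:

```python
import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled=True)

# Persist the scaler alongside the model/optimizer state (hypothetical layout).
checkpoint = {
    "grad_scaler": scaler.state_dict(),  # loss scale, growth tracker, etc.
}
torch.save(checkpoint, "checkpoint.pt")  # hypothetical path

# On resume, restore it the same way the optimizer state is restored.
loaded = torch.load("checkpoint.pt", map_location="cpu")
if "grad_scaler" in loaded:
    scaler.load_state_dict(loaded["grad_scaler"])
```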
@@ -599,6 +620,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -622,6 +644,8 @@
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -644,22 +668,24 @@
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            sp=sp,
-            batch=batch,
-            is_training=True,
-            warmup=(params.batch_idx_train / params.model_warm_step)
-        )
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+                warmup=(params.batch_idx_train / params.model_warm_step)
+            )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
 
         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.
-        loss.backward()
+        scaler.scale(loss).backward()
         scheduler.step_batch(params.batch_idx_train)
-        optimizer.step()
+        scaler.step(optimizer)
+        scaler.update()
         optimizer.zero_grad()
 
         if params.print_diagnostics and batch_idx == 5:
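Pulled out of context, the step introduced above follows PyTorch's standard GradScaler recipe: run the forward pass under autocast, scale the loss before backward, then step and update through the scaler so steps with inf/nan gradients are skipped and the loss scale is adjusted. A self-contained sketch with a toy model (assumes a CUDA device; not the icefall loop):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
use_fp16 = True
scaler = GradScaler(enabled=use_fp16)

for _ in range(3):
    x = torch.randn(8, 16, device="cuda")
    y = torch.randn(8, 4, device="cuda")

    with autocast(enabled=use_fp16):
        loss = torch.nn.functional.mse_loss(model(x), y)

    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.step(optimizer)          # unscales grads; skips the step on inf/nan
    scaler.update()                 # adjusts the loss scale for the next step
    optimizer.zero_grad()
```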
@@ -676,6 +702,7 @@
                 optimizer=optimizer,
                 scheduler=scheduler,
                 sampler=train_dl.sampler,
+                scaler=scaler,
                 rank=rank,
             )
             del params.cur_batch_idx
@@ -695,7 +722,9 @@
             )
 
             if tb_writer is not None:
-                tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
 
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
@@ -850,6 +879,11 @@ def run(rank, world_size, args):
             params=params,
         )
 
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
     for epoch in range(params.start_epoch, params.num_epochs):
         scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
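Constructing the scaler with enabled=params.use_fp16 is what lets the rest of the training code stay branch-free: when the flag is off, scale/step/update degrade to pass-throughs, so the same code path serves both fp32 and fp16 runs. A tiny CPU-runnable sketch of that behaviour (toy model, not from the diff):

```python
import torch
from torch.cuda.amp import GradScaler

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=False)   # fp32 path: scaler is a pass-through

loss = model(torch.randn(3, 4)).sum()
assert scaler.scale(loss) is loss    # no scaling applied when disabled
loss.backward()
scaler.step(optimizer)               # falls through to optimizer.step()
scaler.update()                      # no-op when disabled
optimizer.zero_grad()
```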
@@ -869,6 +903,7 @@ def run(rank, world_size, args):
             sp=sp,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -884,6 +919,7 @@
             optimizer=optimizer,
             scheduler=scheduler,
             sampler=train_dl.sampler,
+            scaler=scaler,
             rank=rank,
         )
@@ -913,14 +949,15 @@ def scan_pessimistic_batches_for_oom(
             # warmup = 0.0 is so that the derivs for the pruned loss stay zero
             # (i.e. are not remembered by the decaying-average in adam), because
             # we want to avoid these params being subject to shrinkage in adam.
-            loss, _ = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-                warmup=0.0,
-            )
+            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                    warmup=0.0,
+                )
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
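One practical consequence of the scaler-driven step elsewhere in this diff: when fp16 gradients overflow, GradScaler skips the optimizer step and lowers its loss scale, which can be observed via get_scale(). A self-contained sketch of that check (assumes a CUDA device; not part of the diff):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=True)

x = torch.randn(8, 16, device="cuda")
with autocast():
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()
prev_scale = scaler.get_scale()
scaler.step(optimizer)
scaler.update()
# A drop in the loss scale means the last step was skipped due to inf/nan grads.
if scaler.get_scale() < prev_scale:
    print(f"Gradient overflow: loss scale {prev_scale} -> {scaler.get_scale()}")
optimizer.zero_grad()
```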