Support mixed precision training on the reworked model (#305)

* Add mixed precision support

* Minor fixes

* Minor fixes

* Minor fixes
This commit is contained in:
Wei Kang 2022-04-11 16:49:54 +08:00 committed by GitHub
parent 34aad74a2c
commit 7012fd65b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 38 deletions

View File

@ -141,17 +141,21 @@ class Transducer(nn.Module):
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=self.simple_lm_proj(decoder_out),
am=self.simple_am_proj(encoder_out),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
lm=self.simple_lm_proj(decoder_out)
am=self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction="sum",
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
@ -176,13 +180,14 @@ class Transducer(nn.Module):
logits = self.joiner(am_pruned, lm_pruned,
project_input=False)
pruned_loss = k2.rnnt_loss_pruned(
logits=logits,
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction="sum",
)
return (simple_loss, pruned_loss)

View File

@ -29,7 +29,16 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--full-libri 1 \
--max-duration 300
# For mixed precision training:
./pruned_transducer_stateless2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 550
"""
@ -58,6 +67,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eve, Eden
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -249,6 +259,13 @@ def get_parser():
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
@ -447,6 +464,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -460,6 +478,8 @@ def save_checkpoint(
The optimizer used in the training.
sampler:
The sampler for the training dataset.
scaler:
The scaler used for mixed precision training.
"""
if rank != 0:
return
@ -471,6 +491,7 @@ def save_checkpoint(
optimizer=optimizer,
scheduler=scheduler,
sampler=sampler,
scaler=scaler,
rank=rank,
)
@ -599,6 +620,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@ -622,6 +644,8 @@ def train_one_epoch(
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mixed precision training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
@ -644,22 +668,24 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step)
)
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step)
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
loss.backward()
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
optimizer.step()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5:
@ -676,6 +702,7 @@ def train_one_epoch(
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
@ -695,7 +722,9 @@ def train_one_epoch(
)
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train)
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@ -850,6 +879,11 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs):
scheduler.step_epoch(epoch)
fix_random_seed(params.seed + epoch)
@ -869,6 +903,7 @@ def run(rank, world_size, args):
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
@ -884,6 +919,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
@ -913,14 +949,15 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup = 0.0
)
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup = 0.0
)
loss.backward()
optimizer.step()
optimizer.zero_grad()