Support mixed precision training on the reworked model (#305)

* Add mixed precision support

* Minor fixes

* Minor fixes

* Minor fixes
This commit is contained in:
Wei Kang 2022-04-11 16:49:54 +08:00 committed by GitHub
parent 34aad74a2c
commit 7012fd65b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 80 additions and 38 deletions

View File

@ -141,9 +141,13 @@ class Transducer(nn.Module):
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm=self.simple_lm_proj(decoder_out)
am=self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=self.simple_lm_proj(decoder_out),
am=self.simple_am_proj(encoder_out),
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
@ -176,8 +180,9 @@ class Transducer(nn.Module):
logits = self.joiner(am_pruned, lm_pruned,
project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits,
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,

View File

@ -29,7 +29,16 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--full-libri 1 \
--max-duration 300
# For mixed precision training:
./pruned_transducer_stateless2/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \
--max-duration 550
"""
@ -58,6 +67,7 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eve, Eden
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
@ -249,6 +259,13 @@ def get_parser():
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
@ -447,6 +464,7 @@ def save_checkpoint(
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -460,6 +478,8 @@ def save_checkpoint(
The optimizer used in the training.
sampler:
The sampler for the training dataset.
scaler:
The gradient scaler used for mixed precision training.
"""
if rank != 0:
return
@ -471,6 +491,7 @@ def save_checkpoint(
optimizer=optimizer,
scheduler=scheduler,
sampler=sampler,
scaler=scaler,
rank=rank,
)
@ -599,6 +620,7 @@ def train_one_epoch(
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
@ -622,6 +644,8 @@ def train_one_epoch(
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The gradient scaler used for mixed precision training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
@ -644,6 +668,7 @@ def train_one_epoch(
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
@ -657,9 +682,10 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
loss.backward()
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
optimizer.step()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if params.print_diagnostics and batch_idx == 5:
@ -676,6 +702,7 @@ def train_one_epoch(
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
@ -695,7 +722,9 @@ def train_one_epoch(
)
if tb_writer is not None:
tb_writer.add_scalar("train/learning_rate", cur_params.batch_idx_train)
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
@ -850,6 +879,11 @@ def run(rank, world_size, args):
params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs):
scheduler.step_epoch(epoch)
fix_random_seed(params.seed + epoch)
@ -869,6 +903,7 @@ def run(rank, world_size, args):
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
@ -884,6 +919,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
@ -913,6 +949,7 @@ def scan_pessimistic_batches_for_oom(
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,