mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-04 06:34:20 +00:00
Add mix precision support
This commit is contained in:
parent
34aad74a2c
commit
4ebe821769
@ -141,17 +141,21 @@ class Transducer(nn.Module):
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = x_lens
|
||||
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=self.simple_lm_proj(decoder_out),
|
||||
am=self.simple_am_proj(encoder_out),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
lm=self.simple_lm_proj(decoder_out)
|
||||
am=self.simple_am_proj(encoder_out)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
@ -176,13 +180,14 @@ class Transducer(nn.Module):
|
||||
logits = self.joiner(am_pruned, lm_pruned,
|
||||
project_input=False)
|
||||
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits,
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return (simple_loss, pruned_loss)
|
||||
|
@ -58,6 +58,7 @@ from lhotse.utils import fix_random_seed
|
||||
from model import Transducer
|
||||
from optim import Eve, Eden
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
@ -249,6 +250,13 @@ def get_parser():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-fp16",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="Whether to use half precision training.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@ -447,6 +455,7 @@ def save_checkpoint(
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
sampler: Optional[CutSampler] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
@ -460,6 +469,8 @@ def save_checkpoint(
|
||||
The optimizer used in the training.
|
||||
sampler:
|
||||
The sampler for the training dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
@ -471,6 +482,7 @@ def save_checkpoint(
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
@ -599,6 +611,7 @@ def train_one_epoch(
|
||||
sp: spm.SentencePieceProcessor,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
scaler: GradScaler,
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
rank: int = 0,
|
||||
@ -622,6 +635,8 @@ def train_one_epoch(
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
scaler:
|
||||
The scaler used for mix precision training.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
@ -644,22 +659,24 @@ def train_one_epoch(
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=(params.batch_idx_train / params.model_warm_step)
|
||||
)
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
batch=batch,
|
||||
is_training=True,
|
||||
warmup=(params.batch_idx_train / params.model_warm_step)
|
||||
)
|
||||
# summary stats
|
||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||
|
||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||
# in the batch and there is no normalization to it so far.
|
||||
loss.backward()
|
||||
scaler.scale(loss).backward()
|
||||
scheduler.step_batch(params.batch_idx_train)
|
||||
optimizer.step()
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if params.print_diagnostics and batch_idx == 5:
|
||||
@ -676,6 +693,7 @@ def train_one_epoch(
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
del params.cur_batch_idx
|
||||
@ -841,7 +859,7 @@ def run(rank, world_size, args):
|
||||
valid_cuts += librispeech.dev_other_cuts()
|
||||
valid_dl = librispeech.valid_dataloaders(valid_cuts)
|
||||
|
||||
if not params.print_diagnostics:
|
||||
if not params.print_diagnostics and not params.use_fp16:
|
||||
scan_pessimistic_batches_for_oom(
|
||||
model=model,
|
||||
train_dl=train_dl,
|
||||
@ -850,6 +868,11 @@ def run(rank, world_size, args):
|
||||
params=params,
|
||||
)
|
||||
|
||||
scaler = GradScaler(enabled=params.use_fp16)
|
||||
if checkpoints and "grad_scaler" in checkpoints:
|
||||
logging.info("Loading grad scaler state dict")
|
||||
scaler.load_state_dict(checkpoints["grad_scaler"])
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
scheduler.step_epoch(epoch)
|
||||
fix_random_seed(params.seed + epoch)
|
||||
@ -869,6 +892,7 @@ def run(rank, world_size, args):
|
||||
sp=sp,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
scaler=scaler,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
@ -884,6 +908,7 @@ def run(rank, world_size, args):
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
sampler=train_dl.sampler,
|
||||
scaler=scaler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user