mirror of https://github.com/k2-fsa/icefall.git
Support mix precision training on the reworked model (#305)
* Add mix precision support
* Minor fixes
* Minor fixes
* Minor fixes
parent 34aad74a2c
commit 7012fd65b5
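
This change wires PyTorch automatic mixed precision (AMP) into the pruned_transducer_stateless2 recipe: the forward pass runs under torch.cuda.amp.autocast, the k2 loss computations are pinned to float32, and a GradScaler compensates for the limited dynamic range of float16 gradients. Below is a minimal, self-contained sketch of that general AMP pattern, for orientation only; the model, optimizer, and data are placeholders, not icefall code, and a CUDA device is assumed.

import torch
from torch.cuda.amp import GradScaler, autocast

# Generic AMP training loop (placeholder model/optimizer, CUDA assumed).
model = torch.nn.Linear(80, 500).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = GradScaler(enabled=True)  # enabled=params.use_fp16 in this recipe

for _ in range(10):
    x = torch.randn(16, 80, device="cuda")
    y = torch.randint(0, 500, (16,), device="cuda")

    with autocast(enabled=True):       # forward in fp16 where safe
        loss = torch.nn.functional.cross_entropy(model(x), y)

    scaler.scale(loss).backward()      # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)             # unscales; skips the step on inf/NaN
    scaler.update()                    # adjusts the scale for the next step
    optimizer.zero_grad()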
@@ -141,17 +141,21 @@ class Transducer(nn.Module):
         boundary[:, 2] = y_lens
         boundary[:, 3] = x_lens
 
-        simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
-            lm=self.simple_lm_proj(decoder_out),
-            am=self.simple_am_proj(encoder_out),
-            symbols=y_padded,
-            termination_symbol=blank_id,
-            lm_only_scale=lm_scale,
-            am_only_scale=am_scale,
-            boundary=boundary,
-            reduction="sum",
-            return_grad=True,
-        )
+        lm = self.simple_lm_proj(decoder_out)
+        am = self.simple_am_proj(encoder_out)
+
+        with torch.cuda.amp.autocast(enabled=False):
+            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+                lm=lm.float(),
+                am=am.float(),
+                symbols=y_padded,
+                termination_symbol=blank_id,
+                lm_only_scale=lm_scale,
+                am_only_scale=am_scale,
+                boundary=boundary,
+                reduction="sum",
+                return_grad=True,
+            )
 
         # ranges : [B, T, prune_range]
         ranges = k2.get_rnnt_prune_ranges(
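Note how the simple loss is kept out of half precision: the projections are computed first, then cast with .float() inside a region where autocast is disabled, so k2.rnnt_loss_smoothed always sees float32 inputs. The snippet below is an illustration only (placeholder tensors, CUDA assumed); it shows why the explicit cast is needed — disabling autocast does not convert tensors that were already produced in float16.

import torch

x = torch.randn(4, 8, device="cuda")
w = torch.randn(8, 8, device="cuda")

with torch.cuda.amp.autocast():
    y = x @ w                          # autocast runs this matmul in float16
    assert y.dtype == torch.float16
    with torch.cuda.amp.autocast(enabled=False):
        # y is still float16 here; the explicit cast mirrors lm.float()/am.float()
        loss = (y.float() ** 2).mean()

assert loss.dtype == torch.float32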
@@ -176,13 +180,14 @@ class Transducer(nn.Module):
         logits = self.joiner(am_pruned, lm_pruned,
                              project_input=False)
 
-        pruned_loss = k2.rnnt_loss_pruned(
-            logits=logits,
-            symbols=y_padded,
-            ranges=ranges,
-            termination_symbol=blank_id,
-            boundary=boundary,
-            reduction="sum",
-        )
+        with torch.cuda.amp.autocast(enabled=False):
+            pruned_loss = k2.rnnt_loss_pruned(
+                logits=logits.float(),
+                symbols=y_padded,
+                ranges=ranges,
+                termination_symbol=blank_id,
+                boundary=boundary,
+                reduction="sum",
+            )
 
         return (simple_loss, pruned_loss)
@@ -29,7 +29,16 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --full-libri 1 \
   --max-duration 300
 
+# For mix precision training:
 
+./pruned_transducer_stateless2/train.py \
+  --world-size 4 \
+  --num-epochs 30 \
+  --start-epoch 0 \
+  --use_fp16 1 \
+  --exp-dir pruned_transducer_stateless2/exp \
+  --full-libri 1 \
+  --max-duration 550
 
 """
 
@@ -58,6 +67,7 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eve, Eden
 from torch import Tensor
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -249,6 +259,13 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
     return parser
 
 
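The flag is parsed with str2bool, icefall's boolean-argument helper, so spellings like 1/0 or true/false work on the command line. A rough, illustrative equivalent of such a helper follows; it is not icefall's implementation, and str2bool_sketch is a hypothetical name.

import argparse

def str2bool_sketch(v):
    # Illustrative only; icefall ships its own str2bool helper.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--use-fp16", type=str2bool_sketch, default=False)
print(parser.parse_args(["--use-fp16", "1"]).use_fp16)  # True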
@@ -447,6 +464,7 @@ def save_checkpoint(
     optimizer: Optional[torch.optim.Optimizer] = None,
     scheduler: Optional[LRSchedulerType] = None,
     sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -460,6 +478,8 @@ def save_checkpoint(
         The optimizer used in the training.
       sampler:
         The sampler for the training dataset.
+      scaler:
+        The scaler used for mix precision training.
     """
     if rank != 0:
         return
@@ -471,6 +491,7 @@ def save_checkpoint(
         optimizer=optimizer,
         scheduler=scheduler,
         sampler=sampler,
+        scaler=scaler,
         rank=rank,
     )
 
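Threading the optional scaler through save_checkpoint lets the GradScaler's state (current loss scale and growth counters) survive a restart; later in this diff, run() reads it back from checkpoints["grad_scaler"]. A hedged sketch of that round trip, with illustrative function names rather than icefall's actual checkpoint helpers:

import torch
from torch.cuda.amp import GradScaler

def save_with_scaler(filename, model, scaler=None):
    # Sketch only; the real logic lives in icefall's checkpoint utilities.
    checkpoint = {"model": model.state_dict()}
    if scaler is not None:
        checkpoint["grad_scaler"] = scaler.state_dict()
    torch.save(checkpoint, filename)

def restore_scaler(filename, scaler):
    checkpoint = torch.load(filename, map_location="cpu")
    if "grad_scaler" in checkpoint:
        scaler.load_state_dict(checkpoint["grad_scaler"])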
@@ -599,6 +620,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -622,6 +644,8 @@ def train_one_epoch(
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mix precision training.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
|
|||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
loss, loss_info = compute_loss(
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||||
params=params,
|
loss, loss_info = compute_loss(
|
||||||
model=model,
|
params=params,
|
||||||
sp=sp,
|
model=model,
|
||||||
batch=batch,
|
sp=sp,
|
||||||
is_training=True,
|
batch=batch,
|
||||||
warmup=(params.batch_idx_train / params.model_warm_step)
|
is_training=True,
|
||||||
)
|
warmup=(params.batch_idx_train / params.model_warm_step)
|
||||||
|
)
|
||||||
# summary stats
|
# summary stats
|
||||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||||
|
|
||||||
# NOTE: We use reduction==sum and loss is computed over utterances
|
# NOTE: We use reduction==sum and loss is computed over utterances
|
||||||
# in the batch and there is no normalization to it so far.
|
# in the batch and there is no normalization to it so far.
|
||||||
loss.backward()
|
scaler.scale(loss).backward()
|
||||||
scheduler.step_batch(params.batch_idx_train)
|
scheduler.step_batch(params.batch_idx_train)
|
||||||
optimizer.step()
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if params.print_diagnostics and batch_idx == 5:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
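The order of the new calls matters: scaler.scale(loss).backward() produces scaled gradients, scaler.step(optimizer) unscales them and silently skips the optimizer update if it finds inf/NaN values, and scaler.update() then lowers (or periodically raises) the loss scale. The snippet below is purely an illustration and not part of this recipe (placeholder model, CUDA assumed); it forces an overflow to show the skip behaviour.

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(10, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

x = torch.randn(4, 10, device="cuda")
with autocast():
    loss = model(x).float().pow(2).mean() * float("inf")  # force an overflow

scaler.scale(loss).backward()
scale_before = scaler.get_scale()
scaler.step(optimizer)   # finds inf/NaN gradients, skips optimizer.step()
scaler.update()          # reduces the loss scale in response
assert scaler.get_scale() < scale_before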
@@ -676,6 +702,7 @@ def train_one_epoch(
                 optimizer=optimizer,
                 scheduler=scheduler,
                 sampler=train_dl.sampler,
+                scaler=scaler,
                 rank=rank,
             )
             del params.cur_batch_idx
@@ -695,7 +722,9 @@ def train_one_epoch(
             )
 
             if tb_writer is not None:
-                tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train)
+                tb_writer.add_scalar(
+                    "train/learning_rate", cur_lr, params.batch_idx_train
+                )
 
                 loss_info.write_summary(
                     tb_writer, "train/current_", params.batch_idx_train
@@ -850,6 +879,11 @@ def run(rank, world_size, args):
             params=params,
         )
 
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
     for epoch in range(params.start_epoch, params.num_epochs):
         scheduler.step_epoch(epoch)
         fix_random_seed(params.seed + epoch)
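Because the scaler is constructed with enabled=params.use_fp16, the same training loop also serves plain fp32 runs: when disabled, GradScaler's methods are pass-throughs. A tiny CPU-only illustration with placeholder model and optimizer:

import torch
from torch.cuda.amp import GradScaler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=False)     # fp32 run: the scaler is a no-op

loss = model(torch.randn(2, 4)).pow(2).mean()
scaler.scale(loss).backward()          # same as loss.backward() when disabled
scaler.step(optimizer)                 # same as optimizer.step() when disabled
scaler.update()                        # does nothing when disabled
optimizer.zero_grad()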
@@ -869,6 +903,7 @@ def run(rank, world_size, args):
             sp=sp,
             train_dl=train_dl,
             valid_dl=valid_dl,
+            scaler=scaler,
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
@@ -884,6 +919,7 @@ def run(rank, world_size, args):
             optimizer=optimizer,
             scheduler=scheduler,
             sampler=train_dl.sampler,
+            scaler=scaler,
             rank=rank,
         )
 
@@ -913,14 +949,15 @@ def scan_pessimistic_batches_for_oom(
         # warmup = 0.0 is so that the derivs for the pruned loss stay zero
         # (i.e. are not remembered by the decaying-average in adam), because
         # we want to avoid these params being subject to shrinkage in adam.
-        loss, _ = compute_loss(
-            params=params,
-            model=model,
-            sp=sp,
-            batch=batch,
-            is_training=True,
-            warmup=0.0
-        )
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, _ = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+                warmup=0.0
+            )
         loss.backward()
         optimizer.step()
         optimizer.zero_grad()
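With the trial forward pass now wrapped in the same autocast setting as training, scan_pessimistic_batches_for_oom exercises roughly the memory footprint the real fp16 run will see, so an out-of-memory failure surfaces in this probe rather than mid-epoch. A hedged helper sketch (not part of icefall; step_fn is a placeholder) for measuring that footprint with standard torch.cuda memory APIs:

import torch

def peak_memory_bytes(step_fn, use_fp16):
    # Illustrative only: run one trial step under the training autocast
    # setting and report the peak CUDA memory it required.
    torch.cuda.reset_peak_memory_stats()
    with torch.cuda.amp.autocast(enabled=use_fp16):
        step_fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()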