Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-05 15:14:18 +00:00)

support half precision training

This commit is contained in:
parent 395a3f952b
commit 2ff81b2838

@@ -51,8 +51,8 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import Transducer
 from torch import Tensor
+from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_norm_
 from torch.utils.tensorboard import SummaryWriter
 from transformer import Noam

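
As background (not part of this diff): the GradScaler imported above is effectively a no-op when constructed with enabled=False, which is what lets the fp16 machinery be threaded through the existing fp32 code path unconditionally. A minimal sketch:

# Background sketch, not code from the commit: a disabled GradScaler turns
# scale()/step()/update() into pass-throughs, so one training loop can serve
# both fp32 and fp16 runs.
import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled=False)
loss = 2 * torch.tensor(1.0, requires_grad=True)
scaler.scale(loss).backward()  # with the scaler disabled this is just loss.backward()
print(scaler.is_enabled())     # False
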
@@ -221,6 +221,13 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
     return parser

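
The str2bool converter referenced by the new argument is defined elsewhere in the recipe; as a rough illustration (an assumption, not the actual icefall helper), such an argparse converter typically looks like the sketch below, which is why "--use-fp16 1" and "--use-fp16 true" both enable the feature:

# Hypothetical str2bool converter, shown only for illustration (the real one
# lives in icefall's utilities and may differ).
import argparse


def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")
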
@@ -411,6 +418,7 @@ def save_checkpoint(
     model: nn.Module,
     optimizer: Optional[torch.optim.Optimizer] = None,
     sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
     rank: int = 0,
 ) -> None:
     """Save model, optimizer, scheduler and training stats to file.
@@ -424,6 +432,8 @@ def save_checkpoint(
         The optimizer used in the training.
       sampler:
         The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
     """
     if rank != 0:
         return
@@ -434,6 +444,7 @@ def save_checkpoint(
         params=params,
         optimizer=optimizer,
         sampler=sampler,
+        scaler=scaler,
         rank=rank,
     )

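
The checkpoint helper that save_checkpoint forwards to presumably serializes the scaler next to the model and optimizer state. A standalone sketch of that pattern (the function below is an illustrative assumption, not icefall's checkpoint code; only the "grad_scaler" key mirrors this diff):

# Sketch: store GradScaler state alongside model/optimizer state so fp16
# training can resume with the same loss scale.
from typing import Optional

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler


def save_training_state(
    filename: str,
    model: nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scaler: Optional[GradScaler] = None,
) -> None:
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        # The diff later reloads this via checkpoints["grad_scaler"].
        "grad_scaler": scaler.state_dict() if scaler is not None else None,
    }
    torch.save(checkpoint, filename)
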
@@ -473,6 +484,7 @@ def compute_loss(
     feature = batch["inputs"]
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
+
     feature = feature.to(device)

     supervisions = batch["supervisions"]
@@ -547,6 +559,7 @@ def train_one_epoch(
     sp: spm.SentencePieceProcessor,
     train_dl: torch.utils.data.DataLoader,
     valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
     tb_writer: Optional[SummaryWriter] = None,
     world_size: int = 1,
     rank: int = 0,
@@ -568,6 +581,8 @@ def train_one_epoch(
         Dataloader for the training dataset.
       valid_dl:
         Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
       tb_writer:
         Writer to write log messages to tensorboard.
       world_size:
@@ -610,14 +625,17 @@ def train_one_epoch(
             and tb_writer is not None
             and params.batch_idx_train % (params.log_interval * 5) == 0
         ):
-            deltas = optim_step_and_measure_param_change(model, optimizer)
+            deltas = optim_step_and_measure_param_change(
+                model, optimizer, scaler=scaler
+            )
             tb_writer.add_scalars(
                 "train/relative_param_change_per_minibatch",
                 deltas,
                 global_step=params.batch_idx_train,
             )
         else:
-            optimizer.step()
+            scaler.step(optimizer)
+            scaler.update()

     cur_batch_idx = params.get("cur_batch_idx", 0)

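
For reference (standard PyTorch AMP behaviour rather than anything specific to this commit): scaler.step(optimizer) unscales the gradients and silently skips the optimizer step when they contain infs or NaNs, and scaler.update() then adjusts the loss scale for the next iteration; with the scaler disabled the pair reduces to a plain optimizer.step(). A toy sketch of the cycle (model, optimizer and data are illustrative assumptions):

# Toy sketch of the scale/step/update cycle.
import torch
from torch.cuda.amp import GradScaler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(4, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=device.type == "cuda")

loss = model(torch.randn(2, 4, device=device)).sum()
scaler.scale(loss).backward()  # backward through the scaled loss
scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
scaler.update()                # grows/shrinks the scale factor
optimizer.zero_grad()
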
@@ -629,20 +647,23 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

-        loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            sp=sp,
-            batch=batch,
-            is_training=True,
-        )
+        with torch.autocast(
+            device_type=model.device.type, enabled=params.use_fp16
+        ):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                sp=sp,
+                batch=batch,
+                is_training=True,
+            )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info

         # NOTE: We use reduction==sum and loss is computed over utterances
         # in the batch and there is no normalization to it so far.

-        loss.backward()
+        scaler.scale(loss).backward()

         maybe_log_weights("train/param_norms")
         maybe_log_gradients("train/grad_norms")
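
With this change the forward pass (and hence compute_loss) runs under torch.autocast, so matmul-heavy ops execute in float16 on CUDA when --use-fp16 is given, and the loss is scaled before backward so small fp16 gradients do not underflow. A small sketch of the autocast context in isolation (toy module and device handling are assumptions):

# Sketch of torch.autocast on its own, not code from the commit.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch.nn.Linear(8, 8).to(device)
x = torch.randn(3, 8, device=device)

with torch.autocast(device_type=device.type, enabled=device.type == "cuda"):
    y = net(x)              # runs in float16 on CUDA when autocast is enabled
    loss = y.float().sum()  # cast back to float32 for the reduction

print(y.dtype, loss.dtype)
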
@@ -662,6 +683,7 @@ def train_one_epoch(
                 params=params,
                 optimizer=optimizer,
                 sampler=train_dl.sampler,
+                scaler=scaler,
                 rank=rank,
             )
             del params.cur_batch_idx
@@ -831,6 +853,11 @@ def run(rank, world_size, args):
         params=params,
     )

+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in checkpoints:
+        logging.info("Loading grad scaler state dict")
+        scaler.load_state_dict(checkpoints["grad_scaler"])
+
     for epoch in range(params.start_epoch, params.num_epochs):
         fix_random_seed(params.seed + epoch)
         train_dl.sampler.set_epoch(epoch)
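
GradScaler is stateful (it tracks the current loss scale and a growth counter), so its state dict has to survive a save/resume cycle just like the optimizer's. A small round-trip sketch using the same "grad_scaler" key (the file handling is an illustrative assumption):

# Round-trip sketch for the scaler state.
import os
import tempfile

import torch
from torch.cuda.amp import GradScaler

enabled = torch.cuda.is_available()
scaler = GradScaler(enabled=enabled)

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "checkpoint.pt")
    torch.save({"grad_scaler": scaler.state_dict()}, ckpt_path)

    checkpoints = torch.load(ckpt_path, map_location="cpu")
    resumed = GradScaler(enabled=enabled)
    if checkpoints and "grad_scaler" in checkpoints:
        resumed.load_state_dict(checkpoints["grad_scaler"])
    assert resumed.get_scale() == scaler.get_scale()
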
@@ -857,6 +884,7 @@ def run(rank, world_size, args):
             tb_writer=tb_writer,
             world_size=world_size,
             rank=rank,
+            scaler=scaler,
         )

         save_checkpoint(
@@ -864,6 +892,7 @@ def run(rank, world_size, args):
             model=model,
             optimizer=optimizer,
             sampler=train_dl.sampler,
+            scaler=scaler,
             rank=rank,
         )

@@ -891,15 +920,17 @@ def scan_pessimistic_batches_for_oom(
         batch = train_dl.dataset[cuts]
         try:
             optimizer.zero_grad()
-            loss, _ = compute_loss(
-                params=params,
-                model=model,
-                sp=sp,
-                batch=batch,
-                is_training=True,
-            )
+            with torch.autocast(
+                device_type=model.device.type, enabled=params.use_fp16
+            ):
+                loss, _ = compute_loss(
+                    params=params,
+                    model=model,
+                    sp=sp,
+                    batch=batch,
+                    is_training=True,
+                )
             loss.backward()
-            clip_grad_norm_(model.parameters(), 5.0, 2.0)
             optimizer.step()
         except RuntimeError as e:
             if "CUDA out of memory" in str(e):
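
Note that the clip_grad_norm_ call in the OOM scan is dropped rather than adapted to the scaler. For reference only (general PyTorch AMP practice, not something this commit does): clipping under a GradScaler normally requires unscaling the gradients first, e.g.:

# General AMP clipping pattern, shown for reference; toy model/optimizer are
# illustrative assumptions.
import torch
from torch.cuda.amp import GradScaler
from torch.nn.utils import clip_grad_norm_

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(4, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=device.type == "cuda")

loss = model(torch.randn(2, 4, device=device)).sum()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                     # bring grads back to true scale
clip_grad_norm_(model.parameters(), 5.0, 2.0)  # clipping now sees real norms
scaler.step(optimizer)                         # will not unscale a second time
scaler.update()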