make modifications to support full bf16 training

marcoyang 2024-07-23 23:14:32 +08:00
parent 0c29c45c32
commit 5a05da8fcc
4 changed files with 2913 additions and 63 deletions


@@ -296,10 +296,6 @@ class SoftmaxFunction(torch.autograd.Function):
         # if x dtype is float16, x.softmax() returns a float32 because
         # (presumably) that op does not support float16, and autocast
         # is enabled.
-        # import pdb; pdb.set_trace()
-        if torch.is_autocast_enabled():
-            # ans = ans.to(torch.float16)
-            ans = ans.to(ans.dtype)
         ctx.save_for_backward(ans)
         ctx.x_dtype = x.dtype
         ctx.dim = dim
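
For context: under autocast, softmax on a float16 input is computed in (and returned as) float32, which is what the comment above refers to; with full bf16 training autocast is disabled, so there is no implicit up-cast left to undo and the cast-back branch can go. A small illustrative check, not part of this commit (the autocast cases need a CUDA device):

```python
import torch

if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda")
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
        print(x.half().softmax(dim=-1).dtype)      # typically torch.float32
    with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
        print(x.bfloat16().softmax(dim=-1).dtype)  # typically torch.float32
    # Outside autocast (the full-bf16 regime), softmax keeps the input dtype:
    print(x.bfloat16().softmax(dim=-1).dtype)      # torch.bfloat16
```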
@@ -309,10 +305,6 @@ class SoftmaxFunction(torch.autograd.Function):
     def backward(ctx, ans_grad: Tensor):
         (ans,) = ctx.saved_tensors
         with torch.cuda.amp.autocast(enabled=False):
-            # import pdb; pdb.set_trace()
-            if ctx.x_dtype == torch.float16:
-                ans_grad = ans_grad.to(torch.float32)
-                ans = ans.to(torch.float32)
            x_grad = ans_grad * ans
            x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
            return x_grad, None
@@ -764,9 +756,6 @@ class BalancerFunction(torch.autograd.Function):
         try:
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
-                    # import pdb; pdb.set_trace()
-                    if x.dtype == torch.float16:
-                        x = x.to(torch.float32)
                     x = x.detach()
                     x.requires_grad = True
                     mean_dims = [i for i in range(x.ndim) if i != channel_dim]
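
The float16-to-float32 promotions removed here exist because fp16 has a narrow exponent range; bfloat16 keeps fp32's exponent range (at reduced precision), so the Balancer statistics can be computed directly in the input dtype. A quick comparison of the three formats, illustrative only:

```python
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    fi = torch.finfo(dtype)
    print(f"{str(dtype):16s} max={fi.max:.3e}  tiny={fi.tiny:.3e}  eps={fi.eps:.3e}")
# float16 overflows past ~6.6e4, while bfloat16 reaches ~3.4e38 like float32,
# at the cost of precision (eps ~= 7.8e-3 vs ~1.2e-7).
```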
@@ -797,15 +786,17 @@ class BalancerFunction(torch.autograd.Function):
                     loss_grad = loss_grad * (grad_scale / loss_grad_rms)
-                    if x_grad.dtype == torch.float16:
-                        x_grad_float = x_grad.to(torch.float32)
-                    else:
-                        x_grad_float = x_grad
+                    # if x_grad.dtype == torch.float16:
+                    #     x_grad_float = x_grad.to(torch.float32)
+                    # else:
+                    #     x_grad_float = x_grad
                     # scale each element of loss_grad by the absolute value of the corresponding
                     # element of x_grad, which we view as a noisy estimate of its magnitude for that
                     # (frame and dimension). later we can consider factored versions.
-                    x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
-                    x_grad = x_grad_mod.to(x_grad.dtype)
+                    # x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+                    x_grad = x_grad + (x_grad.abs() * loss_grad)
+                    # x_grad = x_grad_mod.to(x_grad.dtype)
             except Exception as e:
                 logging.info(
                     f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
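
The backward now applies the nudge x_grad = x_grad + |x_grad| * loss_grad directly in the incoming dtype rather than routing through a float32 copy. A tiny standalone illustration of that update (values are made up; loss_grad is assumed already scaled by grad_scale / loss_grad_rms as above):

```python
import torch

x_grad = torch.tensor([0.5, -2.0, 0.01], dtype=torch.bfloat16)
loss_grad = torch.tensor([0.1, -0.1, 0.1], dtype=torch.bfloat16)

# each element is moved by |x_grad| * loss_grad, staying in bfloat16 throughout
x_grad = x_grad + x_grad.abs() * loss_grad
print(x_grad)
```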
@@ -1025,11 +1016,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
                     dtype = x_orig.dtype
-                    # import pdb; pdb.set_trace()
-                    if x_orig.dtype == torch.float16:
-                        x_detached = x_orig.to(torch.float32).detach()
-                    else:
-                        x_detached = x_orig.detach()
+                    x_detached = x_orig.detach()
                     x_detached.requires_grad = True

                     metric = _whitening_metric(x_detached, w.num_groups)
@@ -1248,8 +1235,6 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
-            x = x.to(torch.float32)

         s = torch.sigmoid(x - 1.0)
         y = x * s
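
DoubleSwish is y = x * sigmoid(x - 1), as the two context lines show. A rough, commit-independent check of how far a straight bfloat16 evaluation drifts from a float32 reference:

```python
import torch

x = torch.linspace(-6.0, 6.0, steps=25)
ref = x * torch.sigmoid(x - 1.0)                                   # float32 reference
bf16 = (x.bfloat16() * torch.sigmoid(x.bfloat16() - 1.0)).float()  # same math in bfloat16
print((ref - bf16).abs().max())   # roughly 1e-2 absolute at these magnitudes
```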
@@ -1360,8 +1345,6 @@ class SwooshLFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
-            x = x.to(torch.float32)

         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1415,10 +1398,11 @@ class SwooshL(torch.nn.Module):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
-            return k2.swoosh_l_forward(x)
+            # return k2.swoosh_l_forward(x)
+            return SwooshLForward(x)
         else:
-            return k2.swoosh_l(x)
-            # return SwooshLFunction.apply(x)
+            # return k2.swoosh_l(x)
+            return SwooshLFunction.apply(x)  # this supports bf16


class SwooshLOnnx(torch.nn.Module):
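
SwooshL's closed form, logaddexp(0, x - 4) - 0.08 * x - 0.035, is quoted in the hunk above; the change routes training through SwooshLForward / SwooshLFunction.apply (pure PyTorch) instead of the k2 kernels so bf16 tensors are accepted. A small sketch of how much bf16 rounding of the input perturbs that formula (evaluated in float32 to stay independent of per-op bf16 kernel coverage):

```python
import torch

def swoosh_l(x: torch.Tensor) -> torch.Tensor:
    # closed form quoted in the SwooshL hunk above
    zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035

x32 = torch.linspace(-8.0, 8.0, steps=33)
x_rounded = x32.bfloat16().float()   # inputs rounded to bf16, math kept in fp32
print((swoosh_l(x32) - swoosh_l(x_rounded)).abs().max())
```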
@@ -1489,10 +1473,11 @@ class SwooshR(torch.nn.Module):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
         if not x.requires_grad:
-            return k2.swoosh_r_forward(x)
+            # return k2.swoosh_r_forward(x)
+            return SwooshRForward(x)
         else:
-            return k2.swoosh_r(x)
-            # return SwooshRFunction.apply(x)
+            # return k2.swoosh_r(x)
+            return SwooshRFunction.apply(x)


class SwooshROnnx(torch.nn.Module):
@@ -1647,6 +1632,7 @@ class ActivationDropoutAndLinear(torch.nn.Module):
         self.activation = activation
         self.dropout_p = dropout_p
         self.dropout_shared_dim = dropout_shared_dim
+        self.dropout = Dropout3(dropout_p, shared_dim=dropout_shared_dim)

     def forward(self, x: Tensor):
         if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -1658,14 +1644,23 @@ class ActivationDropoutAndLinear(torch.nn.Module):
                 assert False, self.activation
             return torch.nn.functional.linear(x, self.weight, self.bias)

-        return ActivationDropoutAndLinearFunction.apply(
-            x,
-            self.weight,
-            self.bias,
-            self.activation,
-            float(self.dropout_p),
-            self.dropout_shared_dim,
-        )
+        if self.activation == "SwooshL":
+            x = SwooshL()(x)
+        elif self.activation == "SwooshR":
+            x = SwooshR()(x)
+
+        x = self.dropout(x)
+        return torch.nn.functional.linear(x, self.weight, self.bias)
+
+        # return ActivationDropoutAndLinearFunction.apply(
+        #     x,
+        #     self.weight,
+        #     self.bias,
+        #     self.activation,
+        #     float(self.dropout_p),
+        #     self.dropout_shared_dim,
+        # )


def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
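
The fused ActivationDropoutAndLinearFunction is replaced above by a plain activation, dropout, linear sequence built from modules that already handle bf16. A self-contained sketch of the equivalent composition, using the Swoosh formulas quoted earlier and ordinary F.dropout in place of Dropout3 (which additionally shares its mask along one dimension); names and sizes are illustrative:

```python
import torch
import torch.nn.functional as F

def act_dropout_linear(x, weight, bias, activation="SwooshL", p=0.1, training=True):
    zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
    if activation == "SwooshL":
        x = torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
    elif activation == "SwooshR":
        x = torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
    x = F.dropout(x, p=p, training=training)   # Dropout3 would share the mask over one dim
    return F.linear(x, weight, bias)

x = torch.randn(4, 10)
w = torch.randn(6, 10)
b = torch.zeros(6)
print(act_dropout_linear(x, w, b, training=False).shape)   # torch.Size([4, 6])
```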


@@ -0,0 +1,406 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Daniel Povey,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Tuple
import torch
from scaling_bf16 import (
Balancer,
BiasNorm,
Dropout3,
FloatLike,
Optional,
ScaledConv2d,
ScaleGrad,
ScheduledFloat,
SwooshL,
SwooshR,
Whiten,
)
from torch import Tensor, nn
class ConvNeXt(nn.Module):
"""
Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
"""
def __init__(
self,
channels: int,
hidden_ratio: int = 3,
kernel_size: Tuple[int, int] = (7, 7),
layerdrop_rate: FloatLike = None,
):
super().__init__()
self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
hidden_channels = channels * hidden_ratio
if layerdrop_rate is None:
layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
self.layerdrop_rate = layerdrop_rate
self.depthwise_conv = nn.Conv2d(
in_channels=channels,
out_channels=channels,
groups=channels,
kernel_size=kernel_size,
padding=self.padding,
)
self.pointwise_conv1 = nn.Conv2d(
in_channels=channels, out_channels=hidden_channels, kernel_size=1
)
self.hidden_balancer = Balancer(
hidden_channels,
channel_dim=1,
min_positive=0.3,
max_positive=1.0,
min_abs=0.75,
max_abs=5.0,
)
self.activation = SwooshL()
self.pointwise_conv2 = ScaledConv2d(
in_channels=hidden_channels,
out_channels=channels,
kernel_size=1,
initial_scale=0.01,
)
self.out_balancer = Balancer(
channels,
channel_dim=1,
min_positive=0.4,
max_positive=0.6,
min_abs=1.0,
max_abs=6.0,
)
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=5.0,
prob=(0.025, 0.25),
grad_scale=0.01,
)
def forward(self, x: Tensor) -> Tensor:
if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
return self.forward_internal(x)
layerdrop_rate = float(self.layerdrop_rate)
if layerdrop_rate != 0.0:
batch_size = x.shape[0]
mask = (
torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
> layerdrop_rate
)
else:
mask = None
# turns out this caching idea does not work with --world-size > 1
# return caching_eval(self.forward_internal, x, mask)
return self.forward_internal(x, mask)
def forward_internal(
self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
) -> Tensor:
"""
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
The returned value has the same shape as x.
"""
bypass = x
x = self.depthwise_conv(x)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
if layer_skip_mask is not None:
x = x * layer_skip_mask
x = bypass + x
x = self.out_balancer(x)
if x.requires_grad:
x = x.transpose(1, 3) # (N, W, H, C); need channel dim to be last
x = self.out_whiten(x)
x = x.transpose(1, 3) # (N, C, H, W)
return x
def streaming_forward(
self,
x: Tensor,
cached_left_pad: Tensor,
) -> Tuple[Tensor, Tensor]:
"""
Args:
x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)
Returns:
- The returned value has the same shape as x.
- Updated cached_left_pad.
"""
padding = self.padding
# The length without right padding for depth-wise conv
T = x.size(2) - padding[0]
bypass = x[:, :, :T, :]
# Pad left side
assert cached_left_pad.size(2) == padding[0], (
cached_left_pad.size(2),
padding[0],
)
x = torch.cat([cached_left_pad, x], dim=2)
# Update cached left padding
cached_left_pad = x[:, :, T : padding[0] + T, :]
# depthwise_conv
x = torch.nn.functional.conv2d(
x,
weight=self.depthwise_conv.weight,
bias=self.depthwise_conv.bias,
padding=(0, padding[1]),
groups=self.depthwise_conv.groups,
)
x = self.pointwise_conv1(x)
x = self.hidden_balancer(x)
x = self.activation(x)
x = self.pointwise_conv2(x)
x = bypass + x
return x, cached_left_pad
class Conv2dSubsampling(nn.Module):
"""Convolutional 2D subsampling (to 1/2 length).
Convert an input of shape (N, T, idim) to an output
with shape (N, T', odim), where
T' = (T-3)//2 - 2 == (T-7)//2
It is based on
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
"""
def __init__(
self,
in_channels: int,
out_channels: int,
layer1_channels: int = 8,
layer2_channels: int = 32,
layer3_channels: int = 128,
dropout: FloatLike = 0.1,
) -> None:
"""
Args:
in_channels:
Number of channels in. The input shape is (N, T, in_channels).
Caution: It requires: T >=7, in_channels >=7
out_channels:
Output dim. The output shape is (N, (T-7)//2, out_channels)
layer1_channels:
Number of channels in layer1
layer2_channels:
Number of channels in layer2
layer3_channels:
Number of channels in layer3
dropout:
Dropout rate applied after the output projection
"""
assert in_channels >= 7
super().__init__()
# The ScaleGrad module is there to prevent the gradients
# w.r.t. the weight or bias of the first Conv2d module in self.conv from
# exceeding the range of fp16 when using automatic mixed precision (amp)
# training. (The second one is necessary to stop its bias from getting
# a too-large gradient).
self.conv = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=layer1_channels,
kernel_size=3,
padding=(0, 1), # (time, freq)
),
ScaleGrad(0.2),
Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
SwooshR(),
nn.Conv2d(
in_channels=layer1_channels,
out_channels=layer2_channels,
kernel_size=3,
stride=2,
padding=0,
),
Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
nn.Conv2d(
in_channels=layer2_channels,
out_channels=layer3_channels,
kernel_size=3,
stride=(1, 2), # (time, freq)
),
Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
SwooshR(),
)
# just one convnext layer
self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))
# (in_channels-3)//4
self.out_width = (((in_channels - 1) // 2) - 1) // 2
self.layer3_channels = layer3_channels
self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
# use a larger than normal grad_scale on this whitening module; there is
# only one such module, so there is not a concern about adding together
# many copies of this extra gradient term.
self.out_whiten = Whiten(
num_groups=1,
whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
prob=(0.025, 0.25),
grad_scale=0.02,
)
# max_log_eps=0.0 is to prevent both eps and the output of self.out from
# getting large, there is an unnecessary degree of freedom.
self.out_norm = BiasNorm(out_channels)
self.dropout = Dropout3(dropout, shared_dim=1)
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in x.
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
# training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
# gradients.
x = self.conv(x)
x = self.convnext(x)
# Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, (T-7)//2, out_width * layer3_channels)
x = self.out(x)
# Now x is of shape (N, (T-7)//2, odim)
x = self.out_whiten(x)
x = self.out_norm(x)
x = self.dropout(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
x_lens = (x_lens - 7) // 2
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
x_lens = (x_lens - 7) // 2
assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())
return x, x_lens
def streaming_forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
cached_left_pad: Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Subsample x.
Args:
x:
Its shape is (N, T, idim).
x_lens:
A tensor of shape (batch_size,) containing the number of frames in x.
Returns:
- a tensor of shape (N, (T-7)//2, odim)
- output lengths, of shape (batch_size,)
- updated cache
"""
# On entry, x is (N, T, idim)
x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
# T' = (T-7)//2
x = self.conv(x)
# T' = (T-7)//2-3
x, cached_left_pad = self.convnext.streaming_forward(
x, cached_left_pad=cached_left_pad
)
# Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
b, c, t, f = x.size()
x = x.transpose(1, 2).reshape(b, t, c * f)
# now x: (N, T', out_width * layer3_channels)
x = self.out(x)
# Now x is of shape (N, T', odim)
x = self.out_norm(x)
if torch.jit.is_scripting() or torch.jit.is_tracing():
assert self.convnext.padding[0] == 3
# The ConvNeXt module needs 3 frames of right padding after subsampling
x_lens = (x_lens - 7) // 2 - 3
else:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# The ConvNeXt module needs 3 frames of right padding after subsampling
assert self.convnext.padding[0] == 3
x_lens = (x_lens - 7) // 2 - 3
assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())
return x, x_lens, cached_left_pad
@torch.jit.export
def get_init_states(
self,
batch_size: int = 1,
device: torch.device = torch.device("cpu"),
) -> Tensor:
"""Get initial states for Conv2dSubsampling module.
It is the cached left padding for ConvNeXt module,
of shape (batch_size, num_channels, left_pad, num_freqs)
"""
left_pad = self.convnext.padding[0]
freq = self.out_width
channels = self.layer3_channels
cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
device
)
return cached_embed_left_pad
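
Shape bookkeeping for the module above, per its docstrings: time is subsampled to (T - 7) // 2 and the frequency axis shrinks to ((idim - 1) // 2 - 1) // 2 before the final Linear. Illustrative arithmetic with example sizes:

```python
T, idim = 1000, 80
layer3_channels = 128

T_out = (T - 7) // 2                    # time frames after subsampling
out_width = ((idim - 1) // 2 - 1) // 2  # frequency width fed into self.out
print(T_out, out_width, out_width * layer3_channels)   # 496 19 2432
```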


@@ -76,13 +76,13 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
 from optim import Eden, ScaledAdam
-from scaling import ScheduledFloat
-from subsampling import Conv2dSubsampling
+from scaling_bf16 import ScheduledFloat
+from subsampling_bf16 import Conv2dSubsampling
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from zipformer_bf16 import Zipformer2
+from zipformer_full_bf16 import Zipformer2

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -870,6 +870,8 @@ def compute_loss(
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)
+    if params.full_bf16:
+        feature = feature.to(torch.bfloat16)

     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
@@ -1041,7 +1043,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
+            with torch.cuda.amp.autocast(enabled=params.use_autocast, dtype=params.dtype):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1054,11 +1056,16 @@ def train_one_epoch(
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.

-            scaler.scale(loss).backward()
-            scheduler.step_batch(params.batch_idx_train)
+            if params.use_autocast:
+                scaler.scale(loss).backward()
+                scheduler.step_batch(params.batch_idx_train)
                 scaler.step(optimizer)
                 scaler.update()
+            else:
+                loss.backward()
+                scheduler.step_batch(params.batch_idx_train)
+                optimizer.step()
             optimizer.zero_grad()
         except Exception as e:
             logging.info(
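
With params.use_autocast False (the full-bf16 case), the new else-branch runs an ordinary backward and optimizer step with no GradScaler, relying on bf16 having the same dynamic range as fp32. A toy, self-contained illustration of that regime (plain SGD and an MSE-style loss stand in for the recipe's ScaledAdam, scheduler, and transducer loss):

```python
import torch

model = torch.nn.Linear(16, 4).to(torch.bfloat16)       # whole model in bfloat16
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(8, 16, dtype=torch.bfloat16)             # inputs cast to bfloat16 too
target = torch.randn(8, 4, dtype=torch.bfloat16)

loss = ((model(x) - target) ** 2).mean()
loss.backward()          # no autocast, no GradScaler
optimizer.step()
optimizer.zero_grad()
```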
@@ -1104,7 +1111,7 @@ def train_one_epoch(
                     rank=rank,
                 )

-        if batch_idx % 100 == 0 and params.use_fp16:
+        if batch_idx % 100 == 0 and params.use_autocast:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
@@ -1123,14 +1130,14 @@ def train_one_epoch(
         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
-            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+            cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0

             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
             )

             if tb_writer is not None:
@@ -1242,11 +1249,16 @@ def run(rank, world_size, args):
     if params.use_fp16:
         params.dtype = torch.float16 if not params.use_bf16 else torch.bfloat16
+        params.use_autocast = True
     else:
         params.dtype = torch.float32
+        params.use_autocast = False

     logging.info(f"Training using: {params.dtype}")
+    model.to(params.dtype)

+    if params.full_bf16:
+        assert params.use_bf16
+        params.use_autocast = False  # use full bf16 training, no autocast and grad scaling

     model.to(device)

     if world_size > 1:
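
A condensed restatement of the flag handling above (illustrative only; names mirror the script's params):

```python
import torch

def resolve_regime(use_fp16: bool, use_bf16: bool, full_bf16: bool):
    dtype = torch.float32
    use_autocast = False
    if use_fp16:
        dtype = torch.bfloat16 if use_bf16 else torch.float16
        use_autocast = True
    if full_bf16:
        assert use_bf16, "full bf16 training builds on bf16"
        use_autocast = False   # whole model runs in bf16; no autocast, no grad scaling
    return dtype, use_autocast

print(resolve_regime(True, False, False))  # (torch.float16, True): fp16 autocast + GradScaler
print(resolve_regime(True, True, False))   # (torch.bfloat16, True): bf16 autocast
print(resolve_regime(True, True, True))    # (torch.bfloat16, False): full bf16, plain backward
```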
@@ -1352,16 +1364,16 @@ def run(rank, world_size, args):
         valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1461,7 +1473,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(enabled=params.use_autocast, dtype=params.dtype):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
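
Passing enabled=params.use_autocast means the same wrapper covers every regime: with enabled=False the autocast context is a no-op and tensors keep their dtype, which is what full bf16 training relies on. A minimal check (illustrative, runs on CPU):

```python
import torch

x = torch.randn(4, 4, dtype=torch.bfloat16)
with torch.cuda.amp.autocast(enabled=False):
    y = x @ x            # nothing is cast inside a disabled autocast region
print(y.dtype)           # torch.bfloat16
```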

File diff suppressed because it is too large.