Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

Commit 5a05da8fcc (parent 0c29c45c32): make modifications to support full bf16 training
@@ -296,10 +296,6 @@ class SoftmaxFunction(torch.autograd.Function):
         # if x dtype is float16, x.softmax() returns a float32 because
         # (presumably) that op does not support float16, and autocast
         # is enabled.
-        # import pdb; pdb.set_trace()
-        if torch.is_autocast_enabled():
-            # ans = ans.to(torch.float16)
-            ans = ans.to(ans.dtype)
         ctx.save_for_backward(ans)
         ctx.x_dtype = x.dtype
         ctx.dim = dim
@@ -309,10 +305,6 @@ class SoftmaxFunction(torch.autograd.Function):
     def backward(ctx, ans_grad: Tensor):
         (ans,) = ctx.saved_tensors
         with torch.cuda.amp.autocast(enabled=False):
-            # import pdb; pdb.set_trace()
-            if ctx.x_dtype == torch.float16:
-                ans_grad = ans_grad.to(torch.float32)
-                ans = ans.to(torch.float32)
             x_grad = ans_grad * ans
             x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
             return x_grad, None
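The two hunks above drop the float16 special-casing from SoftmaxFunction: under full bf16 training there is no autocast-induced dtype mismatch to undo. A minimal sketch of the pattern that remains (a hypothetical standalone class, not the icefall file itself), showing that a custom autograd softmax stays in bfloat16 end to end:

import torch
from torch import Tensor


class SimpleSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, dim: int) -> Tensor:
        ans = x.softmax(dim=dim)          # stays in x.dtype (e.g. bfloat16)
        ctx.save_for_backward(ans)
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        (ans,) = ctx.saved_tensors
        # standard softmax Jacobian-vector product, in the saved dtype
        x_grad = ans_grad * ans
        x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
        return x_grad, None


x = torch.randn(2, 5, dtype=torch.bfloat16, requires_grad=True)
y = SimpleSoftmax.apply(x, -1)
y.sum().backward()
assert x.grad.dtype == torch.bfloat16  # no float16/float32 round trip needed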
@@ -764,9 +756,6 @@ class BalancerFunction(torch.autograd.Function):
         try:
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
-                    # import pdb; pdb.set_trace()
-                    if x.dtype == torch.float16:
-                        x = x.to(torch.float32)
                     x = x.detach()
                     x.requires_grad = True
                     mean_dims = [i for i in range(x.ndim) if i != channel_dim]
@@ -797,15 +786,17 @@ class BalancerFunction(torch.autograd.Function):

                     loss_grad = loss_grad * (grad_scale / loss_grad_rms)

-                    if x_grad.dtype == torch.float16:
-                        x_grad_float = x_grad.to(torch.float32)
-                    else:
-                        x_grad_float = x_grad
+                    # if x_grad.dtype == torch.float16:
+                    #     x_grad_float = x_grad.to(torch.float32)
+                    # else:
+                    #     x_grad_float = x_grad

                     # scale each element of loss_grad by the absolute value of the corresponding
                     # element of x_grad, which we view as a noisy estimate of its magnitude for that
                     # (frame and dimension).  later we can consider factored versions.
-                    x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
-                    x_grad = x_grad_mod.to(x_grad.dtype)
+                    # x_grad_mod = x_grad_float + (x_grad_float.abs() * loss_grad)
+                    x_grad = x_grad + (x_grad.abs() * loss_grad)
+                    # x_grad = x_grad_mod.to(x_grad.dtype)
         except Exception as e:
             logging.info(
                 f"Caught exception in Balancer backward: {e}, size={list(x_grad.shape)}, will continue."
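The key change in this hunk is that the balancing term is now added directly in the gradient's own dtype instead of routing through a float32 copy. A tiny illustrative snippet (hypothetical tensors, not the Balancer code path):

import torch

x_grad = torch.randn(4, 8, dtype=torch.bfloat16)
loss_grad = torch.full_like(x_grad, 0.01)      # stand-in for the balancer's penalty gradient
x_grad = x_grad + (x_grad.abs() * loss_grad)   # same dtype in, same dtype out
assert x_grad.dtype == torch.bfloat16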
@@ -1025,11 +1016,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
             with torch.enable_grad():
                 with torch.cuda.amp.autocast(enabled=False):
                     dtype = x_orig.dtype
-                    # import pdb; pdb.set_trace()
-                    if x_orig.dtype == torch.float16:
-                        x_detached = x_orig.to(torch.float32).detach()
-                    else:
-                        x_detached = x_orig.detach()
+                    x_detached = x_orig.detach()
                     x_detached.requires_grad = True

                     metric = _whitening_metric(x_detached, w.num_groups)
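WhiteningPenaltyFunction keeps its recompute-under-enable_grad trick; the only change is that the detached copy now keeps x_orig's dtype instead of being forced to float32 for the float16 case. A generic sketch of that pattern (illustrative metric, not _whitening_metric):

import torch

x_orig = torch.randn(3, 4, dtype=torch.bfloat16)
with torch.enable_grad():
    x_detached = x_orig.detach()        # keep the original dtype; no float32 copy
    x_detached.requires_grad = True
    metric = (x_detached ** 2).mean()   # stand-in for _whitening_metric(x_detached, num_groups)
    metric.backward()
assert x_detached.grad.dtype == torch.bfloat16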
@@ -1248,8 +1235,6 @@ class DoubleSwishFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
-            x = x.to(torch.float32)

         s = torch.sigmoid(x - 1.0)
         y = x * s
@@ -1360,8 +1345,6 @@ class SwooshLFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x: Tensor) -> Tensor:
         requires_grad = x.requires_grad
-        if x.dtype == torch.float16:
-            x = x.to(torch.float32)

         zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)

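The DoubleSwish and SwooshL forward hunks above rely on the elementwise ops involved having native bfloat16 kernels, so the float16-to-float32 upcast can simply be removed. A quick sanity check of that assumption:

import torch

x = torch.randn(4, dtype=torch.bfloat16)
s = torch.sigmoid(x - 1.0)          # DoubleSwish uses sigmoid(x - 1)
y = x * s
assert y.dtype == torch.bfloat16    # no float32 round trip needed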
@@ -1415,10 +1398,11 @@ class SwooshL(torch.nn.Module):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 4.0) - 0.08 * x - 0.035
         if not x.requires_grad:
-            return k2.swoosh_l_forward(x)
+            # return k2.swoosh_l_forward(x)
+            return SwooshLForward(x)
         else:
-            return k2.swoosh_l(x)
-            # return SwooshLFunction.apply(x)
+            # return k2.swoosh_l(x)
+            return SwooshLFunction.apply(x)  # this supports bf16


 class SwooshLOnnx(torch.nn.Module):
@@ -1489,10 +1473,11 @@ class SwooshR(torch.nn.Module):
             zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
             return logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687
         if not x.requires_grad:
-            return k2.swoosh_r_forward(x)
+            # return k2.swoosh_r_forward(x)
+            return SwooshRForward(x)
         else:
-            return k2.swoosh_r(x)
-            # return SwooshRFunction.apply(x)
+            # return k2.swoosh_r(x)
+            return SwooshRFunction.apply(x)


 class SwooshROnnx(torch.nn.Module):
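Both SwooshL and SwooshR now fall back from the k2 implementations (k2.swoosh_l / k2.swoosh_r) to the pure-PyTorch paths so that bfloat16 inputs are handled. For reference, the closed forms these activations compute (a hedged reference implementation, not the recipe's code; shown in float32 for portability):

import torch

def swoosh_l_reference(x: torch.Tensor) -> torch.Tensor:
    # SwooshL(x) = log(1 + exp(x - 4)) - 0.08*x - 0.035
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 4.0) - 0.08 * x - 0.035

def swoosh_r_reference(x: torch.Tensor) -> torch.Tensor:
    # SwooshR(x) = log(1 + exp(x - 1)) - 0.08*x - 0.313261687
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    return torch.logaddexp(zero, x - 1.0) - 0.08 * x - 0.313261687

x = torch.linspace(-4.0, 4.0, steps=9)
print(swoosh_l_reference(x))
print(swoosh_r_reference(x))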
@@ -1647,6 +1632,7 @@ class ActivationDropoutAndLinear(torch.nn.Module):
         self.activation = activation
         self.dropout_p = dropout_p
         self.dropout_shared_dim = dropout_shared_dim
+        self.dropout = Dropout3(dropout_p, shared_dim=dropout_shared_dim)

     def forward(self, x: Tensor):
         if torch.jit.is_scripting() or torch.jit.is_tracing():
@@ -1658,14 +1644,23 @@ class ActivationDropoutAndLinear(torch.nn.Module):
                 assert False, self.activation
             return torch.nn.functional.linear(x, self.weight, self.bias)

-        return ActivationDropoutAndLinearFunction.apply(
-            x,
-            self.weight,
-            self.bias,
-            self.activation,
-            float(self.dropout_p),
-            self.dropout_shared_dim,
-        )
+        if self.activation == "SwooshL":
+            x = SwooshL()(x)
+        elif self.activation == "SwooshR":
+            x = SwooshR()(x)
+
+        x = self.dropout(x)
+        return torch.nn.functional.linear(x, self.weight, self.bias)
+
+        # return ActivationDropoutAndLinearFunction.apply(
+        #     x,
+        #     self.weight,
+        #     self.bias,
+        #     self.activation,
+        #     float(self.dropout_p),
+        #     self.dropout_shared_dim,
+        # )


 def convert_num_channels(x: Tensor, num_channels: int) -> Tensor:
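The fused ActivationDropoutAndLinearFunction is replaced by the ordinary activation → dropout → linear sequence, which keeps every op in whatever dtype the module's parameters use. A minimal sketch of the same decomposition (hypothetical module with stand-in submodules, not the icefall class):

import torch
import torch.nn as nn


class ActThenDropoutThenLinear(nn.Module):
    """Hypothetical stand-in for ActivationDropoutAndLinear's new eager path."""

    def __init__(self, in_dim: int, out_dim: int, dropout_p: float = 0.1):
        super().__init__()
        self.act = nn.SiLU()                  # stand-in for SwooshL()/SwooshR()
        self.dropout = nn.Dropout(dropout_p)  # stand-in for Dropout3(...)
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.act(x)
        x = self.dropout(x)
        return self.linear(x)


m = ActThenDropoutThenLinear(16, 8).to(torch.bfloat16)
y = m(torch.randn(2, 16, dtype=torch.bfloat16))
assert y.dtype == torch.bfloat16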
egs/librispeech/ASR/zipformer/subsampling_bf16.py (new file, 406 lines):

#!/usr/bin/env python3
# Copyright    2023  Xiaomi Corp.        (authors: Daniel Povey,
#                                                  Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Tuple

import torch
from scaling_bf16 import (
    Balancer,
    BiasNorm,
    Dropout3,
    FloatLike,
    Optional,
    ScaledConv2d,
    ScaleGrad,
    ScheduledFloat,
    SwooshL,
    SwooshR,
    Whiten,
)
from torch import Tensor, nn


class ConvNeXt(nn.Module):
    """
    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
    """

    def __init__(
        self,
        channels: int,
        hidden_ratio: int = 3,
        kernel_size: Tuple[int, int] = (7, 7),
        layerdrop_rate: FloatLike = None,
    ):
        super().__init__()
        self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        hidden_channels = channels * hidden_ratio
        if layerdrop_rate is None:
            layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
        self.layerdrop_rate = layerdrop_rate

        self.depthwise_conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=self.padding,
        )

        self.pointwise_conv1 = nn.Conv2d(
            in_channels=channels, out_channels=hidden_channels, kernel_size=1
        )

        self.hidden_balancer = Balancer(
            hidden_channels,
            channel_dim=1,
            min_positive=0.3,
            max_positive=1.0,
            min_abs=0.75,
            max_abs=5.0,
        )

        self.activation = SwooshL()
        self.pointwise_conv2 = ScaledConv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            initial_scale=0.01,
        )

        self.out_balancer = Balancer(
            channels,
            channel_dim=1,
            min_positive=0.4,
            max_positive=0.6,
            min_abs=1.0,
            max_abs=6.0,
        )
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=5.0,
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )

    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            return self.forward_internal(x)
        layerdrop_rate = float(self.layerdrop_rate)

        if layerdrop_rate != 0.0:
            batch_size = x.shape[0]
            mask = (
                torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                > layerdrop_rate
            )
        else:
            mask = None
        # turns out this caching idea does not work with --world-size > 1
        # return caching_eval(self.forward_internal, x, mask)
        return self.forward_internal(x, mask)

    def forward_internal(
        self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
    ) -> Tensor:
        """
        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)

        The returned value has the same shape as x.
        """
        bypass = x
        x = self.depthwise_conv(x)
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        if layer_skip_mask is not None:
            x = x * layer_skip_mask

        x = bypass + x
        x = self.out_balancer(x)

        if x.requires_grad:
            x = x.transpose(1, 3)  # (N, W, H, C); need channel dim to be last
            x = self.out_whiten(x)
            x = x.transpose(1, 3)  # (N, C, H, W)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        cached_left_pad: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
            cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)

        Returns:
            - The returned value has the same shape as x.
            - Updated cached_left_pad.
        """
        padding = self.padding

        # The length without right padding for depth-wise conv
        T = x.size(2) - padding[0]

        bypass = x[:, :, :T, :]

        # Pad left side
        assert cached_left_pad.size(2) == padding[0], (
            cached_left_pad.size(2),
            padding[0],
        )
        x = torch.cat([cached_left_pad, x], dim=2)
        # Update cached left padding
        cached_left_pad = x[:, :, T : padding[0] + T, :]

        # depthwise_conv
        x = torch.nn.functional.conv2d(
            x,
            weight=self.depthwise_conv.weight,
            bias=self.depthwise_conv.bias,
            padding=(0, padding[1]),
            groups=self.depthwise_conv.groups,
        )
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        x = bypass + x
        return x, cached_left_pad


class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = (T-3)//2 - 2 == (T-7)//2

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
        dropout: FloatLike = 0.1,
    ) -> None:
        """
        Args:
          in_channels:
            Number of channels in. The input shape is (N, T, in_channels).
            Caution: It requires: T >=7, in_channels >=7
          out_channels
            Output dim. The output shape is (N, (T-3)//2, out_channels)
          layer1_channels:
            Number of channels in layer1
          layer1_channels:
            Number of channels in layer2
          bottleneck:
            bottleneck dimension for 1d squeeze-excite
        """
        assert in_channels >= 7
        super().__init__()

        # The ScaleGrad module is there to prevent the gradients
        # w.r.t. the weight or bias of the first Conv2d module in self.conv from
        # exceeding the range of fp16 when using automatic mixed precision (amp)
        # training.  (The second one is necessary to stop its bias from getting
        # a too-large gradient).

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=(0, 1),  # (time, freq)
            ),
            ScaleGrad(0.2),
            Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            ),
            Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=(1, 2),  # (time, freq)
            ),
            Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
        )

        # just one convnext layer
        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))

        # (in_channels-3)//4
        self.out_width = (((in_channels - 1) // 2) - 1) // 2
        self.layer3_channels = layer3_channels

        self.out = nn.Linear(self.out_width * layer3_channels, out_channels)
        # use a larger than normal grad_scale on this whitening module; there is
        # only one such module, so there is not a concern about adding together
        # many copies of this extra gradient term.
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
            prob=(0.025, 0.25),
            grad_scale=0.02,
        )

        # max_log_eps=0.0 is to prevent both eps and the output of self.out from
        # getting large, there is an unnecessary degree of freedom.
        self.out_norm = BiasNorm(out_channels)
        self.dropout = Dropout3(dropout, shared_dim=1)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)
        # scaling x by 0.1 allows us to use a larger grad-scale in fp16 "amp" (automatic mixed precision)
        # training, since the weights in the first convolution are otherwise the limiting factor for getting infinite
        # gradients.
        x = self.conv(x)
        x = self.convnext(x)

        # Now x is of shape (N, odim, (T-7)//2, (idim-3)//4)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, (T-7)//2, out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, (T-7)//2, odim)
        x = self.out_whiten(x)
        x = self.out_norm(x)
        x = self.dropout(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            x_lens = (x_lens - 7) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                x_lens = (x_lens - 7) // 2
            assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())

        return x, x_lens

    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        cached_left_pad: Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
          - updated cache
        """
        # On entry, x is (N, T, idim)
        x = x.unsqueeze(1)  # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W)

        # T' = (T-7)//2
        x = self.conv(x)

        # T' = (T-7)//2-3
        x, cached_left_pad = self.convnext.streaming_forward(
            x, cached_left_pad=cached_left_pad
        )

        # Now x is of shape (N, odim, T', ((idim-1)//2 - 1)//2)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # now x: (N, T', out_width * layer3_channels))

        x = self.out(x)
        # Now x is of shape (N, T', odim)
        x = self.out_norm(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert self.convnext.padding[0] == 3
            # The ConvNeXt module needs 3 frames of right padding after subsampling
            x_lens = (x_lens - 7) // 2 - 3
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # The ConvNeXt module needs 3 frames of right padding after subsampling
                assert self.convnext.padding[0] == 3
                x_lens = (x_lens - 7) // 2 - 3

        assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())

        return x, x_lens, cached_left_pad

    @torch.jit.export
    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        """Get initial states for Conv2dSubsampling module.
        It is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs)
        """
        left_pad = self.convnext.padding[0]
        freq = self.out_width
        channels = self.layer3_channels
        cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
            device
        )

        return cached_embed_left_pad
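A usage sketch for the new module (assumes it is run from egs/librispeech/ASR/zipformer so that subsampling_bf16.py and the companion scaling_bf16.py are importable, and that your PyTorch build has bfloat16 conv kernels for the device you use; sizes are illustrative):

import torch
from subsampling_bf16 import Conv2dSubsampling

# 80-dim fbank features in, 192-dim encoder-embedding frames out
sub = Conv2dSubsampling(in_channels=80, out_channels=192).to(torch.bfloat16)
sub.eval()

x = torch.randn(2, 100, 80, dtype=torch.bfloat16)  # (N, T, idim)
x_lens = torch.tensor([100, 96])
with torch.no_grad():
    y, y_lens = sub(x, x_lens)
print(y.shape)   # expected torch.Size([2, 46, 192]), since (100 - 7) // 2 == 46
print(y_lens)    # expected tensor([46, 44])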
@@ -76,13 +76,13 @@ from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
 from optim import Eden, ScaledAdam
-from scaling import ScheduledFloat
-from subsampling import Conv2dSubsampling
+from scaling_bf16 import ScheduledFloat
+from subsampling_bf16 import Conv2dSubsampling
 from torch import Tensor
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-from zipformer_bf16 import Zipformer2
+from zipformer_full_bf16 import Zipformer2

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
@@ -870,6 +870,8 @@ def compute_loss(
     # at entry, feature is (N, T, C)
     assert feature.ndim == 3
     feature = feature.to(device)
+    if params.full_bf16:
+        feature = feature.to(torch.bfloat16)

     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
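This is the data-side half of full bf16 training: the model weights are cast once in run() (model.to(params.dtype)) and the features are cast here, so no autocast context is needed inside compute_loss(). A minimal sketch of the same idea with a stand-in model:

import torch

model = torch.nn.Linear(80, 512).to(torch.bfloat16)  # stand-in for the Zipformer encoder
feature = torch.randn(8, 100, 80)                    # fbank features arrive as float32
feature = feature.to(torch.bfloat16)                 # mirrors: if params.full_bf16: ...
out = model(feature)
assert out.dtype == torch.bfloat16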
@@ -1041,7 +1043,7 @@ def train_one_epoch(
         batch_size = len(batch["supervisions"]["text"])

         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16, dtype=params.dtype):
+            with torch.cuda.amp.autocast(enabled=params.use_autocast, dtype=params.dtype):
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
@@ -1054,11 +1056,16 @@ def train_one_epoch(

             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
-            scaler.scale(loss).backward()
-            scheduler.step_batch(params.batch_idx_train)
+            if params.use_autocast:
+                scaler.scale(loss).backward()
+                scheduler.step_batch(params.batch_idx_train)

-            scaler.step(optimizer)
-            scaler.update()
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                loss.backward()
+                scheduler.step_batch(params.batch_idx_train)
+                optimizer.step()
             optimizer.zero_grad()
         except Exception as e:
             logging.info(
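The update step now branches: the autocast path keeps the GradScaler, while the full-bf16 path backpropagates and steps the optimizer directly. A self-contained toy of the same control flow (hypothetical variable names and a stand-in model; the LR scheduler call is omitted):

import torch

model = torch.nn.Linear(4, 1).to(torch.bfloat16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
use_autocast = False                      # full bf16: no autocast, no grad scaling

x = torch.randn(8, 4, dtype=torch.bfloat16)
loss = model(x).pow(2).mean()
if use_autocast:
    scaler = torch.cuda.amp.GradScaler(enabled=True)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
else:
    loss.backward()
    optimizer.step()
optimizer.zero_grad()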
@@ -1104,7 +1111,7 @@ def train_one_epoch(
                 rank=rank,
             )

-        if batch_idx % 100 == 0 and params.use_fp16:
+        if batch_idx % 100 == 0 and params.use_autocast:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
@@ -1123,14 +1130,14 @@ def train_one_epoch(

         if batch_idx % params.log_interval == 0:
             cur_lr = max(scheduler.get_last_lr())
-            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+            cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0

             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
             )

             if tb_writer is not None:
@@ -1242,11 +1249,16 @@ def run(rank, world_size, args):

     if params.use_fp16:
         params.dtype = torch.float16 if not params.use_bf16 else torch.bfloat16
+        params.use_autocast = True
     else:
         params.dtype = torch.float32
+        params.use_autocast = False
     logging.info(f"Training using: {params.dtype}")
+    model.to(params.dtype)

-    if params.full_bf16
+    if params.full_bf16:
+        assert params.use_bf16
+        params.use_autocast = False  # use full bf16 training, no autocast and grad scaling

     model.to(device)
     if world_size > 1:
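The precision flags resolve to three regimes: fp16 autocast, bf16 autocast, and full bf16 (whole model cast to bfloat16, autocast and grad scaling disabled). A compact restatement of the logic above as a hypothetical helper, mirroring the diff's semantics:

import torch

def resolve_precision(use_fp16: bool, use_bf16: bool, full_bf16: bool):
    """Hypothetical helper, not part of the recipe; mirrors the flag handling in run()."""
    if use_fp16:
        dtype = torch.bfloat16 if use_bf16 else torch.float16
        use_autocast = True
    else:
        dtype = torch.float32
        use_autocast = False
    if full_bf16:
        assert use_bf16
        use_autocast = False  # cast the whole model instead of autocasting per op
    return dtype, use_autocast

print(resolve_precision(use_fp16=True, use_bf16=True, full_bf16=True))  # (torch.bfloat16, False)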
@@ -1352,16 +1364,16 @@ def run(rank, world_size, args):
     valid_cuts += librispeech.dev_other_cuts()
     valid_dl = librispeech.valid_dataloaders(valid_cuts)

-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         sp=sp,
-    #         params=params,
-    #     )
+    if not params.print_diagnostics:
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            sp=sp,
+            params=params,
+        )

-    scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+    scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
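Note that GradScaler(enabled=False) already degrades to a pass-through (scale() returns the loss unchanged and step() simply calls optimizer.step()), which is why constructing it with enabled=params.use_autocast is safe in the full-bf16 case; the explicit branch in train_one_epoch just makes the no-scaling path obvious. A quick check of that pass-through behaviour:

import torch

scaler = torch.cuda.amp.GradScaler(enabled=False)
loss = torch.tensor(1.5, requires_grad=True)
assert scaler.scale(loss) is loss  # a disabled scaler is a no-op wrapper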
@@ -1461,7 +1473,7 @@ def scan_pessimistic_batches_for_oom(
     for criterion, cuts in batches.items():
         batch = train_dl.dataset[cuts]
         try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            with torch.cuda.amp.autocast(enabled=params.use_autocast, dtype=params.dtype):
                 loss, _ = compute_loss(
                     params=params,
                     model=model,
egs/librispeech/ASR/zipformer/zipformer_full_bf16.py (new file, 2437 lines)
File diff suppressed because it is too large.