mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Change for memory efficiency

commit d31e2e12c6
parent 903955f5d9
@@ -439,23 +439,29 @@ class BasicNormFunction(torch.autograd.Function):
     # some other reason, related to the next operation, so we can save memory).
     @staticmethod
     @custom_fwd
-    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int) -> Tensor:
+    def forward(ctx, x: Tensor, bias: Tensor, eps: Tensor, channel_dim: int,
+                store_output_for_backprop: bool) -> Tensor:
         assert bias.ndim == 1
         if channel_dim < 0:
             channel_dim = channel_dim + x.ndim
+        ctx.store_output_for_backprop = store_output_for_backprop
         ctx.channel_dim = channel_dim
         for _ in range(channel_dim + 1, x.ndim):
             bias = bias.unsqueeze(-1)
         scales = (torch.mean((x + bias) ** 2, dim=channel_dim, keepdim=True) + eps.exp()) ** -0.5
         ans = x * scales
-        ctx.save_for_backward(ans, scales, bias, eps)
+        ctx.save_for_backward(ans if store_output_for_backprop else x,
+                              scales, bias, eps)
         return ans

     @staticmethod
     @custom_bwd
     def backward(ctx, ans_grad: Tensor) -> Tensor:
-        ans, scales, bias, eps = ctx.saved_tensors
-        x = ans / scales
+        ans_or_x, scales, bias, eps = ctx.saved_tensors
+        if ctx.store_output_for_backprop:
+            x = ans_or_x / scales
+        else:
+            x = ans_or_x
         x = x.detach()
         bias = bias.detach()
         eps = eps.detach()
@@ -467,7 +473,7 @@ class BasicNormFunction(torch.autograd.Function):
         scales = (torch.mean((x + bias) ** 2, dim=ctx.channel_dim, keepdim=True) + eps.exp()) ** -0.5
         ans = x * scales
         ans.backward(gradient=ans_grad)
-        return x.grad, bias.grad.flatten(), eps.grad, None
+        return x.grad, bias.grad.flatten(), eps.grad, None, None



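The pattern above, saving the output instead of the input and reconstructing the input inside backward(), is what yields the memory saving: the output is often kept alive by the following operation anyway, so nothing extra has to stay resident. Below is a minimal, self-contained sketch of the same idea; the class RMSNormFn and every detail in it are illustrative assumptions (no learned bias or eps, normalization over the last dimension), not icefall's actual code.

# A minimal sketch of the save-output-instead-of-input pattern used above.
# All names here (RMSNormFn etc.) are illustrative, not icefall's code.
import torch
from torch import Tensor


class RMSNormFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: Tensor, store_output_for_backprop: bool) -> Tensor:
        scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-5) ** -0.5
        ans = x * scales
        ctx.store_output_for_backprop = store_output_for_backprop
        # Either way only one large tensor is kept; `scales` is small.
        ctx.save_for_backward(ans if store_output_for_backprop else x, scales)
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ans_or_x, scales = ctx.saved_tensors
        # If we kept the output, reconstruct the input from it.
        x = ans_or_x / scales if ctx.store_output_for_backprop else ans_or_x
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            # Recompute the forward pass and let autograd derive the gradient.
            recomputed_scales = (torch.mean(x ** 2, dim=-1, keepdim=True) + 1e-5) ** -0.5
            (x * recomputed_scales).backward(gradient=ans_grad)
        return x.grad, None
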
@@ -497,10 +503,13 @@ class BasicNorm(torch.nn.Module):
            to indicate the connection with conventional LayerNorm.
         learn_eps: if true, we learn epsilon; if false, we keep it
            at the initial value.
+        store_output_for_backprop: this option makes no difference
+            to the output, but may affect memory usage; determines
+            whether, for backprop purposes, we store the input or the output
+            of this module.
         eps_min: float
         eps_max: float
     """

     def __init__(
         self,
         num_channels: int,
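As the new docstring says, the flag changes only what is cached for backprop, not the result. A quick sanity check with the illustrative RMSNormFn sketched above (again, not the real BasicNorm) shows both settings give identical outputs and gradients:

# Both storage modes should produce the same output and the same gradient.
x1 = torch.randn(3, 8, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

y1 = RMSNormFn.apply(x1, True)    # store the output for backprop
y2 = RMSNormFn.apply(x2, False)   # store the input for backprop
y1.sum().backward()
y2.sum().backward()

assert torch.allclose(y1, y2)
assert torch.allclose(x1.grad, x2.grad)
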
@@ -509,6 +518,7 @@ class BasicNorm(torch.nn.Module):
         learn_eps: bool = True,
         eps_min: float = -3.0,
         eps_max: float = 3.0,
+        store_output_for_backprop: bool = True
     ) -> None:
         super(BasicNorm, self).__init__()
         self.num_channels = num_channels
@@ -520,6 +530,7 @@ class BasicNorm(torch.nn.Module):
             self.bias = nn.Parameter(torch.zeros(num_channels))
         self.eps_min = eps_min
         self.eps_max = eps_max
+        self.store_output_for_backprop = store_output_for_backprop

     def forward(self, x: Tensor) -> Tensor:
         assert x.shape[self.channel_dim] == self.num_channels
@@ -544,7 +555,8 @@ class BasicNorm(torch.nn.Module):
             # region if it happens to exit it.
             eps = eps.clamp(min=self.eps_min, max=self.eps_max)

-        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim)
+        return BasicNormFunction.apply(x, self.bias, eps, self.channel_dim,
+                                       self.store_output_for_backprop)



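Whether the flag helps depends on what the neighbouring ops cache. For example, a Linear layer saves its input (here, the norm's output) for its own backward, so storing the output lets both share one tensor, while storing the input keeps an extra activation alive. The rough, hypothetical harness below (CUDA assumed, RMSNormFn from the sketch above; not a benchmark from this commit) illustrates the comparison:

# Compare peak GPU memory for a stack of (norm -> Linear) blocks.
def peak_mem(store_output: bool) -> int:
    torch.cuda.reset_peak_memory_stats()
    layers = [torch.nn.Linear(384, 384).to("cuda") for _ in range(12)]
    y = torch.randn(30, 500, 384, device="cuda", requires_grad=True)
    for lin in layers:
        y = RMSNormFn.apply(y, store_output)
        y = lin(y)  # Linear caches its input (the norm output) for backward
    y.sum().backward()
    return torch.cuda.max_memory_allocated()

if torch.cuda.is_available():
    print("store output:", peak_mem(True))   # typically the lower peak here
    print("store input: ", peak_mem(False))
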
@@ -451,7 +451,7 @@ class ZipformerEncoderLayer(nn.Module):

         self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2)

-        self.norm_final = BasicNorm(embed_dim)
+        self.norm_final = BasicNorm(embed_dim, store_output_for_backprop=False)

         self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))
