Mirror of https://github.com/k2-fsa/icefall.git
Implement memory-saving measure in randomized modules
commit cc739b193a (parent 86bb0623e9)
@@ -20,7 +20,7 @@ from itertools import repeat
from typing import Optional, Tuple, Union
from functools import reduce
import logging

from torch.cuda.amp import custom_fwd, custom_bwd
import random
import torch
import torch.nn as nn
@@ -230,6 +230,96 @@ def random_cast_to_half(x: Tensor,
    return torch.where(is_too_small, random_val, x).to(torch.float16)


class CutoffEstimator:
    """
    Estimates cutoffs of an arbitrary numerical quantity such that a specified
    proportion of items will be above the cutoff on average.

    p is the proportion of items that should be above the cutoff.
    """
    def __init__(self, p: float):
        self.p = p
        # total count of items
        self.count = 0
        # total count of items that were above the cutoff
        self.count_above = 0
        # initial cutoff value
        self.cutoff = 0

    def __call__(self, x: float) -> bool:
        """
        Returns true if x is above the cutoff.
        """
        ans = (x > self.cutoff)
        self.count += 1
        if ans:
            self.count_above += 1
        cur_p = self.count_above / self.count
        delta_p = cur_p - self.p
        if (delta_p > 0) == ans:
            q = abs(delta_p)
            self.cutoff = x * q + self.cutoff * (1-q)
        return ans
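
Not part of the diff: a minimal usage sketch of CutoffEstimator, just to illustrate the adaptive-threshold behaviour the class documents. It assumes only the class definition above and Python's random module.

# Illustrative sketch, not part of this commit.  Feeding a stream of values to
# CutoffEstimator(0.1) should leave roughly 10% of calls returning True once
# the cutoff has adapted.
import random

est = CutoffEstimator(0.1)
num_above = sum(est(random.random()) for _ in range(10000))
print(num_above / 10000.0)  # expected to settle near 0.1
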
class CachingEvalFunction(torch.autograd.Function):
    # @custom_fwd and @custom_bwd are related to automatic mixed precision (amp) and ensure
    # that the backward pass runs with the same autocast context as the forward pass.
    @staticmethod
    @custom_fwd
    def forward(ctx, x: Tensor, m) -> Tensor:
        """
        m might be an nn.Module
        """
        ctx.x_requires_grad = x.requires_grad
        ctx.m = m
        # we need any random numbers used in this evaluation and the next evaluation
        # to be identical.
        # Caution: this assumes you are not going to use any random numbers from torch
        # (for any purpose that matters in the forward pass), e.g. there should be no
        # dropout.
        ctx.random_state = random.getstate()
        # we are inside torch.no_grad() here, so the following won't create the
        # computation graph.
        with torch.no_grad():
            y = m(x)
        ctx.save_for_backward(x, y)
        return y

    @staticmethod
    @custom_bwd
    def backward(ctx, y_grad: Tensor):
        x, y = ctx.saved_tensors
        x = x.detach()
        x.requires_grad = ctx.x_requires_grad
        m = ctx.m  # e.g. an nn.Module

        random_state = random.getstate()
        # set the state to what we used in the 1st forward pass.
        random.setstate(ctx.random_state)
        with torch.enable_grad():
            y2 = m(x)
            assert torch.allclose(y, y2, atol=1.0e-02)
            # this call to backward() should create grads in the module's parameters.
            y2.backward(gradient=y_grad)

        # restore the state from before we entered this function.
        random.setstate(random_state)

        return x.grad, None  # x.grad will be None if x.requires_grad is False.

def caching_eval(x: Tensor, m: nn.Module) -> Tensor:
    if m.training:
        # The purpose of this code is to make all parameters of m reachable in
        # the computation graph, so that if we give find_unused_parameters=True
        # to PyTorch's DistributedDataParallel it does not assign them zero gradient.
        tot = 0.0
        for p in m.parameters():
            tot = tot + 0.0 * p.flatten()[0]
        x = x + tot  # tot will be 0, so this does nothing.
    return CachingEvalFunction.apply(x, m)
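
Not part of the diff: a minimal sketch of how caching_eval might be called, assuming the definitions above are in scope; the toy `layer` module is hypothetical. The idea is the same memory-for-compute trade as torch.utils.checkpoint: only the input and output of `m` are kept for backward, and `m(x)` is re-run (with the restored Python random state) when gradients are needed.

# Hypothetical usage sketch (not from this commit): wrap a submodule call in
# caching_eval() so only its input and output are stored; intermediate
# activations are recomputed during backward.
import torch
import torch.nn as nn

layer = nn.Sequential(nn.Linear(16, 64), nn.ReLU(), nn.Linear(64, 16))
x = torch.randn(8, 16, requires_grad=True)

y = caching_eval(x, layer)   # forward runs under no_grad; saves only x and y
loss = y.pow(2).mean()
loss.backward()              # re-evaluates layer(x) with grad enabled
print(x.grad.shape, layer[0].weight.grad.shape)
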
class RandomGradFunction(torch.autograd.Function):
    """
    Does nothing in forward pass; in backward pass, gets rid of very small grads using
@@ -549,6 +639,9 @@ class ActivationBalancer(torch.nn.Module):
        if prob is None:
            prob = ScheduledFloat((0.0, 0.4), (8000.0, 0.1), default=0.4)
        self.prob = prob
        # 10% of the time we will return and do nothing because memory usage
        # is too high.
        self.mem_cutoff = CutoffEstimator(0.1)

        # actually self.num_channels is no longer needed except for an assertion.
        self.num_channels = num_channels
@@ -562,11 +655,9 @@ class ActivationBalancer(torch.nn.Module):
        self.scale_gain_factor = scale_gain_factor

    def forward(self, x: Tensor) -> Tensor:
-        if torch.jit.is_scripting() or not x.requires_grad:
+        if (torch.jit.is_scripting() or not x.requires_grad or
+                (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
            return _no_op(x)

        prob = float(self.prob)
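
Not part of the diff: a small sketch of the gating pattern added in the condition above, assuming the CutoffEstimator from this commit. When the current torch.cuda.memory_allocated() value lands above the estimator's adaptive cutoff (roughly the top 10% of the values it has seen), the module returns early instead of doing its extra gradient-shaping work.

# Illustrative only: the memory gate in isolation (CUDA required).
if torch.cuda.is_available():
    mem_cutoff = CutoffEstimator(0.1)
    x = torch.randn(4, 256, device="cuda", requires_grad=True)
    skip = x.is_cuda and mem_cutoff(torch.cuda.memory_allocated())
    print("skip balancer this call:", skip)
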
@@ -720,7 +811,7 @@ class Whiten(nn.Module):
                 num_groups: int,
                 whitening_limit: FloatLike,
                 prob: Union[float, Tuple[float,float]],
-                 grad_scale: float):
+                 grad_scale: FloatLike):
        """
        Args:
          num_groups: the number of groups to divide the channel dim into before
@@ -745,6 +836,12 @@
        assert grad_scale >= 0
        self.num_groups = num_groups
        self.whitening_limit = whitening_limit
        self.grad_scale = grad_scale

        # 10% of the time we will return and do nothing because memory usage
        # is too high.
        self.mem_cutoff = CutoffEstimator(0.1)

        if isinstance(prob, float):
            assert 0 < prob <= 1
            self.prob = prob
@@ -753,7 +850,6 @@
            assert 0 < self.min_prob < self.max_prob <= 1
            self.prob = self.max_prob
        self.name = None  # will be set in training loop
-        self.grad_scale = grad_scale

    def forward(self,
                x: Tensor) -> Tensor:
@@ -773,7 +869,9 @@
          you use the returned value, or the graph will be freed
          and nothing will happen in backprop.
        """
-        if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0:
+        grad_scale = float(self.grad_scale)
+        if (not x.requires_grad or random.random() > self.prob or grad_scale == 0
+                or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated()))):
            return _no_op(x)
        else:
            whitening_limit = float(self.whitening_limit)
@@ -789,7 +887,7 @@
            return WhiteningPenaltyFunction.apply(x,
                                                  self.num_groups,
                                                  whitening_limit,
-                                                 self.grad_scale,
+                                                 grad_scale,
                                                  self.name)