import torch
from torch import nn
from torch import Tensor
from typing import Callable


class CheckpointFunction(torch.autograd.Function):
    """Gradient checkpointing: run `function` without storing its autograd
    graph in forward, then recompute it from detached inputs in backward."""

    @staticmethod
    def forward(ctx, function: Callable, *args):
        # `function` must return either a Tensor or a tuple of Tensors.
        ctx.function = function
        # Detach the inputs so forward builds no graph, but remember which
        # of them need gradients so the recomputation can produce them.
        ctx.args = [x.detach() if isinstance(x, Tensor) else x for x in args]
        for i, arg in enumerate(args):
            if isinstance(arg, Tensor) and arg.requires_grad:
                ctx.args[i].requires_grad_(True)
        with torch.no_grad():
            ans = function(*args)
        return ans

    @staticmethod
    def backward(ctx, *ans_grads):
        # Must return one value per forward() argument: one slot for
        # `function` itself (always None) plus one per element of args.
        if not any(g is not None for g in ans_grads):
            return (None,) * (len(ctx.args) + 1)
        # Recompute the forward pass, this time recording the graph.  Grad
        # mode is off inside backward(), so the surrogate loss below must
        # also be built inside enable_grad(), or loss.backward() would fail.
        with torch.enable_grad():
            ans = ctx.function(*ctx.args)
            # Vector-Jacobian product via a surrogate scalar: the incoming
            # grads are constants, so d/dx [sum(ans * grad)] == J^T grad.
            if isinstance(ans, Tensor):
                assert len(ans_grads) == 1
                loss = (ans * ans_grads[0]).sum()
            else:
                assert len(ans_grads) == len(ans)
                loss = torch.stack(
                    [
                        (a * g).sum()
                        for a, g in zip(ans, ans_grads)
                        if g is not None
                    ]
                ).sum()
        loss.backward()
        return tuple(
            [None]
            + [a.grad if isinstance(a, Tensor) else None for a in ctx.args]
        )


def checkpoint(function, *args):
    """Run `function(*args)`, recomputing it in backward to save memory."""
    return CheckpointFunction.apply(function, *args)
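

# Illustrative usage sketch: wrap an expensive sub-network so its intermediate
# activations are not kept alive by an autograd graph during forward and are
# instead recomputed in backward.  The small net and shapes are made-up values.
def _example_usage():
    net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 4))
    x = torch.randn(2, 4, requires_grad=True)
    y = checkpoint(net, x)  # forward runs under no_grad inside CheckpointFunction
    y.sum().backward()      # backward re-runs `net` to rebuild the graph
    print("x grad shape = ", x.grad.shape)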


def _test1():
    x = torch.tensor([0.0])
    y = torch.tensor([1.0])
    y.requires_grad = True
    func = lambda x, y, trash: torch.stack((x, y))
    ans = checkpoint(func, x, y, None)
    # ans = func(x, y, None)  # non-checkpointed reference
    print("ans = ", ans)
    ans.sum().backward()
    print("y grad = ", y.grad)


def _test2():
    x = torch.tensor([0.0])
    y = torch.tensor([1.0])
    x.requires_grad = True
    func = lambda x, y, trash: torch.stack((x, y))
    ans = checkpoint(func, x, y, None)
    # Checkpointed calls compose: the output of one can feed the next.
    ans = checkpoint(torch.sum, ans)
    # ans = func(x, y, None)  # non-checkpointed reference
    print("ans = ", ans)
    ans.backward()
    print("x grad = ", x.grad)


if __name__ == "__main__":
    _test1()
    _test2()
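    _example_usage()
    _test3()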