pytorch-day07
动态计算图
- Pytorch 中的计算图是动态图
- 计算图的正向传播立即执行,无需等待完整的图创建完毕
- 计算图在反向传播后立即销毁,下次调用需要重新构建计算图。如果使用backward方法或者torch.autograd.grad 方法计算了梯度,创建的梯度会被立即销毁,释放储存空间。
1
2
3
4
5
6
7
8
9
10
11
12
13
14#计算图在反向传播之后立即销毁
import torch
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))
#计算图在反向传播后立即销毁,如果需要保留计算图, 需要设置retain_graph = True
loss.backward() #loss.backward(retain_graph = True)
#loss.backward() #如果再次执行反向传播将报错
计算图中的function
- 与python中的函数不同,同时包含正向计算逻辑和反向计算逻辑
- 继承
torch.autograd.Function来船舰支持反向传播的function1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17# 创建relu函数的示例
class MyReLU(torch.autograd.Function):
#正向传播逻辑,可以用ctx存储一些值,供反向传播使用。
def forward(ctx, input):
ctx.save_for_backward(input)
return input.clamp(min=0)
#反向传播逻辑
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = grad_output.clone()
grad_input[input < 0] = 0
return grad_input
- 创建的函数可以使用
apply属性进行调用1
2
3
4
5
6
7
8
9
10
11
12
13
14import torch
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.tensor([[-1.0,-1.0],[1.0,1.0]])
Y = torch.tensor([[2.0,3.0]])
relu = MyReLU.apply # relu现在也可以具有正向传播和反向传播功能
Y_hat = relu(X@w.t() + b)
loss = torch.mean(torch.pow(Y_hat-Y,2))
loss.backward()
print(w.grad)
print(b.grad)
计算图与反向传播
loss.backward()函数被调用后发生如下过程:- loss 自身的grad 被为赋值为1
- loss根据自身梯度以及管理那的backward 方法,计算出对应自变量的梯度并存入
y1.grad和y2.grad - y2和y1根据其自身梯度以及关联的backward方法, 分别计算出其对应的自变量x的梯度,x.grad将其收到的多个梯度值累加。
- 因为求导链式法则衍生的梯度累加规则,张量的grad不会自动清零,在需要的时候要手动清零
叶子节点和非叶子节点
- 在反向传播过程中,只有
is_leaf=True的叶子节点,需要求导的张量的导数结果才会被最后保留。- 叶子节点张量是由用户直接创建的张量,而非某个Function通过计算得到的张量
- 叶子节点张量的
requires_grad属性必须为True
- Pytorch这种设计是为了节约内存或显存空间,因为几乎所有时候用户都只关心自己创建张量的梯度
- 依赖于叶子节点张量的张量,其
require_grad属性必定为True,但梯度值只在计算中用到,不会被存储到grad属性中 - 如果需要保存可以使用
retain-grad方法,如果只是为了调试代码,可以利用register_hook打印日志.1
2
3
4
5
6
7
8
9
10
11
12import torch
x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2
loss.backward()
print("loss.grad:", loss.grad)
print("y1.grad:", y1.grad)
print("y2.grad:", y2.grad)
print(x.grad)1
2
3
4print(x.is_leaf)
print(y1.is_leaf)
print(y2.is_leaf)
print(loss.is_leaf)
All articles on this blog are licensed under CC BY-NC-SA 3.0 CN unless otherwise stated.