pytorch-day08
使用线性回归和DNN作为示例进行演示
- 首先手动创建训练数据集
1
2
3
4
5n = 400
X = 10*torch.rand([n,2])-5.0
w0 = torch.tensor([[2.0,3.0]])
b0 = torch.tensor([[10.0]]) # 这里为什么要创建一个二维的张量?
Y = X@w0 + b0 + torch.normal(0.0,2.0,size=[n,1]) - 创建数据管道
1
2
3
4
5
6
7
8def data_iter(features, labels, batch_size=8):
num_examples = len(features)
indices = list(range(num_examples))
np.random.shuffle(indices)
for i in range(0,num_examples,batch_size):
indexs = torch.LongTensor(indices[i:min(i + batch_size,num_examples)])
yield features.index_select(0, indexs), labels.index_select(0,indexs)
低阶API
- Pytorch 的低阶API主要包括张量操作,计算图和自动微分
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22# define the model
class LinearRegression:
def __init__(self):
self.w = torch.randn_like(w0,requires_grad = True)
self.b = torch.zeros_like(b0,requires = True)
def forward(self,x)
return x@self.w0 + self.b
def loss_func(self,y_pred,y_true):
return torch.mean((y_pred - y_true)**2/2)
# train the model
def train_step(model, features, labels):
predicitons = model.forward(features)
loss = model.loss_func(predictions, labels)
loss.backward()
with torch.no_grad():
model,w -= 0.001*model.w.grad
model.b -= 0.001*model.b.grad
model.w.grad.zero_()
model.b.grad.zero()
return loss
中阶API
- 包括各类模型层,损失函数,优化器,数据管道等等
- 一般来说,都会使用这里的API,有的模型需要自定义一些新的模型层,才有可能会用到低阶API的内容
1
2
3
4
5
6
7# load the data
ds = TensorDataset(X,Y)
dl = DataLoader(ds,batch_size = 10,shuffle = True,num_workers = 2)
# define the model
model = nn.Linear(2,1)
model.loss_fn = nn.MSELoss()
model.optimizer = torch.optim.SGD(model.parameters(),lr =0.01)
高阶API
- 对于每个模型,可以自己定义自己的API封装,这一点就因人而异了。
- 一般来说,可以把一些常用操作做成API用来调用,比如summary,eval等常用操作。
Something else
什么是yields
yields将一个普通函数变成”生成器函数”,调用它不会立刻执行完,而是返回一个生成器对象- 每次遇到
yield会产出一个值并挂起函数状态;下次迭代会从挂起处据悉执行 - return 结束函数而 yield 可以多次产出多个值
什么是 num_worker
- 指dataloader用多少个子进程并行加载/预取批次数据
- 0:在主机成加载,最稳定,最省内存,最少坑
- 1 开启多进程并行调用
getitem,在后台异步读取,常见于有耗时 I/O 或 CPU 预处理时(图像解码,数据增强)提速 - 更多worker会更快,但会占用更多内存
All articles on this blog are licensed under CC BY-NC-SA 3.0 CN unless otherwise stated.