pytorch-day-05
张量的数据类型
张量的数据类型与 numpy.array 基本一一对应,除了不支持str类型
- 一般的神经网络用的是
torch.float32类型 - 如果要显示指定数据类型,可以使用
torch.tensor(data,dtype = torch.type) - 也可以使用特定的构造函数
1
2
3i = torch.Inttensor() #构造数据类型为 int 的张量
x = torch.Tensor() # 构造数据类型为 float 的张量
b = torch.BoolTensor() #构造数据类型为 bool 的张量 - 此外,还可以对不同类型的张量进行转化
1
2
3
4i = torch.tensor(1) # 构建类型为int64的张量
x = i.float() # 调用float方法转换为float类型
y = i.type(torch.float) # 使用type函数转换为浮点类型
z = i.type_as(x) # 使用type_as 方法转化为与某个Tensor相同类型的张量
张量的维度
张量的尺寸
- 可以使用shape属性或者size() 方法查看张量在每一维的长度
- 可以使用view方法改变张量的尺寸
- view失败的情况下,可以使用reshape方法
view和reshape的区别:- view方法要求原张量在内存中是连续的,如果不连续则会失败;reshape则会自动处理布局
- view方法总是与原张量共享内存,返回的是原向量的”视图”; reshape则可能返回视图或者副本,取决于内存的布局
- 为什么不只使用
reshape- 性能考虑 view 更快,因为只是改变张量的元数据不涉及数据复制;reshape可能涉及到复制数据,会有额外的开销
- 内存效率 view保证内存共享,修改一个会影响另一个;reshape可能创建副本导致占用更多内存
- 语义的明确性 view 明确表示期望的是内存共享的试图操作;当view失败时,提醒开发者注意内存布局问题
- 以下是失败情况的例子
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29import torch
# 情况1: 转置后的张量不连续
x = torch.randn(3, 4)
y = x.transpose(0, 1) # 转置后内存不连续
print(y.is_contiguous()) # False
# 这会失败
try:
z = y.view(2, 6)
except RuntimeError as e:
print(f"view失败: {e}")
# reshape可以成功
z = y.reshape(2, 6) # 成功
print(z.shape) # torch.Size([2, 6])
# 情况2: 切片操作后的张量
x = torch.randn(4, 4)
y = x[:, ::2] # 非连续切片
print(y.is_contiguous()) # False
# view失败,reshape成功
try:
z = y.view(-1)
except RuntimeError as e:
print(f"view失败: {e}")
z = y.reshape(-1) # 成功
1 | # 有些操作会让张量存储结构扭曲,直接使用view会失败,可以用reshape方法 |
张量与numpy数组
- 可以使用numpy方法从tensor得到numpy数组,也可以用torch.from_numpy从numpy数组得到tensor.
- 两种方法共享数据内存,改变一个另一个也会随之改变
- 可以用张量的clone 方法拷贝张量,中断这种关联
- 可以使用item方法从标量张量得到对应的python数值
- 使用tolist方法从张量得到对应的python数值列表
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60#torch.from_numpy函数从numpy数组得到Tensor
arr = np.zeros(3)
tensor = torch.from_numpy(arr)
print("before add 1:")
print(arr)
print(tensor)
print("\nafter add 1:")
np.add(arr,1, out = arr) #给 arr增加1,tensor也随之改变
print(arr)
print(tensor)
# numpy方法从Tensor得到numpy数组
tensor = torch.zeros(3)
arr = tensor.numpy()
print("before add 1:")
print(tensor)
print(arr)
print("\nafter add 1:")
#使用带下划线的方法表示计算结果会返回给调用 张量
tensor.add_(1) #给 tensor增加1,arr也随之改变
#或: torch.add(tensor,1,out = tensor)
print(tensor)
print(arr)
# 可以用clone() 方法拷贝张量,中断这种关联
tensor = torch.zeros(3)
#使用clone方法拷贝张量, 拷贝后的张量和原始张量内存独立
arr = tensor.clone().numpy() # 也可以使用tensor.data.numpy()
print("before add 1:")
print(tensor)
print(arr)
print("\nafter add 1:")
#使用 带下划线的方法表示计算结果会返回给调用 张量
tensor.add_(1) #给 tensor增加1,arr不再随之改变
print(tensor)
print(arr)
# item方法和tolist方法可以将张量转换成Python数值和数值列表
scalar = torch.tensor(1.0)
s = scalar.item()
print(s)
print(type(s))
tensor = torch.rand(2,2)
t = tensor.tolist()
print(t)
print(type(t))
Something else
什么是元数据
- 元数据(metadata)描述数据张什么样而不是数据本身的信息,张量结构性的描述信息,不包括实际的数值
在 PyTorch 中,张量由“数据存储(storage)”与“元数据”两部分构成: - 数据存储:真正的数值内存区域
- 元数据:描述如何解释这些数值的结构信息
典型元数据:shape、stride、dtype、device、storage_offset、requires_grad、layout、(可选) names
像view()这类操作只是改元数据(不复制数据);而当现有 stride 组合无法支持新形状时,reshape()会退化为复制,得到新的连续存储。
1 | import torch |
All articles on this blog are licensed under CC BY-NC-SA 3.0 CN unless otherwise stated.