pytorch-day09
张量的结构操作
张量创建
- 与numpy中创建array的方法相似
1
2
3
4
5
6
7
8
9
10
11a = torch.tensor([1,2,3],dtype=torch.float)
b = torch.arange(1,10,step = 2)
c = torch.linspace(0.0,2**3,10)
d = torch.zeros((3,3))
e = torch.ones((3,3),dtype=torch.int)
f = torch.zeros_like(a,dtype = torch.float)1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17# 创建服从某种分布的张量
## 均匀分布
min, max = 0,10
a = min + (max - min)*torch.rand([5])
## 正态分布
b = torch.normal(mean = torch.zeros(3,3),std = torch.ones(3,3))
c = std*torch.randn((3,3)) + mean
## 整数随机排列
d = torch.randperm(20)
## 特殊矩阵
I = torch.eye(3,3) # 单位矩阵
t = torch.diag(torch.tensor([1,2,3])) # 对角矩阵
索引切片
- 方式也与numpy几乎一致,切片支持缺省参数和省略号
- 对于不规则切片提取,可以使用
torch.index_select,torch.masked_select,torch.take - 如果通过修改张量某些元素得到新的张量,可以使用
torch.where,torch.masked_fill,torch.index_fill - 省略号可以表示多个冒号,
a[...,1]
维度变换
- 相关的函数有
torch.reshape,torch.squeeze,torch.unsqueeze,torch.transpose
合并分割
- 可以用
torch.cat和torch.stack方法讲多个张量合并,可以用torch.split方法讲一个张量分割 torch.cat和torch.stack有略微区别:torch.cat是连接,不会增加维度;而torch.stack是堆叠,会增加维度torch.split是torch.cat的逆运算,可以指定分割份数平均分割,也可以指定记录数量进行分割1
2
3
4print(abc_cat)
a,b,c = torch.split(abc_cat,split_size_or_sections = 2,dim = 0) #每份2个进行分割
p,q,r = torch.split(abc_cat,split_size_or_sections =[4,1,1],dim = 0) #每份分别为[4,1,1]
张量的数学运算
标量运算
- 标量运算符对张量实施逐元素运算,有些运算符对常用的运算符进行了重载,并且支持类似numpy 广播的特性
- 还有一些特殊的运算符
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15x = torch.tensor([2.6,-2.7])
print(torch.round(x)) #保留整数部分,四舍五入
print(torch.floor(x)) #保留整数部分,向下归整
print(torch.ceil(x)) #保留整数部分,向上归整
print(torch.trunc(x)) #保留整数部分,向0归整
x = torch.tensor([2.6,-2.7])
print(torch.fmod(x,2)) #作除法取余数
print(torch.remainder(x,2)) #作除法取剩余的部分,结果恒正
# 幅值裁剪
x = torch.tensor([0.9,-0.8,100.0,-20.0,0.7])
y = torch.clamp(x,min=-1,max = 1)
z = torch.clamp(x,max = 1)
print(y)
print(z)
向量运算
- 向量运算符在特定轴上运算,将一个向量映射到一个标量或另一个向量
1
2
3
4
5
6
7
8
9
10a = torch.arange(1,10).float().view(3,3)
print(torch.sum(a))
print(torch.mean(a))
print(torch.max(a))
print(torch.min(a))
print(torch.prod(a)) #累乘
print(torch.std(a)) #标准差
print(torch.var(a)) #方差
print(torch.median(a)) #中位数1
2
3
4
5
6
7
8#cum扫描
a = torch.arange(1,10)
print(torch.cumsum(a,0))
print(torch.cumprod(a,0))
print(torch.cummax(a,0).values)
print(torch.cummax(a,0).indices)
print(torch.cummin(a,0))
All articles on this blog are licensed under CC BY-NC-SA 3.0 CN unless otherwise stated.