张量的数据类型

张量的数据类型与 numpy.array 基本一一对应,除了不支持str类型

  • 一般的神经网络用的是torch.float32类型
  • 如果要显示指定数据类型,可以使用torch.tensor(data,dtype = torch.type)
  • 也可以使用特定的构造函数
    1
    2
    3
    i = torch.Inttensor() #构造数据类型为 int 的张量
    x = torch.Tensor() # 构造数据类型为 float 的张量
    b = torch.BoolTensor() #构造数据类型为 bool 的张量
  • 此外,还可以对不同类型的张量进行转化
    1
    2
    3
    4
    i = torch.tensor(1)  # 构建类型为int64的张量
    x = i.float() # 调用float方法转换为float类型
    y = i.type(torch.float) # 使用type函数转换为浮点类型
    z = i.type_as(x) # 使用type_as 方法转化为与某个Tensor相同类型的张量

张量的维度

张量的尺寸

  • 可以使用shape属性或者size() 方法查看张量在每一维的长度
  • 可以使用view方法改变张量的尺寸
  • view失败的情况下,可以使用reshape方法
    • viewreshape 的区别:
      1. view方法要求原张量在内存中是连续的,如果不连续则会失败;reshape则会自动处理布局
      2. view方法总是与原张量共享内存,返回的是原向量的”视图”; reshape则可能返回视图或者副本,取决于内存的布局
    • 为什么不只使用 reshape
      1. 性能考虑 view 更快,因为只是改变张量的元数据不涉及数据复制;reshape可能涉及到复制数据,会有额外的开销
      2. 内存效率 view保证内存共享,修改一个会影响另一个;reshape可能创建副本导致占用更多内存
      3. 语义的明确性 view 明确表示期望的是内存共享的试图操作;当view失败时,提醒开发者注意内存布局问题
    • 以下是失败情况的例子
      1
      2
      3
      4
      5
      6
      7
      8
      9
      10
      11
      12
      13
      14
      15
      16
      17
      18
      19
      20
      21
      22
      23
      24
      25
      26
      27
      28
      29
      import torch

      # 情况1: 转置后的张量不连续
      x = torch.randn(3, 4)
      y = x.transpose(0, 1) # 转置后内存不连续
      print(y.is_contiguous()) # False

      # 这会失败
      try:
      z = y.view(2, 6)
      except RuntimeError as e:
      print(f"view失败: {e}")

      # reshape可以成功
      z = y.reshape(2, 6) # 成功
      print(z.shape) # torch.Size([2, 6])

      # 情况2: 切片操作后的张量
      x = torch.randn(4, 4)
      y = x[:, ::2] # 非连续切片
      print(y.is_contiguous()) # False

      # view失败,reshape成功
      try:
      z = y.view(-1)
      except RuntimeError as e:
      print(f"view失败: {e}")

      z = y.reshape(-1) # 成功
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 有些操作会让张量存储结构扭曲,直接使用view会失败,可以用reshape方法

matrix26 = torch.arange(0,12).view(2,6)
print(matrix26)
print(matrix26.shape)

# 转置操作让张量存储结构扭曲
matrix62 = matrix26.t()
print(matrix62.is_contiguous())


# 直接使用view方法会失败,可以使用reshape方法
#matrix34 = matrix62.view(3,4) #error!
matrix34 = matrix62.reshape(3,4) #等价于matrix34 = matrix62.contiguous().view(3,4)
print(matrix34)

张量与numpy数组

  • 可以使用numpy方法从tensor得到numpy数组,也可以用torch.from_numpy从numpy数组得到tensor.
    • 两种方法共享数据内存,改变一个另一个也会随之改变
    • 可以用张量的clone 方法拷贝张量,中断这种关联
  • 可以使用item方法从标量张量得到对应的python数值
  • 使用tolist方法从张量得到对应的python数值列表
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    #torch.from_numpy函数从numpy数组得到Tensor

    arr = np.zeros(3)
    tensor = torch.from_numpy(arr)
    print("before add 1:")
    print(arr)
    print(tensor)

    print("\nafter add 1:")
    np.add(arr,1, out = arr) #给 arr增加1,tensor也随之改变
    print(arr)
    print(tensor)


    # numpy方法从Tensor得到numpy数组

    tensor = torch.zeros(3)
    arr = tensor.numpy()
    print("before add 1:")
    print(tensor)
    print(arr)

    print("\nafter add 1:")

    #使用带下划线的方法表示计算结果会返回给调用 张量
    tensor.add_(1) #给 tensor增加1,arr也随之改变
    #或: torch.add(tensor,1,out = tensor)
    print(tensor)
    print(arr)


    # 可以用clone() 方法拷贝张量,中断这种关联

    tensor = torch.zeros(3)

    #使用clone方法拷贝张量, 拷贝后的张量和原始张量内存独立
    arr = tensor.clone().numpy() # 也可以使用tensor.data.numpy()
    print("before add 1:")
    print(tensor)
    print(arr)

    print("\nafter add 1:")

    #使用 带下划线的方法表示计算结果会返回给调用 张量
    tensor.add_(1) #给 tensor增加1,arr不再随之改变
    print(tensor)
    print(arr)



    # item方法和tolist方法可以将张量转换成Python数值和数值列表
    scalar = torch.tensor(1.0)
    s = scalar.item()
    print(s)
    print(type(s))

    tensor = torch.rand(2,2)
    t = tensor.tolist()
    print(t)
    print(type(t))

Something else

什么是元数据

  • 元数据(metadata)描述数据张什么样而不是数据本身的信息,张量结构性的描述信息,不包括实际的数值
    在 PyTorch 中,张量由“数据存储(storage)”与“元数据”两部分构成:
  • 数据存储:真正的数值内存区域
  • 元数据:描述如何解释这些数值的结构信息
    典型元数据:shape、stride、dtype、device、storage_offset、requires_grad、layout、(可选) names
    view() 这类操作只是改元数据(不复制数据);而当现有 stride 组合无法支持新形状时,reshape() 会退化为复制,得到新的连续存储。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
x = torch.arange(12)
y = x.view(3, 4) # 仅改元数据
print(y.shape, y.stride())

t = y.t() # 转置改变stride
print(t.shape, t.stride(), t.is_contiguous())

# view失败
try:
t.view(12)
except RuntimeError as e:
print("view失败:", e)

# reshape自动处理(必要时复制)
z = t.reshape(12)
print(z.is_contiguous())