前言
在学习 PyTorch 的过程中,遇到计算的时候经常会遇到广播机制(Broadcast),而Pytorch官方文档中关于广播机制的介绍比较简单,而且说明支持的是NumPy
的广播机制。
广播的意义
在机器学习过程中,张量的计算往往是有维度要求的,需要满足一定的要求才能进行计算。在实际操作的时候就需要为了满足这种要求进行一些操作,如维度扩展。为了计算方便,就引入了广播机制(Broadcast),简化计算。
适用情况
- 每个张量至少有一个维度。
- 在维度大小上迭代时,从尾部维度开始,维度大小必须相等,其中一个为1,或者其中一个不存在。
首先广播的时候他不会给张量胡乱赋值,而是扩展(复制)已有的数据,所有这里必须有一个维度的数据作为复制的基础。
第二,在扩展的时候他不会扩展一半或者部分,也不会填充部分,而是把0或者1变n的过程。如:一个shape(3,3)可以变成shape(1,3,3)或者shape(3,3,3),但不能变成shape(4,3)。
实例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
A (2d array): 5 x 4
B (1d array): 1
Result (2d array): 5 x 4
A (2d array): 5 x 4
B (1d array): 4
Result (2d array): 5 x 4
A (3d array): 15 x 3 x 5
B (3d array): 15 x 1 x 5
Result (3d array): 15 x 3 x 5
A (3d array): 15 x 3 x 5
B (2d array): 3 x 5
Result (3d array): 15 x 3 x 5
A (3d array): 15 x 3 x 5
B (2d array): 3 x 1
Result (3d array): 15 x 3 x 5
通过上面的例子可以更好的理解,在维度大小上迭代时,从尾部维度开始,维度大小必须相等,其中一个为1,或者其中一个不存在。要么相等,要么为1,要么不存在。