2020年3月30日 星期一

PyTorch add function


在 PyTorch 中的 add function,可以實現將兩個張量相加,可以寫成 c = torch.add(a, b),也可以直接寫成 c = a + b。

其中 a 和 b 分別可以是一個張量加上一個純量,結果 c 就會是把純量加到張量中的每一個元素上。

a 和 b 也可以是兩個張量,但是若要實現兩個張量的相加,a 和 b 的 size 就有特定的限制 (broadcastable)。

以下幾種 a 和 b 的 size 是允許相加的:

Size of a:  (5)
Size of b:  (3, 2, 1, 5)

Size of a:  (1, 5)
Size of b:  (3, 2, 1, 5)

Size of a:  (2, 1, 5)
Size of b:  (3, 2, 1, 5)

Size of a:  (3, 2, 1, 5)
Size of b:  (3, 2, 1, 5)

由以上的例子可以看出,其中一個張量 (a) 的維度必須小於或等於另一個張量 (b),而且 a 的 size 必須與另一個張量 b 的最後幾個維度的 size 相同。

以下幾種 a 和 b 也是允許相加的:

Size of a:  (3, 2, 2, 5)
Size of b:  (1, 5)

Size of a:  (3, 2, 2, 5)
Size of b:  (1, 1, 5)

Size of a:  (3, 2, 2, 5)
Size of b:  (2, 1, 5)

Size of a:  (3, 2, 2, 5)
Size of b:  (3, 1, 1, 5)

Size of a:  (1, 5)
Size of b:  (2, 1)

Size of a:  (1, 1, 1, 5)
Size of b:  (1, 2, 1, 1)

由以上的例子可以看出,在其中一個張量的維度中,可以允許一個或多個維度的大小為 1。在大小為 1 的那個維度中,b (a) 的所有元素會被廣播到 a (b) 裡面同一個維度中所有的 index 上並執行加法運算。


沒有留言:

張貼留言