2020年3月25日 星期三

PyTorch 的 backward function


在呼叫 t.backward() 時,若 t 是一個純量,就不需要傳入任何的引數。但是當 t 是一個張量的時候,就需要傳入一個和 t 同樣 size 的張量,這個張量是做什麼用的?

假設我們的 t = [t1, t2],這個張量的運算圖裡面有一個 leaf node 是 x = [x1, x2]。如果在呼叫 backward 時,我們傳入 [1, 1],PyTorch 會將 t 當中的每一個元素 (t1 和 t2) 都對 x1 進行偏微分,並且將所有的結果加起來得到 x.grad = [g1, g2] 中的 g1。

此時如果我們只想要知道 t1 對 x1 的偏微分結果,就要在呼叫 backward 時,傳入引數 [1, 0]。

這個傳入的引數,可以看成一個權重,可以控制 t 裡面每一個元素對偏微分的影響程度。在針對一個運算圖進行 backward 時,每一個 operation 的 backward function 會接收它的下游所傳來的 gradient,因為在進行 back propagation 時,上游的 gradient 是透過微分的鏈鎖律 (chain rule) 算出來的,也就是 backward function 的 input 會和該 operation 的 gradient 相乘,去得到上游的 gradient。


Reference
Sherlock, "PyTorch中的backward," https://zhuanlan.zhihu.com/p/27808095.

沒有留言:

張貼留言