2020年4月13日 星期一

PyTorch 如何追蹤 backward 路徑


next_functions

當我們在建立一個 model 並進行訓練時,有時候可能 forward 結果是對的,但是訓練結果卻不如預期,這時候可能會想要知道 back propagation 的運算是否正確。這時候應該怎麼做呢?我們舉以下這段程式為例:


a = torch.tensor(1.1, requires_grad=True)
b = torch.tensor(1.2, requires_grad=True)
c = torch.tensor(1.3, requires_grad=True)
d = a * b
e = d * c
e.backward()

我們可以使用 e.grad_fn.next_functions 來追蹤 backward 的路徑。e.grad_fn.next_functions 是一個 list,其中包含了 e 往前走一層的所有 gradient function。上面這個例子中,e.grad_fn.next_functions 中包含兩個 element,一個是往 d 那邊走的 gradient function,另一個是往 c 那邊走的。

往 d 走的是 e.grad_fn.next_functions[0][0]
往 c 走的是 e.grad_fn.next_functions[1][0]

若我們想要知道 c 這條路徑的 forward 運算結果和 gradient,分別是:
Forward 結果:e.grad_fn.next_functions[1][0].variable   (結果為 1.3)
Gradient 大小:e.grad_fn.next_functions[1][0].variable.grad   (結果為 1.32)

而由於 d 這條路徑還有上游,我們必須再往上一層才能抵達 a 和 b

其中往 a 這條路徑的 forward 運算結果和 gradient,分別是:
Forward 結果:e.grad_fn.next_functions[0][0].next_functions[0][0].variable   (結果為 1.1)
Gradient 大小:e.grad_fn.next_functions[0][0].next_functions[0][0].variable.grad   (結果為 1.56 = 1.3 * 1.2)

而往 b 這條路徑的 forward 運算結果和 gradient,分別是:
Forward 結果:e.grad_fn.next_functions[0][0].next_functions[1][0].variable   (結果為 1.2)
Gradient 大小:e.grad_fn.next_functions[0][0].next_functions[1][0].variable.grad   (結果為 1.43 = 1.3 * 1.1)

只有位於 leaf node 的 grad_fn 有 variable 這個 attribute,因此必須使用 next_function 不斷的往前追蹤到所有的 leaf node 位置,過程中我們可以用 hasattr(object, name) 來確認 grad_fn 是否包含 variable 這個 attribute。

register_hook

如上所述,只有 graph 的 leaf node 存在 variable 能夠用於追蹤 gradient back-propagation。假如我們想要知道 internal node d 的 gradient 時,可以使用 register_hook。

承接上面這個例子,我們可以使用如下程式碼來取得 d 的 gradient。

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook
d.register_hook(save_grad('d'))

在做完 e.backward() 以後,就能夠用 grads['d'] 來存取 d 路徑上的 gradient,其結果為 1.3。

Reference
[1] https://stackoverflow.com/questions/52988876/how-can-i-visualize-what-happens-during-loss-backward
[2] https://codertw.com/%E7%A8%8B%E5%BC%8F%E8%AA%9E%E8%A8%80/368709/
[3] https://discuss.pytorch.org/t/why-cant-i-see-grad-of-an-intermediate-variable/94/6


沒有留言:

張貼留言