2020年4月25日 星期六

PyTorch no_grad


在進行 back-propagation 時,有些路徑是我們不想要去計算 gradient 的,這時我們可以用 no_grad() 這個 function。但有一個需要注意的地方是,只有在 no_grad() 底下產生出來的 tensor 才會被 disable gradient。以下舉兩個例子:

a = torch.tensor(1.0, requires_grad=True)
with torch.no_grad():
     b = a
b.backward()

這時候我們去看 a.grad,會發現還是有 gradient 產生。原因就在於,這個 b = a 其實只是讓 b 成為一個 a 的 reference,也就是 b 其實就是 a,只是名字換了而已。在這種情況下,b 不是一個在 no_grad() 底下產生的 tensor,因此還是會有 gradient。

a = torch.tensor(1.0, requires_grad=True)
with torch.no_grad():
     b = a + 0
b.backward()

假如改成 b = a + 0,那所產生的 b 就是一個新的 tensor,並且此 tensor 是不會被計算 gradient 的。

沒有留言:

張貼留言