2020年3月30日 星期一

PyTorch 特殊函數之求導

torch.max() / torch.min()

對 torch.max() 進行 backward() 時,對輸入的 tensor 中數值最大的元素的 gradient 會貢獻 1,其餘則為 0。

對 torch.min() 進行 backward() 時,對數值最小的元素的 gradient 會貢獻 1,其餘則為 0。

程式
a = torch.tensor([1., 2., 3.], requires_grad=True)
b = torch.max(a)
b.backward()
a.grad

輸出
tensor([0., 0., 1.])

torch.abs()

對 torch.abs() 進行求導時,對輸入的 tensor 中正的元素的 gradient 貢獻為 1,對負的元素的 gradient 貢獻為 -1,對元素值為 0 的 gradient 貢獻為 0。

程式
a = torch.tensor([-1., -2., 3., 0.], requires_grad=True)
b = torch.abs(a)
b.backward(torch.ones_like(b))
a.grad

輸出
tensor([-1., -1., 1., 0.])

torch.where()

d = torch.where(a, b, c)
a 通常是一個 bool type 的張量,在 a 張量裡面是 True 的位置,b 該位置的元素值會傳到 d 的相應位置,在 a 張量裡面是 False 的位置,c 該位置的元素值會傳到 d 的相應位置。當我們對 where 進行 backward 時,在 b 和 c 當中有傳到 d 的那些位置的 grad 是 1,其餘位置的 grad 是 0。

程式
a = torch.tensor([True, False, False, True, True])
b = torch.tensor([1., 2., 3., 4., 5.], requires_grad=True)
c = torch.tensor([2., 4., 6., 8., 10.], requires_grad=True)
d = torch.where(a, b, c)
d
d.backward(torch.ones_like(d))
b.grad
c.grad

輸出
tensor([1., 4., 6., 4., 5.], grad_fn=<SWhereBackward>)   # d
tensor([1., 0., 0., 1., 1.])   # b.grad
tensor([0., 1., 1., 0., 0.])   # c.grad

求導結果為 0 的函數

sign, ceil

不可求導的函數

argmax, argmin, lt, le, eq, ne, ge, gt

當一條運算路徑中有經過不可求導函數時,在進行 backward() 時,這條路徑就無法進行求導運算了。

沒有留言:

張貼留言