PyTorch 自動(dòng)求導(dǎo)機(jī)制

2020-09-16 14:37 更新

原文: PyTorch 自動(dòng)求導(dǎo)機(jī)制

本說(shuō)明將概述 autograd 的工作方式并記錄操作。 不一定要完全了解所有這些內(nèi)容,但我們建議您熟悉它,因?yàn)樗梢詭椭帉懜咝?,更?jiǎn)潔的程序,并可以幫助您進(jìn)行調(diào)試。

從向后排除子圖

每個(gè)張量都有一個(gè)標(biāo)志:requires_grad,允許從梯度計(jì)算中細(xì)粒度地排除子圖,并可以提高效率。

requires_grad

如果某個(gè)操作的單個(gè)輸入需要進(jìn)行漸變,則其輸出也將需要進(jìn)行漸變。 相反,僅當(dāng)所有輸入都不需要漸變時(shí),輸出才不需要。 在所有張量都不要求漸變的子圖中,永遠(yuǎn)不會(huì)執(zhí)行向后計(jì)算。

>>> x = torch.randn(5, 5)  # requires_grad=False by default
>>> y = torch.randn(5, 5)  # requires_grad=False by default
>>> z = torch.randn((5, 5), requires_grad=True)
>>> a = x + y
>>> a.requires_grad
False
>>> b = a + z
>>> b.requires_grad
True

當(dāng)您要凍結(jié)部分模型,或者事先知道您將不使用漸變色時(shí),此功能特別有用。 一些參數(shù)。 例如,如果您想微調(diào)預(yù)訓(xùn)練的 CNN,只需在凍結(jié)的基數(shù)中切換requires_grad標(biāo)志,就不會(huì)保存任何中間緩沖區(qū),直到計(jì)算到達(dá)最后一層,仿射變換將使用權(quán)重為 需要梯度,網(wǎng)絡(luò)的輸出也將需要它們。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
## Replace the last fully-connected layer
## Parameters of newly constructed modules have requires_grad=True by default
model.fc = nn.Linear(512, 100)


## Optimize only the classifier
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

autograd 如何編碼歷史

Autograd 是反向自動(dòng)分化系統(tǒng)。 從概念上講,autograd 會(huì)記錄一個(gè)圖形,記錄執(zhí)行操作時(shí)創(chuàng)建數(shù)據(jù)的所有操作,從而為您提供一個(gè)有向無(wú)環(huán)圖,其葉子為輸入張量,根為輸出張量。 通過(guò)從根到葉跟蹤該圖,您可以使用鏈?zhǔn)揭?guī)則自動(dòng)計(jì)算梯度。

在內(nèi)部,autograd 將該圖表示為Function對(duì)象(真正的表達(dá)式)的圖,可以將其apply()編輯以計(jì)算評(píng)估圖的結(jié)果。 在計(jì)算前向通過(guò)時(shí),autograd 同時(shí)執(zhí)行請(qǐng)求的計(jì)算,并建立一個(gè)表示表示計(jì)算梯度的函數(shù)的圖形(每個(gè) torch.Tensor.grad_fn屬性是該圖形的入口)。 完成前向遍歷后,我們?cè)诤笙虮闅v中評(píng)估此圖以計(jì)算梯度。

需要注意的重要一點(diǎn)是,每次迭代都會(huì)從頭開始重新創(chuàng)建圖形,這正是允許使用任意 Python 控制流語(yǔ)句的原因,它可以在每次迭代時(shí)更改圖形的整體形狀和大小。 在開始訓(xùn)練之前,您不必編碼所有可能的路徑-跑步就是您的與眾不同。

使用 autograd 進(jìn)行就地操作

在 autograd 中支持就地操作很困難,并且在大多數(shù)情況下,我們不鼓勵(lì)使用它們。 Autograd 積極的緩沖區(qū)釋放和重用使其非常高效,就地操作實(shí)際上很少顯著降低內(nèi)存使用量的情況很少。 除非您在高內(nèi)存壓力下進(jìn)行操作,否則可能永遠(yuǎn)不需要使用它們。

限制就地操作的適用性的主要原因有兩個(gè):

  1. 就地操作可能會(huì)覆蓋計(jì)算梯度所需的值。
  2. 實(shí)際上,每個(gè)就地操作都需要實(shí)現(xiàn)來(lái)重寫計(jì)算圖。 異地版本僅分配新對(duì)象并保留對(duì)舊圖形的引用,而就地操作則需要更改表示此操作的Function的所有輸入的創(chuàng)建者。 這可能很棘手,特別是如果有許多張量引用相同的存儲(chǔ)(例如通過(guò)索引或轉(zhuǎn)置創(chuàng)建的),并且如果修改后的輸入的存儲(chǔ)被任何其他Tensor引用,則就地函數(shù)實(shí)際上會(huì)引發(fā)錯(cuò)誤。

就地正確性檢查

每個(gè)張量都有一個(gè)版本計(jì)數(shù)器,每次在任何操作中被標(biāo)記為臟時(shí),該計(jì)數(shù)器都會(huì)增加。 當(dāng)函數(shù)保存任何張量以供向后時(shí),也會(huì)保存其包含 Tensor 的版本計(jì)數(shù)器。 訪問(wèn)self.saved_tensors后,將對(duì)其進(jìn)行檢查,如果該值大于保存的值,則會(huì)引發(fā)錯(cuò)誤。 這樣可以確保,如果您使用的是就地函數(shù)并且沒(méi)有看到任何錯(cuò)誤,則可以確保計(jì)算出的梯度是正確的。

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)