Pytorch Hub 是經(jīng)過預先訓練的模型資料庫,旨在促進研究的可重復性。
Pytorch Hub 支持通過添加簡單的hubconf.py
文件將預訓練的模型(模型定義和預訓練的權重)發(fā)布到 github 存儲庫;
hubconf.py
可以有多個入口點。 每個入口點都定義為 python 函數(shù)(例如:您要發(fā)布的經(jīng)過預先訓練的模型)。
def entrypoint_name(*args, **kwargs):
# args & kwargs are optional, for models which take positional/keyword arguments.
...
如果我們擴展pytorch/vision/hubconf.py
中的實現(xiàn),則以下代碼段指定了resnet18
模型的入口點。 在大多數(shù)情況下,在hubconf.py
中導入正確的功能就足夠了。 在這里,我們僅以擴展版本為例來說明其工作原理。
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18
## resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
Resnet18 model
pretrained (bool): kwargs, load pretrained weights into the model
"""
# Call the model, load pretrained weights
model = _resnet18(pretrained=pretrained, **kwargs)
return model
dependencies
變量是加載模型所需的軟件包名稱的列表。 請注意,這可能與訓練模型所需的依賴項稍有不同。args
和kwargs
傳遞給實際的可調用函數(shù)。torch.hub.list()
中顯示。torch.hub.load_state_dict_from_url()
加載。 如果少于 2GB,建議將其附加到項目版本,并使用該版本中的網(wǎng)址。 在上面的示例中,torchvision.models.resnet.resnet18
處理pretrained
,或者,您可以在入口點定義中添加以下邏輯。if pretrained:
# For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
dirname = os.path.dirname(__file__)
checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
state_dict = torch.load(checkpoint)
model.load_state_dict(state_dict)
# For checkpoint saved elsewhere
checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
Pytorch Hub 提供了便捷的 API,可通過torch.hub.list()
瀏覽集線器中的所有可用模型,通過torch.hub.help()
顯示文檔字符串和示例,并使用torch.hub.load()
加載經(jīng)過預先訓練的模型
torch.hub.list(github, force_reload=False)?
列出 <cite>github</cite> hubconf 中可用的所有入口點。
參數(shù)
退貨
可用入口點名稱的列表
返回類型
入口點
例
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
torch.hub.help(github, model, force_reload=False)?
顯示入口點<cite>模型</cite>的文檔字符串。
Parameters
Example
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
torch.hub.load(github, model, *args, **kwargs)?
使用預訓練的權重從 github 存儲庫加載模型。
Parameters
Returns
具有相應預訓練權重的單個模型。
Example
>>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)?
將給定 URL 上的對象下載到本地路徑。
Parameters
Example
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False)?
將 Torch 序列化對象加載到給定的 URL。
如果下載的文件是 zip 文件,它將被自動解壓縮。
如果 <cite>model_dir</cite> 中已經(jīng)存在該對象,則將其反序列化并返回。 <cite>model_dir</cite> 的默認值為$TORCH_HOME/checkpoints
,其中環(huán)境變量$TORCH_HOME
的默認值為$XDG_CACHE_HOME/torch
。 $XDG_CACHE_HOME
遵循 Linux 文件系統(tǒng)布局的 X 設計組規(guī)范,如果未設置,則默認值為~/.cache
。
Parameters
filename-<sha256>.ext
,其中[ <sha256>
是文件內(nèi)容的 SHA256 哈希值的前 8 位或更多位。 哈希用于確保唯一的名稱并驗證文件的內(nèi)容。 默認值:FalseExample
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
注意,torch.load()
中的*args, **kwargs
用于實例化模型。 加載模型后,如何找到可以使用該模型的功能? 建議的工作流程是
dir(model)
查看模型的所有可用方法。help(model.foo)
檢查model.foo
需要執(zhí)行哪些參數(shù)為了幫助用戶探索而又不來回參考文檔,我們強烈建議回購所有者使功能幫助消息清晰明了。 包含一個最小的工作示例也很有幫助。
這些位置按以下順序使用
hub.set_dir(<PATH_TO_HUB_DIR>)
$TORCH_HOME/hub
,如果設置了環(huán)境變量TORCH_HOME
。$XDG_CACHE_HOME/torch/hub
,如果設置了環(huán)境變量XDG_CACHE_HOME
。~/.cache/torch/hub
torch.hub.set_dir(d)?
(可選)將 hub_dir 設置為本地目錄,以保存下載的模型&權重。
如果未調用set_dir
,則默認路徑為$TORCH_HOME/hub
,其中環(huán)境變量$TORCH_HOME
默認為$XDG_CACHE_HOME/torch
。 $XDG_CACHE_HOME
遵循 Linux 文件系統(tǒng)布局的 X 設計組規(guī)范,如果未設置環(huán)境變量,則默認值為~/.cache
。
Parameters
d (字符串)–本地文件夾的路徑,用于保存下載的模型&權重。
默認情況下,加載文件后我們不會清理文件。 如果hub_dir
中已經(jīng)存在,則集線器默認使用緩存。
用戶可以通過調用hub.load(..., force_reload=True)
來強制重新加載。 這將刪除現(xiàn)有的 github 文件夾和下載的權重,重新初始化新的下載。 當更新發(fā)布到同一分支時,此功能很有用,用戶可以跟上最新版本。
Torch 集線器通過導入軟件包來進行工作,就像安裝軟件包一樣。 在 Python 中導入會帶來一些副作用。 例如,您可以在 Python 緩存sys.modules
和sys.path_importer_cache
中看到新項目,這是正常的 Python 行為。
在這里值得一提的已知限制是用戶無法在相同的 python 進程中加載同一存儲庫的兩個不同分支。 就像在 Python 中安裝兩個具有相同名稱的軟件包一樣,這是不好的。 快取可能會加入聚會,如果您實際嘗試的話會給您帶來驚喜。 當然,將它們分別加載是完全可以的。
更多建議: