[Pytorch] timm.create_model()通过指定pretrained_cfg从本地加载pretrained模型


问题描述

timm.models.create_model在选择pretrained=True时会默认在本地路径查找是否有相应的pretrained模型参数文件,如果没有则下载到本地指定目录:

windows:C:\Users\用户名\.cache\torch\hub\checkpoints
Linux:/home/用户名/.cache/torch/hub/checkpoints

model = timm.models.create_model('swinv2_tiny_window8_256', pretrained=True)

通过设置pretrained_cfg,从file路径去加载本地pretrained模型

方式一

print(timm.models.create_model('swinv2_tiny_window8_256').default_cfg)
'''
{'url': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth', 
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': None, 'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True, 
'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'architecture': 'swinv2_tiny_window8_256'}
'''

pretrained_cfg = timm.models.create_model('swinv2_tiny_window8_256').default_cfg
pretrained_cfg['file'] = r'E:\proj\AI\dataset\build_dataset\pretrained\swinv2_tiny_patch4_window8_256.pth'
print(pretrained_cfg)
'''
{'url': 'https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth', 
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': None, 'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True, 
'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'patch_embed.proj', 'classifier': 'head', 'architecture': 'swinv2_tiny_window8_256', 
'file': 'E:\\proj\\AI\\dataset\\build_dataset\\pretrained\\swinv2_tiny_patch4_window8_256.pth'}
'''

model = timm.models.create_model('swinv2_tiny_window8_256', pretrained=True, pretrained_cfg=pretrained_cfg)
print(model)

方式二

pretrained_cfg = timm.models.create_model('swinv2_tiny_window8_256').default_cfg
pretrained_cfg['file'] = r'E:\proj\AI\dataset\build_dataset\pretrained\swinv2_tiny_patch4_window8_256.pth'
model = timm.models.swinv2_tiny_window8_256(pretrained=True, pretrained_cfg=pretrained_cfg)

Debug记录

在这里插入图片描述
进入_create_swin_transformer_v2调用build_model_with_cfg
在这里插入图片描述
进入build_model_with_cfg调用load_pretrained
在这里插入图片描述
进入load_pretrained调用_resolve_pretrained_source(pretrained_cfg)
在这里插入图片描述
首先检查pretrained_cfg文件中是否有file,如果有file则从file的值中加载,如果没有则从url进行下载
在这里插入图片描述
在这里插入图片描述

Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐