python使用训练好的模型进行特征提取
因为最近可用的开源代码库里找到了一个网络,因为要提高网络训练速度,网络的输入就改成了一组组特征,所以就需要提前利用特征提取模型进行特征提取,并将特征存为数组。
·
前言
因为最近可用的开源代码库里找到了一个网络,因为要提高网络训练速度,网络的输入就改成了一组组特征,所以就需要提前利用特征提取模型进行特征提取,并将特征存为数组。
代码
import os
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import torch.utils.data
def main(device, in_jit_model, out_feat_dir):
model = torch.jit.load(in_jit_model, 'cpu')
model.eval()
model = model.to(device)
f = open('.../xxx.txt', 'r')
files = f.readlines() # 读取整个文件所有行,保存在 list 列表中
feats = []
tmp = '0'
n = 0
for filename in files:
filename = filename[:-1]
image = Image.open('.../' + filename)
image = image.convert("RGB")
image = np.array(image)
# image = np.expand_dims(image, axis=0)
image = torch.tensor(image)
image = image.float()
image /= 255
image = image.permute(2,1,0)
# image = data_transform(image)
image = torch.tensor(np.expand_dims(image, axis=0))
B, N = [1,1]
image = image.contiguous().to(device)
t_y = model(image)
assert t_y.ndim == 2
t_y = t_y.reshape(B, N, t_y.shape[-1])
if filename[:12]==tmp :
feats.extend(t_y.cpu().detach().numpy())
n = n+1
print(filename+' get {0} feat'.format(n))
elif tmp =='0':
feats.extend(t_y.cpu().detach().numpy())
n = n + 1
print(filename+' get {0} feat'.format(n))
else:
files2 = open('.../xxx.txt',
'r')
files2 = files2.readlines() # 读取整个文件所有行,保存在 list 列表中
for k in files2:
k = k[:-1]
if tmp == k[:12]:
feats = np.stack(feats, 0)
out_feat_file = f'{out_feat_dir}/{k}.svs.npy'
os.makedirs(os.path.dirname(out_feat_file), exist_ok=True)
np.save(out_feat_file, feats, allow_pickle=False)
n = 0
print(k+"----------save 1 feats")
feats = []
feats.extend(t_y.cpu().detach().numpy())
break
tmp = filename[:12]
if __name__ == '__main__':
device = 'cuda:0'
in_jit_model = '.../xxx.pt'
out_feat_dir = '..../...'
main(
device=device,
in_jit_model=in_jit_model,
out_feat_dir=out_feat_dir
)

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。
更多推荐
所有评论(0)