文章目录


前言

最近在开源代码库里找到了一个可用的网络。为了提高训练速度,这个网络的输入被改成了一组组预先提取好的特征,所以需要先用特征提取模型对图像逐张提取特征,并将每组特征保存为数组(.npy 文件)。

代码

import os
import numpy as np
from tqdm import tqdm
from PIL import Image
import torch
import torch.utils.data
def _read_lines(path):
    """Return the non-blank lines of the text file at *path*, line endings stripped."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.rstrip('\r\n') for line in f if line.strip()]


def _load_image_tensor(path):
    """Load an RGB image as a float tensor of shape (1, C, W, H) scaled to [0, 1]."""
    with Image.open(path) as img:
        array = np.array(img.convert("RGB"))
    tensor = torch.from_numpy(array).float() / 255
    # NOTE(review): permute(2, 1, 0) maps HWC -> CWH (width and height swapped).
    # The more common layout is permute(2, 0, 1); kept as-is because the
    # downstream model may rely on it -- confirm against its preprocessing.
    return tensor.permute(2, 1, 0).unsqueeze(0).contiguous()


def _flush_group(feats, prefix, slide_by_prefix, out_feat_dir):
    """Stack one slide's patch features and save them to ``<out_feat_dir>/<slide>.svs.npy``.

    Groups whose prefix matches no slide name are reported and skipped
    rather than being merged into the following slide.
    """
    slide_name = slide_by_prefix.get(prefix)
    if slide_name is None:
        print(f'no slide name matches prefix {prefix!r}; skipping {len(feats)} feats')
        return
    out_feat_file = f'{out_feat_dir}/{slide_name}.svs.npy'
    os.makedirs(os.path.dirname(out_feat_file), exist_ok=True)
    np.save(out_feat_file, np.stack(feats, 0), allow_pickle=False)
    print(slide_name + "----------save 1 feats")


def main(device, in_jit_model, out_feat_dir, *, patch_list='.../xxx.txt',
         image_dir='.../', slide_list='.../xxx.txt', prefix_len=12):
    """Extract one feature vector per patch image and save them grouped by slide.

    Every line of *patch_list* names an image under *image_dir*. Consecutive
    patches whose names share the same first *prefix_len* characters form one
    group. Patches of one slide are assumed to be listed contiguously; a prefix
    that reappears later would overwrite the earlier file. When the prefix
    changes, and after the last patch, the group's features are stacked into
    an array of shape (n, 1, D). That array is saved to
    ``<out_feat_dir>/<slide>.svs.npy``, where ``<slide>`` is the first line of
    *slide_list* with the same prefix.

    Args:
        device: torch device the model runs on, e.g. ``'cuda:0'``.
        in_jit_model: path to a TorchScript model whose output is 2-D.
        out_feat_dir: directory that receives the ``.npy`` files.
        patch_list: text file listing patch image names, one per line.
        image_dir: directory the patch names are relative to.
        slide_list: text file listing slide names, one per line.
        prefix_len: number of leading characters shared by a slide's patches.

    Raises:
        ValueError: if the model output is not 2-D.
    """
    model = torch.jit.load(in_jit_model, 'cpu')
    model.eval()
    model = model.to(device)

    patch_names = _read_lines(patch_list)
    # Build the prefix -> slide-name lookup once instead of re-reading the
    # slide list at every group boundary. First match wins, like a
    # top-to-bottom scan of the file.
    slide_by_prefix = {}
    for slide_name in _read_lines(slide_list):
        slide_by_prefix.setdefault(slide_name[:prefix_len], slide_name)

    feats = []
    current_prefix = None
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for filename in patch_names:
            prefix = filename[:prefix_len]
            if feats and prefix != current_prefix:
                _flush_group(feats, current_prefix, slide_by_prefix, out_feat_dir)
                feats = []
            current_prefix = prefix

            image = _load_image_tensor(os.path.join(image_dir, filename))
            t_y = model(image.to(device))
            if t_y.ndim != 2:
                raise ValueError(
                    f'expected a 2-D model output, got shape {tuple(t_y.shape)}')
            # extend() over a (1, 1, D) array appends one (1, D) row, so the
            # stacked group has shape (n, 1, D).
            feats.extend(t_y.reshape(1, 1, t_y.shape[-1]).detach().cpu().numpy())
            print(filename + ' get {0} feat'.format(len(feats)))

    # The loop only flushes on a prefix change, so save the final group here.
    if feats:
        _flush_group(feats, current_prefix, slide_by_prefix, out_feat_dir)


if __name__ == '__main__':
    # Placeholder settings: point these at the real TorchScript model file
    # and the directory that should receive the per-slide feature arrays.
    config = {
        'device': 'cuda:0',
        'in_jit_model': '.../xxx.pt',
        'out_feat_dir': '..../...',
    }
    main(**config)

Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 AI 帮你完成复杂任务”。

更多推荐