👋 你好!这里有实用干货与深度分享✨✨ 若有帮助,欢迎:​
👍 点赞 | ⭐ 收藏 | 💬 评论 | ➕ 关注 ,解锁更多精彩!​
📁 收藏专栏即可第一时间获取最新推送🔔。​
📖后续我将持续带来更多优质内容,期待与你一同探索知识,携手前行,共同进步🚀。​



人工智能

数据集读取

本文使用PyTorch框架,介绍PyTorch中数据读取的相关知识。

本文目标:

  1. 了解PyTorch中数据读取的基本概念
  2. 了解PyTorch中集成的开源数据集的读取方法
  3. 了解PyTorch中自定义数据集的读取方法
  4. 了解PyTorch中数据读取的流程

一、数据的准备

使用开源数据集或者自己采集数据后进行数据标注。

PyTorch中数据读取的基本概念

PyTorch中数据读取的基本概念是DatasetDataLoader

Dataset是一个抽象类,用于表示数据集。它包含了数据集的长度、索引、数据获取等方法。

DataLoader是一个类,用于将数据集按批次加载到模型中。它包含了数据读取、数据转换、数据打乱等方法。

实现数据集读取的步骤:

  1. 继承Dataset类,实现__len____getitem__方法
  2. 使用DataLoader类,将数据集按批次加载到模型中

示例代码:

import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))

dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch_data, batch_labels in dataloader:
    print(batch_data.shape, batch_labels.shape)

PyTorch中集成的开源数据集的读取方法

使用开源数据MNIST作为示范。

数据集链接:MNIST数据集

PyTorch中以及集成了很多开源数据集,我们可以直接使用。MNIST也包括在其中。

只需要使用PyTorch中的torchvision.datasets模块即可。

示例代码:

  1. 引入必要的库:
import torch
from torchvision import datasets
import matplotlib.pyplot as plt
  1. 加载数据集:
train_dataset = datasets.MNIST(root='./data', train=True, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, download=True)

参数说明:

  • root:数据集保存的路径
  • train:是否为训练集
  • download:是否下载数据集
  1. 查看数据集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
  1. 可视化数据集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
  1. 数据加载:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:
    print(batch_data.shape, batch_labels.shape)
    break

参数说明:

  • batch_size:批次大小
  • shuffle:是否打乱数据,训练集一般需要打乱数据,测试集一般不需要打乱数据

其实,真实的训练过程只需要步骤1、2、5即可,3、4步骤是为了验证数据集是否正确。

二、PyTorch中自定义数据集的读取方法

自定义数据集的读取方法是指,我们自己定义一个数据集,然后使用PyTorch中的DatasetDataLoader类来读取数据集。因为不是所有的数据集都在PyTorch中集成了,当我们有拥有(自己标注或下载)一个新的数据集时,就需要自己定义数据集的读取方法。

这时候需要将数据集以一定的规则保存起来,然后使用PyTorch中的DatasetDataLoader类来读取数据集。

示例代码:

  1. 引入必要的库:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import numpy as np
import matplotlib.pyplot as plt
  1. 定义数据集类:
class MyDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.data_list = os.listdir(data_dir)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        data_path = os.path.join(self.data_dir, self.data_list[index])
        data = np.load(data_path)
        label = data['label']

        if self.transform is not None:
            data = self.transform(data)

        return data, label

参数说明:

  • data_dir:数据集保存的路径
  • transform:数据转换函数,可选。1. 用于数据增强,一般的数据增强方法有:随机裁剪、随机旋转、随机翻转、随机缩放等。2. 也可以用于数据预处理,如归一化、标准化等。
  1. 定义数据转换函数:
def transform(data):
    data = data['data']
    data = data.astype(np.float32)
    data = data / 255.0
    data = torch.from_numpy(data)
    return data
  1. 加载数据集:
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)
  1. 查看数据集信息:
print(len(train_dataset), len(test_dataset))
print(train_dataset[0][0].size, train_dataset[0][1])
  1. 可视化数据集:
plt.imshow(train_dataset[0][0], cmap='gray')
plt.show()
  1. 数据加载:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:
    print(batch_data.shape, batch_labels.shape)
    break
  1. 数据增强:
from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomCrop(28),  # 随机裁剪,裁剪大小为28x28
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomVerticalFlip(),  # 随机垂直翻转
    transforms.RandomRotation(10),  # 随机旋转
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # 随机仿射变换
    transforms.ToTensor()  # 转换为张量
])
train_dataset = MyDataset(data_dir='./data/train', transform=transform)
test_dataset = MyDataset(data_dir='./data/test', transform=transform)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
for batch_data, batch_labels in train_dataloader:
    print(batch_data.shape, batch_labels.shape)
    break

DataLoader核心参数详解

DataLoader(
    dataset, 
    batch_size=1, 
    shuffle=False, 
    sampler=None,
    batch_sampler=None,
    num_workers=0, 
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None,
)

关键参数解析

  • num_workers:数据预加载进程数(建议设为CPU核心数的70-80%)
  • pin_memory:启用CUDA锁页内存加速GPU传输
  • prefetch_factor:每个worker预加载的batch数(PyTorch 1.7+)

数据加载性能优化公式

理论最大吞吐量
T h r o u g h p u t = min ⁡ ( B a t c h S i z e × n u m _ w o r k e r s D a t a L o a d T i m e , G P U C o m p u t e T i m e − 1 ) Throughput = \min\left(\frac{BatchSize \times num\_workers}{DataLoadTime}, GPUComputeTime^{-1}\right) Throughput=min(DataLoadTimeBatchSize×num_workers,GPUComputeTime1)

三、拓展:多模态数据加载示例

class MultiModalDataset(Dataset):
    def __init__(self, img_dir, text_path):
        self.img_dir = img_dir
        self.text_data = pd.read_csv(text_path)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        
    def __getitem__(self, idx):
        # 图像处理
        img_path = os.path.join(self.img_dir, self.text_data.iloc[idx]['image_id'])
        image = Image.open(img_path).convert('RGB')
        image = transforms.ToTensor()(image)
        
        # 文本处理
        text = self.text_data.iloc[idx]['description']
        inputs = self.tokenizer(text, padding='max_length', truncation=True, max_length=128)
        
        return {
            'image': image,
            'input_ids': torch.tensor(inputs['input_ids']),
            'attention_mask': torch.tensor(inputs['attention_mask'])
        }

四、总结

本文介绍了PyTorch中数据读取的基本概念、集成的开源数据集的读取方法、自定义数据集的读取方法和数据读取的流程。

数据读取是深度学习训练的重要环节,数据读取的流程是:

  1. 定义数据集类
  2. 定义数据转换函数、数据增强函数
  3. 加载数据集



📌 感谢阅读!若文章对你有用,别吝啬互动~​
👍 点个赞 | ⭐ 收藏备用 | 💬 留下你的想法 ,关注我,更多干货持续更新!

Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐