PostgreSQL数据库
2025/12/17 23:01:56
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt


# ==========================================
# 1. Understanding Dataset in depth (custom demo)
# ==========================================
class SimpleNumberDataset(Dataset):
    """Toy map-style dataset over the integers in ``range(start, end)``.

    Each item is a ``(sample, label)`` pair of 0-d tensors where the label
    is the number itself. It exists only to illustrate the ``__len__`` /
    ``__getitem__`` protocol that ``DataLoader`` consumes.
    """

    def __init__(self, start: int, end: int):
        """Store the integers of ``range(start, end)`` as the simulated data."""
        self.data = list(range(start, end))

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, index: int) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return ``(sample, label)`` for *index*; the label equals the sample."""
        sample = self.data[index]
        label = sample
        return torch.tensor(sample), torch.tensor(label)


# ==========================================
# 2. MNIST dataset and DataLoader in practice
# ==========================================
# Preprocessing: convert each image to a float tensor, then normalize with
# (mean, std) = (0.1307, 0.3081), the commonly cited MNIST training-set
# statistics.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Download (if not already present under ./data) and load the MNIST
# training split.
# NOTE: this runs at import time, so merely importing this module touches
# the network/filesystem.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Serve the training set in batches of 64, reshuffled every epoch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)


# ==========================================
# 3. Verification and recap
# ==========================================
def review():
    """Print MNIST dataset/batch shapes and display one sample image."""
    # Dataset protocol recap: __len__ and __getitem__.
    print(f"MNIST 数据集总长度: {len(train_dataset)}")

    # Fetch a single sample.
    image, label = train_dataset[0]
    print(f"单个样本形状: {image.shape}, 标签: {label}")

    # DataLoader protocol recap: iterate once to obtain one batch.
    images, labels = next(iter(train_loader))
    print(f"一个 Batch 的图片形状: {images.shape}")  # [64, 1, 28, 28]
    print(f"一个 Batch 的标签形状: {labels.shape}")  # [64]

    # Visualize the first image of the batch. squeeze() drops the size-1
    # channel dim so imshow receives a 2-D array; .item() converts the 0-d
    # label tensor to a plain Python number so the title reads "Label: 5"
    # instead of "Label: tensor(5)".
    plt.imshow(images[0].numpy().squeeze(), cmap='gray')
    plt.title(f"Label: {labels[0].item()}")
    plt.show()


if __name__ == "__main__":
    review()