@浙大疏锦行
📘 Day 35 实战作业:最后一公里 —— 可视化、保存与推理
1. 作业综述
核心目标:
完成深度学习项目的闭环。
训练不是终点,应用才是。我们需要学会查看模型结构,将训练好的模型保存为文件(.pth),并重新加载它来进行预测(推理)。
涉及知识点:
- 结构检视:
print(model)与model.named_parameters()。 - 模型持久化:
torch.save与torch.load。 - 状态字典: 理解
state_dict的核心作用。 - 推理模式:
model.eval()与torch.no_grad()的重要性。
场景类比:
- 训练: 像是读书上课,把知识(权重)装进脑子。
- 保存: 像是把脑子里的知识写成一本“秘籍”(.pth文件)。
- 加载: 别人拿到秘籍,修炼一下,也拥有了同样的功力。
- 推理: 用这身功力去解决实际问题(考试/打架)。
步骤 1:模型解剖
场景描述:
我们在代码里写了nn.Linear,但模型内部到底有多少参数?
比如一个4 -> 10的全连接层,参数量是10 × 4 10 \times 410×4(权重) +10 1010(偏置) =50 5050个。
我们需要学会查看这些细节。
任务:
- 定义并实例化之前的 MLP 模型。
- 直接打印模型对象(查看层结构)。
- 遍历
named_parameters(),打印每一层的参数形状。
importtorchimporttorch.nnasnn# --- 1. 快速复现模型 (复习) ---classMLP(nn.Module):def__init__(self):super(MLP,self).__init__()self.fc1=nn.Linear(4,10)self.relu=nn.ReLU()self.fc2=nn.Linear(10,3)defforward(self,x):out=self.fc1(x)out=self.relu(out)out=self.fc2(out)returnout model=MLP()# --- 2. 宏观:看结构 ---print("=== 模型结构图 ===")print(model)# --- 3. 微观:看参数 ---print("\n=== 参数细节 ===")total_params=0forname,paraminmodel.named_parameters():print(f"层:{name}| 形状:{param.shape}")# 累加参数数量 (numel = number of elements)total_params+=param.numel()print(f"\n🔥 模型总参数量:{total_params}")# 算一下:fc1 (4*10 + 10) + fc2 (10*3 + 3) = 50 + 33 = 83。对上了吗?=== 模型结构图 === MLP( (fc1): Linear(in_features=4, out_features=10, bias=True) (relu): ReLU() (fc2): Linear(in_features=10, out_features=3, bias=True) ) === 参数细节 === 层: fc1.weight | 形状: torch.Size([10, 4]) 层: fc1.bias | 形状: torch.Size([10]) 层: fc2.weight | 形状: torch.Size([3, 10]) 层: fc2.bias | 形状: torch.Size([3]) 🔥 模型总参数量: 83步骤 2:模型的保存与加载
核心概念:
PyTorch 推荐只保存参数(权重和偏置),而不是保存整个模型对象。
这些参数存储在一个字典里,叫state_dict。
- 保存:
torch.save(model.state_dict(), "best_model.pth") - 加载:先实例化一个空模型,然后
model.load_state_dict(...)
任务:
- 模拟训练(这里直接保存初始模型即可)。
- 将模型参数保存到
iris_model.pth。 - 删除原模型,创建一个新模型,加载参数,验证是否复活。
importos# --- 1. 保存 (Save) ---save_path="iris_model.pth"print(f"💾 正在保存模型参数到:{save_path}...")# 这里的 state_dict() 就是那本“秘籍”torch.save(model.state_dict(),save_path)print("✅ 保存成功!")# 检查文件是否存在print(f"文件存在性检查:{os.path.exists(save_path)}")# --- 2. 加载 (Load) ---print("\n🔄 正在模拟加载过程...")# 假设我们在另一台电脑上,首先需要定义同样的模型结构(空壳)new_model=MLP()# 此时 new_model 的参数是随机初始化的# 我们把保存的参数加载进去# weights_only=True 是为了安全(防止pickle注入),新版本推荐加上state_dict=torch.load(save_path,weights_only=True)new_model.load_state_dict(state_dict)print("✅ 模型加载完毕!")print("新模型 fc1 偏置的前5个值:",new_model.fc1.bias[:5].detach().numpy())💾 正在保存模型参数到: iris_model.pth ... ✅ 保存成功! 文件存在性检查: True 🔄 正在模拟加载过程... ✅ 模型加载完毕! 新模型 fc1 偏置的前5个值: [ 0.01631749 -0.33050632 -0.18811053 -0.42050672 -0.4696836 ]步骤 3:推理模式 (Inference)
场景描述:
模型训练好并加载后,就可以上线使用了。
在推理(预测)阶段,有两个关键动作:
model.eval(): 告诉模型“我要考试了”,关闭 Dropout 和 BatchNorm 等训练专用的层。torch.no_grad(): 告诉 PyTorch “不需要算梯度”,这样能省大量内存并加速。
任务:
- 准备一条新的测试数据。
- 切换到推理模式。
- 预测这条数据属于哪一类鸢尾花。
# --- 1. 准备一条假数据 ---# 假设有4个特征:花萼长、宽,花瓣长、宽# 注意:输入必须是 Tensor,且通常需要加一个 batch 维度 (1, 4)sample_data=torch.tensor([[5.1,3.5,1.4,0.2]])print(f"输入数据形状:{sample_data.shape}")# --- 2. 推理流程 (标准范式) ---# A. 切换评估模式new_model.eval()# B. 关闭梯度计算上下文withtorch.no_grad():# 前向传播outputs=new_model(sample_data)# 获取预测结果# outputs 是 (1, 3) 的概率分布(Logits)print(f"模型原始输出 (Logits):{outputs}")# 转化为概率 (Softmax)probs=torch.softmax(outputs,dim=1)print(f"预测概率:{probs}")# 取概率最大的类别索引predicted_class=torch.argmax(probs,dim=1).item()# --- 3. 结果解读 ---class_names=['Setosa','Versicolor','Virginica']print(f"\n🔮 最终预测类别:{predicted_class}->{class_names[predicted_class]}")输入数据形状: torch.Size([1, 4]) 模型原始输出 (Logits): tensor([[ 0.0609, -0.5886, -0.3129]]) 预测概率: tensor([[0.4524, 0.2363, 0.3113]]) 🔮 最终预测类别: 0 -> Setosa🎓 Day 35 总结:深度学习基础通关!
恭喜你!完成了从 Numpy 手搓感知机,到 PyTorch 搭建、训练、保存、推理的全过程。
回顾今天的重点:
- 参数量: 以后看到论文里的 “10B parameters” (100亿参数),你就知道那是
numel()累加出来的。 - State Dict: 模型文件本质上就是一个 Python 字典,存着
{'fc1.weight': tensor(...), ...}。 - Eval Mode: 预测时不加
model.eval()和no_grad()是新手最容易犯的错误,可能导致结果不准或显存爆炸。
Next Level (预告):
从明天开始,我们将不再处理简单的表格数据。
我们将进入计算机视觉 (Computer Vision)的世界,去处理真正的图像数据。卷积神经网络 (CNN)、ResNet、迁移学习……激动人心的旅程才刚刚开始!
准备好你的 GPU,我们明天见!🚀