手把手教你部署CV-UNet,5分钟实现智能去背
2026/1/22 7:35:46
@浙大疏锦行
作业:对信贷数据集进行训练后保持权重,后继续训练50次,采取早停策略
import torch import torch.nn as nn import torch.optim as optim from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split from sklearn.preprocessing import MinMaxScaler import time import matplotlib.pyplot as plt from tqdm import tqdm import warnings warnings.filterwarnings("ignore") # 检查GPU是否可用,优先使用GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") # 若有多个GPU,可指定具体GPU,例如cuda:1 # 验证GPU是否真的在使用(可选) if torch.cuda.is_available(): print(f"GPU名称: {torch.cuda.get_device_name(0)}") torch.cuda.empty_cache() # 清空GPU缓存 # 加载信贷数据集 iris = load_iris() X = iris.data # 特征数据 y = iris.target # 标签数据 # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 归一化数据 scaler = MinMaxScaler() X_train = scaler.fit_transform(X_train) X_test = scaler.transform(X_test) # 转换为PyTorch张量并强制移至指定设备(GPU/CPU) X_train = torch.FloatTensor(X_train).to(device, non_blocking=True) y_train = torch.LongTensor(y_train).to(device, non_blocking=True) X_test = torch.FloatTensor(X_test).to(device, non_blocking=True) y_test = torch.LongTensor(y_test).to(device, non_blocking=True) class MLP(nn.Module): def __init__(self): super(MLP, self).__init__() self.fc1 = nn.Linear(4, 10) # 输入层(信贷数据集需修改输入维度) self.relu = nn.ReLU() self.fc2 = nn.Linear(10, 3) # 输出层(信贷数据集需修改输出维度) def forward(self, x): out = self.fc1(x) out = self.relu(out) out = self.fc2(out) return out # 实例化模型并移至GPU model = MLP().to(device) criterion = nn.CrossEntropyLoss() # 分类损失函数 optimizer = optim.SGD(model.parameters(), lr=0.01) # 优化器 # 首次训练参数 first_train_epochs = 20000 train_losses = [] # 首次训练损失 test_losses = [] epochs = [] # 早停参数(首次训练和继续训练共用相同策略) best_test_loss = float('inf') best_epoch = 0 patience = 50 counter = 0 early_stopped = False print("\n===== 开始首次训练 =====") start_time = time.time() with tqdm(total=first_train_epochs, desc="首次训练进度", unit="epoch") as pbar: for epoch in range(first_train_epochs): model.train() # 前向传播 outputs = model(X_train) train_loss = criterion(outputs, y_train) # 反向传播和优化 optimizer.zero_grad() train_loss.backward() optimizer.step() # 每200轮记录损失并检查早停 if (epoch + 1) % 200 == 0: model.eval() with torch.no_grad(): test_outputs = model(X_test) test_loss = criterion(test_outputs, y_test) train_losses.append(train_loss.item()) test_losses.append(test_loss.item()) epochs.append(epoch + 1) # 更新进度条 pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'}) # 早停逻辑 if test_loss.item() < best_test_loss: best_test_loss = test_loss.item() best_epoch = epoch + 1 counter = 0 # 保存最佳模型 torch.save(model.state_dict(), 'best_model.pth') else: counter += 1 if counter >= patience: print(f"\n首次训练早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。") print(f"最佳测试集损失出现在第{best_epoch}轮,损失值为{best_test_loss:.4f}") early_stopped = True break # 更新进度条 if (epoch + 1) % 1000 == 0: pbar.update(1000) # 补全进度条 if pbar.n < first_train_epochs: pbar.update(first_train_epochs - pbar.n) # 保存首次训练结束后的模型权重(核心修改点1) torch.save({ 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': epoch + 1, 'best_loss': best_test_loss }, 'trained_model.pth') print(f"\n首次训练完成,权重已保存至 trained_model.pth") print(f"首次训练总耗时: {time.time() - start_time:.2f} 秒") print("\n===== 加载权重并开始继续训练 =====") # 加载保存的权重(核心修改点2) checkpoint = torch.load('trained_model.pth', map_location=device) model.load_state_dict(checkpoint['model_state_dict']) print(f"加载了首次训练至第{checkpoint['epoch']}轮的权重,最佳损失: {checkpoint['best_loss']:.4f}") # 重新初始化优化器(核心修改点3:继续训练必须重置优化器) optimizer = optim.SGD(model.parameters(), lr=0.01) # 若需要延续优化器状态,可取消下面注释(视场景选择) # optimizer.load_state_dict(checkpoint['optimizer_state_dict']) # 继续训练的参数 continue_train_epochs = 50 # 目标继续训练50轮 continue_train_losses = [] # 继续训练损失 continue_test_losses = [] continue_epochs = [] # 重置早停参数(针对继续训练) continue_best_loss = checkpoint['best_loss'] continue_counter = 0 continue_early_stop = False start_continue_time = time.time() with tqdm(total=continue_train_epochs, desc="继续训练进度", unit="epoch") as pbar: for epoch in range(continue_train_epochs): model.train() # 前向传播 outputs = model(X_train) train_loss = criterion(outputs, y_train) # 反向传播和优化 optimizer.zero_grad() train_loss.backward() optimizer.step() # 每1轮就检查损失和早停(继续训练轮数少,无需间隔) model.eval() with torch.no_grad(): test_outputs = model(X_test) test_loss = criterion(test_outputs, y_test) continue_train_losses.append(train_loss.item()) continue_test_losses.append(test_loss.item()) continue_epochs.append(epoch + 1) # 更新进度条 pbar.set_postfix({'Train Loss': f'{train_loss.item():.4f}', 'Test Loss': f'{test_loss.item():.4f}'}) pbar.update(1) # 继续训练的早停逻辑 if test_loss.item() < continue_best_loss: continue_best_loss = test_loss.item() continue_counter = 0 # 保存继续训练后的最佳模型 torch.save(model.state_dict(), 'continue_best_model.pth') else: continue_counter += 1 if continue_counter >= patience: print(f"\n继续训练早停触发!在第{epoch+1}轮,测试集损失已有{patience}轮未改善。") print(f"继续训练最佳损失: {continue_best_loss:.4f}") continue_early_stop = True break print(f"继续训练完成,总耗时: {time.time() - start_continue_time:.2f} 秒") print(f"继续训练实际轮数: {len(continue_epochs)} 轮(早停触发则少于50轮)") print("\n===== 最终模型评估 =====") model.load_state_dict(torch.load('continue_best_model.pth', map_location=device)) model.eval() with torch.no_grad(): outputs = model(X_test) _, predicted = torch.max(outputs, 1) correct = (predicted == y_test).sum().item() accuracy = correct / y_test.size(0) print(f'测试集最终准确率: {accuracy * 100:.2f}%') # ====================== 8. 可视化 ====================== plt.figure(figsize=(12, 6)) # 绘制首次训练损失 plt.subplot(1, 2, 1) plt.plot(epochs, train_losses, label='Train Loss') plt.plot(epochs, test_losses, label='Test Loss') plt.xlabel('Epoch') plt.ylabel('Loss') plt.title('首次训练损失曲线') plt.legend() plt.grid(True) # 绘制继续训练损失 plt.subplot(1, 2, 2) plt.plot(continue_epochs, continue_train_losses, label='Train Loss') plt.plot(continue_epochs, continue_test_losses, label='Test Loss') plt.xlabel('Continue Epoch') plt.ylabel('Loss') plt.title('继续训练50轮损失曲线') plt.legend() plt.grid(True) plt.tight_layout() plt.show()