聊城市网站建设_网站建设公司_Oracle_seo优化
2025/12/26 11:45:47 网站建设 项目流程

import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont

# Make Matplotlib able to render Chinese labels: try CJK fonts first and
# fall back to DejaVu Sans.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'DejaVu Sans']
# Draw the minus sign as ASCII '-' to avoid broken minus glyphs with these fonts.
plt.rcParams['axes.unicode_minus'] = False

# Device configuration: use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Target character set: Chinese numerals one..ten plus year/month/day.
# A character's class id is its index in CHAR_SET.
CHAR_SET = ['一', '二', '三', '四', '五', '六', '七', '八', '九', '十', '年', '月', '日']
CHAR_TO_ID = {char: idx for idx, char in enumerate(CHAR_SET)}
ID_TO_CHAR = dict(enumerate(CHAR_SET))
NUM_CLASSES = len(CHAR_SET)

# Image / training hyperparameters.
IMG_SIZE = 64  # larger canvas to fit the more complex stroke structure of Chinese characters
BATCH_SIZE = 64
NUM_EPOCHS = 15
LR = 0.0005  # lowered learning rate to avoid oscillation
NUM_SAMPLES_TRAIN = 2000  # training images generated per character
NUM_SAMPLES_TEST = 500  # test images generated per character

class ChineseCharDataset(Dataset):
    """In-memory dataset of synthetically rendered Chinese character images.

    All samples are generated eagerly in ``__init__``: every character in
    ``char_set`` is drawn ``num_samples_per_char`` times on a black grayscale
    canvas with a random font size (±10%), position and stroke brightness.
    Training samples (``is_train=True``) additionally get a random rotation
    and Gaussian noise. ``transform``, if given, is applied once per sample at
    generation time, so stored items are whatever it returns.

    Labels are the character's index in ``char_set``.
    """

    # Candidate CJK font files, tried in order. The Windows fonts come first
    # (original behaviour); the rest are common Linux/macOS locations so the
    # dataset can also be built on those systems. Missing paths are skipped.
    FONT_PATHS = (
        'C:/Windows/Fonts/simhei.ttf',  # SimHei
        'C:/Windows/Fonts/msyh.ttc',  # Microsoft YaHei
        'C:/Windows/Fonts/simsun.ttc',  # SimSun
        '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc',
        '/usr/share/fonts/noto-cjk/NotoSansCJK-Regular.ttc',
        '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',
        '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc',
        '/System/Library/Fonts/STHeiti Medium.ttc',
        '/System/Library/Fonts/Hiragino Sans GB.ttc',
    )

    def __init__(self, char_set, img_size, num_samples_per_char, transform=None, is_train=True):
        """
        Args:
            char_set: sequence of characters to render; index = label.
            img_size: side length in pixels of the square canvas.
            num_samples_per_char: number of images generated per character.
            transform: optional callable applied to each PIL image.
            is_train: when True, apply rotation + noise augmentation.
        """
        self.char_set = char_set
        self.img_size = img_size
        self.num_samples_per_char = num_samples_per_char
        self.transform = transform
        self.is_train = is_train
        self.data, self.labels = self._generate_chinese_chars()

    def _get_font(self):
        """Return a PIL font loaded from the first existing path in FONT_PATHS.

        The font size is 60% of ``img_size`` so a glyph fits in the canvas.

        Raises:
            FileNotFoundError: if none of the candidate font files exists.
        """
        for path in self.FONT_PATHS:
            if os.path.exists(path):
                return ImageFont.truetype(path, size=int(self.img_size * 0.6))
        raise FileNotFoundError("未找到中文字体文件,请检查路径!")

    def _generate_chinese_chars(self):
        """Render all samples and return ``(data, labels)`` as two lists."""
        data = []
        labels = []
        font_path = self._get_font().path
        # font size -> loaded font; sizes repeat, so avoid reopening the file per sample.
        font_cache = {}

        for char_id, char in enumerate(self.char_set):
            for _ in range(self.num_samples_per_char):
                # Random font size (±10%) to imitate handwriting variation.
                font_size = int(self.img_size * 0.6 * np.random.uniform(0.9, 1.1))
                font = font_cache.get(font_size)
                if font is None:
                    font = font_cache[font_size] = ImageFont.truetype(font_path, size=font_size)

                img = self._render_char(char, font)
                if self.is_train:
                    img = self._augment(img)
                if self.transform:
                    img = self.transform(img)

                data.append(img)
                labels.append(char_id)
        return data, labels

    def _render_char(self, char, font):
        """Draw ``char`` at a random position and brightness on a black canvas."""
        img = Image.new('L', (self.img_size, self.img_size), color=0)  # 0 = black background
        draw = ImageDraw.Draw(img)
        # The bbox measured at the origin includes the font's bearing offsets
        # (left/top are usually > 0); subtract them when drawing so the ink
        # lands exactly at (x, y) and is never clipped by the canvas edge.
        left, top, right, bottom = draw.textbbox((0, 0), char, font=font)
        text_width = right - left
        text_height = bottom - top
        # max(1, ...) keeps randint valid even if a glyph fills the whole canvas.
        x = np.random.randint(0, max(1, self.img_size - text_width))
        y = np.random.randint(0, max(1, self.img_size - text_height))
        # Random stroke brightness 200..255 inclusive (white-ish on black).
        fill_color = int(np.random.randint(200, 256))
        draw.text((x - left, y - top), char, font=font, fill=fill_color)
        return img

    @staticmethod
    def _augment(img):
        """Apply a random rotation in [-15°, 15°] and Gaussian noise (std 15)."""
        img = img.rotate(np.random.uniform(-15, 15), expand=False, fillcolor=0)
        # int16 so adding negative noise cannot wrap around before clipping.
        pixels = np.asarray(img, dtype=np.int16)
        noise = np.random.normal(0, 15, pixels.shape).astype(np.int16)
        return Image.fromarray(np.clip(pixels + noise, 0, 255).astype(np.uint8))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Per-image preprocessing: PIL image -> tensor in [0, 1] -> normalized to
# [-1, 1]. The rendered images are already IMG_SIZE x IMG_SIZE, so Resize is
# only a safeguard.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),  # maps [0, 1] to [-1, 1]
])

# Build the training and test sets (all images are rendered eagerly here).
train_dataset = ChineseCharDataset(
    CHAR_SET, IMG_SIZE, NUM_SAMPLES_TRAIN, transform=transform, is_train=True
)
test_dataset = ChineseCharDataset(
    CHAR_SET, IMG_SIZE, NUM_SAMPLES_TEST, transform=transform, is_train=False
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

class ChineseCharCNN(nn.Module):
    """Convolutional classifier for single-channel character images.

    Four conv blocks (3x3 Conv -> BatchNorm -> ReLU -> 2x2 MaxPool) each halve
    the spatial size (64x64 -> 4x4 for the default input), followed by a
    three-layer MLP head with dropout.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Feature extractor (deeper network; BatchNorm to stabilize training).
        self.features = nn.Sequential(
            # Block 1: 64x64 -> 32x32
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # Block 2: 32x32 -> 16x16
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # Block 3: 16x16 -> 8x8
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            # Block 4: 8x8 -> 4x4
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        # Pin the feature map to 4x4 so the classifier's input width is fixed.
        # For 64x64 inputs the map is already 4x4 and this is an exact identity;
        # it has no parameters, so saved state_dicts stay compatible.
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        # Classification head sized for the 512 x 4 x 4 feature map.
        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),  # regularization against overfitting
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Map a batch ``x`` of shape (N, 1, H, W) to logits of shape (N, num_classes)."""
        x = self.features(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

def train_model(model, criterion, optimizer, train_loader, num_epochs):
    """Train ``model`` for ``num_epochs`` epochs and record per-epoch metrics.

    Args:
        model: network to train; updated in place and also returned.
        criterion: loss called as ``criterion(outputs, target)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(data, target)`` batches.
        num_epochs: number of passes over ``train_loader``.

    Returns:
        ``(model, train_time, train_losses, train_accs)``: ``train_time`` is
        wall-clock seconds; the lists hold each epoch's mean per-sample loss
        and accuracy in percent.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    # Move batches to whatever device the model already lives on.
    run_device = next(model.parameters()).device
    train_losses = []
    train_accs = []
    start_time = time.time()

    for epoch in range(num_epochs):
        total_loss = 0.0
        correct = 0
        total = 0
        for data, target in train_loader:
            data, target = data.to(run_device), target.to(run_device)

            # Forward pass.
            outputs = model(data)
            loss = criterion(outputs, target)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate loss (re-weighted by batch size) and accuracy.
            batch_size = target.size(0)
            total_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == target).sum().item()

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        # Average over samples actually seen (not len(dataset)), which stays
        # correct with drop_last or a subset sampler.
        avg_loss = total_loss / total
        train_acc = 100.0 * correct / total
        train_losses.append(avg_loss)
        train_accs.append(train_acc)
        print(f'Epoch [{epoch + 1}/{num_epochs}] | Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}%')

    train_time = time.time() - start_time
    return model, train_time, train_losses, train_accs

def evaluate_model(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return macro-averaged metrics.

    Precision, recall and F1 are computed per class and then averaged with
    equal weight over all classes the model outputs (macro average). Classes
    that are never predicted or never present score 0 for the undefined ratio.

    Returns:
        ``(accuracy, precision, recall, f1)`` as fractions in [0, 1].

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    run_device = next(model.parameters()).device
    pred_batches = []
    target_batches = []
    num_classes = 0

    with torch.no_grad():
        for data, target in test_loader:
            outputs = model(data.to(run_device))
            num_classes = outputs.size(1)
            pred_batches.append(outputs.argmax(dim=1).cpu().numpy())
            target_batches.append(target.cpu().numpy())

    if not target_batches:
        raise ValueError("test_loader yielded no samples")
    all_preds = np.concatenate(pred_batches)
    all_targets = np.concatenate(target_batches)

    # Per-class confusion counts; cover every output class (and any larger
    # target id, so the count arrays always line up).
    k = max(num_classes, int(all_targets.max()) + 1)
    tp = np.bincount(all_targets[all_preds == all_targets], minlength=k)
    fp = np.bincount(all_preds, minlength=k) - tp
    fn = np.bincount(all_targets, minlength=k) - tp

    eps = 1e-8  # avoids 0/0 for classes with no predictions / no samples
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    # Macro F1 is the mean of per-class F1, not the F1 of the averaged P and R.
    f1 = 2 * precision * recall / (precision + recall + eps)

    accuracy = np.mean(all_preds == all_targets)
    return accuracy, np.mean(precision), np.mean(recall), np.mean(f1)

# Initialize the model, loss function and optimizer.
model = ChineseCharCNN(NUM_CLASSES).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)  # with weight decay

# Train.
print("=" * 60 + " 开始训练汉字识别模型 " + "=" * 60)
model, train_time, train_losses, train_accs = train_model(model, criterion, optimizer, train_loader, NUM_EPOCHS)

# Evaluate on the test set and report metrics as percentages.
print("\n" + "=" * 60 + " 模型评估结果 " + "=" * 60)
accuracy, precision, recall, f1 = evaluate_model(model, test_loader)
print(f"测试集准确率: {accuracy * 100:.2f}%")
print(f"测试集精确率: {precision * 100:.2f}%")
print(f"测试集召回率: {recall * 100:.2f}%")
print(f"测试集F1值: {f1 * 100:.2f}%")
print(f"总训练时间: {train_time:.2f}秒")

def plot_training_curves(losses, accs):
    """Plot per-epoch training loss (left) and accuracy (right) side by side.

    The figure is written to ``chinese_char_training_curves.png`` at 300 dpi
    and then shown.
    """
    _fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))
    # (axes, series, line format, legend label, y-axis label, title)
    panels = (
        (loss_ax, losses, 'b-', '训练损失', 'Loss', '训练损失变化'),
        (acc_ax, accs, 'r-', '训练准确率', 'Accuracy (%)', '训练准确率变化'),
    )
    for ax, series, fmt, label, ylabel, title in panels:
        epochs = range(1, len(series) + 1)
        ax.plot(epochs, series, fmt, label=label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig('chinese_char_training_curves.png', dpi=300)
    plt.show()

def plot_predictions(model, test_loader, id_to_char, num_samples=12):
    """Show test images with their true and predicted characters.

    Samples are taken at evenly spaced indices across ``test_loader.dataset``
    instead of from the first batch: the test loader is not shuffled and the
    dataset is generated character by character, so the first batch would
    contain a single character only. The grid has 4 columns and as many rows
    as needed. Titles are green for correct and red for wrong predictions.
    The figure is saved to ``chinese_char_predictions.png`` and shown.

    Args:
        model: trained classifier.
        test_loader: loader whose ``.dataset`` yields ``(image_tensor, label)``.
        id_to_char: mapping from class id to character.
        num_samples: number of images to show (capped at the dataset size).
    """
    dataset = test_loader.dataset
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return
    indices = [int(i) for i in np.linspace(0, len(dataset) - 1, num_samples)]
    images = torch.stack([dataset[i][0] for i in indices])
    labels = torch.tensor([int(dataset[i][1]) for i in indices])

    model.eval()
    run_device = next(model.parameters()).device
    with torch.no_grad():
        predicted = model(images.to(run_device)).argmax(dim=1).cpu()

    cols = 4
    rows = (num_samples + cols - 1) // cols
    _fig, axes = plt.subplots(rows, cols, figsize=(3.75 * cols, 10 / 3 * rows), squeeze=False)
    axes = axes.flatten()
    for ax in axes[num_samples:]:
        ax.axis('off')  # hide unused grid cells

    for ax, img, true_id, pred_id in zip(axes, images, labels, predicted):
        img = img.numpy().squeeze() * 0.5 + 0.5  # undo Normalize: [-1, 1] -> [0, 1]
        ax.imshow(img, cmap='gray')
        ax.axis('off')
        true_char = id_to_char[true_id.item()]
        pred_char = id_to_char[pred_id.item()]
        color = 'green' if true_char == pred_char else 'red'
        ax.set_title(f'真实: {true_char}\n预测: {pred_char}', color=color)

    plt.tight_layout()
    plt.savefig('chinese_char_predictions.png', dpi=300)
    plt.show()

# NOTE(review): "3020" looks like a leftover marker/debug print with no
# meaning in this script — confirm whether it can be removed.
print("3020")
# Plot the recorded training curves and a sample of test-set predictions.
plot_training_curves(train_losses, train_accs)
plot_predictions(model, test_loader, ID_TO_CHAR)

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询