From c162bd97282e464d7c6157377ea0e0478f8330c4 Mon Sep 17 00:00:00 2001 From: wolves Date: Sun, 16 Mar 2025 20:29:33 +0800 Subject: [PATCH] 250316 --- lab/8_FC-MNIST.ipynb | 4 +- lab/9_CNN-MNIST.ipynb | 279 +++++++++++++++++++++++++++++++++++++++++- 2 files changed, 279 insertions(+), 4 deletions(-) diff --git a/lab/8_FC-MNIST.ipynb b/lab/8_FC-MNIST.ipynb index e9277a7..1e3ac03 100644 --- a/lab/8_FC-MNIST.ipynb +++ b/lab/8_FC-MNIST.ipynb @@ -500,9 +500,7 @@ "# 导入数据\n", "trainset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)\n", "trainloader = torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)\n", - "\n", - "trainset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)\n", - "testloader = torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)\n", + "testloader = torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=False)\n", "\n", "# 定义模型\n", "class SimpleNet(nn.Module):\n", diff --git a/lab/9_CNN-MNIST.ipynb b/lab/9_CNN-MNIST.ipynb index a291974..d0c3262 100644 --- a/lab/9_CNN-MNIST.ipynb +++ b/lab/9_CNN-MNIST.ipynb @@ -4,7 +4,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# CNN实现MNIST手写数字识别" + "# CNN实现MNIST手写数字识别\n", + "\n", + "## 模型结构\n", + "- 32个卷积核,大小为3x3\n", + "- 池化核大小为2x2\n", + "- 64个卷积核,大小为3x3\n", + "- 池化核大小为2x2\n", + "- 展平\n", + "- 全连接层,64个神经元\n", + "- 全连接层,10个神经元" ] }, { @@ -299,6 +308,274 @@ "predicted_class = tf.argmax(predictions, axis=1).numpy()[0]\n", "print(f\"预测结果:{predicted_class}\")" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CNN实现MNIST手写数字识别 - torch" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Epoch 1/10: 100%|██████████| 938/938 [00:26<00:00, 35.81batch/s, Loss=0.1387, Accuracy=95.90]\n", + "Epoch 2/10: 100%|██████████| 938/938 [00:26<00:00, 34.93batch/s, Loss=0.0491, Accuracy=98.49]\n", + "Epoch 3/10: 
100%|██████████| 938/938 [00:26<00:00, 35.87batch/s, Loss=0.0362, Accuracy=98.86]\n", + "Epoch 4/10: 100%|██████████| 938/938 [00:26<00:00, 35.98batch/s, Loss=0.0297, Accuracy=99.08]\n", + "Epoch 5/10: 100%|██████████| 938/938 [00:26<00:00, 35.63batch/s, Loss=0.0218, Accuracy=99.30]\n", + "Epoch 6/10: 100%|██████████| 938/938 [00:26<00:00, 35.56batch/s, Loss=0.0198, Accuracy=99.36]\n", + "Epoch 7/10: 100%|██████████| 938/938 [00:26<00:00, 35.84batch/s, Loss=0.0159, Accuracy=99.46]\n", + "Epoch 8/10: 100%|██████████| 938/938 [00:27<00:00, 34.09batch/s, Loss=0.0125, Accuracy=99.58]\n", + "Epoch 9/10: 100%|██████████| 938/938 [00:26<00:00, 35.21batch/s, Loss=0.0125, Accuracy=99.57]\n", + "Epoch 10/10: 100%|██████████| 938/938 [00:27<00:00, 34.48batch/s, Loss=0.0125, Accuracy=99.58]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Loss: 0.0438, Test Accuracy: 98.95%\n", + "模型已保存 ./models/mnist_model_cnn_torch.pth\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
# Train a small CNN on MNIST with PyTorch, evaluate it on the held-out test
# split, save a checkpoint and plot the per-epoch training curves.

# NOTE(review): unused in this cell; looks like an accidental IDE auto-import.
# Safe to delete if no later cell relies on the name `test`.
from modulefinder import test
import os

import torch
from torch.utils.data import DataLoader  # import DataLoader explicitly
import torch.nn as nn
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

# Preprocessing: convert to a tensor, then normalize with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Load the data: separate train and test splits.
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)


class SimpleNet(nn.Module):
    """CNN with two conv/ReLU/max-pool stages and two fully connected layers.

    The first linear layer expects 64 * 5 * 5 flattened features, which is
    what two unpadded 3x3 convolutions, each followed by 2x2 pooling, produce
    from a 1x28x28 input (28 -> 26 -> 13 -> 11 -> 5).
    ``forward`` returns raw logits for 10 classes (no softmax applied).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        # Parameter-free, so state_dict keys are unchanged by adding it.
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flat(x)
        # Without a non-linearity here fc1 and fc2 would collapse into a
        # single affine map, making the hidden layer pointless.
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x


model = SimpleNet()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

epochs = 10
train_losses = []
train_accuracies = []

for epoch in range(epochs):
    model.train()  # make the mode explicit in case eval() was set earlier
    running_loss = 0.0  # sum of per-sample losses seen so far this epoch
    correct = 0
    total = 0

    train_iter = tqdm(trainloader, desc=f'Epoch {epoch+1}/{epochs}', unit='batch')

    for images, labels in train_iter:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        # Call the module (not .forward) so registered hooks are honoured.
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        predicted = output.argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

        # criterion returns the batch mean; weight it by the batch size so the
        # running average is per sample and is not skewed by a short last batch.
        running_loss += loss.item() * batch_size
        train_iter.set_postfix({
            # Divide by the samples seen so far, not by the epoch total,
            # so the value shown mid-epoch is a true running average.
            'Loss': f"{running_loss / total:.4f}",
            'Accuracy': f"{100 * correct / total:.2f}"
        })

    # Per-epoch statistics, recorded once the epoch has finished.
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    train_losses.append(epoch_loss)
    train_accuracies.append(epoch_acc)

# Evaluate on the test split.
model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        output = model(images)
        loss = criterion(output, labels)

        batch_size = labels.size(0)
        predicted = output.argmax(dim=1)
        test_total += batch_size
        test_correct += (predicted == labels).sum().item()
        test_loss += loss.item() * batch_size

test_loss /= test_total
test_acc = 100 * test_correct / test_total
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")

checkpoint_path = './models/mnist_model_cnn_torch.pth'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # NOTE(review): this pickles a *reference* to the class, not its code.
    # Loading needs SimpleNet defined under the same module/name and a full
    # (non weights-only) unpickle. Kept because the prediction cell reads
    # this key; prefer re-declaring the class and loading only the state_dict.
    'model_definition': SimpleNet
}, checkpoint_path)
print(f"模型已保存 {checkpoint_path}")

# Plot the curves once training has finished.
plt.figure(figsize=(12, 5))

# Loss curve
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()

# Accuracy curve
plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training Accuracy')
plt.legend()

plt.tight_layout()
plt.show()
# Predict several images in one batch with the saved PyTorch CNN checkpoint.
import torch
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Load the model.
# map_location='cpu': the inputs below stay on the CPU, and a checkpoint
# written on a CUDA machine would otherwise fail to load on a CPU-only host.
# weights_only=False: the checkpoint stores the model class itself
# ('model_definition'), which the restricted weights-only unpickler rejects.
# SECURITY: this is a full pickle load — only open checkpoints you trust.
checkpoint = torch.load(
    './models/mnist_model_cnn_torch.pth',
    map_location='cpu',
    weights_only=False,
)
model = checkpoint['model_definition']()
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # evaluation mode

# Image preprocessing: resize to 28x28, convert to grayscale, convert to a
# tensor, then normalize with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load and preprocess the images.
image_paths = ['./test/0.png', './test/2.png', './test/3.png']  # add more paths here
images = []
for path in image_paths:
    # Close each file handle promptly; copy() keeps the pixel data usable
    # after the underlying file has been closed.
    with Image.open(path) as img:
        images.append(img.copy())
processed_images = torch.stack([transform(img) for img in images])  # stack into one batch

# Show the preprocessed images.
# squeeze=False always yields a 2-D axes array, so a single image works too.
fig, axes = plt.subplots(1, len(images), figsize=(12, 4), squeeze=False)
for i, img in enumerate(processed_images):
    axes[0][i].imshow(img.squeeze(), cmap='gray')
    axes[0][i].set_title(f'Image {i+1}')
plt.show()

# Run the prediction on the whole batch at once.
with torch.no_grad():
    outputs = model(processed_images)
    probabilities = torch.nn.functional.softmax(outputs, dim=1)
    predicted_classes = torch.argmax(probabilities, dim=1).cpu().numpy()

# Print the prediction and the class probabilities for each image.
for path, pred, prob in zip(image_paths, predicted_classes, probabilities):
    print(f"图片 {path} 的预测结果:{pred}")
    print(f"概率分布:{prob.cpu().numpy()}")
    print("-" * 40)