技术 · 2024年3月31日

AI学习之路(1)失败的图像识别

尝试了一下pytorch数字识别

# 导入所需库
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt# 定义神经网络模型类
import torchvision.transforms as transforms
from PIL import Image

class Net(torch.nn.Module):
    def __init__(self):
        # 继承父类初始化方法
        super().__init__()
        # 定义全连接层(FC):输入维度为28*28,输出维度为64;中间层和最后一层保持64维输出,最终输出维度为10(对应MNIST数据集的10个类别)
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)
    def forward(self, x):
        # 前向传播过程:使用ReLU激活函数对每一层FC输出进行非线性变换,最后一层使用log_softmax函数计算每个类别的对数概率
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x

# 获取MNIST数据集加载器的函数,根据是否为训练数据设置参数
def get_data_loader(is_train):
    # 定义将图像数据转换为张量的转换器
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # 加载MNIST数据集,指定是否为训练集、转换器、并下载数据(如果本地不存在)
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    # 创建DataLoader实例,设置批次大小为15,开启随机打乱
    return DataLoader(data_set, batch_size=15, shuffle=True)
# 评估模型准确率的函数,接收测试数据集和模型作为输入

def evaluate(test_data, net):
    n_correct = 0  # 正确预测数量计数器
    n_total = 0  # 总预测数量计数器
    # 使用无梯度模式进行评估
    with torch.no_grad():
        for (x, y) in test_data:
            # 计算模型对当前批次数据的输出
            outputs = net.forward(x.view(-1, 28 * 28))
            # 遍历每条样本,若模型预测类别与真实标签相符,则增加正确预测数量
            for i, output in enumerate(outputs):
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    # 返回准确率(正确预测数量/总预测数量)
    return n_correct / n_total
# 主函数,执行训练及测试流程
def preprocess_image(image_path): # 读取图片
    img = Image.open(image_path).convert('L')  # 转为灰度图
    # 调整尺寸至28x28像素
    img = img.resize((28, 28))
    # 归一化至[0, 1]区间
    img = np.array(img) / 255.0
    # 转换为PyTorch张量
    img_tensor = torch.from_numpy(img).unsqueeze(0).float()
    return img_tensor

def predict_image(net, image_path):#识别图片
    img_tensor = preprocess_image(image_path)
    with torch.no_grad():
        output = net.forward(img_tensor.view(-1, 28 * 28))
        prediction = torch.argmax(output)
    return int(prediction)

def main():
    # 获取训练集和测试集数据加载器
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    # 实例化神经网络模型
    net = Net()
    # 输出初始模型在测试集上的准确率
    print("initial accuracy:", evaluate(test_data, net))
    # 设置优化器(Adam算法,学习率为0.001),用于更新模型参数
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    # 进行2个训练周期(epochs)
    for epoch in range(3):
        # 遍历训练集中的每个批次数据
        for (x, y) in train_data:
            # 清零梯度
            net.zero_grad()
            # 前向传播计算模型输出
            output = net.forward(x.view(-1, 28 * 28))
            # 计算损失(使用负对数似然损失函数)
            loss = torch.nn.functional.nll_loss(output, y)
            # 反向传播计算梯度
            loss.backward()
            # 使用优化器更新模型参数
            optimizer.step()
        # 在每个周期结束后,输出当前模型在测试集上的准确率
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))
    # 展示前4个测试样本及其模型预测结果
    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break
        # 获取模型对当前样本的预测类别
        predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
        # 绘制图像并显示预测结果
        plt.figure(n)
        plt.imshow(x[0].view(28, 28))
        plt.title("prediction: " + str(int(predict)))
    # 显示所有绘制的图像
    plt.show()
    image_path = r"D:\QQ\2481153962\FileRecv\MobileFile\IMG_20240327_155037_edit_230949122434450.jpg"
    predicted_digit = predict_image(net, image_path)
    print(f"Predicted digit: {predicted_digit}")

# 如果脚本直接运行,执行主函数
if __name__ == "__main__":
    main()

编写之后发现对于训练集没啥毛病,但是我的图片会寄,于是又写了个代码看了眼我图片的灰度图

唉我去合着没啥毛病啊,安详地逝了。(

AI学习之路道阻且长 待我重新归来。

苏ICP备2024067700号 | 苏公网安备32098202000238号