【Python】MacBook M系列芯片Anaconda下载Pytorch,并开发一个简单的数字识别代码(附带踩坑记录)

文章目录

  • 配置镜像源
  • 下载Pytorch
  • 验证
  • 使用Pytorch进行数字识别

配置镜像源

Anaconda下载完毕之后,有两种方式下载pytorch,一种是用页面可视化的方式去下载,另一种方式就是直接用命令行工具去下载。
在这里插入图片描述
但是由于默认的Anaconda走的是外网,所以下载很慢,我们得首先配置镜像源,这里推荐用清华的,之前用中科大的出问题了,换成清华马上就好了。。。

打开Termial或者iTerm2
输入如下命令

conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro
conda config --add channels http://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2

然后输入如下命令查看是否ok了

conda config --show channels

在输入如下命令

conda config --set show_channel_urls yes

这个时候你的配置基本就完成了,接下来你就可以开始下载了

下载Pytorch

pytorch官网
进入到官网,然后基于你的机器配置选择命令
在这里插入图片描述
然后将命令放入到命令行中进行运行。
特别注意!!!
这里一定要把梯子等工具都关掉,不然会出现HTTP相关的异常。
可以考虑使用如下命令处理一下

conda config --set ssl_verify false

如果踩坑了,从如下几个地方思考:

  1. 镜像源问题,换镜像源
  2. ssl验证关闭,使用上面的命令
  3. 别开梯子!!!!!!!

验证

使用如下命令就可以查看是否安装成功了

conda list | grep pytorch

在这里插入图片描述

使用Pytorch进行数字识别

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
from PIL import Image

# 定义神经网络模型
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(28*28, 64)  # 第一个全连接层,将输入从784维映射到64维
        self.fc2 = torch.nn.Linear(64, 64)     # 第二个全连接层,将输入从64维映射到64维
        self.fc3 = torch.nn.Linear(64, 64)     # 第三个全连接层,将输入从64维映射到64维
        self.fc4 = torch.nn.Linear(64, 10)     # 第四个全连接层,将输入从64维映射到10维(对应10个类别)
    
    def forward(self, x):
        x = torch.nn.functional.relu(self.fc1(x))  # 应用ReLU激活函数
        x = torch.nn.functional.relu(self.fc2(x))  # 应用ReLU激活函数
        x = torch.nn.functional.relu(self.fc3(x))  # 应用ReLU激活函数
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)  # 应用log_softmax激活函数
        return x

# 定义数据加载函数
def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])  # 定义数据转换
    data_set = MNIST("", is_train, transform=to_tensor, download=True)  # 加载MNIST数据集
    return DataLoader(data_set, batch_size=15, shuffle=True)  # 创建数据加载器

# 定义模型评估函数
def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():  # 禁用梯度计算
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28*28))  # 前向传播计算输出
            for i, output in enumerate(outputs):
                if torch.argmax(output) == y[i]:  # 比较预测结果与真实标签
                    n_correct += 1
                n_total += 1
    return n_correct / n_total  # 返回准确率

# 定义模型保存函数
def save_model(net, path="mnist_model.pth"):
    torch.save(net.state_dict(), path)  # 保存模型权重到文件

# 定义模型加载函数
def load_model(net, path="mnist_model.pth"):
    net.load_state_dict(torch.load(path))  # 从文件加载模型权重

# 定义图像预测函数
def predict_image(image, net):
    net.eval()  # 设置为评估模式
    with torch.no_grad():  # 禁用梯度计算
        output = net(image.view(-1, 28*28))  # 前向传播计算输出
        predicted = torch.argmax(output, dim=1)  # 获取预测结果
    return predicted.item()  # 返回预测类别

# 定义图像加载函数
def load_image(image_path):
    image = Image.open(image_path).convert('L')  # 打开图像并转换为灰度图
    transform = transforms.Compose([transforms.Resize((28, 28)), transforms.ToTensor()])  # 定义图像转换
    image = transform(image)  # 应用转换
    return image  # 返回处理后的图像

def main():
    train_data = get_data_loader(is_train=True)  # 加载训练数据
    test_data = get_data_loader(is_train=False)  # 加载测试数据
    net = Net()  # 初始化神经网络模型

    # 训练模型
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)  # 定义Adam优化器
    for epoch in range(2):  # 训练2个epoch
        for (x, y) in train_data:
            net.zero_grad()  # 清零梯度
            output = net.forward(x.view(-1, 28*28))  # 前向传播计算输出
            loss = torch.nn.functional.nll_loss(output, y)  # 计算损失
            loss.backward()  # 反向传播计算梯度
            optimizer.step()  # 更新模型参数
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))  # 打印每个epoch后的准确率

    # 保存模型
    save_model(net)

    # 加载模型
    net = Net()  # 初始化新的神经网络模型
    load_model(net)  # 加载已保存的模型权重
    print("Loaded model accuracy:", evaluate(test_data, net))  # 打印加载模型后的准确率

    # 使用模型预测新图像
    image_path = "path_to_your_image.png"  # 替换为你要预测的图像路径
    image = load_image(image_path)  # 加载并预处理图像
    prediction = predict_image(image, net)  # 使用模型进行预测
    print(f"Predicted digit: {prediction}")  # 打印预测结果

if __name__ == "__main__":
    main()  # 运行main函数

第一次运行的时候,会加载数字识别模型到本地,第二次运行的时候,你就可以把训练过程的代码都注释掉了,直接使用这个最终的模型
在这里插入图片描述
第二次运行
你的模型就是这个pth文件
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/765647.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

3D Gaussian Splatting代码中的forward和backward两个文件代码解读

3dgs代码前向传播部分 先来讨论一下glm,因为定义变量的时候用到了这个。 glm的解释 glm 是指 OpenGL Mathematics,这是一个针对图形编程的数学库。它的全称是 OpenGL Mathematics (GLM),主要用于 OpenGL 的开发。这个库是基于 C 的模板库&…

heic格式转化jpg,手把手教你将heic转换成jpg【办公必备】

一、什么是heic heic格式是一种高效的图片格式,它可以在较小的文件大小下提供高质量的图片。 二、如何打开heic 然而,这种图片因其格式的特殊性,在实际应用中仍存在一些问题:压缩效果可能不够理想,一些老旧的软件和设…

墨烯的C语言技术栈-C语言基础-003

三.数据类型 1.char // 字符数据型 2.short // 短整型 3.int // 整型 4.long // 长整型 5.long long // 更长的整型 6.float // 单精度浮点数 7.double // 双精度浮点数 为什么写代码? 为了解决生活中的问题 购物,点餐,看电影 为什么有这么多类型呢? 因为说的话都是字符型…

Ubuntu下反弹shell的思考

目录 Ubuntu的命令执行环境 bash (Bourne Again SHell): sh (Bourne SHell): dash (Debian Almquist SHell): 它们之间的关系: 可能遇到的问题 一、脚本权限问题 二、命令执行环境(shell解释器)问题 如何解决? 1.修改/bin/sh软连接的指向为bas…

什么美业门店管理系统好用?2024美业收银系统软件排名分享

美业SAAS系统在美容、美发、美甲等行业中十分重要,这种系统为美业提供了一种数字化解决方案,帮助企业更高效地管理业务和客户关系。 美业门店管理系统通常提供预约管理、客户管理、库存管理、报表生成等一系列功能,以满足美容院、美发沙龙等…

iptables防火墙详解、相关命令示例

目录 Linux包过滤防火墙 包过滤的工作层次 iptables的链结构 规则链 默认包括5中规则链(对数据包控制的时机) iptables的表结构 规则表 默认包括4个规则表 数据包过滤的匹配流程 规则表之间的顺序 规则链之间的顺序 规则链内的匹配顺序 匹配…

【Arduino】XIAOFEIYU实验ESP32使用TOUCH触摸模块(图文)

今天XIAOFEIYU继续来实验ESP32使用传感器模块,这次用到的模块为TOUCH触摸模块。 三个针脚分别为正负极,IO针脚。 #define pin 25void setup(){Serial.begin(9600); pinMode(pin, INPUT); }float value 0.0; void loop(){value digitalRead(pin); …

Vue3详解

vite和webpack区别 vite vite使用原生ES模块进行开发,无需在编译时将所有代码转换为JS打包,从而提供了更快的热更新和自动刷新功能; vite在开发模式下没有打包步骤,而是利用浏览器的ES Module Imports特性实现按需编译&#xff…

提高候选人的招聘感受:成功的策略

大约78%的候选人表示,他们的整体应聘体验表明企业对员工的关注。然而,超过一半的候选人透露,他们在招聘过程中有过负面的候选人经历,80%的候选人在经历了令人失望的招聘过程后会公开与他人分享他们的不良经历。 但也有一线希望&am…

友好前端vue脚手架

企业级后台集成方案vue-element-admin-CSDN博客在哔站学习,老师说可以有直接的脚手架(vue-element-admin)立马去搜索,找到了这博主这篇文章 介绍 | vue-element-admin​​​​​​ 官方默认英文版: git clone https:/…

【力扣 - 每日一题】3115. 质数的最大距离(一次遍历、头尾遍历、空间换时间、埃式筛、欧拉筛、打表)Golang实现

原题链接 题目描述 给你一个整数数组 nums。 返回两个(不一定不同的)质数在 nums 中 下标 的 最大距离。 示例 1: 输入: nums [4,2,9,5,3] 输出: 3 解释: nums[1]、nums[3] 和 nums[4] 是质数。因此答…

力扣每日一题 7/2 数学、数论、数组/双指针

博客主页:誓则盟约系列专栏:IT竞赛 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 3115.质数的最大距离【中等】 题目: 给你一个整数数组 nums。…

uview文本框组件计数count报错u--textarea

报错内容: [Vue warn]: Error in render: “TypeError: Cannot read property ‘length’ of null” found in —> at uni_modules/uview-ui/components/u-textarea/u-textarea.vue at uni_modules/uview-ui/components/u–textarea/u–textarea.vue mp.runtime.…

C盘清理和管理

本篇是C盘一些常用的管理方法,以及定期清理C盘的方法,大部分情况下都能避免C盘爆红。 C盘清理和管理 C盘存储管理查看存储情况清理存储存储感知清理临时文件清理不需要的 迁移存储 磁盘清理桌面存储管理应用存储管理浏览器微信 工具清理 C盘存储管理 查…

ERROR: No matching distribution found for torch==2.0.1+cu117(比手动下载方便)

ERROR: No matching distribution found for torch2.0.1cu117 遇见这种报错可以把pip install -r requirements.txt修改为 pip install -r requirements.tx --extra-index-url https://download.pytorch.org/whl/cu117 -i https://pypi.tuna.tsinghua.edu.cn/simple或者直接…

大科技公司大量裁员背后的真相

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

网络配线架的隐藏功能

网络布线是确保现代信息社会高效运转的关键技术之一。在这一领域,网络配线架扮演着至关重要 的角色。它不仅仅是一个简单的物理连接点,更拥有许多隐藏功能,这些功能极大地提升了网络的 效率、稳定性和可管理性。 1、集中管理 网络配线架提…

VS2022+Qt+OpenCV Debug模式下,循环中格式转换引起的内存异常问题 debug_heap.cpp

文章目录 前言一、问题二、报错1.提示图片2.提示堆栈3.反汇编位置 三、解决办法总结 前言 最近在使用VS2022,C,OpenCV,Qt开发时,遇到了一个疑难杂症-在循环中执行字符串格式转换会触发内存异常,经过痛苦的排查过程&am…

24/07/02数据结构(1.1201)算法效率顺序表

数据结构基本内容:1.时间复杂度 空间复杂度2.顺序表链表3.栈 队列4.二叉树5.排序 数据结构是存储,组织数据的方式.指相互之间存在一种或多种特定关系的数据元素的集合 算法是定义良好的计算过程.取一个或一组值为输入并产生一个或一组值为输出. 需要知道虽然选择题有20-30个…

MyBatis-plus这么好用,不允许还有人不会

你好呀,我是 javapub. 做 Java 的同学都会用到的三件套,Spring、SpringMV、MyBatis。但是由于使用起来配置较多,依赖冲突频发。所有,各路大佬又在这上边做了包装,像我们常用的 SpringBoot、MyBatisPlus。 基于当前要…