pytorch训练五子棋ai

有3个文件

game.py  五子棋游戏

mod.py  神经网络模型

xl.py   训练的代码

aigame.py   玩家与对战的五子棋

game.py

 
class Game:
    def __init__(self, h, w):
        # 行数
        self.h = h
        # 列数
        self.w = w
        # 棋盘
        self.L = [['-' for _ in range(w)] for _ in range(h)]
        # 当前玩家 - 表示空 X先下 然后是O
        self.cur = 'X'
        # 游戏胜利者
        self.win_user = None
 
    # 检查下完这步后有没有赢 y是行 x是列 返回True表示赢
    def check_win(self, y, x):
        directions = [
            # 水平、垂直、两个对角线方向
            (1, 0), (0, 1), (1, 1), (1, -1)
        ]
        player = self.L[y][x]
        for dy, dx in directions:
            count = 0
            # 检查四个方向上的连续相同棋子
            for i in range(-4, 5):  # 检查-4到4的范围,因为五子连珠需要5个棋子
                ny, nx = y + i * dy, x + i * dx
                if 0 <= ny < self.h and 0 <= nx < self.w and self.L[ny][nx] == player:
                    count += 1
                    if count == 5:
                        return True
                else:
                    count = 0
        return False
 
    # 检查能不能下这里 y行 x列 返回True表示能下
    def check(self, y, x):
        return self.L[y][x] == '-' and self.win_user is None
 
    # 打印棋盘 可视化用得到
    def __str__(self):
        # 确定行号和列号的宽度
        row_width = len(str(self.h - 1))
        col_width = len(str(self.w - 1))
        
        # 生成带有行号和列号的棋盘字符串表示
        result = []
        # 添加列号标题
        result.append(' ' * (row_width + 1) + ' '.join(f'{i:>{col_width}}' for i in range(self.w)))
        # 添加分隔线(可选)
        result.append(' ' * (row_width + 1) + '-' * (col_width * self.w))
        # 添加棋盘行
        for y, row in enumerate(self.L):
            # 添加行号
            result.append(f'{y:>{row_width}} ' + ' '.join(f'{cell:>{col_width}}' for cell in row))
        return '\n'.join(result)
 
    # 一步棋
    def set(self, y, x):
        if self.win_user or not self.check(y, x):
            return False
        self.L[y][x] = self.cur
        if self.check_win(y, x):
            self.win_user = self.cur
            return True
        self.cur = 'X' if self.cur == 'O' else 'O'
        return True
    #和棋
    def heqi(self):
        for y in range(self.h):
            for x in range(self.w):
                if self.L[y][x]=='-':
                    return False
        return True
    

#玩家自己下
def run_game01():
    g = Game(15, 15)
    while not g.win_user:
        # 打印当前棋盘状态
        while 1:
            print(g)
            try:
                y,x=input(g.cur+':').split(',')
                x=int(x)
                y=int(y)
                if g.set(y,x):
                    break
            except Exception as e:
                print(e)
    print(g)
    print('胜利者',g.win_user)
 
 
 

mod.py

import torch
import torch.nn as nn
import torch.optim as optim
from game import Game

class MyMod(nn.Module):
    def __init__(self, input_channels=1, output_size=15*15):
        super(MyMod, self).__init__()
        
        # 定义卷积层,用于提取特征
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)  # 输出 32 x 15 x 15
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)  # 输出 64 x 15 x 15
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)  # 输出 128 x 15 x 15
        
        # 定义全连接层,用于最后的得分预测
        self.fc1 = nn.Linear(128 * 15 * 15, 1024)  # 展平后传入全连接层
        self.fc2 = nn.Linear(1024, output_size)  # 输出 15*15 的得分预测

    def forward(self, x):
        # 卷积层 -> 激活函数 -> 最大池化
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        
        # 将卷积层输出展平为一维
        x = x.view(x.size(0), -1)
        
        # 全连接层
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    # 保存模型权重
    def save(self, path):
        torch.save(self.state_dict(), path)

    # 加载模型权重
    def load(self, path):
        self.load_state_dict(torch.load(path))

#改进一下  output 把有棋子的地方的概率=0避免下这些地方
# 输入Game对象和MyMod对象,用于得到概率最大的落棋点 (行y, 列x)
def input_qi(g: Game, m: MyMod):
    # 获取当前棋盘状态
    board_state = g.L  # 使用 game.L 获取当前棋盘的状态 (15x15的二维列表)
    
    # 将棋盘状态转换为PyTorch的Tensor并增加一个维度(batch_size = 1)
    board_tensor = torch.tensor([[1 if cell == 'X' else -1 if cell == 'O' else 0 for cell in row] for row in board_state], 
                                 dtype=torch.float32).unsqueeze(0).unsqueeze(0)  # 形状变为 (1, 1, 15, 15)
    
    # 传入模型获取每个位置的得分
    output = m(board_tensor)
    
    # 将输出转为概率值(可以使用softmax来归一化)
    probabilities = torch.softmax(output, dim=-1).view(g.h, g.w).detach().numpy()  # 变为 (15, 15) 大小
    
    # 将已有棋子的位置的概率设置为 -inf,避免选择这些位置
    for y in range(g.h):
        for x in range(g.w):
            if board_state[y][x] != '-':
                probabilities[y, x] = -float('inf')  # 设置已经有棋子的地方的概率为 -inf
    
    # 找到概率最大的落子点
    max_prob_pos = divmod(probabilities.argmax(), g.w)  # 得到最大概率的行列坐标
    
    # 确保返回的是合法的位置
    y, x = max_prob_pos
    
    return (y, x), output  # 返回坐标和模型输出



xl.py

import os
import torch
import torch.optim as optim
import torch.nn.functional as F
from mod import MyMod, input_qi, Game

# 两个权重文件,分别代表 X 棋和 O 棋
MX = 'MX'
MO = 'MO'

# 加载模型,若文件不存在则初始化
def load_model(model, path):
    if os.path.exists(path):
        model.load(path)
        print(f"Loaded model from {path}")
    else:
        print(f"{path} not found, initializing new model.")
        # 这里可以加一些初始化模型的代码,例如:
        # model.apply(init_weights) 如果需要初始化权重

# 初始化模型
modx = MyMod()
load_model(modx, MX)

modo = MyMod()
load_model(modo, MO)

# 定义优化器
lr=0.001
optimizer_x = optim.Adam(modx.parameters(), lr=lr)
optimizer_o = optim.Adam(modo.parameters(), lr=lr)

# 损失函数:根据游戏结果调整损失
def compute_loss(winner: int, player: str, model_output):
    # 将目标值转换为相应的张量
    if player == "X":
        if winner == 1:  # X 胜
            target = torch.tensor(1.0, dtype=torch.float32)
        elif winner == 0:  # 平局
            target = torch.tensor(0.5, dtype=torch.float32)
        else:  # X 输
            target = torch.tensor(0.0, dtype=torch.float32)
    else:
        if winner == -1:  # O 胜
            target = torch.tensor(1.0, dtype=torch.float32)
        elif winner == 0:  # 平局
            target = torch.tensor(0.5, dtype=torch.float32)
        else:  # O 输
            target = torch.tensor(0.0, dtype=torch.float32)

    # 确保目标值的形状和 model_output 一致,假设 model_output 是单一的值
    target = target.unsqueeze(0).unsqueeze(0)  # 形状变为 (1, 1)

    # 使用均方误差损失计算
    return F.mse_loss(model_output, target)



# 训练模型的过程
def train_game():
    modx.train()
    modo.train()

    # 创建新的游戏实例
    game = Game(15, 15)  # 默认是 15x15 棋盘
    # 反向传播和优化
    optimizer_x.zero_grad()
    optimizer_o.zero_grad()

    while not game.win_user:  # 游戏未结束
        # X 方落子
        x_move, x_output = input_qi(game, modx)  # 获取落子位置和模型输出(x_output 是模型的输出)
        game.set(x_move[0], x_move[1])  # X 下棋
        
        if game.win_user:
            break
        
        # O 方落子
        
        o_move, o_output = input_qi(game, modo)  # 获取落子位置和模型输出(o_output 是模型的输出)
        #print(o_move,game)
        game.set(o_move[0], o_move[1])  # O 下棋
       

    # 获取比赛结果
    winner = 0 if game.heqi() else (1 if game.win_user == 'X' else -1)  # 1为X胜,-1为O胜,0为平局

    # 计算损失
    loss_x = compute_loss(winner, "X", x_output)  # 传递模型输出给计算损失函数
    loss_o = compute_loss(winner, "O", o_output)  # 传递模型输出给计算损失函数

    # 计算损失并进行反向传播
    loss_x.backward()
    loss_o.backward()

    # 更新权重
    optimizer_x.step()
    optimizer_o.step()
    print(game)
    return loss_x.item(), loss_o.item()


# 训练多个回合
def train(num_epochs,n):
    k=0
    for epoch in range(num_epochs):
        loss_x, loss_o = train_game()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss X: {loss_x}, Loss O: {loss_o}")
        k+=1
        if k==n:
            modo.save('MO')
            modx.save('MX')
            print('saved')
            k=0

# 开始训练
train(50000,1000)

aigame.py

from game import Game
from mod import MyMod,input_qi


#玩家下X ai下O
def playX():
    m=MyMod()
    m.load('MO')
    g=Game(15,15)
    while 1:
        print(g)
        if g.heqi() or g.win_user:
            break
        while 1:
            try:
                r=input('X:')
                y,x=r.split(',')
                y=int(y)
                x=int(x)
                if g.set(y,x):
                    break
            except Exception as e:
                print(e)
        if g.heqi() or g.win_user:
            break
        while 1:
            (y,x),_=input_qi(g,m)
            if g.set(y,x):
                break
    print(g)
    print('winner',g.win_user)


#玩家下O ai下X
def playO():
    m=MyMod()
    m.load('MX')
    g=Game(15,15)
    while 1:
        
        if g.heqi() or g.win_user:
            break
        while 1:
            (y,x),_=input_qi(g,m)
            if g.set(y,x):
                break
        if g.heqi() or g.win_user:
            break
        print(g)
        while 1:
            try:
                r=input('O:')
                y,x=r.split(',')
                y=int(y)
                x=int(x)
                if g.set(y,x):
                    break
            except Exception as e:
                print(e)
    print(g)
    print('winner',g.win_user)

playX()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/969780.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

制造业物联网的十大用例

预计到 2026 年&#xff0c;物联网制造市场价值将达到 4000 亿美元。实时收集和分析来自联网物联网设备与传感器的数据&#xff0c;这一能力为制造商提供了对生产流程前所未有的深入洞察。物联网&#xff08;IoT&#xff09;有潜力彻底改变制造业&#xff0c;使工厂能够更高效地…

无法读取配置节“system.web.extensions”,因为它缺少节声明

无法读取配置节“system.web.extensions”&#xff0c;因为它缺少节声明 在IIS配置.net接口时&#xff0c;报错&#xff1a; 无法读取配置节“system.web.extensions”&#xff0c;因为它缺少节声明 解决办法&#xff1a;打开IIS&#xff0c;右键>>管理网站>>高级…

Android Studio:键值对存储sharedPreferences

一、了解 SharedPreferences SharedPreferences是Android的一个轻量级存储工具&#xff0c;它采用的存储结构是Key-Value的键值对方式&#xff0c;类似于Java的Properties&#xff0c;二者都是把Key-Value的键值对保存在配置文件中。不同的是&#xff0c;Properties的文件内容形…

Redis——优惠券秒杀问题(分布式id、一人多单超卖、乐悲锁、CAS、分布式锁、Redisson)

#想cry 好想cry 目录 1 全局唯一id 1.1 自增ID存在的问题 1.2 分布式ID的需求 1.3 分布式ID的实现方式 1.4 自定义分布式ID生成器&#xff08;示例&#xff09; 1.5 总结 2 优惠券秒杀接口实现 3 单体系统下一人多单超卖问题及解决方案 3.1 问题背景 3.2 超卖问题的…

easyexcel快速使用

1.easyexcel EasyExcel是一个基于ava的简单、省内存的读写Excel的开源项目。在尽可能节约内存的情况下支持读写百M的Excel 即通过java完成对excel的读写操作&#xff0c; 上传下载 2.easyexcel写操作 把java类中的对象写入到excel表格中 步骤 1.引入依赖 <depen…

数据结构 04

4. 栈 4.2. 链式栈 4.2.1. 特性 逻辑结构&#xff1a;线性结构 存储结构&#xff1a;链式存储结构 操作&#xff1a;创建&#xff0c;入栈&#xff0c;出栈&#xff0c;清空&#xff0c;获取 4.2.2. 代码实现 头文件 LinkStack.h #ifndef __LINKSTACK_H__ #define __LINKST…

LeetCode刷题第7题【整数反转】---解题思路及源码注释

LeetCode刷题第7题【整数反转】—解题思路及源码注释 结果预览 目录 LeetCode刷题第7题【整数反转】---解题思路及源码注释结果预览一、题目描述二、解题思路1、问题理解2、解题思路 三、代码实现及注释1、源码实现2、代码解释 四、执行效果1、时间和空间复杂度分析 一、题目描…

相机闪光灯拍照流程分析

和你一起终身学习&#xff0c;这里是程序员Android 经典好文推荐&#xff0c;通过阅读本文&#xff0c;您将收获以下知识点: 一、Flash 基础知识二、MTK 闪光灯拍照log分析 一、Flash 基础知识 1.1 Flash HAL 场景枚举值 Flash HAL 场景枚举值 1.2 AE AF mode State 枚举值 AE …

给本地模型“投喂“数据

如何训练本地Deepseek-r1:7b模型 在前面两篇文章中&#xff0c;我在自己的电脑的本地部署了Deepseek的7b的模型&#xff0c;并接入到我Chrome浏览器的插件中&#xff0c;使用起来更方便了。在使用的过程中发现7b的推理能力确实没有671满血版本的能力强&#xff0c;很多问题回答…

大脑网络与智力:基于图神经网络的静息态fMRI数据分析方法|文献速递-医学影像人工智能进展

Title 题目 Brain networks and intelligence: A graph neural network based approach toresting state fMRI data 大脑网络与智力&#xff1a;基于图神经网络的静息态fMRI数据分析方法 01 文献速递介绍 智力是一个复杂的构念&#xff0c;包含了多种认知过程。研究人员通…

原生Three.js 和 Cesium.js 案例 。 智慧城市 数字孪生常用功能列表

对于大多数的开发者来言&#xff0c;看了很多文档可能遇见不到什么有用的&#xff0c;就算有用从文档上看&#xff0c;把代码复制到自己的本地大多数也是不能用的&#xff0c;非常浪费时间和学习成本&#xff0c; 尤其是three.js &#xff0c; cesium.js 这种难度较高&#xff…

学习总结三十二

map #include<iostream> #include<map> using namespace std;int main() {//首先创建一个map对象map<int, char>oneMap;//插入数据oneMap.insert(pair<int, char>(1, A));oneMap.insert(make_pair(2,B));oneMap.insert(map<int,char>::value_ty…

AI如何与DevOps集成,提升软件质量效能

随着技术的不断演进&#xff0c;DevOps和AI的融合成为推动软件开发质量提升的重要力量。传统的DevOps已经为软件交付速度和可靠性打下了坚实的基础&#xff0c;而随着AI技术的加入&#xff0c;DevOps流程不仅能提升效率&#xff0c;还能在质量保障、缺陷预测、自动化测试等方面…

ESP学习-1(MicroPython VSCode开发环境搭建)

下载ESP8266固件&#xff1a;https://micropython.org/download/ESP8266_GENERIC/win电脑&#xff1a;pip install esptools python.exe -m pip install --upgrade pip esptooo.py --port COM5 erase_flash //清除之前的固件 esptool --port COM5 --baud 115200 write_fla…

解决DeepSeek服务器繁忙问题

目录 解决DeepSeek服务器繁忙问题 一、用户端即时优化方案 二、高级技术方案 三、替代方案与平替工具&#xff08;最推荐简单好用&#xff09; 四、系统层建议与官方动态 用加速器本地部署DeepSeek 使用加速器本地部署DeepSeek的完整指南 一、核心原理与工具选择 二、…

在WPS中通过JavaScript宏(JSA)调用本地DeepSeek API优化文档教程

既然我们已经在本地部署了DeepSeek,肯定希望能够利用本地的模型对自己软件开发、办公文档进行优化使用,接下来就先在WPS中通过JavaScript宏(JSA)调用本地DeepSeek API优化文档的教程奉上。 前提: (1)已经部署好了DeepSeek,可以看我的文章:个人windows电脑上安装DeepSe…

CentOS-Stream 9安装

文章目录 1 CentOS9安装引导界面2 CentOS9安装过程2.1 语言选择2.2 安装项选择2.2.1 安装目标位置2.2.2 软件选择2.2.3 网络和主机名2.2.4 root密码2.2.5 创建用户 2.3 开始安装2.4 等待安装成功 3 安装成功 1 CentOS9安装引导界面 选择Install CentOS Stream 9后按Enter键&…

【神经网络框架】非局部神经网络

一、非局部操作的数学定义与理论框架 1.1 非局部操作的通用公式 非局部操作(Non-local Operation)是该研究的核心创新点,其数学定义源自经典计算机视觉中的非局部均值算法(Non-local Means)。在深度神经网络中,非局部操作被形式化为: 其中: 1.2 与传统操作的对比分析…

RAG科普文!检索增强生成的技术全景解析

RAG 相关技术的八个主题&#xff1a;https://pub.towardsai.net/a-taxonomy-of-retrieval-augmented-generation-a39eb2c4e2ab 增强生成 (RAG) 是塑造应用生成式 AI 格局的关键技术。Lewis 等人在其开创性论文中提出了一个新概念面向知识密集型 NLP 任务的检索增强生成之后&…

【做一个微信小程序】校园地图页面实现

前言 上一个教程我们实现了小程序的一些的功能&#xff0c;有背景渐变色&#xff0c;发布功能有的呢&#xff0c;已支持图片上传功能&#xff0c;表情和投票功能开发中&#xff08;请期待&#xff09;。下面是一个更高级的微信小程序实现&#xff0c;包含以下功能&#xff1a;…