深度学习 GNN图神经网络(四)线性回归之ESOL数据集水溶性预测

线性回归之ESOL数据集水溶性预测

  • 一、前言
  • 二、ESOL数据集
  • 三、加载数据集
  • 四、数据拆分
  • 五、构造模型
  • 六、训练模型
  • 七、测试结果
  • 八、分类问题
  • 参考文献

一、前言

本文旨在使用化合物分子的SMILES字符串进行数据模型训练,对其水溶性的值进行预测。

之前的文章《深度学习 GNN图神经网络(三)模型思想及文献分类案例实战》引用的Cora数据集只有一张图,属于图神经网络的节点分类问题。本文介绍的是多图批量训练的线性回归问题,在文章最后也讨论了图分类问题。

二、ESOL数据集

本文使用的是ESOL数据集,在文章《如何将化学分子SMILES字符串转化为Pytorch图数据结构——ESOL分子水溶性数据集解析》中有详细介绍,在此不作详述。

三、加载数据集

from torch_geometric.datasets import MoleculeNet

dataset = MoleculeNet(root="data", name="ESOL")

print('num_features:',dataset.num_features)
print('num_classes:',dataset.num_classes)
print('num_node_features',dataset.num_node_features)
print("size:", len(dataset))

d=dataset[10]
print("Sample:", d)
print("Sample y:", d.y)
print("Sample num_nodes:",d.num_nodes)
print("Sample num_edges:",d.num_edges)

这里可以得到数据集的一些基本信息:

num_features: 9
num_classes: 734
num_node_features 9
size: 1128
Sample: Data(x=[6, 9], edge_index=[2, 12], edge_attr=[12, 3], smiles='O=C1CCCN1', y=[1, 1])
Sample y: tensor([[1.0700]])
Sample num_nodes: 6
Sample num_edges: 12

四、数据拆分

将数据集拆分为训练数据和测试数据:

from torch_geometric.loader import DataLoader
data_size = len(dataset)
batch_size = 128
train_data=dataset[:int(data_size*0.8)]
test_data=dataset[int(data_size*0.8):]

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=len(test_data))

五、构造模型

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt
from torch_geometric.nn import global_mean_pool

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

hidden_channels = 64

class GNN(nn.Module):
    
    def __init__(self):
        # 初始化Pytorch父类
        super().__init__()
        
        self.conv1=GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2=GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.out = nn.Linear(hidden_channels, 1)
        
        # 创建损失函数,使用均方误差
        self.loss_function = nn.MSELoss()

        # 创建优化器,使用Adam梯度下降
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.005,weight_decay=5e-4)

        # 训练次数计数器
        self.counter = 0
        # 训练过程中损失值记录
        self.progress = []
    
    # 前向传播函数
    def forward(self, x, edge_index,batch):
        
        x=x.to(device)
        edge_index=edge_index.to(device)
        batch=batch.to(device)

        x=self.conv1(x, edge_index)
        x=x.relu()
        x=self.conv2(x, edge_index)
        x=x.relu()
        x=self.conv3(x, edge_index)
        x=x.relu()
        x=self.conv4(x, edge_index)
        x=x.relu()

        # 全局池化
        x = global_mean_pool(x, batch)  # [x, batch]

        out=self.out(x)
        return out
    
    # 训练函数
    def train(self, data):

        # 前向传播计算,获得网络输出
        outputs = self.forward(data.x.float(),data.edge_index,data.batch)
        
        # 计算损失值
        y=data.y.to(device)
        loss = self.loss_function(outputs, y)

        # 累加训练次数
        self.counter += 1

        # 每10次训练记录损失值
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())

        # 每1000次输出训练次数   
        if (self.counter % 1000 == 0):
            print(f"counter={self.counter}, loss={loss.item()}")
            
        # 梯度清零, 反向传播, 更新权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
    
    # 测试函数
    def test(self, data):
        # 前向传播计算,获得网络输出
        outputs = self.forward(data.x.float(),data.edge_index,data.batch)

        # 把绝对值误差小于1的视为正确,计算准确度
        y=data.y.to(device)
        acc=sum(torch.abs(y-outputs)<1)/len(data.y)
        return acc

    # 绘制损失变化图
    def plot_progress(self):
        plt.plot(range(len(self.progress)),self.progress)
      

六、训练模型

model = GNN()
model.to(device)

for i in range(1001):
    for data in train_loader:
        # print(data,'num_graphs:',data.num_graphs)
        model.train(data)
counter=1000, loss=1.4304862022399902
counter=2000, loss=0.9842458963394165
counter=3000, loss=0.27240827679634094
counter=4000, loss=0.23295772075653076
counter=5000, loss=0.38499030470848083
counter=6000, loss=1.470423698425293
counter=7000, loss=0.845589816570282
counter=8000, loss=0.15707021951675415

绘制损失值变化图::

model.plot_progress()

在这里插入图片描述

七、测试结果

#torch.set_printoptions(precision=4,sci_mode=False) #pytorch不使用科学计数法显示

for data in test_loader:
    acc=model.test(data)
    print(acc)
tensor([0.8186], device='cuda:0')

可以看到,预测值误差小于1的占了81.86%,效果还行。

八、分类问题

对于图分类问题,其实也差不多。只需要修改下Linear网络层:

self.out = Linear(hidden_channels, dataset.num_classes)

这样预测结果就会有num_classes个,取最大值的下标索引即可。
伪代码为:

pred=outputs.argmax(dim=1)
correct += int((pred == data.y).sum())

参考文献

[1] https://pytorch-geometric.readthedocs.io/en/latest/get_started/colabs.html
[2] https://zhuanlan.zhihu.com/p/504978470

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/15717.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue - pc端实现对div的拖动功能

实现对div的拖动功能&#xff0c;需要先要知道以下的一些原生事件和方法&#xff1b; 1&#xff0c;事件: 方法描述onmousedown鼠标按钮被按下onmousemove鼠标被移动onmouseup鼠标按键被松开 2&#xff0c;方法: 方法描述event.clientX返回当事件被触发时鼠标指针相对于浏览…

02 【Sass语法介绍-变量】

sass有两种语法格式Sass(早期的缩进格式&#xff1a;Indented Sass)和SCSS(Sassy CSS) 目前最常用的是SCSS&#xff0c;任何css文件将后缀改为scss&#xff0c;都可以直接使用Sassy CSS语法编写。 所有有效的 CSS 也同样都是有效的 SCSS。 Sass语法介绍-变量 1.前言 Sass …

【VM服务管家】VM4.0平台SDK_2.5 全局工具类

目录 2.5.1 全局相机&#xff1a;全局相机设置参数的方法2.5.2 全局相机&#xff1a;获取全局相机列表的方法2.5.3 全局通信&#xff1a;通信管理中设备开启状态管理2.5.4 全局通信&#xff1a;接收和发送数据的方法2.5.5 全局变量获取和设置全局变量的方法 2.5.1 全局相机&…

2023-4-27-深入理解C++指针类型间强制转换

&#x1f37f;*★,*:.☆(&#xffe3;▽&#xffe3;)/$:*.★* &#x1f37f; &#x1f4a5;&#x1f4a5;&#x1f4a5;欢迎来到&#x1f91e;汤姆&#x1f91e;的csdn博文&#x1f4a5;&#x1f4a5;&#x1f4a5; &#x1f49f;&#x1f49f;喜欢的朋友可以关注一下&#xf…

记一次峰回路转的注入

0X01 背景 自己之前写过一篇记录&#xff0c;当时是由于之前是一位校友刚做开发&#xff0c;叫我友情帮忙测试一波&#xff0c;由于是开发的新手&#xff0c;漏洞比较多&#xff0c;所以直接从注入开始讲起&#xff0c;但是到getshell的过程也算是一场峰回路转再跌跌撞撞的路程…

【Java笔试强训 17】

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔&#x1f93a;&#x1f93a;&#x1f93a; 目录 一、选择题 二、编程题 &#x1f525;杨辉三角…

洞见数字时代的创新原力,数云原力大会暨2023TECH第五届数字中国技术年会开幕

4月25日&#xff0c;神州控股、神州信息、神州数码集团共同主办的数云原力大会暨2023TECH第五届数字中国技术年会开幕。开幕式上&#xff0c;数百位投身并关注数字技术、数字产业发展的学者、技术专家、从业者、行业用户齐聚一堂&#xff0c;围绕云原生、数字原生、大数据、金融…

这一次,让Kotlin Flow 操作符真正好用起来

前言 Kotlin Flow 如此受欢迎大部分归功于其丰富、简洁的操作符&#xff0c;巧妙使用Flow操作符可以大大简化我们的程序结构&#xff0c;提升可读性与可维护性。 然而&#xff0c;虽然好用&#xff0c;但有些操作符不太好理解&#xff0c;可惜的是网上大部分文章只是简单介绍其…

吴恩达 Chatgpt prompt 工程--1.Guidelines

课程链接 Setup #安装 !pip install openai#设置key !export OPENAI_API_KEYsk-... # or #import openai #openai.api_key "sk-..."import openai import osfrom dotenv import load_dotenv, find_dotenv _ load_dotenv(find_dotenv())openai.api_key os.geten…

记一次SSRF漏洞的学习和利用

导语&#xff1a;本文主要记录一次我们在复盘嘶吼网站渗透报告时遇到的一个SSRF漏洞。 1.前言 本文主要记录一次我们在复盘嘶吼网站渗透报告时遇到的一个SSRF漏洞。此漏洞并结合腾讯云的API接口&#xff0c;可以获取大量嘶吼服务器的敏感信息。利用这些敏感信息&#xff0c;又…

android10 关闭默认输入法的“更正建议”

1. 场景 使用系统默认的输入法&#xff0c;在进行输入时&#xff0c;在输入法上方&#xff0c;会显示更正建议列表&#xff0c;同时会干扰我们的输入内容&#xff1a;会自动补全到输入框&#xff0c;而且删除不掉&#xff0c;甚至越删越多&#xff0c;非常讨厌。 如下&#x…

一段凄惨Android 面试经历分享,败在了项目架构原理上……

大家应该看过很多分享面试成功的经验&#xff0c;但根据幸存者偏差的理论&#xff0c;也许多看看别人面试失败在哪里&#xff0c;对自己才更有帮助。 这是一位网友分享的面试经历&#xff0c;他准备了3个月&#xff0c;刚刚参加完字节跳动的第三面&#xff0c;视频面&#xff…

都23年了你还记得渐进式框架是什么意思吗

vue 文章目录 vue前言一、眼见为实举个栗子二、渐进式的优势 前言 渐进式框架是一种能够逐步增强应用功能的框架&#xff0c;它允许开发者在不影响应用性能的情况下&#xff0c;逐步添加新的功能和特性。Vue.js提供了一些基础功能&#xff0c;如数据绑定和组件化&#xff0c;然…

这8个摸鱼神器,千万别让你老板知道!

工欲善其事&#xff0c;必先利其器&#xff0c;对于程序员来说也是如此&#xff0c;想早点下班就不能死脑筋&#xff0c;必须借助于一些开发工具来提高自己的工作效率&#xff0c;小编选取了8款任务/项目管理工具&#xff0c;能助你[打通任督二脉]&#xff0c;工作效率大大提升…

React之动态路由创建以及解决刷新白屏问题

动态路由的创建和动态菜单的创建几乎类似&#xff0c;只不过的是&#xff0c;动态路由需要导入组件。这样才能完成跳转。 动态路由与动态菜单一样都需要封装一个转化函数&#xff0c;将后端传来的数据进行转换&#xff0c;转换成我们需要的格式。 需要导入的依赖 导入路由use…

【ChatGPT】阿里版 ChatGPT 突然官宣意味着什么?

Yan-英杰的主页 悟已往之不谏 知来者之可追 C程序员&#xff0c;2024届电子信息研究生 目录 阿里版 ChatGPT 突然官宣 ​ ChatGPT 技术在 AI 领域的重要性 自然语言生成 上下文连续性 多语言支持 ChatGPT 未来可能的应用场景 社交领域 商业领域 ​编辑 医疗领域…

网络安全常用术语

肉鸡 肉鸡指的就是被黑客成功入侵并取得控制权限的电脑。黑客们可以随意的控制肉鸡&#xff0c;就像在使用自己的电脑一样&#xff0c;很形象的比喻&#xff0c;就像是养的肉鸡&#xff0c;任黑客宰杀和利用。关键的是&#xff0c;在成为肉鸡后&#xff0c;只要黑客不对电脑进…

linux|进程间通信如何加锁

进程间通信有一种[共享内存]方式&#xff0c;大家有没有想过&#xff0c;这种通信方式中如何解决数据竞争问题&#xff1f;我们可能自然而然的就会想到用锁。但我们平时使用的锁都是用于解决线程间数据竞争问题&#xff0c;貌似没有看到过它用在进程中&#xff0c;那怎么办&…

Java——把数组排成最小的数

题目链接 牛客网在线oj题——把数组排成最小的数 题目描述 输入一个非负整数数组numbers&#xff0c;把数组里所有数字拼接起来排成一个数&#xff0c;打印能拼接出的所有数字中最小的一个。 例如输入数组[3&#xff0c;32&#xff0c;321]&#xff0c;则打印出这三个数字能…

如何提高三维模型OSGB格式转换3DTILES的转换速度和数据质量

如何提高三维模型OSGB格式转换3DTILES的转换速度和数据质量 提高三维模型从OSGB格式转换为3DTILES格式的转换速度和数据质量&#xff0c;可以从以下几个方面进行优化&#xff1a; 1、选用高效的转换工具&#xff1a;选择高效的转换工具是提高转换速度和数据质量的关键。目前市…