YoloV8改进策略:卷积篇|Kan行天下之GRAM,KAN遇见Gram多项式V2版本

GRAM(GRAM可能是一个新提出的模型或方法的缩写,这里我们根据上下文进行解释)受到诸如TorchKAN和ChebyKAN等Kolmogorov-Arnold网络(KAN)替代方案的启发。GRAM引入了一种简化的KAN模型,但同时利用了Gram多项式变换的简单性。它与其他替代方案的不同之处在于其独特的离散性特征。与其他在连续区间上定义的多项式不同,Gram多项式是在一组离散点上定义的。GRAM的这种离散性为处理像图像和文本数据这样的离散数据集提供了一种新颖的方法。
在这里插入图片描述

Kolmogorov-Arnold Networks (KAN): KAN是一种基于Kolmogorov-Arnold表示定理的神经网络模型,该定理表明任何多元连续函数都可以表示为一系列一元函数的复合。KAN通常用于函数逼近和复杂系统的建模。

Gram Polynomials: Gram多项式是一种在离散点上定义的特殊多项式,与在连续区间上定义的传统多项式不同。这种离散性使得Gram多项式在处理离散数据集时具有独特的优势。

离散数据集处理: 在机器学习和数据科学中,处理离散数据集(如图像和文本)是一个重要任务。图像可以视为像素的离散集合,而文本则可以视为字符或单词的离散序列。GRAM通过其离散性特征,为这类数据的处理提供了一种新的视角和方法。

模型简化与效率: GRAM通过简化KAN模型并引入Gram多项式变换,旨在提高模型的效率和实用性。这种简化可能使得GRAM在保持一定性能的同时,具有更低的计算复杂度和更快的训练速度。

应用前景: GRAM的离散性特征使其在图像处理、自然语言处理等领域具有潜在的应用前景。例如,在图像分类任务中,GRAM可能能够更有效地提取图像中的特征;在自然语言处理任务中,GRAM可能能够更准确地捕捉文本中的语义信息。

GRAM通过结合KAN和Gram多项式的思想,为离散数据集的处理提供了一种新颖而有效的解决方案。随着进一步的研究和应用,GRAM有望在多个领域展现出其独特的优势和价值。

代码

代码在原作者的基础上做了修改,代码如下:

# Based on this: https://github.com/Khochawongwat/GRAMKAN/blob/main/model.py
from functools import lru_cache
import torch
import torch.nn as nn
from torch.nn.functional import conv3d, conv2d, conv1d


class KAGNConvNDLayer(nn.Module):
    def __init__(self, conv_class, norm_class, conv_w_fun, input_dim, output_dim, degree, kernel_size,
                 groups=1, padding=0, stride=1, dilation=1, dropout: float = 0.0, ndim: int = 2.,
                 **norm_kwargs):
        super(KAGNConvNDLayer, self).__init__()
        self.inputdim = input_dim
        self.outdim = output_dim
        self.degree = degree
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.base_activation = nn.SiLU()
        self.conv_w_fun = conv_w_fun
        self.ndim = ndim
        self.dropout = None
        self.norm_kwargs = norm_kwargs
        self.p_dropout = dropout
        if dropout > 0:
            if ndim == 1:
                self.dropout = nn.Dropout1d(p=dropout)
            if ndim == 2:
                self.dropout = nn.Dropout2d(p=dropout)
            if ndim == 3:
                self.dropout = nn.Dropout3d(p=dropout)

        if groups <= 0:
            raise ValueError('groups must be a positive integer')
        if input_dim % groups != 0:
            raise ValueError('input_dim must be divisible by groups')
        if output_dim % groups != 0:
            raise ValueError('output_dim must be divisible by groups')

        self.base_conv = nn.ModuleList([conv_class(input_dim // groups,
                                                   output_dim // groups,
                                                   kernel_size,
                                                   stride,
                                                   padding,
                                                   dilation,
                                                   groups=1,
                                                   bias=False) for _ in range(groups)])

        self.layer_norm = nn.ModuleList([norm_class(output_dim // groups, **norm_kwargs) for _ in range(groups)])

        poly_shape = (groups, output_dim // groups, (input_dim // groups) * (degree + 1)) + tuple(
            kernel_size for _ in range(ndim))

        self.poly_weights = nn.Parameter(torch.randn(*poly_shape))
        self.beta_weights = nn.Parameter(torch.zeros(degree + 1, dtype=torch.float32))

        # Initialize weights using Kaiming uniform distribution for better training start
        for conv_layer in self.base_conv:
            nn.init.kaiming_uniform_(conv_layer.weight, nonlinearity='linear')

        nn.init.kaiming_uniform_(self.poly_weights, nonlinearity='linear')
        nn.init.normal_(
            self.beta_weights,
            mean=0.0,
            std=1.0 / ((kernel_size ** ndim) * self.inputdim * (self.degree + 1.0)),
        )

    def beta(self, n, m):
        return (
                       ((m + n) * (m - n) * n ** 2) / (m ** 2 / (4.0 * n ** 2 - 1.0))
               ) * self.beta_weights[n]

    @lru_cache(maxsize=128)  # Cache to avoid recomputation of Gram polynomials
    def gram_poly(self, x, degree):
        p0 = x.new_ones(x.size())

        if degree == 0:
            return p0.unsqueeze(-1)

        p1 = x
        grams_basis = [p0, p1]

        for i in range(2, degree + 1):
            p2 = x * p1 - self.beta(i - 1, i) * p0
            grams_basis.append(p2)
            p0, p1 = p1, p2

        return torch.concatenate(grams_basis, dim=1)

    def forward_kag(self, x, group_index):
        # Apply base activation to input and then linear transform with base weights
        basis = self.base_conv[group_index](self.base_activation(x))

        # Normalize x to the range [-1, 1] for stable Legendre polynomial computation
        x = torch.tanh(x).contiguous()

        if self.dropout is not None:
            x = self.dropout(x)

        grams_basis = self.base_activation(self.gram_poly(x, self.degree))

        y = self.conv_w_fun(grams_basis, self.poly_weights[group_index],
                            stride=self.stride, dilation=self.dilation,
                            padding=self.padding, groups=1)

        y = self.base_activation(self.layer_norm[group_index](y + basis))

        return y

    def forward(self, x):

        split_x = torch.split(x, self.inputdim // self.groups, dim=1)
        output = []
        for group_ind, _x in enumerate(split_x):
            y = self.forward_kag(_x.clone(), group_ind)
            output.append(y.clone())
        y = torch.cat(output, dim=1)
        return y


class KAGNConv3DLayer(KAGNConvNDLayer):
    def __init__(self, input_dim, output_dim, kernel_size, degree=3, groups=1, padding=0, stride=1, dilation=1,
                 dropout: float = 0.0, norm_layer=nn.InstanceNorm3d, **norm_kwargs):
        super(KAGNConv3DLayer, self).__init__(nn.Conv3d, norm_layer, conv3d,
                                              input_dim, output_dim,
                                              degree, kernel_size,
                                              groups=groups, padding=padding, stride=stride, dilation=dilation,
                                              ndim=3, dropout=dropout, **norm_kwargs)


class KAGNConv2DLayer(KAGNConvNDLayer):
    def __init__(self, input_dim, output_dim, kernel_size, degree=3, groups=1, padding=0, stride=1, dilation=1,
                 dropout: float = 0.0, norm_layer=nn.InstanceNorm2d, **norm_kwargs):
        super(KAGNConv2DLayer, self).__init__(nn.Conv2d, norm_layer, conv2d,
                                              input_dim, output_dim,
                                              degree, kernel_size,
                                              groups=groups, padding=padding, stride=stride, dilation=dilation,
                                              ndim=2, dropout=dropout, **norm_kwargs)


class KAGNConv1DLayer(KAGNConvNDLayer):
    def __init__(self, input_dim, output_dim, kernel_size, degree=3, groups=1, padding=0, stride=1, dilation=1,
                 dropout: float = 0.0, norm_layer=nn.InstanceNorm1d, **norm_kwargs):
        super(KAGNConv1DLayer, self).__init__(nn.Conv1d, norm_layer, conv1d,
                                              input_dim, output_dim,
                                              degree, kernel_size,
                                              groups=groups, padding=padding, stride=stride, dilation=dilation,
                                              ndim=1, dropout=dropout, **norm_kwargs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/792283.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

paddla模型转gguf

在使用ollama配置本地模型时&#xff0c;只支持gguf格式的模型&#xff0c;所以我们首先需要把自己的模型转化为bin格式&#xff0c;本文为paddle&#xff0c;onnx&#xff0c;pytorch格式的模型提供说明&#xff0c;safetensors格式比较简单请参考官方文档&#xff0c;或其它教…

Docker存储目录问题,如何修改Docker默认存储位置?(Docker存储路径、Docker存储空间)etc/docker/daemon.json

文章目录 如何更改docker默认存储路径&#xff1f;版本1&#xff08;没测试&#xff09;版本2&#xff08;可行&#xff09;1. 停止 Docker 服务&#xff1a;2. 创建新的存储目录&#xff1a;3. 修改 Docker 配置文件&#xff1a;4. 移动现有的 Docker 数据&#xff1a;5. 重新…

【Pytorch】Conda环境下载慢换源/删源/恢复默认源

文章目录 背景临时换源永久换源打开conda配置condarc换源执行配置 命令行修改源添加源查看源 删源恢复默认源使用示范 背景 随着实验增多&#xff0c;需要分割创建环境的情况时有出现&#xff0c;在此情况下使用conda create --name xx python3.10 pytorch torchvision pytorc…

香港紧缺什么类型人才?如何通过香港优才计划去香港就业?

香港目前紧缺多种类型的人才&#xff0c;这些需求反映在不同行业和专业领域。以下是根据最新信息整理的紧缺人才概览&#xff1a; 资讯科技&#xff08;IT&#xff09;人才&#xff1a;香港在IT领域&#xff0c;尤其是人工智能、云计算、软件开发、数据分析、用户体验设计&…

基于4G、5G和卫星宽带的应急通信车载聚合路由器组网方案

应急指挥车、现场应急指挥系统作为整个应急指挥平台的主要组成部分&#xff0c;被广泛用于救灾抢险,安全保障等特殊场景&#xff0c;可通过应急指挥车或现场应急指挥系统与后方指挥中心间传输音视频信息&#xff0c;实现现场与指挥中心的实时通信&#xff0c;进行视频会议和远程…

通用代码生成器模板体系,域对象,枚举和动词算子

通用代码生成器模板体系&#xff0c;域对象&#xff0c;枚举和动词算子 通用代码生成器或者叫动词算子式通用目的代码生成器是一组使用Java编写的通用代码生成器。它们的原理基于动词算子和域对象的笛卡尔积。它们没有使用FreeMarker和或者Velocity等现成的文件式模板引擎。而…

win11下部署Jenkins,build c#项目

一个c#的项目&#xff0c;由于项目经理总要新版本测试&#xff0c;以前每次都是手动出包&#xff0c;现在改成jenkins自动生成&#xff0c;节省时间。 一、下载Jenkins&#xff0c; 可以通过清华镜像下载Index of /jenkins/windows-stable/ | 清华大学开源软件镜像站 | Tsingh…

Java面试八股之Redis有哪些数据类型?底层实现分别是什么

Redis有哪些数据类型&#xff1f;底层实现分别是什么 Redis数据类型概述 Redis作为一款键值存储系统&#xff0c;提供了丰富多样的数据类型以满足不同场景的需求。以下是Redis支持的主要数据类型及其基本用途&#xff1a; String&#xff08;字符串&#xff09; 存储单个键…

嵌入式ARM控制器在AGV里的应用

随着ARM技术以及芯片加工工艺的迅猛发展&#xff0c; ARM工业计算机得到了越来越广泛的应用&#xff0c;尤其在工业智慧城市、智能设备以及工业自动化控制等领域。本文将为大家详细介绍ARM控制器在AGV控制系统中的应用&#xff0c;来供大家学习和参考&#xff0c;欢迎大家一起来…

【VUE进阶】安装使用Element Plus组件

Element Plus组件 安装引入组件使用Layout 布局button按钮行内表单菜单 安装 包管理安装 # 选择一个你喜欢的包管理器# NPM $ npm install element-plus --save# Yarn $ yarn add element-plus# pnpm $ pnpm install element-plus浏览器直接引入 例如 <head><!-- I…

银河麒麟(Kylin)KYSEC使用

1.推荐使用方法 *.临时禁用指令: setstatus disable--禁用 注&#xff1a;执行reboot后系统会自动启动 2.选用指令&#xff1a; *.永久禁用指令&#xff1a; setstatus disable -p *.重启后,KYSEC还是处理关闭关状态。 *.使用如下指令启用&#xff1a;setstatus enable …

第十九章 Nest multer 文件上传

上章我们了解了Express multer 文件上传的相关操作 本章将了解Nest中的文件上传。用 multer 包处理 multipart/form-data 类型的请求中的 file 新建个 nest 项目: nest new nest-multer-upload 安装 multer 的 ts 类型的包&#xff1a; npm install -D types/multer1、单文件…

tableau数据分层,数据组,与数据集 - 11

tableau数据分层&#xff0c;数据组&#xff0c;与数据集 1. 数据分层1.1 下钻1.2 上钻1.3 创建层级结构1.4 层级排序 2. 数据组2.1 创建分组2.2 编辑组2.3 分组2.4 相关结果2.5 相关例子 3. 静态数据集3.1 数据集相关概念3.2 静态数据集创建方法一3.3 静态数据集创建方法二3.4…

SpringBoot集成Sentinel 实现QPS限流

Spring Cloud Alibaba 的 Sentinel 组件提供了丰富的“流量控制“规则&#xff0c; 单体SpringBoot应用中也可以集成 Sentinel 来实现流量控制&#xff0c;本文主要讲 QPS流量控制。 SpringBoot集成Sentinel有两种方式&#xff1a; 一种是 dashboard控制面板的方式&#xff0…

Java接口案例

一案例要求&#xff1a; 二代码&#xff1a;(换方案只需要将操作类第二行的new新对象修改就能更改项目) Ⅰ&#xff1a;&#xff08;主函数&#xff09; package d1;public class test {public static void main(String[] args) {operator anew operator();a.show();a.averag…

Vue在一个页面调用另一个同级页面的方法

1、建个中转站 2、然后在两个页面都引入它&#xff0c;注意引入路径。 import Utils from src/utils/way 3、调用方的写法 //eg :Utils.$emit(demo, msg) 4、被调用方的写法 //eg :Utils.$on(demo, val>{})

想要制作自己的歌曲伴奏?提取伴奏人声分离软件我用这几款

在音乐制作、翻唱创作或是学术研究等领域&#xff0c;将人声与伴奏分离是一项常见的需求。随着这两年 AI 技术的发展&#xff0c;现在有许多软件和工具可以帮助用户实现这一目标&#xff0c;无需专业的音频编辑知识。本文将重点介绍简鹿人声分离工具以及其他几款知名的人声伴奏…

数据库第四次练习

数据准备 创建两张表&#xff1a;部门&#xff08;dept&#xff09;和员工&#xff08;emp&#xff09;&#xff0c;并插入数据&#xff0c;代码如下 create table dept( dept_id int primary key auto_increment comment 部门编号, dept_name char(20) comment 部门名称 ); in…

全现金!6.65亿美刀!AMD大手一挥收购欧洲最大私人AI实验室

当地时间2024年7月10日&#xff0c;AMD宣布&#xff0c;他们刚刚签署了一项最终协议&#xff0c;豪掷6.65亿美刀收购欧洲最大的私人人工智能实验室Silo AI&#xff0c;而且还是全现金收购&#xff01; 据AMD官网报道&#xff0c;该协议代表着该公司基于开放标准并与全球人工智…

IAR全面支持芯驰科技E3系列车规MCU产品E3119/E3118

中国上海&#xff0c;2024年7月11日 — 全球领先的嵌入式系统开发软件解决方案供应商IAR与全场景智能车芯引领者芯驰科技宣布进一步扩大合作&#xff0c;最新版IAR Embedded Workbench for Arm已全面支持芯驰科技的E3119/E3118车规级MCU产品。IAR与芯驰科技有着悠久的合作历史&…