KAN 学习 Day4 —— MultKAN 正向传播代码解读及测试

在KAN学习Day1——模型框架解析及HelloKAN中,我对KAN模型的基本原理进行了简单说明,并将作者团队给出的入门教程hellokan跑了一遍;

在KAN 学习 Day2 —— utils.py及spline.py 代码解读及测试中,我对项目的基本模块代码进行了解释,并以单元测试的形式深入理解模块功能,其中还发现了一个细小的错误。

在KAN 学习 Day3 —— KANLayer.py 与 Symbolic_KANLayer.py 代码解读及测试中,我对两种KAN层的实现进行了解读,它们分别是 “基于B样条曲线的KAN层” 和 “基于 eq?c*f%28a*x+b%29+d 的KAN层” 。(在下文中就称 B样条KAN层 和 符号KAN层)

今天我们开始对完整的KAN网络进行剖析,根据之前的经验,MultKAN类应该包括网络初始化、层之间网格参数传递、反向传播参数更新、网络剪枝、画图等等操作。

目录

一、kan目录

二、MultKAN.py 

2.1 类注释

​​​​2.2 构造函数 __init__

2.3 节点数计算

2.4 前向传播 forward

 0. 方法定义及注释

1. 初始化阶段

2. 前向传播循环

2.5 训练方法 fit

三、总结


一、kan目录

kan目录结构如下,包括了模型源码、检查点、实验以及assets等

e12295be65d94b3381e647242dc51eba.pngcc0f7d4a5a5148c995fcd44ad9bbbab6.png

 先了解一下这些文件/文件夹的大致信息:

  • kan\__init__.py:用于初始化Python包,方便使用时导入模块
  • kan\compiler.py:用于编译模型
  • kan\experiment.py:实验代码
  • kan\feynman.py:费曼函数,根据传入“name”的值确定函数,暂时没找到这个在哪里用到
  • kan\hypothesis.py:将函数进行线性分离,还包含一些画图函数
  •  kan\KANLayer.py:KAN层的实现,使用B样条曲线作为激活函数 
  • kan\LBFGS.py:这个文件名似乎昨天见过,训练时的opt参数。L-BFGS是一种用于无约束优化问题的算法,它是一种拟牛顿方法,特别适用于大型稀疏问题。
  • kan\MLP.py:作者自己实现了一个MLP,应该使来与KAN做对比的
  • kan\MultKAN.py:在KANLayer的基础上实现的KAN类的定义,提供了关于构建和配置这种网络的详细信息。
  • kan\spline.py:样条函数的实现
  •  kan\Symbolic_KANLayer.py:符号化的KAN层,使用四参线性函数作为激活函数 
  • kan\utils.py:通用模块
  • kan\.ipynb_checkpoints:看目录名,这个文件夹下存放的应该是检查点文件,但是似乎和模型的实现代码区别不大,没遇到过,还不知道有什么用。
  • kan\assets:这个目录下存放了两张图片,一张加号一张乘号,应该是对函数进行线性分离后,可视化时用的
  • kan\experiments:这个目录下是experiment1.ipynb,和昨天跑的hellokan差不多,今天再跑一下

二、MultKAN.py 

import torch
import torch.nn as nn
import numpy as np
from .KANLayer import KANLayer
#from .Symbolic_MultKANLayer import *
from .Symbolic_KANLayer import Symbolic_KANLayer
from .LBFGS import *
import os
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import copy
#from .MultKANLayer import MultKANLayer
import pandas as pd
from sympy.printing import latex
from sympy import *
import sympy
import yaml
from .spline import curve2coef
from .utils import SYMBOLIC_LIB
from .hypothesis import plot_tree

导入的这些依赖中,只有 LBFGS 和 plot_tree 我们还没介绍,这两部分内容我也没打算深入研究

  • LBFGS(Limited-memory BFGS)是一种优化算法,它主要用于求解无约束优化问题。
  • plot_tree则是画出网络的树状图

2.1 类注释

class MultKAN(nn.Module):
    '''
    KAN class
    
    Attributes:
    -----------
        grid : int
            the number of grid intervals
        k : int
            spline order
        act_fun : a list of KANLayers
        symbolic_fun: a list of Symbolic_KANLayer
        depth : int
            depth of KAN
        width : list
            number of neurons in each layer.
            Without multiplication nodes, [2,5,5,3] means 2D inputs, 3D outputs, with 2 layers of 5 hidden neurons.
            With multiplication nodes, [2,[5,3],[5,1],3] means besides the [2,5,53] KAN, there are 3 (1) mul nodes in layer 1 (2). 
        mult_arity : int, or list of int lists
            multiplication arity for each multiplication node (the number of numbers to be multiplied)
        grid : int
            the number of grid intervals
        k : int
            the order of piecewise polynomial
        base_fun : fun
            residual function b(x). an activation function phi(x) = sb_scale * b(x) + sp_scale * spline(x)
        symbolic_fun : a list of Symbolic_KANLayer
            Symbolic_KANLayers
        symbolic_enabled : bool
            If False, the symbolic front is not computed (to save time). Default: True.
        width_in : list
            The number of input neurons for each layer
        width_out : list
            The number of output neurons for each layer
        base_fun_name : str
            The base function b(x)
        grip_eps : float
            The parameter that interpolates between uniform grid and adaptive grid (based on sample quantile)
        node_bias : a list of 1D torch.float
        node_scale : a list of 1D torch.float
        subnode_bias : a list of 1D torch.float
        subnode_scale : a list of 1D torch.float
        symbolic_enabled : bool
            when symbolic_enabled = False, the symbolic branch (symbolic_fun) will be ignored in computation (set to zero)
        affine_trainable : bool
            indicate whether affine parameters are trainable (node_bias, node_scale, subnode_bias, subnode_scale)
        sp_trainable : bool
            indicate whether the overall magnitude of splines is trainable
        sb_trainable : bool
            indicate whether the overall magnitude of base function is trainable
        save_act : bool
            indicate whether intermediate activations are saved in forward pass
        node_scores : None or list of 1D torch.float
            node attribution score
        edge_scores : None or list of 2D torch.float
            edge attribution score
        subnode_scores : None or list of 1D torch.float
            subnode attribution score
        cache_data : None or 2D torch.float
            cached input data
        acts : None or a list of 2D torch.float
            activations on nodes
        auto_save : bool
            indicate whether to automatically save a checkpoint once the model is modified
        state_id : int
            the state of the model (used to save checkpoint)
        ckpt_path : str
            the folder to store checkpoints
        round : int
            the number of times rewind() has been called
        device : str
    '''

这段代码定义了一个名为 MultKAN 的类,它是基于 nn.Module 构建的,这个类具有众多的属性,用于描述和控制其行为和特征:

  • grid:网格的间隔数(使用网格进行参数优化)
  • k:分段多项式的阶数,或者说B样条的控制点数
  • act_fun:B样条KAN层列表
  • symbolic_fun:符号KAN层列表。
  • depth:表示模型的深度。
  • width:描述了各层神经元的数量。
  • mult_arity:与乘法节点的乘法运算的元数有关。
  • base_fun:公式中的eq?b%28x%29
  •  symbolic_enabled:布尔值,是否使用符号KAN层 
  • width_in 和 width_out:分别表示各层的输入和输出神经元数量。
  • base_fun_name:基础函数的名称。
  • grip_eps:可能用于在均匀网格和自适应网格之间进行插值。
  • 各种与偏差、缩放、训练相关的属性,如 node_biasnode_scale 等,用于控制模型的训练和参数调整。
  • 各种与分数、缓存、自动保存、设备等相关的属性,用于模型的评估、数据存储、模型保存和硬件设置等方面。 

嘛,就是说这里好多注释又重复了......

​​​​2.2 构造函数 __init__

    def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, scale_base_mu=0.0, scale_base_sigma=1.0, base_fun='silu', symbolic_enabled=True, affine_trainable=False, grid_eps=0.02, grid_range=[-1, 1], sp_trainable=True, sb_trainable=True, seed=1, save_act=True, sparse_init=False, auto_save=True, first_init=True, ckpt_path='./model', state_id=0, round=0, device='cpu'):
        '''
        initalize a KAN model
        
        Args:
        -----
            width : list of int
                Without multiplication nodes: :math:`[n_0, n_1, .., n_{L-1}]` specify the number of neurons in each layer (including inputs/outputs)
                With multiplication nodes: :math:`[[n_0,m_0=0], [n_1,m_1], .., [n_{L-1},m_{L-1}]]` specify the number of addition/multiplication nodes in each layer (including inputs/outputs)
            grid : int
                number of grid intervals. Default: 3.
            k : int
                order of piecewise polynomial. Default: 3.
            mult_arity : int, or list of int lists
                multiplication arity for each multiplication node (the number of numbers to be multiplied)
            noise_scale : float
                initial injected noise to spline.
            base_fun : str
                the residual function b(x). Default: 'silu'
            symbolic_enabled : bool
                compute (True) or skip (False) symbolic computations (for efficiency). By default: True. 
            affine_trainable : bool
                affine parameters are updated or not. Affine parameters include node_scale, node_bias, subnode_scale, subnode_bias
            grid_eps : float
                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
            grid_range : list/np.array of shape (2,))
                setting the range of grids. Default: [-1,1]. This argument is not important if fit(update_grid=True) (by default updata_grid=True)
            sp_trainable : bool
                If true, scale_sp is trainable. Default: True.
            sb_trainable : bool
                If true, scale_base is trainable. Default: True.
            device : str
                device
            seed : int
                random seed
            save_act : bool
                indicate whether intermediate activations are saved in forward pass
            sparse_init : bool
                sparse initialization (True) or normal dense initialization. Default: False.
            auto_save : bool
                indicate whether to automatically save a checkpoint once the model is modified
            state_id : int
                the state of the model (used to save checkpoint)
            ckpt_path : str
                the folder to store checkpoints. Default: './model'
            round : int
                the number of times rewind() has been called
            device : str
            
        Returns:
        --------
            self
            
        Example
        -------
        >>> from kan import *
        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
        checkpoint directory created: ./model
        saving model version 0.0
        '''

这段代码是 MultKAN 类的构造函数 __init__ 的定义。构造函数用于初始化一个 MultKAN 模型实例,并为其设置各种参数。

参数说明:

  • width:一个整数列表,指定了每一层的神经元数量。如果没有乘法节点,列表中的每个元素代表相应层的神经元数量;如果有乘法节点,列表中的元素是一个包含神经元数量和乘法节点数量的元组。
  • grid:网格间隔的数量,默认为3。
  • k:分段多项式的阶数,默认为3。
  • mult_arity:每个乘法节点的乘法运算的元数,可以是单个整数或整数列表。
  • noise_scale:注入到样条函数中的初始噪声的缩放比例。
  • scale_base_mu 和 scale_base_sigma:基础函数的缩放参数的均值和标准差。
  • base_fun:残差函数 b(x) 的类型,默认为 'silu'。
  • symbolic_enabled:是否启用符号计算,默认为True。
  • affine_trainable:是否更新仿射参数,包括节点缩放、节点偏差、子节点缩放和子节点偏差。
  • grid_eps:用于在均匀网格和自适应网格之间进行插值的参数。
  • grid_range:设置网格范围的列表或NumPy数组。
  • sp_trainable:如果为真,则spline的缩放是可训练的。
  • sb_trainable:如果为真,则基础函数的缩放是可训练的。
  • device:指定设备,如 'cpu' 或 'cuda'。
  • seed:随机种子,用于初始化权重。
  • save_act:指示是否在正向传递中保存中间激活。
  • sparse_init:是否进行稀疏初始化。
  • auto_save:指示是否在修改模型后自动保存检查点。
  • state_id:模型的当前状态,用于保存检查点。
  • ckpt_path:存储检查点的文件夹路径。
  • roundrewind() 被调用次数。
  • device:设备类型。

代码说明:

        super(MultKAN, self).__init__()
  • 调用父类 MultKAN 的初始化方法,用于设置一些基本的属性或执行一些初始化操作。
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)
  • 这三行代码设置了随机数种子,确保每次运行代码时生成的随机数序列相同,这对于测试和调试非常有用。据说将seed设置为3407会将模型的性能提升1%

        ### initializeing the numerical front ###

        self.act_fun = []
        self.depth = len(width) - 1
  • 这里初始化了激活函数列表 self.act_fun 和模型的深度 self.depth,深度是通过宽度列表的长度减一得到的。
        for i in range(len(width)):
            if type(width[i]) == int:
                width[i] = [width[i],0]
            
        self.width = width
  • 遍历宽度列表,如果宽度为整数,则将其转换为列表形式,形式为 [宽度, 0]。
  • 将宽度列表赋值给 self.width 属性。
  • 注意到,注释中的width属性是有两种形式的,这几行代码使其都转化为了第二种形式,即如果有乘法节点,列表中的元素是一个包含神经元数量和乘法节点数量的元组。
        # if mult_arity is just a scalar, we extend it to a list of lists
        # e.g, mult_arity = [[2,3],[4]] means that in the first hidden layer, 2 mult ops have arity 2 and 3, respectively;
        # in the second hidden layer, 1 mult op has arity 4.
        if isinstance(mult_arity, int):
            self.mult_homo = True # when homo is True, parallelization is possible
        else:
            self.mult_homo = False # when home if False, for loop is required. 
        self.mult_arity = mult_arity
  • 如果 mult_arity 是一个标量(即单个数字),代码将把它扩展为一个列表的列表。这样做通常是为了将单一的参数应用到多个乘法操作上。
  • 例如,如果 mult_arity = [[2,3],[4]],这意味着在第一个隐藏层中有两个乘法操作,它们的参数分别是 2 和 3;在第二个隐藏层中有一个乘法操作,其参数是 4。
  • 这里检查 mult_arity 是否是一个整数。如果是,那么所有乘法操作的参数都是相同的,这意味着它们是同质的。在这种情况下,可以将这些操作并行化,以提高计算效率。因此,将 self.mult_homo 设置为 True
  • 如果 mult_arity 不是一个整数,那么它可能是一个列表的列表,其中包含不同层级的不同参数。在这种情况下,不能并行化乘法操作,因为每个操作的参数可能不同。因此,将 self.mult_homo 设置为 False,这意味着可能需要使用循环来处理每个操作。
  • 最后,将处理后的 mult_arity 参数赋值给 self.mult_arity,这样模型就可以使用这个参数来定义其乘法操作了。
        width_in = self.width_in
        width_out = self.width_out

调用了两个方法,获得了KAN层真正的输入输出节点数。

        self.base_fun_name = base_fun
        if base_fun == 'silu':
            base_fun = torch.nn.SiLU()
        elif base_fun == 'identity':
            base_fun = torch.nn.Identity()
        elif base_fun == 'zero':
            base_fun = lambda x: x*0.
  • 将传入的 base_fun 参数赋值给实例变量 self.base_fun_name。这意味着 base_fun 是一个字符串,它表示想要使用的基础函数的名称。
  • 如果 base_fun_name 是字符串 'silu',那么代码将创建一个 torch.nn.SiLU() 对象。SiLU(Sigmoid-weighted Linear Unit)是一个激活函数,通常用于神经网络中。
  • 如果 base_fun_name 是字符串 'identity',那么代码将创建一个 torch.nn.Identity() 对象。Identity 函数是一个恒等函数,它直接返回其输入值,通常用作默认激活函数或不改变输入的层。
  • 如果 base_fun_name 是字符串 'zero',那么代码将创建一个匿名函数(lambda 函数),这个函数将任何输入 x 乘以 0,从而输出 0。这可能表示一个“关闭”激活状态的函数,不激活任何神经元。
        self.grid_eps = grid_eps
        self.grid_range = grid_range
  • 将网格相关的参数赋值给 self.grid_eps 和 self.grid_range
  • grid_eps:控制网格细化策略的浮点数,默认为0.02。当 grid_eps = 1 时,网格是均匀的;当 grid_eps = 0 时,它使用样本的百分位数进行分区。0 < grid_eps < 1 插值在两种极端之间。
        for l in range(self.depth):
            # splines
            sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
            self.act_fun.append(sp_batch)
  • 这段代码在循环中为每一层创建一个 KANLayer 实例,并把这些实例添加到一个列表中,以便后续可以用于KAN网络模型。
        self.node_bias = []
        self.node_scale = []
        self.subnode_bias = []
        self.subnode_scale = []
  •  初始化用于节点和子节点的偏差和缩放参数的列表。
        globals()['self.node_bias_0'] = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)
        exec('self.node_bias_0' + " = torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False)")
  • globals() 返回当前全局命名空间中的所有全局变量。
  • self.node_bias_0 是类的一个属性,这个属性在类的定义中尚未明确定义(即,它不是类的内部成员,而是通过全局命名空间访问的)。
  • torch.nn.Parameter(torch.zeros(3,1)) 创建了一个PyTorch的参数(Parameter对象),该对象是张量,用于在神经网络中存储权重,并且支持梯度计算。
  • .requires_grad_(False) 设置了该参数对象不进行梯度计算,即不会追踪其在计算图中的操作,这对于不需要计算梯度的参数(如偏置项)来说是合理的。
  • exec() 是Python的内置函数,用于执行字符串形式的Python代码。在这里,它被用来动态地创建或更新类属性。
  • 'self.node_bias_0' 是一个字符串,表示类中要创建或更新的属性名。
  • torch.nn.Parameter(torch.zeros(3,1)).requires_grad_(False) 是一个字符串表达式,创建了一个新的PyTorch参数并设置了其梯度计算为False。

这种做法在某些情况下非常有用,比如在定义神经网络模型时,需要动态地为特定的参数创建属性,或者在模型中为某些不需要梯度计算的参数(如偏置项)创建独立的属性。

但是!我没找到这两行代码有啥用处。

        for l in range(self.depth):
            exec(f'self.node_bias_{l} = torch.nn.Parameter(torch.zeros(width_in[l+1])).requires_grad_(affine_trainable)')
            exec(f'self.node_scale_{l} = torch.nn.Parameter(torch.ones(width_in[l+1])).requires_grad_(affine_trainable)')
            exec(f'self.subnode_bias_{l} = torch.nn.Parameter(torch.zeros(width_out[l+1])).requires_grad_(affine_trainable)')
            exec(f'self.subnode_scale_{l} = torch.nn.Parameter(torch.ones(width_out[l+1])).requires_grad_(affine_trainable)')
            exec(f'self.node_bias.append(self.node_bias_{l})')
            exec(f'self.node_scale.append(self.node_scale_{l})')
            exec(f'self.subnode_bias.append(self.subnode_bias_{l})')
            exec(f'self.subnode_scale.append(self.subnode_scale_{l})')
  • 通过循环,它为模型的每一层创建了节点偏置、节点缩放、子节点偏置和子节点缩放参数,并将这些参数存储在类的属性中,以便后续使用。
  • 通过 affine_trainable 参数来控制哪些参数是可训练的。
        self.act_fun = nn.ModuleList(self.act_fun)

        self.grid = grid
        self.k = k
        self.base_fun = base_fun

这几个基础的设置就不解释了。

        ### initializing the symbolic front ###
        self.symbolic_fun = []
        for l in range(self.depth):
            sb_batch = Symbolic_KANLayer(in_dim=width_in[l], out_dim=width_out[l+1])
            self.symbolic_fun.append(sb_batch)

刚刚创建了B样条KAN层,现在创建符号KAN层。

        self.symbolic_fun = nn.ModuleList(self.symbolic_fun)
        self.symbolic_enabled = symbolic_enabled
        self.affine_trainable = affine_trainable
        self.sp_trainable = sp_trainable
        self.sb_trainable = sb_trainable
  • 将符号层加入列表
  • 设置符号层是否可用
  • 设置符号层线性函数的四个参数是否可训练
  • 设置激活函数中的参数 eq?w_%7Bs%7D 是否可训练,分为了sp和sb两种,sp为B样条KAN层的,sb为符号KAN层的
        self.save_act = save_act
            
        self.node_scores = None
        self.edge_scores = None
        self.subnode_scores = None
        
        self.cache_data = None
        self.acts = None
        
        self.auto_save = auto_save
        self.state_id = 0
        self.ckpt_path = ckpt_path
        self.round = round

一些中间结果的存储变量和保存操作设置,保存的具体操作如下:

        if auto_save:
            if first_init:
                if not os.path.exists(ckpt_path):
                    # Create the directory
                    os.makedirs(ckpt_path)
                print(f"checkpoint directory created: {ckpt_path}")
                print('saving model version 0.0')

                history_path = self.ckpt_path+'/history.txt'
                with open(history_path, 'w') as file:
                    file.write(f'### Round {self.round} ###' + '\n')
                    file.write('init => 0.0' + '\n')
                self.saveckpt(path=self.ckpt_path+'/'+'0.0')
            else:
                self.state_id = state_id

我们在hellokan中就见识过,模型在训练过程中会保存中间数据、状态和历史信息等内容

        self.input_id = torch.arange(self.width_in[0],)
  • 给输入节点编号 ,从0开始
        self.device = device
        self.to(device)
    def to(self, device):
        '''
        move the model to device
        
        Args:
        -----
            device : str or device

        Returns:
        --------
            self
            
        Example
        -------
        >>> from kan import *
        >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
        >>> model.to(device)
        '''
        super(MultKAN, self).to(device)
        self.device = device
        
        for kanlayer in self.act_fun:
            kanlayer.to(device)
            
        for symbolic_kanlayer in self.symbolic_fun:
            symbolic_kanlayer.to(device)
            
        return self
  •  选择计算设备

测试:

from kan import *
torch.set_default_dtype(torch.float64)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model = KAN(width=[2,[5,3],[5,1],3], mult_arity=[0,[2,3,4],[2],0],grid=3, k=3, seed=42, device=device)
model.input_id

cuda
checkpoint directory created: ./model
saving model version 0.0

tensor([0, 1])

2.3 节点数计算

2.3.1 width_in 

    @property
    def width_in(self):
        '''
        The number of input nodes for each layer
        '''
        width = self.width
        width_in = [width[l][0]+width[l][1] for l in range(len(width))]
        return width_in

这段代码定义了一个属性 width_in ,它的作用是计算并返回模型每一层的输入节点数量。

  •  首先,获取了模型的宽度信息 width 。
  • 然后,通过列表推导式计算每一层输入节点的数量,计算方式是将每一层的总和维度 width[l][0] 和乘法操作维度 width[l][1] 相加。
  • 最后,返回计算得到的输入节点数量列表。

所以每层节点数=设置的节点数+乘法操作次数

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.width_in

[[2, 0], [5, 3], [5, 1], [3, 0]]

 [2, 8, 6, 3]

2.3.2 width_out

    @property
    def width_out(self):
        '''
        The number of output subnodes for each layer
        '''
        width = self.width
        if self.mult_homo == True:
            width_out = [width[l][0]+self.mult_arity*width[l][1] for l in range(len(width))]
        else:
            width_out = [width[l][0]+int(np.sum(self.mult_arity[l])) for l in range(len(width))]
        return width_out

这段代码定义了一个属性 width_out,其目的是计算并返回模型每一层的输出子节点数量。

  • 首先,获取了模型的宽度信息 width。然后根据 self.mult_homo 的值来决定计算输出节点数量的方式。
  • 如果 self.mult_homo 为 True,则使用列表推导式计算每一层的输出节点数量。计算方式是将每一层的总和维度 width[l][0] 与乘法操作维度 width[l][1] 的结果乘以 mult_arity 的值相加。mult_arity 是一个数组,表示每一层的乘法操作的幅度。
  • 如果 self.mult_homo 为 False,则使用列表推导式计算每一层的输出节点数量。计算方式是将每一层的总和维度 width[l][0] 与 mult_arity[l] 的元素之和相加。mult_arity[l] 是一个数组,表示每一层的乘法操作的幅度。
  • 最后,返回计算得到的输出节点数量列表。

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.width_out

[[2, 0], [5, 3], [5, 1], [3, 0]]

[2, 14, 7, 3]

2.3.3 n_sum

    @property
    def n_sum(self):
        '''
        The number of addition nodes for each layer
        '''
        width = self.width
        n_sum = [width[l][0] for l in range(1,len(width)-1)]
        return n_sum

这段代码定义了一个属性 n_sum ,用于计算并返回除了第一层和最后一层之外,每一层的总和维度 width[l][0] 所组成的列表。

首先,获取了模型的宽度信息 width 。然后通过列表推导式,从第二层到倒数第二层,提取出每一层的 width[l][0] ,并将这些值组成一个新的列表 n_sum ,最后返回这个列表。

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.n_sum

[[2, 0], [5, 3], [5, 1], [3, 0]]
[5, 5]

2.3.4 n_mult

    @property
    def n_mult(self):
        '''
        The number of multiplication nodes for each layer
        '''
        width = self.width
        n_mult = [width[l][1] for l in range(1,len(width)-1)]
        return n_mult

这段代码定义了一个属性 n_mult ,用于计算并返回除了第一层和最后一层之外,每一层的乘法节点数量。这里 width 是一个包含多层宽度信息的数据结构,每一层的信息以列表的形式存储,其中 width[l][1] 表示第 l 层的乘法节点数量。

通过列表推导式,代码遍历从第二层到倒数第二层的所有层,提取每一层的乘法节点数量,并将这些数量组成一个新的列表 n_mult 。最后,这个列表被返回给调用者。

测试:

width=[2,[5,3],[5,1],3],mult_arity = [[0],[2,3,4],[2],[0]]

print(model.width)
model.n_mult

[[2, 0], [5, 3], [5, 1], [3, 0]]

[3, 1]

2.3.5 feature_score

    @property
    def feature_score(self):
        '''
        attribution scores for inputs
        '''
        self.attribute()
        if self.node_scores == None:
            return None
        else:
            return self.node_scores[0]

这段代码定义了一个名为 feature_score 的属性。其功能是计算输入的归因分数。

首先调用了 self.attribute() 方法。然后判断 self.node_scores 是否为 None ,如果是,则直接返回 None ;如果不是,则返回 self.node_scores 中的第一个元素。

这意味着只有在 self.node_scores 不为空的情况下,才会返回其第一个元素作为特征分数。

2.4 前向传播 forward

这个前向传播有点诡异,总感觉跟论文中的对不上,这次我们一边解释一边测试!

先来个简单的:width=[2,5,5,3],mult_arity = 2

from kan import *
torch.set_default_dtype(torch.float64)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model = KAN(width=[2,5,5,3], mult_arity=2,grid=3, k=3, seed=42, device=device)
model.input_id

cuda
checkpoint directory created: ./model
saving model version 0.0
tensor([0, 1])

 测试数据:

x = torch.tensor([[1,2],
                  [3,4],
                  [5,6],
                  [7,8],
                  [9,10]]).float()
x = x.to(device)

 0. 方法定义及注释

    def forward(self, x, singularity_avoiding=False, y_th=10.):
        '''
        forward pass
        
        Args:
        -----
            x : 2D torch.tensor
                inputs
            singularity_avoiding : bool
                whether to avoid singularity for the symbolic branch
            y_th : float
                the threshold for singularity

        Returns:
        --------
            None
            
        Example1
        --------
        >>> from kan import *
        >>> model = KAN(width=[2,5,1], grid=5, k=3, seed=0)
        >>> x = torch.rand(100,2)
        >>> model(x).shape
        
        Example2
        --------
        >>> from kan import *
        >>> model = KAN(width=[1,1], grid=5, k=3, seed=0)
        >>> x = torch.tensor([[1],[-0.01]])
        >>> model.fix_symbolic(0,0,0,'log',fit_params_bool=False)
        >>> print(model(x))
        >>> print(model(x, singularity_avoiding=True))
        >>> print(model(x, singularity_avoiding=True, y_th=1.))
        '''

参数说明:

  • x: 2D torch.tensor,输入数据。
  • singularity_avoiding: bool,默认为 False。如果为 True,则在符号分支中避免奇异点。
  • y_th: float,默认为 10.。用于判断是否避免奇异点的阈值。

返回值:

  • None:方法执行后不返回任何值

1. 初始化阶段

        x = x[:,self.input_id.long()]
        assert x.shape[1] == self.width_in[0]

        # cache data
        self.cache_data = x
        
        self.acts = []  # shape ([batch, n0], [batch, n1], ..., [batch, n_L])
        self.acts_premult = []
        self.spline_preacts = []
        self.spline_postsplines = []
        self.spline_postacts = []
        self.acts_scale = []
        self.acts_scale_spline = []
        self.subnode_actscale = []
        self.edge_actscale = []
        # self.neurons_scale = []

        self.acts.append(x)  # acts shape: (batch, width[l])
  • 数据选择与验证

    • 选择输入数据 x 的特定列,并验证其形状是否符合模型的输入宽度要求。
    • 缓存输入数据 x
  • 初始化变量

    • 初始化用于存储不同层激活、尺度因子等的列表。

2. 前向传播循环

        for l in range(self.depth):
  • 循环遍历模型中的每一层,其中 self.depth 是模型的层数。
            x_numerical, preacts, postacts_numerical, postspline = self.act_fun[l](x)
            #print(preacts, postacts_numerical, postspline)
  • 使用第 l 层的激活函数 act_fun[l] 对输入 x 进行处理。
  • 这里的激活函数是B样条KAN层的激活函数,详情见KANLayer
  • 处理结果包括数值分支的输出 x_numerical、预激活输出 preacts、后激活输出 postacts_numerical 和样条函数的输出 postspline。(对应的是y, preacts, postacts, postspline)
            if self.symbolic_enabled == True:
                x_symbolic, postacts_symbolic = self.symbolic_fun[l](x, singularity_avoiding=singularity_avoiding, y_th=y_th)
            else:
                x_symbolic = 0.
                postacts_symbolic = 0.
  • 可使用符号KAN层时,同样进行计算
            x = x_numerical + x_symbolic

这里要注意了,作者将两种层的计算结果相加了!也就是把B样条和线性函数同时叠加使用!

            # subnode affine transform
            x = self.subnode_scale[l][None,:] * x + self.subnode_bias[l][None,:]
  •  对激活函数的计算结果进行缩放,并增加偏置常数b

对以上这一部分内容做测试:

x = x[:,model.input_id.long()]
assert x.shape[1] == model.width_in[0]

for l in range(model.depth):
            
    x_numerical, preacts, postacts_numerical, postspline = model.act_fun[l](x)
    #print(preacts, postacts_numerical, postspline)
            
    if model.symbolic_enabled == True:
        x_symbolic, postacts_symbolic = model.symbolic_fun[l](x, singularity_avoiding=False, y_th=10)
    else:
        x_symbolic = 0.
        postacts_symbolic = 0.

    x = x_numerical + x_symbolic
    x = model.subnode_scale[l][None,:] * x + model.subnode_bias[l][None,:]

    print(x)
    print(x.shape)

 tensor([[ 1.2935, -0.7047, -1.1071,  0.1673,  0.7162],
        [ 3.6752, -2.1692, -2.5181,  0.0475,  2.4007],
        [ 5.9759, -3.6323, -3.8994, -0.1110,  4.0748],
        [ 8.2045, -5.0443, -5.2469, -0.2554,  5.6880],
        [10.4148, -6.4431, -6.5866, -0.3956,  7.2852]], device='cuda:0',
       grad_fn=<AddBackward0>)
torch.Size([5, 5])
tensor([[-0.1832,  0.2447,  0.2546,  0.0981, -0.0997],
        [-0.8389,  1.2821,  1.2113,  0.4335, -0.0337],
        [-1.3981,  2.3204,  2.1626,  0.7829, -0.0287],
        [-1.9293,  3.2605,  3.0434,  1.1072, -0.0466],
        [-2.4603,  4.1674,  3.9076,  1.4242, -0.0795]], device='cuda:0',
       grad_fn=<AddBackward0>)
torch.Size([5, 5])
tensor([[ 0.0064,  0.0481, -0.1441],
        [ 0.4862,  0.2570, -0.7088],
        [ 1.0812,  0.6035, -1.4443],
        [ 1.6532,  0.8613, -2.0778],
        [ 2.1898,  1.0638, -2.6628]], device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([5, 3])

 所有中间结果的形状都没有问题。

            if self.save_act:
                # save subnode_scale
                self.subnode_actscale.append(torch.std(x, dim=0).detach())

            if self.save_act:
                postacts = postacts_numerical + postacts_symbolic

                # self.neurons_scale.append(torch.mean(torch.abs(x), dim=0))
                #grid_reshape = self.act_fun[l].grid.reshape(self.width_out[l + 1], self.width_in[l], -1)
                input_range = torch.std(preacts, dim=0) + 0.1
                output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline part
                output_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic
                # save edge_scale
                self.edge_actscale.append(output_range)
                
                self.acts_scale.append((output_range / input_range).detach())
                self.acts_scale_spline.append(output_range_spline / input_range)
                self.spline_preacts.append(preacts.detach())
                self.spline_postacts.append(postacts.detach())
                self.spline_postsplines.append(postspline.detach())

                self.acts_premult.append(x.detach())

 如果启用了保存激活函数尺度因子的选项,则计算并保存以下内容:

  • 子节点尺度因子(标准差)。
  • 边尺度因子(输出范围)。
  • 激活函数输出的尺度因子(输出范围与输入范围的比例)。
  • 样条部分的尺度因子。
  • 预激活输出、后激活输出和样条函数输出的副本。

但是我很好奇,这不是self.spline_postacts嘛,但是存的是postacts = postacts_numerical + postacts_symbolic,保存样条的激活输出为什么不只保存postacts_numerical。

还有就是都是判断 save_act,为啥用两个if。有时候就挺不能理解的

接下来介绍的这个东西,非常重要!它在基础节点的基础上引入了乘法操作。并且分为同质和非同质两种。

            # multiplication
            dim_sum = self.width[l+1][0]
            dim_mult = self.width[l+1][1]
  • 获取下一次节点数以及乘法操作次数
  • self.width[l+1][0]是下一层的节点数
  • self.width[l+1][1]是乘法操作次数

对于上面的例子,有

x.shape: torch.Size([5, 5])
dim_sum: 5
dim_mult: 0
x.shape: torch.Size([5, 5])
dim_sum: 5
dim_mult: 0
x.shape: torch.Size([5, 3])
dim_sum: 3
dim_mult: 0

            if self.mult_homo == True:
                for i in range(self.mult_arity-1):
                    if i == 0:
                        x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity]
                    else:
                        x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity]
  • 当本层乘法参数都相同,则进行矩阵运算,即处理同质(homogeneous)乘法操作:
    • 在第一次循环(i == 0)中:
      • x_mult = x[:,dim_sum::self.mult_arity] * x[:,dim_sum+1::self.mult_arity]:对 x 的特定部分进行逐元素乘法。这里 x[:,dim_sum::self.mult_arity] 表示从 dim_sum 开始,每隔 self.mult_arity 个元素取一个元素,形成一个新的张量。同理 x[:,dim_sum+1::self.mult_arity] 表示从 dim_sum+1 开始取元素。这两个张量逐元素相乘得到 x_mult
    • 在后续的循环中(i != 0):
      • x_mult = x_mult * x[:,dim_sum+i+1::self.mult_arity]:将上一次乘法的结果与 x 的另一部分相乘。

对于我们width=[2,5,5,3],mult_arity = 2这个例子,有model.mult_homo == True,但结果如下:

tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])
tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])
tensor([], device='cuda:0', size=(5, 0), grad_fn=<MulBackward0>)
x_mult.shape: torch.Size([5, 0])

 由于dim_mult = 0,所以不进行乘法运算,代码中表现为dim_sum超出index,所以dim_sum::model.mult_arity都为0,自然乘积也为0。

测试升级:

设置width=[2,[5,2],[5,3],3], mult_arity=3,这是一个同质运算,由第一层向第二层传递时,会做乘法运算,次数为mult_arity-1=2,而乘法运算结果维度为dim_mult,然后与原始的dim_sum维度拼接,参数设置 width=[2,[5,1],[5,3],3], mult_arity=3,拼接操作:

            if self.width[l+1][1] > 0:
                x = torch.cat([x[:,:dim_sum], x_mult], dim=1)
  •  将x中未参与乘法计算的部分与乘法计算结果进行拼接,恢复原始张量形状 

x.shape: torch.Size([5, 8])
x_mult.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x.shape: torch.Size([5, 6])

 

x.shape: torch.Size([5, 14])
x_mult.shape: torch.Size([5, 3])
x_mult.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 8])

 

x.shape: torch.Size([5, 3])
x_mult.shape: torch.Size([5, 0])
x_mult.shape: torch.Size([5, 0])
x.shape: torch.Size([5, 3])

 测试再次升级:

我用数据展示第二层的计算:

x = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14]])
dim_sum = 5
dim_mult = 3
mult_arity = 3
if model.mult_homo == True:
    for i in range(2):
        print(f"第{i+1}次乘法:")
        if i == 0:
            print(x[:,dim_sum::mult_arity])
            print(x[:,dim_sum+1::mult_arity])
            x_mult = x[:,dim_sum::mult_arity] * x[:,dim_sum+1::mult_arity]
            
        else:
            print(x_mult)
            print(x[:,dim_sum+i+1::mult_arity])
            x_mult = x_mult * x[:,dim_sum+i+1::mult_arity]

        print(x_mult)

if dim_mult > 0:
    x = torch.cat([x[:,:dim_sum], x_mult], dim=1)

print(x)
print("x.shape:",x.shape)
print()

 第1次乘法:
tensor([[ 6,  9, 12]])
tensor([[ 7, 10, 13]])
tensor([[ 42,  90, 156]])


第2次乘法:
tensor([[ 42,  90, 156]])
tensor([[ 8, 11, 14]])
tensor([[ 336,  990, 2184]])


tensor([[   1,    2,    3,    4,    5,  336,  990, 2184]])
x.shape: torch.Size([1, 8])

 这下就完全理解它的乘法是如何运算的了。同质运算使用了矩阵运算以加快运算速度,这建立在mult_arity为常数的情况下,而当mult_arity的元素为列表时,只能进行遍历运算,如下:

            else:
                for j in range(dim_mult):
                    acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j])
                    for i in range(self.mult_arity[l+1][j]-1):
                        if i == 0:
                            x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
                        else:
                            x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
                            
                    if j == 0:
                        x_mult = x_mult_j
                    else:
                        x_mult = torch.cat([x_mult, x_mult_j], dim=1)
  • 当本层乘法参数不相同,则进行遍历运算,即处理非同质(non-homogeneous)乘法操作:
    • for j in range(dim_mult):循环遍历 dim_mult 次,dim_mult 表示乘法操作的次数。
    • 在每次循环中,计算 acml_id
      • acml_id = dim_sum + np.sum(self.mult_arity[l+1][:j]):计算当前乘法操作的起始索引。
    • 然后对每个乘法操作:
      • for i in range(self.mult_arity[l+1][j]-1)::循环遍历当前维度的乘法操作次数。
      • 在第一次循环(i == 0)中:
        • x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]:对 x 的特定部分进行逐元素乘法。
      • 在后续的循环中(i != 0):
        • x_mult_j = x_mult_j * x[:,[acml_id+i+1]]:将上一次乘法的结果与 x 的另一部分相乘。
    • 如果是第一个乘法操作(j == 0):
      • x_mult = x_mult_j:将第一个乘法操作的结果赋值给 x_mult
    • 如果不是第一个乘法操作:
      • x_mult = torch.cat([x_mult, x_mult_j], dim=1):将当前乘法操作的结果与之前的结果在最后一个维度上连接。

测试:

参数设置 width=[2,[5,1],[5,3],3], mult_arity=[[0],[2],[2,3,4],[0]]

x = torch.tensor([[1,2],
                  [3,4],
                  [5,6],
                  [7,8],
                  [9,10]]).float()
x = x.to(device)

x = x[:,model.input_id.long()]
assert x.shape[1] == model.width_in[0]

for l in range(model.depth):
            
    x_numerical, preacts, postacts_numerical, postspline = model.act_fun[l](x)
    #print(preacts, postacts_numerical, postspline)
            
    if model.symbolic_enabled == True:
        x_symbolic, postacts_symbolic = model.symbolic_fun[l](x, singularity_avoiding=False, y_th=10)
    else:
        x_symbolic = 0.
        postacts_symbolic = 0.

    x = x_numerical + x_symbolic
    x = model.subnode_scale[l][None,:] * x + model.subnode_bias[l][None,:]

    #print(x)
    print("x.shape:",x.shape)

    # multiplication
    dim_sum = model.width[l+1][0]
    dim_mult = model.width[l+1][1]
    #print("dim_sum:",dim_sum)
    #print("dim_mult:",dim_mult)
    if model.mult_homo == False:
        for j in range(dim_mult):
            acml_id = dim_sum + np.sum(model.mult_arity[l+1][:j])
            for i in range(model.mult_arity[l+1][j]-1):
                if i == 0:
                    x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
                else:
                    x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
                print("x_mult_j.shape:",x_mult_j.shape )
                            
            if j == 0:
                x_mult = x_mult_j
            else:
                x_mult = torch.cat([x_mult, x_mult_j], dim=1)
            print("x_mult.shape:",x_mult.shape)

    if model.width[l+1][1] > 0:
        x = torch.cat([x[:,:dim_sum], x_mult], dim=1)

x.shape: torch.Size([5, 7])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x.shape: torch.Size([5, 6])

 

x.shape: torch.Size([5, 14])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 2])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult_j.shape: torch.Size([5, 1])
x_mult.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 8])

 

x.shape: torch.Size([5, 3])
x.shape: torch.Size([5, 3])

 来逐一分析:

  • 第一层到第二层计算,经过B样条KAN层和符号KAN层,x形状为[batch_size,dim_sum+sum(mult_arity[l+1])],其中sum(mult_arity[l+1])=dim_mult*mult_arity[l+1],因为mult_arity[l+1]只有一个元素。然后进行了一次乘法运算,并将结果拼接在x[:,:dim_sum]后面
  • 第二层到第三次计算:经过B样条KAN层和符号KAN层,x形状为[batch_size,dim_sum+sum(mult_arity[l+1])],其中sum(mult_arity[l+1])=np.sum(model.mult_arity[l+1][:j]),对于mult_arity[l+1]列表中的每一个元素,都执行其数值减一的乘法运算,运算结果x_mult_j的形状为[batch_size,1],最终获得的x_mult都是由x_mult_j拼接来的,最后将x_mult拼接在x[:,:dim_sum]后面。
  • 第三层到第四层同理。

使用数据展示第二层的计算:

x = torch.tensor([[1,2,3,4,5,6,7,8,9,10,11,12,13,14]])
dim_sum = 5
dim_mult = 3
mult_arity = [2,3,4]

print(x)

if model.mult_homo == False:
    for j in range(dim_mult):
        print(f"第{j+1}次运算:")
        acml_id = dim_sum + np.sum(mult_arity[:j])
        for i in range(mult_arity[j]-1):
            if i == 0:
                print(x[:,[acml_id]])
                print(x[:,[acml_id+1]])
                x_mult_j = x[:,[acml_id]] * x[:,[acml_id+1]]
            else:
                print(x_mult_j)
                print(x[:,[acml_id+i+1]])
                x_mult_j = x_mult_j * x[:,[acml_id+i+1]]
            print("x_mult_j:",x_mult_j)
        if j == 0:
                x_mult = x_mult_j
        else:
            x_mult = torch.cat([x_mult, x_mult_j], dim=1)
        print("x_mult:",x_mult)

if dim_mult > 0:
    x = torch.cat([x[:,:dim_sum], x_mult], dim=1)

print(x)
print("x.shape:",x.shape)
print()

tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14]])
第1次运算:
tensor([[6]])
tensor([[7]])
x_mult_j: tensor([[42]])
x_mult: tensor([[42]])


第2次运算:
tensor([[8]])
tensor([[9]])
x_mult_j: tensor([[72]])
tensor([[72]])
tensor([[10]])
x_mult_j: tensor([[720]])
x_mult: tensor([[ 42, 720]])


第3次运算:
tensor([[11]])
tensor([[12]])
x_mult_j: tensor([[132]])
tensor([[132]])
tensor([[13]])
x_mult_j: tensor([[1716]])
tensor([[1716]])
tensor([[14]])
x_mult_j: tensor([[24024]])
x_mult: tensor([[   42,   720, 24024]])


tensor([[    1,     2,     3,     4,     5,    42,   720, 24024]])
x.shape: torch.Size([1, 8])

            # x = x + self.biases[l].weight
            # node affine transform
            x = self.node_scale[l][None,:] * x + self.node_bias[l][None,:]
            
            self.acts.append(x.detach())
        
        return x
  • 对拼接后的x进行缩放并且加上偏置常数 
  • 返回计算结果

我们理一下整个计算思路:

  1. 传入x后,首先检查x的形状,遍历KAN层进行计算:
    1. 再分别使用B样条KAN层和符号KAN层计算出x = x_numerical + x_symbolic
    2. 对x进行缩放处理,并加入偏置常数
    3. 乘法运算
      1. 同质乘法运算:对于dim_sum之外的维度,使用矩阵运算计算出x_mult
      2. 非同质乘法运算:根据mult_arity[l+1]列表一次计算出x_mult_j,拼接成x_mult
    4. 如进行了乘法运算,则将x_mult与x[:,:dim_sum]拼接
    5. 对x进行缩放处理,并加入偏置常数
    6. 返回x

2.5 训练方法 fit

通过对前向传播进行剖析,KAN网络并不像论文中展示的那么简单

  • KAN层包含了B样条层和符号层两种,我们可以设置是否使用符号层,如使用的话,中间x计算结果为两者之和。
  • KAN层节点过渡时引入了乘法操作,包括同质乘法和非同质乘法,在定义的基础维度上进行了扩展,进一步加强了网络的学习能力。

现在我们对MultKAN的fit方法的使用进行详解。

    def fit(self, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1.,start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,
              metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
        '''
        training

        Args:
        -----
            dataset : dic
                contains dataset['train_input'], dataset['train_label'], dataset['test_input'], dataset['test_label']
            opt : str
                "LBFGS" or "Adam"
            steps : int
                training steps
            log : int
                logging frequency
            lamb : float
                overall penalty strength
            lamb_l1 : float
                l1 penalty strength
            lamb_entropy : float
                entropy penalty strength
            lamb_coef : float
                coefficient magnitude penalty strength
            lamb_coefdiff : float
                difference of nearby coefficits (smoothness) penalty strength
            update_grid : bool
                If True, update grid regularly before stop_grid_update_step
            grid_update_num : int
                the number of grid updates before stop_grid_update_step
            start_grid_update_step : int
                no grid updates before this training step
            stop_grid_update_step : int
                no grid updates after this training step
            loss_fn : function
                loss function
            lr : float
                learning rate
            batch : int
                batch size, if -1 then full.
            save_fig_freq : int
                save figure every (save_fig_freq) steps
            singularity_avoiding : bool
                indicate whether to avoid singularity for the symbolic part
            y_th : float
                singularity threshold (anything above the threshold is considered singular and is softened in some ways)
            reg_metric : str
                regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
            metrics : a list of metrics (as functions)
                the metrics to be computed in training
            display_metrics : a list of functions
                the metric to be displayed in tqdm progress bar

        Returns:
        --------
            results : dic
                results['train_loss'], 1D array of training losses (RMSE)
                results['test_loss'], 1D array of test losses (RMSE)
                results['reg'], 1D array of regularization
                other metrics specified in metrics

        Example
        -------
        >>> from kan import *
        >>> model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
        >>> f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
        >>> dataset = create_dataset(f, n_var=2)
        >>> model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
        >>> model.plot()
        # Most examples in toturals involve the fit() method. Please check them for useness.
        '''

参数说明:

  1. dataset (dic): 包含训练集和测试集的数据字典,通常包括输入数据(train_inputtest_input)和标签数据(train_labeltest_label)。

  2. opt (str): 选择的优化器,可以是 "LBFGS"(L-BFGS)或 "Adam"。

  3. steps (int): 训练的总步骤数。

  4. log (int): 日志输出的频率,即每多少步骤输出一次日志。

  5. lamb (float): 总体正则化强度,用于控制模型复杂度。

  6. lamb_l1 (float): L1 正则化强度,用于惩罚模型参数的绝对值。

  7. lamb_entropy (float): 用于惩罚模型熵的强度,有助于防止过拟合。

  8. lamb_coef (float): 模型系数的大小惩罚强度。

  9. lamb_coefdiff (float): 邻近系数之间的差异惩罚强度,用于增加模型的平滑性。

  10. update_grid (bool): 如果为 True,则在训练步骤达到 stop_grid_update_step 之前定期更新网格。

  11. grid_update_num (int): 在 stop_grid_update_step 之前更新网格的次数。

  12. start_grid_update_step (int): 在这个步骤之前不进行网格更新。

  13. stop_grid_update_step (int): 这个步骤之后不进行网格更新。

  14. loss_fn (function): 自定义损失函数,用于计算模型的损失。

  15. lr (float): 学习率,决定每次更新参数时的步长。

  16. batch (int): 批处理大小,如果为 -1,则使用完整数据集。

  17. save_fig_freq (int): 每多少步骤保存一次训练结果的图形。

  18. singularity_avoiding (bool): 如果为 True,则在符号部分避免奇异点。

  19. y_th (float): 奇异点阈值,高于此值的任何值都将被视为奇异点。

  20. reg_metric (str): 用于计算正则化的度量标准,可以选择不同的选项如 edge_forward_spline_n 等。

  21. metrics (list of functions): 计算并返回的自定义度量列表。

  22. display_metrics (list of functions): 在训练进度条中显示的度量列表。

返回值:

  • results (dic): 包含训练过程中的关键信息的字典,包括:
    • train_loss: 训练集上的损失(通常为 RMSE)。
    • test_loss: 测试集上的损失(通常为 RMSE)。
    • reg: 正则化项的值。
    • 其他用户指定的度量。

测试1:

from kan import *
model = KAN(width=[2,5,1], grid=5, k=3, noise_scale=0.3, seed=2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
model.plot()

checkpoint directory created: ./model
saving model version 0.0
| train_loss: 1.91e-02 | test_loss: 1.97e-02 | reg: 1.38e+01 | : 100%|█| 20/20 [00:07<00:00,  2.66it
saving model version 0.1

7bdaad0cace146f7b795d0de96bbae94.png

 测试2:

from kan import *
model = KAN(width=[2,[5,3],3], mult_arity=3, grid=5, k=3, noise_scale=0.3, seed=2)
f = lambda x: torch.exp(torch.sin(torch.pi*x[:,[0]]) + x[:,[1]]**2)
dataset = create_dataset(f, n_var=2)
model.fit(dataset, opt='LBFGS', steps=20, lamb=0.001);
model.plot()

checkpoint directory created: ./model
saving model version 0.0
| train_loss: 2.44e-02 | test_loss: 2.81e-02 | reg: 2.67e+01 | : 100%|█| 20/20 [00:09<00:00,  2.15it
saving model version 0.1

b03c0e119ad04e4eb44dc123370a436c.png

这个图包含了3个乘法节点。

三、总结

今天内容主要包括MultKAN网络的初始化、正向传播方法实现、训练方法参数说明。MultKAN网络正向传播有两个特点:

  1. 传播时可以同时使用KANLayer和Symbolic_KANLayer,以叠加的形式计算中间结果
  2. KAN节点的连接既有加法连接也有乘法连接,我们可以自定义乘法运算的方式(同质或非同质)

在上文中,我用数据直观展示了mult的计算过程,实际上只是连续的列相乘,因此在我看来,MultKAN的mult节点运算还有一定的优化空间,除了改善单一控制变量self.mult_homo,将其扩展为列表,还可以用numpy库实现连续列相乘的算法,这些尝试我打算放在实际应用中进行。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/873776.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

顶级出图效果!免费在线使用FLux.1 模型,5s出图无限制!

最近发现一个可以在线免费使用 FLux.1 模型 生成图片的AI工具。 先看效果图&#xff1a; 工具不需要登录即可使用&#xff0c;目前还是完全免费的&#xff0c;国内可以直接使用。 在提示词输入框直接输入提示词即可&#xff0c;选择图片比例之后&#xff0c;直接生图。 出图的…

24年9月通信基础知识补充1

看文献过程中不断发现有太多不懂的基础知识&#xff0c;故长期更新这类blog不断补充在这过程中学到的知识。由于这些内容与我的研究方向并不一定强相关&#xff0c;故记录不会很深入请见谅。 【通信基础知识补充2】9月通信基础知识补充1 一、Zadoff-Chu 序列1.1 Zadoff-Chu 序列…

3GPP协议入门——物理层基础(一)

1. 频段/带宽 NR指定了两个频率范围&#xff0c;FR1&#xff1a;通常称Sub 6GHz&#xff0c;也称低频5G&#xff1b;FR2&#xff1a;通常称毫米波&#xff08;Millimeter Wave&#xff09;&#xff0c;也称高频5G。 2. 子载波间隔 NR中有15kHz&#xff0c;30kHz&#xff0c;6…

C++——入门基础(下)

目录 一、引用 &#xff08;1&#xff09;引用的概念和定义 &#xff08;2&#xff09;引用的特性 &#xff08;3&#xff09;引用的使用 &#xff08;4&#xff09;const引用 &#xff08;5&#xff09;指针和引用的关系 二、inline 三、nullptr 四、写在最后 一、引用…

带相对位置表示的自注意力(201803)

Self-Attention with Relative Position Representations 带相对位置表示的自注意力 https://arxiv.org/pdf/1803.02155v1 Abstract Relying entirely on an attention mechanism, the Transformer introduced by Vaswani et al. (2017) achieves state-of-the-art results …

【加密社】比特币海量数据问题解决方案

加密社 比特币是无敌的存在&#xff0c;刚翻了一遍中本聪的论文&#xff08;其实以前看过一次&#xff0c;那时不明觉厉&#xff09;&#xff0c;发现咱们一直在考虑的问题&#xff0c;基本都能在其论文上找到解决方案了。。 现在出现的这些问题&#xff0c;完全是因为bitcoin…

4千6历年高考英语试题大全ACCESS\EXCEL数据库

《历年高#考英语试题大全ACCESS数据库》搜集了大量的全#国各#地高#考英语模拟试题&#xff0c;每道题目均有相应的答案和解析&#xff1b;这种数据虽然没有《一站到底》类的数据结构&#xff08;一个选项一个字段&#xff09;那么好&#xff0c;但是通过技术人员还是可以很简单…

自适应中值滤波器:图像去噪的高效解决方案

在数字图像处理中&#xff0c;椒盐噪声是常见的干扰之一&#xff0c;它会导致图像出现随机的黑点和白点&#xff0c;严重影响图像质量。传统的中值滤波器虽然在一定程度上能够去除这种噪声&#xff0c;但可能无法完全恢复图像的细节。为此&#xff0c;本文将介绍一种自适应中值…

k8s上搭建devops环境

一、gitlab 1.安装gitlab # 下载安装包 wget https://mirrors.tuna.tsinghua.edu.cn/gitlab-ce/yum/el7/gitlab-ce-15.9.1-ce.0.el7.x86_64.rpm # 安装 rpm -i gitlab-ce-15.9.1-ce.0.el7.x86_64.rpm # 编辑 vi /etc/gitlab/gitlab.rb 文件 # 修改 external_url 访问路径 htt…

【网络安全】分析JS文件实现账户接管

未经许可,不得转载。 文章目录 正文正文 网站使用的是简单的OTP(一次性密码)验证机制,通过用户注册时提供的电子邮件发送邮箱验证码。在功能有限的情况下,我选择去分析网站加载的JavaScript文件。 我发现了一个名为 saveJobseekerPasswordInCache 的函数: 这个函数虽然…

vscode侧边工具栏不见了找回方法

有时候因为误操作&#xff0c;vscode编辑器里面的侧边工具栏不见了找回方法&#xff0c;请按照以下步骤操作。 例:1&#xff1a;这个工具栏不见了 方法&#xff1a;菜单栏点击文件》点击首选项》点击设置》点击工作台》点击外观》勾选如下图选项 例如2&#xff1a;蓝控制台底…

无人机之穿越机的飞行模式

穿越机的飞行模式主要分为两种基本类型&#xff1a;自稳模式&#xff08;ANGLE MODE&#xff09;和手动模式&#xff08;ACRO MODE&#xff09;&#xff0c;以及一些衍生的飞行模式&#xff0c;如半自稳模式&#xff08;Horizon Mode&#xff09;等。下面将详细介绍这两种基本模…

vulhub think PHP 2-rce远程命令执行漏洞

1.开启环境 2。访问对应网站端口 3.这里我们直接构造payload&#xff0c;访问phpinfo() http://192.168.159.149:8080/?s/Index/index/L/${phpinfo()} 4.可以访问到我们的phpinfo&#xff0c; 所以写入一句话木马&#xff0c;也可使用蚁剑进行连接&#xff0c;获得其shell进…

云计算之大数据(下)

目录 一、Hologres 1.1 产品定义 1.2 产品架构 1.3 Hologres基本概念 1.4 最佳实践 - Hologres分区表 1.5 最佳实践 - 分区字段设置 1.6 最佳实践 - 设置字段类型 1.7 最佳实践 - 存储属性设置 1.8 最佳实践 - 分布键设置 1.9 最佳实践 - 聚簇键设置 1.10 最佳实践 -…

AT3340-6T杭州中科微BDS定位授时板卡性能指标

AT3340-6T是一款高性能多系统卫星定位安全授时板卡&#xff0c;可通过配置支持各个单系统的定位授时。 外观尺寸&#xff1a; 电气参数 应用领域&#xff1a; 通信基站授时 电力授时 广播电视授时 轨道系统授时 金融系统授时 其他授时应用 注意事项&#xff1a; 为了充分发挥…

Linux入门攻坚——31、rpc概念及nfs和samba

NFS&#xff1a;Network File System 传统意义上&#xff0c;文件系统在内核中实现 RPC&#xff1a;函数调用&#xff08;远程主机上的函数&#xff09;&#xff0c;Remote Procedure Call protocol 一部分功能由本地程序完成 另一部分功能由远程主机上的 NFS本质…

软件部署-Docker容器化技术

开始前的环境说明 VMware 17 Pro Centos release 7.9.2009(防火墙已关闭) Docker 26.1.4 Docker镜像加速器配置:"https://do.nark.eu.org", "https://dc.j8.work", "https://docker.m.daocloud.io", "https://dockerproxy.com", &…

2. c#从不同cs的文件调用函数

1.文件目录如下&#xff1a; 2. Program.cs文件的主函数如下 using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using System.Windows.Forms;namespace datasAnalysis {internal static class Program{/// <summary>…

HUAWEI华为MateBook B5-420 i5 集显(KLCZ-WXX9,KLCZ-WDH9)原装出厂Windows10系统文件下载

适用型号&#xff1a;KLCZ-WXX9、KLCZ-WDH9 链接&#xff1a;https://pan.baidu.com/s/12xnaLtcPjZoyfCcJUHynVQ?pwdelul 提取码&#xff1a;elul 华为原装系统自带所有驱动、出厂主题壁纸、系统属性联机支持标志、系统属性专属LOGO标志、华为浏览器、Office办公软件、华为…

网络传输的基本流程

目录 0.前言 1.TCP/IP四层协议模型的认识 2.数据传输的大致流程 3.局域网通信的原理 4.同一网段下两台主机之间的通信 5.不同网段下两台主机之间的通信 0.前言 不知道你有没有这样的疑问&#xff0c;为什么不同的设备之间能够进行数据的发送和接收&#xff1f;不同的通信…