一、引言

        KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。


二、技术与原理简介

        1.Kolmogorov-Arnold 表示定理

         Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑ff:[0,1]^{^{n}}\rightarrow \mathbb{R}

f \left( x \right)=f \left( x_{1}, \cdots,x_{n} \right)= \sum_{q=1}^{2n+1} \Phi_{q} \left( \sum_{p=1}^{n} \phi_{q,p} \left( x_{p} \right) \right)

        其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。\boldsymbol{\phi_{q,p}:[0,1]\to\mathbb{R}}\boldsymbol{\Phi_{q}:\mathbb{R}\to\mathbb{R}(2n+1)}

        2.Kolmogorov-Arnold 网络 (KAN)

        Kolmogorov-Arnold 表示可以写成矩阵形式

f(x)=\mathbf{\Phi_{out}}\mathsf{o}\mathbf{\Phi_{in}}\mathsf{o}{}x

其中

\mathbf{\Phi}_{\mathrm{in}}=\begin{pmatrix}\phi_{1,1}(\cdot)&\cdots&\phi_{1,n }(\cdot)\\ \vdots&&\vdots\\ \phi_{2n+1,1}(\cdot)&\cdots&\phi_{2n+1,n}(\cdot)\end{pmatrix}

\quad\mathbf{ \Phi}_{\mathrm{out}}=\left(\Phi_{1}(\cdot)\quad\cdots\quad\Phi_{2n+1}(\cdot)\right)

        我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:\mathbf{\Phi_{in}} \mathbf{\Phi_{out}} \mathbf{\Phi_{n_{in}n_{out}}}

其中\boldsymbol{n_{\text{in}}=n,n_{\text{out}}=2n+1\Phi_{\text{out}}n_{\text{in}}=2n+1,n_{\text{out}}=1}

        定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是Ll^{th} \Phi_{l} \left( n_{l+1},n_{l} \right)

\mathbf{KAN(x)}=\mathbf{\Phi_{L-1}}\circ\cdots\circ\mathbf{\Phi_{1}}\circ \mathbf{\Phi_{0}}\circ\mathbf{x}

        相反,多层感知器由线性层和非线错:\mathbf{W}_{l^{\sigma}}

\text{MLP}(\mathbf{x})=\mathbf{W}_{\textit{L-1}}\circ\sigma\circ\cdots\circ \mathbf{W}_{1}\circ\sigma\circ\mathbf{W}_{0}\circ\mathbf{x}

        KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。


三、代码详解

        该代码实现了一个符号激活神经网络层,其核心特点在于:

  • 灵活的激活函数:每条输入-输出连接可以选择不同的激活函数(例如 sin、cos 等),并支持符号表达式版本,便于后续分析和解释。

  • 仿射参数拟合:每个连接不仅有激活函数,还通过四个参数 [a, b, c, d] 对输入进行仿射变换,使得激活函数能够更好地适应数据分布。

  • 奇异性处理:提供了 funs_avoid_singularity 版本来确保在输入接近奇异值时不会出现数值不稳定。

  • 模块化设计:通过 get_subsetswap 方法,可以方便地对网络结构进行剪枝或重新排列,便于构建和优化模型。

        通过这些设计,该类不仅能实现复杂的非线性映射,还能借助符号计算的优势提高模型的可解释性和灵活性。

        A. 代码详解

        1. 模块导入

import torch
import torch.nn as nn
import numpy as np
import sympy
from .utils import *
  • torch & torch.nn:用于张量运算和构建神经网络模块。

  • numpy:进行数值计算。

  • sympy:用于符号数学计算,可进行符号表达式的操作。

  • from .utils import *:导入同目录下的工具函数,可能包含例如 fit_params 等用于参数拟合的函数。


        2. Symbolic_KANLayer 类定义

该类继承自 nn.Module,目的是构建一个支持符号激活函数的神经网络层。主要特点是:

  • 每个神经元的激活不仅仅是简单的非线性函数,而是通过一个仿射变换后传入符号函数,并且可以进行参数拟合。

  • 支持奇异性避免策略,通过不同的函数版本(funsfuns_avoid_singularity)来确保数值稳定性。


        3. 初始化方法 init

def __init__(self, in_dim=3, out_dim=2, device='cpu'):

参数及属性:

  • in_dim, out_dim:分别表示输入和输出神经元的数量。

  • mask:大小为 (out_dim, in_dim) 的参数张量,用于对各条路径(输入到输出的连接)进行加权。这里初始化为全 0 且不需要梯度更新。

  • funs:二维列表(形状 [out_dim][in_dim]),每个元素是一个 lambda 函数,初始时设定为返回 0(即 lambda x: x*0.)。这些函数将来会被固定为特定的符号函数。

  • funs_avoid_singularity:类似于 funs,但其函数接收额外参数 y_th,用于处理输入在接近奇异值时的数值稳定性问题。

  • funs_name:记录每个连接对应的激活函数名称,初始为字符串 "0"。

  • funs_sympy:二维列表,存储对应的 sympy 符号函数版本,可用于符号计算与表达式分析。

  • affine:大小为 (out_dim, in_dim, 4) 的参数张量,存储每条连接的仿射变换参数 [a, b, c, d],用于表达激活形式:

    output=c⋅f(a⋅x+b)+d\text{output} = c \cdot f(a \cdot x + b) + doutput=c⋅f(a⋅x+b)+d
  • device:设备(如 'cpu' 或 'cuda'),并调用 self.to(device) 将模型移动到指定设备上。


        4. to 方法

def to(self, device):
  • 重写了 PyTorch 的 to() 方法,不仅将模块移动到指定设备,同时更新内部的 self.device 属性,确保所有操作(如张量创建)都在正确的设备上执行。


        5. forward 方法

def forward(self, x, singularity_avoiding=False, y_th=10.):

功能:

执行前向传播计算,将输入 x 经过每个连接的仿射变换及激活函数得到输出。

过程:

  1. 输入说明

    • x:形状为 (batch, in_dim) 的输入张量。

    • singularity_avoiding:布尔标志,决定是否采用避免奇异性的版本函数。

    • y_th:阈值,用于在避免奇异性函数中控制数值稳定性。

  2. 计算流程

    • 对每个输入维度 i 和每个输出神经元 j 进行循环:

      • 对输入 x[:, i] 进行仿射变换:

        z=a⋅x[:,i]+b(其中 a,b来自 affine[j,i,0] 和 affine[j,i,1])z = a \cdot x[:, i] + b \quad (\text{其中 } a, b \text{来自 } \texttt{affine[j,i,0]} \text{ 和 } \texttt{affine[j,i,1]})z=a⋅x[:,i]+b(其中 a,b来自 affine[j,i,0] 和 affine[j,i,1])
      • 根据 singularity_avoiding 的标志,选择使用 funs[j][i]funs_avoid_singularity[j][i] 对变换后的 z 计算激活值。

        • 对于 funs_avoid_singularity,函数接收额外的 y_th 参数,返回一个元组,其第二项为激活值。

      • 将激活值乘以缩放参数 affine[j,i,2] 并加上偏置 affine[j,i,3],再乘以对应的 mask[j][i](用于调节连接的贡献)。

    • 所有计算结果被收集并整理为:

      • postacts:每个节点各输入通道的激活值(用于调试或后续操作)。

      • y:对所有输入通道激活值求和后,得到最终的输出,形状为 (batch, out_dim)。


        6. get_subset 方法

def get_subset(self, in_id, out_id):

功能:

  • 用于从一个较大的 Symbolic_KANLayer 中提取一个子层(子网络),常用于网络剪枝或结构调整。

过程:

  • 根据提供的输入索引列表 in_id 和输出索引列表 out_id

    • 创建一个新的 Symbolic_KANLayer 实例。

    • 调整新的层的 in_dimout_dim 为选定的神经元数目。

    • 从原层中提取相应的 maskfunsfuns_avoid_singularityfuns_sympyfuns_nameaffine 参数,只保留所选索引对应的数据。


        7. fix_symbolic 方法

def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10,10), b_range=(-10,10), verbose=True):

功能:

  • 固定指定输入神经元 i 到输出神经元 j 的激活函数,使其为特定的符号函数,并可根据提供的样本数据拟合出最佳的仿射变换参数。

参数说明:

  • i, j:指定的输入和输出神经元索引。

  • fun_name:激活函数的名称(如 "sin"),或者直接传入函数。如果是字符串,则通过预定义的 SYMBOLIC_LIB 获取该激活函数的多个版本(常规、符号、避免奇异性)。

  • x, y:若提供了样本数据,则调用 fit_params 函数来拟合仿射参数 [a, b, c, d],使得

    y≈c⋅f(a⋅x+b)+dy \approx c \cdot f(a \cdot x + b) + dy≈c⋅f(a⋅x+b)+d

    同时返回拟合优度(决定系数 r²)。

  • random:是否随机初始化仿射参数;默认 False 时初始化为 [1, 0, 1, 0]。

  • a_range, b_range:在参数拟合时,a 和 b 的搜索范围。

  • verbose:是否打印更多调试信息。

过程:

  • fun_name 为字符串,则:

    • SYMBOLIC_LIB 中取出对应激活函数 fun、符号版本 fun_sympy 和避免奇异性的版本 fun_avoid_singularity

    • 更新 funs_sympyfuns_name

  • 如果没有提供 x 和 y 数据,则直接设置对应的激活函数,并初始化仿射参数(随机或默认)。

  • 若提供了数据,则调用 fit_params 函数来拟合参数,更新 affine 参数,并返回拟合的 r² 值作为指标。


        8. swap 方法

def swap(self, i1, i2, mode='in'):

功能:

  • 实现对层内神经元的交换。可以交换输入神经元或输出神经元的顺序,这在网络剪枝、模型重排等操作中非常有用。

过程:

  • 定义了内部辅助函数:

    • swap_list_:用于交换二维列表中的指定元素,针对 funs_namefuns_sympyfuns_avoid_singularity

    • swap_:用于交换 tensor 数据,针对 affinemask 参数。

  • 根据 mode 参数决定是交换输入神经元('in')还是输出神经元('out'),并对各个相关属性进行交换。

        B. 完整代码

import torch
import torch.nn as nn
import numpy as np
import sympy
from .utils import *



class Symbolic_KANLayer(nn.Module):
    '''
    KANLayer class

    Attributes:
    -----------
        in_dim : int
            input dimension
        out_dim : int
            output dimension
        funs : 2D array of torch functions (or lambda functions)
            symbolic functions (torch)
        funs_avoid_singularity : 2D array of torch functions (or lambda functions) with singularity avoiding
        funs_name : 2D arry of str
            names of symbolic functions
        funs_sympy : 2D array of sympy functions (or lambda functions)
            symbolic functions (sympy)
        affine : 3D array of floats
            affine transformations of inputs and outputs
    '''
    def __init__(self, in_dim=3, out_dim=2, device='cpu'):
        '''
        initialize a Symbolic_KANLayer (activation functions are initialized to be identity functions)
        
        Args:
        -----
            in_dim : int
                input dimension
            out_dim : int
                output dimension
            device : str
                device
            
        Returns:
        --------
            self
            
        Example
        -------
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3)
        >>> len(sb.funs), len(sb.funs[0])
        '''
        super(Symbolic_KANLayer, self).__init__()
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.mask = torch.nn.Parameter(torch.zeros(out_dim, in_dim, device=device)).requires_grad_(False)
        # torch
        self.funs = [[lambda x: x*0. for i in range(self.in_dim)] for j in range(self.out_dim)]
        self.funs_avoid_singularity = [[lambda x, y_th: ((), x*0.) for i in range(self.in_dim)] for j in range(self.out_dim)]
        # name
        self.funs_name = [['0' for i in range(self.in_dim)] for j in range(self.out_dim)]
        # sympy
        self.funs_sympy = [[lambda x: x*0. for i in range(self.in_dim)] for j in range(self.out_dim)]
        ### make funs_name the only parameter, and make others as the properties of funs_name?
        
        self.affine = torch.nn.Parameter(torch.zeros(out_dim, in_dim, 4, device=device))
        # c*f(a*x+b)+d
        
        self.device = device
        self.to(device)
        
    def to(self, device):
        '''
        move to device
        '''
        super(Symbolic_KANLayer, self).to(device)
        self.device = device    
        return self
    
    def forward(self, x, singularity_avoiding=False, y_th=10.):
        '''
        forward
        
        Args:
        -----
            x : 2D array
                inputs, shape (batch, input dimension)
            singularity_avoiding : bool
                if True, funs_avoid_singularity is used; if False, funs is used. 
            y_th : float
                the singularity threshold
            
        Returns:
        --------
            y : 2D array
                outputs, shape (batch, output dimension)
            postacts : 3D array
                activations after activation functions but before being summed on nodes
        
        Example
        -------
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5)
        >>> x = torch.normal(0,1,size=(100,3))
        >>> y, postacts = sb(x)
        >>> y.shape, postacts.shape
        (torch.Size([100, 5]), torch.Size([100, 5, 3]))
        '''
        
        batch = x.shape[0]
        postacts = []

        for i in range(self.in_dim):
            postacts_ = []
            for j in range(self.out_dim):
                if singularity_avoiding:
                    xij = self.affine[j,i,2]*self.funs_avoid_singularity[j][i](self.affine[j,i,0]*x[:,[i]]+self.affine[j,i,1], torch.tensor(y_th))[1]+self.affine[j,i,3]
                else:
                    xij = self.affine[j,i,2]*self.funs[j][i](self.affine[j,i,0]*x[:,[i]]+self.affine[j,i,1])+self.affine[j,i,3]
                postacts_.append(self.mask[j][i]*xij)
            postacts.append(torch.stack(postacts_))

        postacts = torch.stack(postacts)
        postacts = postacts.permute(2,1,0,3)[:,:,:,0]
        y = torch.sum(postacts, dim=2)
        
        return y, postacts
        
        
    def get_subset(self, in_id, out_id):
        '''
        get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning)
        
        Args:
        -----
            in_id : list
                id of selected input neurons
            out_id : list
                id of selected output neurons
            
        Returns:
        --------
            spb : Symbolic_KANLayer
         
        Example
        -------
        >>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10)
        >>> sb_small = sb_large.get_subset([0,9],[1,2,3])
        >>> sb_small.in_dim, sb_small.out_dim
        '''
        sbb = Symbolic_KANLayer(self.in_dim, self.out_dim, device=self.device)
        sbb.in_dim = len(in_id)
        sbb.out_dim = len(out_id)
        sbb.mask.data = self.mask.data[out_id][:,in_id]
        sbb.funs = [[self.funs[j][i] for i in in_id] for j in out_id]
        sbb.funs_avoid_singularity = [[self.funs_avoid_singularity[j][i] for i in in_id] for j in out_id]
        sbb.funs_sympy = [[self.funs_sympy[j][i] for i in in_id] for j in out_id]
        sbb.funs_name = [[self.funs_name[j][i] for i in in_id] for j in out_id]
        sbb.affine.data = self.affine.data[out_id][:,in_id]
        return sbb
    
    
    def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10,10), b_range=(-10,10), verbose=True):
        '''
        fix an activation function to be symbolic
        
        Args:
        -----
            i : int
                the id of input neuron
            j : int 
                the id of output neuron
            fun_name : str
                the name of the symbolic functions
            x : 1D array
                preactivations
            y : 1D array
                postactivations
            a_range : tuple
                sweeping range of a
            b_range : tuple
                sweeping range of a
            verbose : bool
                print more information if True
            
        Returns:
        --------
            r2 (coefficient of determination)
            
        Example 1
        ---------
        >>> # when x & y are not provided. Affine parameters are set to a = 1, b = 0, c = 1, d = 0
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
        >>> sb.fix_symbolic(2,1,'sin')
        >>> print(sb.funs_name)
        >>> print(sb.affine)
        
        Example 2
        ---------
        >>> # when x & y are provided, fit_params() is called to find the best fit coefficients
        >>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
        >>> batch = 100
        >>> x = torch.linspace(-1,1,steps=batch)
        >>> noises = torch.normal(0,1,(batch,)) * 0.02
        >>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
        >>> sb.fix_symbolic(2,1,'sin',x,y)
        >>> print(sb.funs_name)
        >>> print(sb.affine[1,2,:].data)
        '''
        if isinstance(fun_name,str):
            fun = SYMBOLIC_LIB[fun_name][0]
            fun_sympy = SYMBOLIC_LIB[fun_name][1]
            fun_avoid_singularity = SYMBOLIC_LIB[fun_name][3]
            self.funs_sympy[j][i] = fun_sympy
            self.funs_name[j][i] = fun_name
            
            if x == None or y == None:
                #initialzie from just fun
                self.funs[j][i] = fun
                self.funs_avoid_singularity[j][i] = fun_avoid_singularity
                if random == False:
                    self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
                else:
                    self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
                return None
            else:
                #initialize from x & y and fun
                params, r2 = fit_params(x,y,fun, a_range=a_range, b_range=b_range, verbose=verbose, device=self.device)
                self.funs[j][i] = fun
                self.funs_avoid_singularity[j][i] = fun_avoid_singularity
                self.affine.data[j][i] = params
                return r2
        else:
            # if fun_name itself is a function
            fun = fun_name
            fun_sympy = fun_name
            self.funs_sympy[j][i] = fun_sympy
            self.funs_name[j][i] = "anonymous"

            self.funs[j][i] = fun
            self.funs_avoid_singularity[j][i] = fun
            if random == False:
                self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
            else:
                self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
            return None
        
    def swap(self, i1, i2, mode='in'):
        '''
        swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out') 
        '''
        with torch.no_grad():
            def swap_list_(data, i1, i2, mode='in'):

                if mode == 'in':
                    for j in range(self.out_dim):
                        data[j][i1], data[j][i2] = data[j][i2], data[j][i1]

                elif mode == 'out':
                    data[i1], data[i2] = data[i2], data[i1] 

            def swap_(data, i1, i2, mode='in'):
                if mode == 'in':
                    data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone()

                elif mode == 'out':
                    data[i1], data[i2] = data[i2].clone(), data[i1].clone()

            swap_list_(self.funs_name,i1,i2,mode)
            swap_list_(self.funs_sympy,i1,i2,mode)
            swap_list_(self.funs_avoid_singularity,i1,i2,mode)
            swap_(self.affine.data,i1,i2,mode)
            swap_(self.mask.data,i1,i2,mode)

四、总结与思考

        KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。

        1. KAN网络架构

  • 关键设计可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。

  • 示例结构输入层 → 隐藏层:每个输入节点通过单变量函数\phi_{q,i} \left( x_{i} \right) 连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数\psi_{q}组合得到输出。

        2. 优势与特点

  • 高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。

  • 可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。

  • 灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。

        3. 挑战与局限

  • 计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。

  • 泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。

  • 训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。

        4. 应用场景

  • 科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。

  • 可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。

  • 符号回归:从数据中自动发现数学公式(如物理定律)。

        5. 与传统MLP的对比

        6. 研究进展

  • 近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。

  • 开源实现:已有PyTorch等框架的初步实现。


【作者声明】

        本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。


 【关注我们】

        如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!

Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐