【KAN】KAN神经网络学习训练营(11)——Symbolic_KANLayer.py
该代码实现了一个符号激活神经网络层,其核心特点在于:灵活的激活函数:每条输入-输出连接可以选择不同的激活函数(例如 sin、cos 等),并支持符号表达式版本,便于后续分析和解释。仿射参数拟合:每个连接不仅有激活函数,还通过四个参数 [a, b, c, d] 对输入进行仿射变换,使得激活函数能够更好地适应数据分布。奇异性处理:提供了funs_avoid_singularity 版本来确保在输入接近
一、引言
KAN神经网络(Kolmogorov–Arnold Networks)是一种基于Kolmogorov-Arnold表示定理的新型神经网络架构。该定理指出,任何多元连续函数都可以表示为有限个单变量函数的组合。与传统多层感知机(MLP)不同,KAN通过可学习的激活函数和结构化网络设计,在函数逼近效率和可解释性上展现出潜力。
二、技术与原理简介
1.Kolmogorov-Arnold 表示定理
Kolmogorov-Arnold 表示定理指出,如果 是有界域上的多元连续函数,那么它可以写为单个变量的连续函数的有限组合,以及加法的二进制运算。更具体地说,对于 光滑
其中 和 。从某种意义上说,他们表明唯一真正的多元函数是加法,因为所有其他函数都可以使用单变量函数和 sum 来编写。然而,这个 2 层宽度 - Kolmogorov-Arnold 表示可能不是平滑的由于其表达能力有限。我们通过以下方式增强它的表达能力将其推广到任意深度和宽度。,
2.Kolmogorov-Arnold 网络 (KAN)
Kolmogorov-Arnold 表示可以写成矩阵形式
其中
我们注意到 和 都是以下函数矩阵(包含输入和输出)的特例,我们称之为 Kolmogorov-Arnold 层:
其中。
定义层后,我们可以构造一个 Kolmogorov-Arnold 网络只需堆叠层!假设我们有层,层的形状为 。那么整个网络是
相反,多层感知器由线性层和非线错:
KAN 可以很容易地可视化。(1) KAN 只是 KAN 层的堆栈。(2) 每个 KAN 层都可以可视化为一个全连接层,每个边缘上都有一个1D 函数。
三、代码详解
该代码实现了一个符号激活神经网络层,其核心特点在于:
-
灵活的激活函数:每条输入-输出连接可以选择不同的激活函数(例如 sin、cos 等),并支持符号表达式版本,便于后续分析和解释。
-
仿射参数拟合:每个连接不仅有激活函数,还通过四个参数
[a, b, c, d]
对输入进行仿射变换,使得激活函数能够更好地适应数据分布。 -
奇异性处理:提供了
funs_avoid_singularity
版本来确保在输入接近奇异值时不会出现数值不稳定。 -
模块化设计:通过
get_subset
与swap
方法,可以方便地对网络结构进行剪枝或重新排列,便于构建和优化模型。
通过这些设计,该类不仅能实现复杂的非线性映射,还能借助符号计算的优势提高模型的可解释性和灵活性。
A. 代码详解
1. 模块导入
import torch
import torch.nn as nn
import numpy as np
import sympy
from .utils import *
-
torch & torch.nn:用于张量运算和构建神经网络模块。
-
numpy:进行数值计算。
-
sympy:用于符号数学计算,可进行符号表达式的操作。
-
from .utils import *:导入同目录下的工具函数,可能包含例如
fit_params
等用于参数拟合的函数。
2. Symbolic_KANLayer 类定义
该类继承自 nn.Module
,目的是构建一个支持符号激活函数的神经网络层。主要特点是:
-
每个神经元的激活不仅仅是简单的非线性函数,而是通过一个仿射变换后传入符号函数,并且可以进行参数拟合。
-
支持奇异性避免策略,通过不同的函数版本(
funs
与funs_avoid_singularity
)来确保数值稳定性。
3. 初始化方法 init
def __init__(self, in_dim=3, out_dim=2, device='cpu'):
参数及属性:
-
in_dim, out_dim:分别表示输入和输出神经元的数量。
-
mask:大小为 (out_dim, in_dim) 的参数张量,用于对各条路径(输入到输出的连接)进行加权。这里初始化为全 0 且不需要梯度更新。
-
funs:二维列表(形状 [out_dim][in_dim]),每个元素是一个 lambda 函数,初始时设定为返回 0(即
lambda x: x*0.
)。这些函数将来会被固定为特定的符号函数。 -
funs_avoid_singularity:类似于
funs
,但其函数接收额外参数y_th
,用于处理输入在接近奇异值时的数值稳定性问题。 -
funs_name:记录每个连接对应的激活函数名称,初始为字符串 "0"。
-
funs_sympy:二维列表,存储对应的 sympy 符号函数版本,可用于符号计算与表达式分析。
-
affine:大小为 (out_dim, in_dim, 4) 的参数张量,存储每条连接的仿射变换参数
output=c⋅f(a⋅x+b)+d\text{output} = c \cdot f(a \cdot x + b) + doutput=c⋅f(a⋅x+b)+d[a, b, c, d]
,用于表达激活形式: -
device:设备(如 'cpu' 或 'cuda'),并调用
self.to(device)
将模型移动到指定设备上。
4. to 方法
def to(self, device):
-
重写了 PyTorch 的
to()
方法,不仅将模块移动到指定设备,同时更新内部的self.device
属性,确保所有操作(如张量创建)都在正确的设备上执行。
5. forward 方法
def forward(self, x, singularity_avoiding=False, y_th=10.):
功能:
执行前向传播计算,将输入 x
经过每个连接的仿射变换及激活函数得到输出。
过程:
-
输入说明:
-
x
:形状为 (batch, in_dim) 的输入张量。 -
singularity_avoiding
:布尔标志,决定是否采用避免奇异性的版本函数。 -
y_th
:阈值,用于在避免奇异性函数中控制数值稳定性。
-
-
计算流程:
-
对每个输入维度
i
和每个输出神经元j
进行循环:-
对输入
z=a⋅x[:,i]+b(其中 a,b来自 affine[j,i,0] 和 affine[j,i,1])z = a \cdot x[:, i] + b \quad (\text{其中 } a, b \text{来自 } \texttt{affine[j,i,0]} \text{ 和 } \texttt{affine[j,i,1]})z=a⋅x[:,i]+b(其中 a,b来自 affine[j,i,0] 和 affine[j,i,1])x[:, i]
进行仿射变换: -
根据
singularity_avoiding
的标志,选择使用funs[j][i]
或funs_avoid_singularity[j][i]
对变换后的z
计算激活值。-
对于
funs_avoid_singularity
,函数接收额外的y_th
参数,返回一个元组,其第二项为激活值。
-
-
将激活值乘以缩放参数
affine[j,i,2]
并加上偏置affine[j,i,3]
,再乘以对应的mask[j][i]
(用于调节连接的贡献)。
-
-
所有计算结果被收集并整理为:
-
postacts
:每个节点各输入通道的激活值(用于调试或后续操作)。 -
y
:对所有输入通道激活值求和后,得到最终的输出,形状为 (batch, out_dim)。
-
-
6. get_subset 方法
def get_subset(self, in_id, out_id):
功能:
-
用于从一个较大的 Symbolic_KANLayer 中提取一个子层(子网络),常用于网络剪枝或结构调整。
过程:
-
根据提供的输入索引列表
in_id
和输出索引列表out_id
:-
创建一个新的 Symbolic_KANLayer 实例。
-
调整新的层的
in_dim
与out_dim
为选定的神经元数目。 -
从原层中提取相应的
mask
、funs
、funs_avoid_singularity
、funs_sympy
、funs_name
和affine
参数,只保留所选索引对应的数据。
-
7. fix_symbolic 方法
def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10,10), b_range=(-10,10), verbose=True):
功能:
-
固定指定输入神经元
i
到输出神经元j
的激活函数,使其为特定的符号函数,并可根据提供的样本数据拟合出最佳的仿射变换参数。
参数说明:
-
i, j:指定的输入和输出神经元索引。
-
fun_name:激活函数的名称(如 "sin"),或者直接传入函数。如果是字符串,则通过预定义的
SYMBOLIC_LIB
获取该激活函数的多个版本(常规、符号、避免奇异性)。 -
x, y:若提供了样本数据,则调用
y≈c⋅f(a⋅x+b)+dy \approx c \cdot f(a \cdot x + b) + dy≈c⋅f(a⋅x+b)+dfit_params
函数来拟合仿射参数[a, b, c, d]
,使得同时返回拟合优度(决定系数 r²)。
-
random:是否随机初始化仿射参数;默认 False 时初始化为 [1, 0, 1, 0]。
-
a_range, b_range:在参数拟合时,a 和 b 的搜索范围。
-
verbose:是否打印更多调试信息。
过程:
-
若
fun_name
为字符串,则:-
从
SYMBOLIC_LIB
中取出对应激活函数fun
、符号版本fun_sympy
和避免奇异性的版本fun_avoid_singularity
。 -
更新
funs_sympy
、funs_name
。
-
-
如果没有提供 x 和 y 数据,则直接设置对应的激活函数,并初始化仿射参数(随机或默认)。
-
若提供了数据,则调用
fit_params
函数来拟合参数,更新 affine 参数,并返回拟合的 r² 值作为指标。
8. swap 方法
def swap(self, i1, i2, mode='in'):
功能:
-
实现对层内神经元的交换。可以交换输入神经元或输出神经元的顺序,这在网络剪枝、模型重排等操作中非常有用。
过程:
-
定义了内部辅助函数:
-
swap_list_:用于交换二维列表中的指定元素,针对
funs_name
、funs_sympy
和funs_avoid_singularity
。 -
swap_:用于交换 tensor 数据,针对
affine
和mask
参数。
-
-
根据
mode
参数决定是交换输入神经元('in')还是输出神经元('out'),并对各个相关属性进行交换。
B. 完整代码
import torch
import torch.nn as nn
import numpy as np
import sympy
from .utils import *
class Symbolic_KANLayer(nn.Module):
'''
KANLayer class
Attributes:
-----------
in_dim : int
input dimension
out_dim : int
output dimension
funs : 2D array of torch functions (or lambda functions)
symbolic functions (torch)
funs_avoid_singularity : 2D array of torch functions (or lambda functions) with singularity avoiding
funs_name : 2D arry of str
names of symbolic functions
funs_sympy : 2D array of sympy functions (or lambda functions)
symbolic functions (sympy)
affine : 3D array of floats
affine transformations of inputs and outputs
'''
def __init__(self, in_dim=3, out_dim=2, device='cpu'):
'''
initialize a Symbolic_KANLayer (activation functions are initialized to be identity functions)
Args:
-----
in_dim : int
input dimension
out_dim : int
output dimension
device : str
device
Returns:
--------
self
Example
-------
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=3)
>>> len(sb.funs), len(sb.funs[0])
'''
super(Symbolic_KANLayer, self).__init__()
self.out_dim = out_dim
self.in_dim = in_dim
self.mask = torch.nn.Parameter(torch.zeros(out_dim, in_dim, device=device)).requires_grad_(False)
# torch
self.funs = [[lambda x: x*0. for i in range(self.in_dim)] for j in range(self.out_dim)]
self.funs_avoid_singularity = [[lambda x, y_th: ((), x*0.) for i in range(self.in_dim)] for j in range(self.out_dim)]
# name
self.funs_name = [['0' for i in range(self.in_dim)] for j in range(self.out_dim)]
# sympy
self.funs_sympy = [[lambda x: x*0. for i in range(self.in_dim)] for j in range(self.out_dim)]
### make funs_name the only parameter, and make others as the properties of funs_name?
self.affine = torch.nn.Parameter(torch.zeros(out_dim, in_dim, 4, device=device))
# c*f(a*x+b)+d
self.device = device
self.to(device)
def to(self, device):
'''
move to device
'''
super(Symbolic_KANLayer, self).to(device)
self.device = device
return self
def forward(self, x, singularity_avoiding=False, y_th=10.):
'''
forward
Args:
-----
x : 2D array
inputs, shape (batch, input dimension)
singularity_avoiding : bool
if True, funs_avoid_singularity is used; if False, funs is used.
y_th : float
the singularity threshold
Returns:
--------
y : 2D array
outputs, shape (batch, output dimension)
postacts : 3D array
activations after activation functions but before being summed on nodes
Example
-------
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=5)
>>> x = torch.normal(0,1,size=(100,3))
>>> y, postacts = sb(x)
>>> y.shape, postacts.shape
(torch.Size([100, 5]), torch.Size([100, 5, 3]))
'''
batch = x.shape[0]
postacts = []
for i in range(self.in_dim):
postacts_ = []
for j in range(self.out_dim):
if singularity_avoiding:
xij = self.affine[j,i,2]*self.funs_avoid_singularity[j][i](self.affine[j,i,0]*x[:,[i]]+self.affine[j,i,1], torch.tensor(y_th))[1]+self.affine[j,i,3]
else:
xij = self.affine[j,i,2]*self.funs[j][i](self.affine[j,i,0]*x[:,[i]]+self.affine[j,i,1])+self.affine[j,i,3]
postacts_.append(self.mask[j][i]*xij)
postacts.append(torch.stack(postacts_))
postacts = torch.stack(postacts)
postacts = postacts.permute(2,1,0,3)[:,:,:,0]
y = torch.sum(postacts, dim=2)
return y, postacts
def get_subset(self, in_id, out_id):
'''
get a smaller Symbolic_KANLayer from a larger Symbolic_KANLayer (used for pruning)
Args:
-----
in_id : list
id of selected input neurons
out_id : list
id of selected output neurons
Returns:
--------
spb : Symbolic_KANLayer
Example
-------
>>> sb_large = Symbolic_KANLayer(in_dim=10, out_dim=10)
>>> sb_small = sb_large.get_subset([0,9],[1,2,3])
>>> sb_small.in_dim, sb_small.out_dim
'''
sbb = Symbolic_KANLayer(self.in_dim, self.out_dim, device=self.device)
sbb.in_dim = len(in_id)
sbb.out_dim = len(out_id)
sbb.mask.data = self.mask.data[out_id][:,in_id]
sbb.funs = [[self.funs[j][i] for i in in_id] for j in out_id]
sbb.funs_avoid_singularity = [[self.funs_avoid_singularity[j][i] for i in in_id] for j in out_id]
sbb.funs_sympy = [[self.funs_sympy[j][i] for i in in_id] for j in out_id]
sbb.funs_name = [[self.funs_name[j][i] for i in in_id] for j in out_id]
sbb.affine.data = self.affine.data[out_id][:,in_id]
return sbb
def fix_symbolic(self, i, j, fun_name, x=None, y=None, random=False, a_range=(-10,10), b_range=(-10,10), verbose=True):
'''
fix an activation function to be symbolic
Args:
-----
i : int
the id of input neuron
j : int
the id of output neuron
fun_name : str
the name of the symbolic functions
x : 1D array
preactivations
y : 1D array
postactivations
a_range : tuple
sweeping range of a
b_range : tuple
sweeping range of a
verbose : bool
print more information if True
Returns:
--------
r2 (coefficient of determination)
Example 1
---------
>>> # when x & y are not provided. Affine parameters are set to a = 1, b = 0, c = 1, d = 0
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
>>> sb.fix_symbolic(2,1,'sin')
>>> print(sb.funs_name)
>>> print(sb.affine)
Example 2
---------
>>> # when x & y are provided, fit_params() is called to find the best fit coefficients
>>> sb = Symbolic_KANLayer(in_dim=3, out_dim=2)
>>> batch = 100
>>> x = torch.linspace(-1,1,steps=batch)
>>> noises = torch.normal(0,1,(batch,)) * 0.02
>>> y = 5.0*torch.sin(3.0*x + 2.0) + 0.7 + noises
>>> sb.fix_symbolic(2,1,'sin',x,y)
>>> print(sb.funs_name)
>>> print(sb.affine[1,2,:].data)
'''
if isinstance(fun_name,str):
fun = SYMBOLIC_LIB[fun_name][0]
fun_sympy = SYMBOLIC_LIB[fun_name][1]
fun_avoid_singularity = SYMBOLIC_LIB[fun_name][3]
self.funs_sympy[j][i] = fun_sympy
self.funs_name[j][i] = fun_name
if x == None or y == None:
#initialzie from just fun
self.funs[j][i] = fun
self.funs_avoid_singularity[j][i] = fun_avoid_singularity
if random == False:
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
else:
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
return None
else:
#initialize from x & y and fun
params, r2 = fit_params(x,y,fun, a_range=a_range, b_range=b_range, verbose=verbose, device=self.device)
self.funs[j][i] = fun
self.funs_avoid_singularity[j][i] = fun_avoid_singularity
self.affine.data[j][i] = params
return r2
else:
# if fun_name itself is a function
fun = fun_name
fun_sympy = fun_name
self.funs_sympy[j][i] = fun_sympy
self.funs_name[j][i] = "anonymous"
self.funs[j][i] = fun
self.funs_avoid_singularity[j][i] = fun
if random == False:
self.affine.data[j][i] = torch.tensor([1.,0.,1.,0.], device=self.device)
else:
self.affine.data[j][i] = torch.rand(4, device=self.device) * 2 - 1
return None
def swap(self, i1, i2, mode='in'):
'''
swap the i1 neuron with the i2 neuron in input (if mode == 'in') or output (if mode == 'out')
'''
with torch.no_grad():
def swap_list_(data, i1, i2, mode='in'):
if mode == 'in':
for j in range(self.out_dim):
data[j][i1], data[j][i2] = data[j][i2], data[j][i1]
elif mode == 'out':
data[i1], data[i2] = data[i2], data[i1]
def swap_(data, i1, i2, mode='in'):
if mode == 'in':
data[:,i1], data[:,i2] = data[:,i2].clone(), data[:,i1].clone()
elif mode == 'out':
data[i1], data[i2] = data[i2].clone(), data[i1].clone()
swap_list_(self.funs_name,i1,i2,mode)
swap_list_(self.funs_sympy,i1,i2,mode)
swap_list_(self.funs_avoid_singularity,i1,i2,mode)
swap_(self.affine.data,i1,i2,mode)
swap_(self.mask.data,i1,i2,mode)
四、总结与思考
KAN神经网络通过融合数学定理与深度学习,为科学计算和可解释AI提供了新思路。尽管在高维应用中仍需突破,但其在低维复杂函数建模上的潜力值得关注。未来可能通过改进计算效率、扩展理论边界,成为MLP的重要补充。
1. KAN网络架构
-
关键设计:可学习的激活函数:每个网络连接的“权重”被替换为单变量函数(如样条、多项式),而非固定激活函数(如ReLU)。分层结构:输入层和隐藏层之间、隐藏层与输出层之间均通过单变量函数连接,形成多层叠加。参数效率:由于理论保证,KAN可能用更少的参数达到与MLP相当或更好的逼近效果。
-
示例结构:输入层 → 隐藏层:每个输入节点通过单变量函数
连接到隐藏节点。隐藏层 → 输出层:隐藏节点通过另一组单变量函数
组合得到输出。
2. 优势与特点
-
高逼近效率:基于数学定理,理论上能以更少参数逼近复杂函数;在低维科学计算任务(如微分方程求解)中表现优异。
-
可解释性:单变量函数可可视化,便于分析输入变量与输出的关系;网络结构直接对应函数分解过程,逻辑清晰。
-
灵活的函数学习:激活函数可自适应调整(如学习平滑或非平滑函数);支持符号公式提取(例如从数据中恢复物理定律)。
3. 挑战与局限
-
计算复杂度:单变量函数的学习(如样条参数化)可能增加训练时间和内存消耗。需要优化高阶连续函数,对硬件和算法提出更高要求。
-
泛化能力:在高维数据(如图像、文本)中的表现尚未充分验证,可能逊色于传统MLP。
-
训练难度:需设计新的优化策略,避免单变量函数的过拟合或欠拟合。
4. 应用场景
-
科学计算:求解微分方程、物理建模、化学模拟等需要高精度函数逼近的任务。
-
可解释性需求领域:医疗诊断、金融风控等需明确输入输出关系的场景。
-
符号回归:从数据中自动发现数学公式(如物理定律)。
5. 与传统MLP的对比
6. 研究进展
-
近期论文:2024年,MIT等团队提出KAN架构(如论文《KAN: Kolmogorov-Arnold Networks》),在低维任务中验证了其高效性和可解释性。
-
开源实现:已有PyTorch等框架的初步实现。
【作者声明】
本文分享的论文内容及观点均来源于《KAN: Kolmogorov-Arnold Networks》原文,旨在介绍和探讨该研究的创新成果和应用价值。作者尊重并遵循学术规范,确保内容的准确性和客观性。如有任何疑问或需要进一步的信息,请参考论文原文或联系相关作者。
【关注我们】
如果您对神经网络、群智能算法及人工智能技术感兴趣,请关注【灵犀拾荒者】,获取更多前沿技术文章、实战案例及技术分享!

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。
更多推荐
所有评论(0)