基于RNNs+Attention的红点位置检测(pytorch)
项目的目标是在图片中精确定位三条红线中心的位置。在之前的任务中,传统RNN模型精确度不佳。文章通过注意力机制来提升模型性能,包括加性注意力(Additive Attention)、点积注意力(Dot-Product Attention)、缩放点积注意力(Scaled Dot-Product Attention)、多头注意力(Multi-Head Attention)以及十字交叉注意力(Criss-
文章目录
1 项目背景
需要在图片精确识别三跟红线所在的位置,并输出这三个像素的位置。
其中,每跟红线占据不止一个像素,并且像素颜色也并不是饱和度和亮度极高的红黑配色,每个红线放大后可能是这样的。
而我们的目标是精确输出每个红点的位置,需要精确到像素。也就是说,对于每根红线,模型需要输出橙色箭头所指的像素而不是蓝色箭头所指的像素的位置。
在之前尝试过 LSTM 和 GRU 检测红点,但是准确率感人,在噪声极低的情况下并不能精准识别位置。但是可以看出它们还是能找到大致位置的,我打算尝试添加注意力机制从模型的角度来提升精度。
2 数据集
还是之前那个代码合成的数据集数据集,每个数据集规模在15000张图片左右,在没有加入噪音的情况下,每个样本预览如图所示:
加入噪音后,每个样本的预览如下图所示:
图中黑色部分包含比较弱的噪声,并非完全为黑色。
数据集包含两个文件,一个是文件夹,里面包含了jpg压缩的图像数据:
另一个是csv文件,里面包含了每个图像的名字以及3根红线所在的像素的位置。
3 思路
之前的模型流程是这样的:
3.1 一些小实验
直觉上我认为问题可能出在FC2的架构不够复杂,无法很好地将Bi-LSTM学习到的序列信息转化成位置,也怀疑过Position Info, FC1的设计问题等等,在一系列实验之后,我发现可能真的是因为前面的RNN模型不能很好地学习到提取序列信息。罗列下方案:
- 优化FC1, FC2:把原来只有1层的FC1, FC2复杂化。
- 加入Position embedding:之前我们做的是[0,1]之内的数字来标记位置,其实这并不合理,我现在选择transformer中的思路来做,也就是用 s i n ( t ) sin(t) sin(t), c o s ( t ) cos(t) cos(t)来做。
- 多通道输入:将原来的
1080 * 3
的RGB序列化数据变成1080 * (3+3+3+1+8)
的多通道(RGB+HSV+LAB+L+FFBB)的数据。其中,RGB、HSV、LAB、L分别是4种经典的色彩模式,FFBB是某个在R通道和H通道分别与前后两个像素的差值(图片首端和末端用0填补)。这种想法本质上是将一张图片的多种表达方式拼接到一起,发挥神经网络在高维 数据中的优势。 - 全明星方案:把上述方案结合到一起。
Ⅰ 多通道输入
RGB三通道的输入能表达的信息有限,因此尝试拼接多种色彩模式,依靠多通道表达像素颜色。考虑到红色区域包含多个像素且像素颜色相近,为了区别一段的像素,引入FFFBBB通道,即每个颜色在R通道和H通道分别与前后三个像素的差值(红色激光在这两个通道中相较背景颜色差异显著)。
在小样本数据集下测试:
实验 | train loss | val loss | test loss | test 完全准确样本 |
---|---|---|---|---|
3通道 | 0.2899 | 0.3751 | 0.5955 | 482.0/600(80%) |
10通道 | 0.3513 | 0.3093 | 0.5839 | 491.0/600(82%) |
22通道 | 0.1654 | 0.2206 | 0.4754 | 528.0/600(88%) |
多通道的有效性在后面的多头注意力机制的实验中也被验证,该实验是在高噪声、长序列的数据集中进行的:
实验 | train loss | val loss | test loss | test 完全准确样本 | 点1平均偏移量 | 点2平均偏移量 | 点3平均偏移量 |
---|---|---|---|---|---|---|---|
10通道 | 203.0760 | 13.9310 | 199.6459 | 510.0/4500 (11%) | 3.0831 | 2.9273 | 3.2598 |
22通道 | 165.6441 | 16.9843 | 162.2208 | 555.0/4500 (12%) | 2.7256 | 3.2001 | 2.7711 |
23通道 | 174.5019 | 18.1041 | 149.0297 | 645.0/4500 (14%) | 2.6598 | 3.2242 | 2.4309 |
Ⅱ 位置编码
之前的思路是通过RNNs对每一个step做评分,得到scores,再将位置信息拼接到scores上通过另一层神经网络预测出红点所在的位置。这种方法的问题很明显:
- 用0-1的数字直接表达位置,信息传递不明确;
- 拼接矩阵会导致后面的神经网络赋予位置信息太多权重从而忽略颜色(可以通过降低位置信息的权重解决,但是权重的设立又需要大量实验);
- 局限于绝对位置信息,不能学会像“第1个像素在第2个像素前面1个位置”这样的相对位置信息。
因此,我想引入transformer中的位置编码,直接把position info与22通道信息相加,这样能一定程度上解决上述的三个问题。我是直接参考论文中的编码方式:
Ⅲ 实验结果
汇总下上文提到的几个优化方案,对比实验结果如下:
实验 | loss | 完全准确的点 |
---|---|---|
GRU | 129.6641 | 1762.0/9000 (20%) |
LSTM | 249.2053 | 1267.0/9000 (14%) |
优化F2的GRU | 30.2363 | 3486.0/9000 (39%) |
优化F2的LSTM | 37.3669 | 3485.0/9000 (39%) |
优化F1的GRU | 1752.6486 | 459.0/9000 (5%) |
优化F1的LSTM | 4578.7138 | 326.0/9000 (4%) |
Position embedding + GRU | 16.3403 | 5025.0/9000 (56%) |
Position embedding + LSTM | 204.1551 | 1603.0/9000 (18%) |
多通道 + GRU | 144.3737 | 1953.0/9000 (22%) |
多通道 + LSTM | 958.9135 | 713.0/9000 (8%) |
其中,优化F2的效果不错,但是训练时间较长,GRU的训练时间是原来的9倍; FC1优化起反效果。改良Position Info效果一般,可能是我设计的问题。
以上实验都是在低噪声、长序列的实验中完成的。
3.2 Attention
从我个人的经验来说,Attention一直是一个很万能的东西,不过这样算是比较早的技术了,早就衍生出了好多不同的种类。我打算尝试我认为比较有用的几种。
因为使用了Attention,这里就不采取低噪声的背景了,直接上高噪声的,不过还是小数据集,只有15000张图。
目前的流程是这样的:
Ⅰ Additive Attention
在 GRU 输出后添加一个加性注意力 (Additive Attention) 层,计算每个时间步的权重,然后对 GRU 的输出进行加权平均。但是因为这个问题不是针对整个序列的问题,而是针对单个step的,因此省略对 output 的 sum。
不过Additive Attention在某些任务上会对特征的选择更加敏感,但如果特征间的关系较弱,模型可能会过分关注一些不重要的信息,导致预测偏差。
Ⅱ Dot-Product Attention
点积注意力机制是比较常见的注意力机制,其核心步骤是:
- 计算 query 和 key 的点积。
- 对点积结果进行 softmax 以生成注意力权重。
- 用注意力权重加权 value,得到最终的 context vector。
我采用它的原因是该任务更加依赖于局部上下文相关性,而不是全局信息,这种 Attention 相对擅长这个任务。加性注意力是基于一个单独的权重投影,而 Dot-Product Attention 会直接操作序列的隐藏状态。
虽然是个长序列的问题,但是我暂时还没打算引入masking填充。
Ⅲ Scaled Dot-Product Attention
Scaled Dot-Product Attention 是 Dot-Product Attention 修改版本。与之前的 Additive Attention 不同,Scaled Dot-Product Attention 利用了向量之间的点积计算和归一化机制来生成注意力权重。因为在流程上跟 Dot-Product Attention 比较相似,就不做流程图了。
此外,该注意力机制还是多头注意力机制的简化版本。
我个人是认为寻找红点的中心这个任务是非常依赖时间步之间的强相关性的,而 Scaled Dot-Product Attention 依赖于 query 和 key 的相似性,这是一个切入点。
Ⅳ Multi-Head Attention (self-attention)
实现多头注意力机制其实就是对上一个注意力机制的优化:
- 将输入分割成多个注意力头。
- 每个头单独计算 Scaled Dot-Product Attention。
- 将所有头的输出拼接并通过一个线性变换。
在初始的实验中,注意力机制效果非常差。在图片长度为100的时候,多头注意力机制能够比较精确地识别到红点的中心,但是当图片长度扩展到1080的时候,就算没有噪音,多头注意力机制也很难找到红点。
我个人觉得是因为在长注意力机制里,计算复杂度为 O ( n 2 ) O(n^2) O(n2),在传播梯度的时候比较困难,特别是在模型层数比较深的情况时;此外因为序列比较长,局部注意力也可能被稀释。
能想到的可能比较好用的缓解方法基本上是1. 放弃全局注意力采用局部窗口注意力和滑动窗口注意力;2. 把RNN层数拉上去或者在注意力层之间加FFNs来增加特征提取能力。我个人倾向于第一个想法。
此外,当input中的query,key,value都为同一个向量的时候,多头注意力机制变为自注意力机制。
Ⅴ Criss-Cross Attention
尝试这个注意力机制是我还是执着于要把精力放在局部的信息上,因此想做一个带有滑动窗口的注意力机制。之前尝试过带滑动窗口的多头注意力机制,感觉效果一般,就重新尝试这个。
Ⅵ Result
低噪声实验
实验 | train loss | val loss | test loss | test 完全准确样本 |
---|---|---|---|---|
Position embedding + GRU | 0.2899 | 0.3751 | 16.3403 | 5025.0/9000 (56%) |
Additive Attention | 368.9917 | 344.9136 | 11018.8885 | 291.0/9000 (3%) |
高噪声实验
实验 | train loss | val loss | test loss | test 完全准确样本 | 点1平均偏移量 | 点2平均偏移量 | 点3平均偏移量 |
---|---|---|---|---|---|---|---|
Position embedding + GRU | 664.6360 | 412.0233 | 2173.8701 | 119.0/3000 (4%) | / | / | / |
Additive Attention | 378.7690 | 23.7633 | 369.3527 | 403.0/4500 (9%) | 4.7889 | 3.2219 | 4.4258 |
Dot-Product Attention | 199.6322 | 14.1448 | 216.6723 | 544.0/4500 (12%) | 3.1844 | 3.1633 | 3.4102 |
Scaled Dot-Product Attention | 179.0531 | 13.0236 | 209.7830 | 517.0/4500 (11%) | 2.7206 | 3.5754 | 3.1143 |
Multi-Head Attention | 174.5019 | 18.1041 | 149.0297 | 645.0/4500 (14%) | 2.6598 | 3.2243 | 2.4309 |
Criss-Cross Attention | 192.4157 | 34.9980 | 205.0001 | 556.0/4500 (12%) | 3.1831 | 3.2340 | 3.1561 |
精确度感觉还是不是很充足:
4 代码
多通道色彩
class RedDotDataset(Dataset):
def __init__(self, csv_file, img_dir, transform=None):
self.data_frame = pd.read_csv(csv_file, dtype={'Stripe_Number': str}, encoding='utf-8-sig')
self.img_dir = img_dir
self.transform = transform or transforms.Compose([
transforms.ToTensor(), # Automatically converts and normalizes
transforms.Normalize((0.5,), (0.5,)) # Example normalization, adjust based on dataset
])
def __len__(self):
return len(self.data_frame)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, str(self.data_frame.iloc[idx, 0]) + '.jpg')
img_rgb = np.array(Image.open(img_path).convert('RGB'))
img_hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)
img_lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB)
img_l = Image.fromarray(img_rgb).convert('L')
height, width, _ = img_rgb.shape
r_channel = img_rgb[:, :, 0].astype('float32')
h_channel = img_hsv[:, :, 0].astype('float32')
def calculate_diff(channel_data):
diff_map = []
for offset in range(-3, 4):
if offset == 0:
continue
offset_data = np.zeros_like(channel_data)
for x in range(width):
new_x = min(max(0, x + offset), width - 1)
offset_data[:, x] = channel_data[:, new_x]
diff_map.append(offset_data - channel_data)
return np.stack(diff_map, axis=-1)
r_diff_map = calculate_diff(r_channel)
h_diff_map = calculate_diff(h_channel)
# [H, W, C] nampy array to [C, H, W] torch tensor
img_rgb = torch.from_numpy(np.array(img_rgb)).permute(2, 0, 1).float() / 255.0
img_hsv = torch.from_numpy(np.array(img_hsv)).permute(2, 0, 1).float() / 255.0
img_lab = torch.from_numpy(np.array(img_lab)).permute(2, 0, 1).float() / 255.0
img_l = torch.from_numpy(np.array(img_l)).unsqueeze(0).float() / 255.0
r_diff_map_n_channel = torch.from_numpy(r_diff_map).permute(2, 0, 1).float() / 255.0
h_diff_map_n_channel = torch.from_numpy(h_diff_map).permute(2, 0, 1).float() / 255.0
# zero_tensor = torch.zeros((1, 1, img_rgb.shape[2]), dtype=img_rgb.dtype)
combined_img = torch.cat((img_rgb, img_hsv, img_lab, img_l, r_diff_map_n_channel, h_diff_map_n_channel), dim=0)
labels = self.data_frame.iloc[idx, 1:].values.astype('float32')
# if self.transform:
# img = self.transform(img)
return combined_img, labels
位置编码
def generate_positional_encoding(self, seq_length, d_model):
pos_encoding = torch.zeros(seq_length, d_model)
for pos in range(seq_length):
for i in range(0, d_model, 2):
pos_encoding[pos, i] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
if i + 1 < d_model:
pos_encoding[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1)) / d_model)))
return pos_encoding.unsqueeze(0) # Add batch dimension for broadcasting
Attention结构
class AdditiveAttention(nn.Module):
def __init__(self, hidden_size):
super(AdditiveAttention, self).__init__()
self.attn = nn.Linear(hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.rand(hidden_size))
def forward(self, x):
# x: (batch_size, seq_length, hidden_size * 2)
u = torch.tanh(self.attn(x)) # (batch_size, seq_length, hidden_size)
attn_weights = torch.matmul(u, self.v) # (batch_size, seq_length)
attn_weights = F.softmax(attn_weights, dim=1).unsqueeze(-1) # (batch_size, seq_length, 1)
output = x * attn_weights # (batch_size, seq_length, hidden_size * 2)
return output # Aggregate context vector (batch_size, hidden_size * 2)
class DotProductAttention(nn.Module):
def __init__(self, hidden_size):
super(DotProductAttention, self).__init__()
self.query_proj = nn.Linear(hidden_size * 2, hidden_size)
self.key_proj = nn.Linear(hidden_size * 2, hidden_size)
self.value_proj = nn.Linear(hidden_size * 2, hidden_size)
self.scale = math.sqrt(hidden_size)
def forward(self, x):
# x: (batch_size, seq_length, hidden_size * 2)
query = self.query_proj(x) # (batch_size, seq_length, hidden_size)
key = self.key_proj(x) # (batch_size, seq_length, hidden_size)
value = self.value_proj(x) # (batch_size, seq_length, hidden_size)
# Compute attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale # (batch_size, seq_length, seq_length)
attn_weights = F.softmax(scores, dim=-1) # (batch_size, seq_length, seq_length)
# Compute context vector
context = torch.matmul(attn_weights, value) # (batch_size, seq_length, hidden_size)
return context
class ScaledDotProductAttention(nn.Module):
def __init__(self, hidden_size):
super(ScaledDotProductAttention, self).__init__()
self.scale = math.sqrt(hidden_size)
def forward(self, x):
# x: (batch_size, seq_length, hidden_size * 2)
query = x # Use the GRU output directly as query, key, and value
key = x
value = x
# Compute attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale # (batch_size, seq_length, seq_length)
attn_weights = F.softmax(scores, dim=-1) # (batch_size, seq_length, seq_length)
# Compute context vector
context = torch.matmul(attn_weights, value) # (batch_size, seq_length, hidden_size * 2)
return context
# 偷懒了哈哈
class MultiHeadAttention(nn.Module):
def __init__(self, embed_size, num_heads):
super(MultiHeadAttention, self).__init__()
self.multihead_att = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads, batch_first=True)
def forward(self, query, key, value):
out, _ = self.multihead_att(query, key, value)
return out
class CrissCrossAttention(nn.Module):
def __init__(self, embed_size, window_size):
super(CrissCrossAttention, self).__init__()
self.window_size = window_size
self.linear_q = nn.Linear(embed_size, embed_size)
self.linear_k = nn.Linear(embed_size, embed_size)
self.linear_v = nn.Linear(embed_size, embed_size)
self.scale = embed_size ** 0.5
def forward(self, x):
# x: (batch_size, seq_length, embed_size)
batch_size, seq_length, embed_size = x.size()
query = self.linear_q(x)
key = self.linear_k(x)
value = self.linear_v(x)
# Sliding window
key = F.pad(key, (0, 0, self.window_size // 2, self.window_size // 2), mode="constant", value=0)
value = F.pad(value, (0, 0, self.window_size // 2, self.window_size // 2), mode="constant", value=0)
key_windows = key.unfold(dimension=1, size=self.window_size, step=1)
value_windows = value.unfold(dimension=1, size=self.window_size, step=1)
scores = torch.einsum("bqe,bqew->bqw", query, key_windows) / self.scale # (batch_size, seq_length, window_size)
attn_weights = F.softmax(scores, dim=-1)
# Weighted sum for context vectors
context = torch.einsum("bqw,bqew->bqe", attn_weights, value_windows)
return context
路过的大佬有什么建议 ball ball 在评论区打出来
都看到这儿了,不妨给个三连吧~

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。
更多推荐
所有评论(0)