裁剪或填充张量(Tensor)(四维与五维)(Python代码)
【代码】裁剪或填充张量(Tensor)(四维与五维)(Python代码)
·
import torch.nn.functional as F
# 将(Batch_Size, Num_Class, Height, Weight)裁剪掉指定的Height和Weight
def crop_tensor_by_height_width(tensor, height_crop, width_crop):
assert len(tensor.shape) == 4, '输入的tensor应为4维'
assert height_crop > 0 and width_crop > 0, 'crop应该大于0'
height_extra = 0
width_extra = 0
if height_crop % 2 != 0:
height_extra = 1
if width_crop % 2 != 0:
width_extra = 1
# 计算截取下界
lower_bound_height_crop = height_crop // 2
lower_bound_width_crop = width_crop // 2
# 获取原始的高度和宽度
original_height, original_width = tensor.shape[2], tensor.shape[3]
# 计算截取上界
upper_width_height_crop = original_height - height_crop // 2 - height_extra
upper_width_width_crop = original_width - width_crop // 2 - width_extra
# 同时裁剪高度和宽度
return tensor[:, :, lower_bound_height_crop:upper_width_height_crop, lower_bound_width_crop:upper_width_width_crop]
def crop_or_pad_tensor_by_height_width(tensor, height_crop, width_crop, pad_value=0):
'''
裁剪或扩展Tensor在高度(仅底部)和宽度(仅右侧)维度上的最后一个像素。
正数表示扩展(用0填充),负数表示裁剪。
参数:
tensor (torch.Tensor): 输入的4维张量,形状为 (batch_size, channels, height, width)
height_crop (int): 高度方向上底部要裁剪或扩展的像素数量,默认为1
width_crop (int): 宽度方向上右侧要裁剪或扩展的像素数量,默认为1
pad_value (float or int): 填充时使用的值,默认为0
返回:
cropped_or_padded_tensor (torch.Tensor): 裁剪或扩展后的张量
'''
assert len(tensor.shape) == 4, '输入的tensor应为4维'
# 获取原始的高度和宽度
original_height, original_width = tensor.shape[2], tensor.shape[3]
# 计算需要裁剪的数量(正值代表不裁剪,负值时代表裁剪)
height_to_remove_from_bottom = min(original_height, -height_crop) if height_crop < 0 else 0
width_to_remove_from_right = min(original_width, -width_crop) if width_crop < 0 else 0
# 计算需要填充的数量(正值代表填充,负值代表不填充)
pad_bottom = abs(height_crop) if height_crop > 0 else 0
pad_right = abs(width_crop) if width_crop > 0 else 0
# 先填充,再裁剪
padded_tensor = F.pad(tensor, pad=(0, pad_right, 0, pad_bottom), mode='constant', value=pad_value)
# 在高度和宽度维度上进行裁剪(如果需要)
if height_to_remove_from_bottom > 0 and width_to_remove_from_right > 0:
# 同时裁剪高度和宽度
cropped_or_padded_tensor = padded_tensor[:, :, :-height_to_remove_from_bottom, :-width_to_remove_from_right]
elif height_to_remove_from_bottom > 0:
# 只裁剪高度
cropped_or_padded_tensor = padded_tensor[:, :, :-height_to_remove_from_bottom, :]
elif width_to_remove_from_right > 0:
# 只裁剪宽度
cropped_or_padded_tensor = padded_tensor[:, :, :, :-width_to_remove_from_right]
else:
# 不裁剪任何维度
cropped_or_padded_tensor = padded_tensor
return cropped_or_padded_tensor
def crop_or_pad_tensor_by_depth_height_width(tensor, depth_crop, height_crop, width_crop, pad_value=0):
'''
裁剪或扩展Tensor在深度(仅最后一个)、高度(仅底部)和宽度(仅右侧)维度上的最后一个像素。
正数表示扩展(用0填充),负数表示裁剪。
参数:
tensor (torch.Tensor): 输入的5维张量,形状为 (batch_size, channels, depth, height, width)
depth_crop (int): 深度方向上最后一个要裁剪或扩展的数量,默认为1
height_crop (int): 高度方向上底部要裁剪或扩展的像素数量,默认为1
width_crop (int): 宽度方向上右侧要裁剪或扩展的像素数量,默认为1
pad_value (float or int): 填充时使用的值,默认为0
返回:
cropped_or_padded_tensor (torch.Tensor): 裁剪或扩展后的张量
'''
assert len(tensor.shape) == 5, '输入的tensor应为5维'
# 获取原始的深度、高度和宽度
original_depth, original_height, original_width = tensor.shape[2], tensor.shape[3], tensor.shape[4]
# 计算需要裁剪的数量(正值代表不裁剪,负值时代表裁剪)
depth_to_remove_from_end = min(original_depth, -depth_crop) if depth_crop < 0 else 0
height_to_remove_from_bottom = min(original_height, -height_crop) if height_crop < 0 else 0
width_to_remove_from_right = min(original_width, -width_crop) if width_crop < 0 else 0
# 计算需要填充的数量(正值代表填充,负值代表不填充)
pad_depth = abs(depth_crop) if depth_crop > 0 else 0
pad_bottom = abs(height_crop) if height_crop > 0 else 0
pad_right = abs(width_crop) if width_crop > 0 else 0
# 先填充,再裁剪
padded_tensor = F.pad(tensor, pad=(0, pad_right, 0, pad_bottom, 0, pad_depth), mode='constant', value=pad_value)
# 在深度、高度和宽度维度上进行裁剪(如果需要)
if depth_to_remove_from_end > 0 and height_to_remove_from_bottom > 0 and width_to_remove_from_right > 0:
# 同时裁剪深度、高度和宽度
cropped_or_padded_tensor = padded_tensor[:, :, :-depth_to_remove_from_end, :-height_to_remove_from_bottom,
:-width_to_remove_from_right]
elif depth_to_remove_from_end > 0 and height_to_remove_from_bottom > 0:
# 只裁剪深度和高度
cropped_or_padded_tensor = padded_tensor[:, :, :-depth_to_remove_from_end, :-height_to_remove_from_bottom, :]
elif depth_to_remove_from_end > 0 and width_to_remove_from_right > 0:
# 只裁剪深度和宽度
cropped_or_padded_tensor = padded_tensor[:, :, :-depth_to_remove_from_end, :, :-width_to_remove_from_right]
elif height_to_remove_from_bottom > 0 and width_to_remove_from_right > 0:
# 只裁剪高度和宽度
cropped_or_padded_tensor = padded_tensor[:, :, :, :-height_to_remove_from_bottom, :-width_to_remove_from_right]
elif depth_to_remove_from_end > 0:
# 只裁剪深度
cropped_or_padded_tensor = padded_tensor[:, :, :-depth_to_remove_from_end, :, :]
elif height_to_remove_from_bottom > 0:
# 只裁剪高度
cropped_or_padded_tensor = padded_tensor[:, :, :, :-height_to_remove_from_bottom, :]
elif width_to_remove_from_right > 0:
# 只裁剪宽度
cropped_or_padded_tensor = padded_tensor[:, :, :, :, :-width_to_remove_from_right]
else:
# 不裁剪任何维度
cropped_or_padded_tensor = padded_tensor
return cropped_or_padded_tensor

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。
更多推荐
所有评论(0)