import torch.nn.functional as F


# 将(Batch_Size, Num_Class, Height, Weight)裁剪掉指定的Height和Weight
def crop_tensor_by_height_width(tensor, height_crop, width_crop):
    assert len(tensor.shape) == 4, '输入的tensor应为4维'
    assert height_crop > 0 and width_crop > 0, 'crop应该大于0'
    height_extra = 0
    width_extra = 0
    if height_crop % 2 != 0:
        height_extra = 1
    if width_crop % 2 != 0:
        width_extra = 1
    # 计算截取下界
    lower_bound_height_crop = height_crop // 2
    lower_bound_width_crop = width_crop // 2
    # 获取原始的高度和宽度
    original_height, original_width = tensor.shape[2], tensor.shape[3]
    # 计算截取上界
    upper_width_height_crop = original_height - height_crop // 2 - height_extra
    upper_width_width_crop = original_width - width_crop // 2 - width_extra
    # 同时裁剪高度和宽度
    return tensor[:, :, lower_bound_height_crop:upper_width_height_crop, lower_bound_width_crop:upper_width_width_crop]


def crop_or_pad_tensor_by_height_width(tensor, height_crop, width_crop, pad_value=0):
    '''
    裁剪或扩展Tensor在高度(仅底部)和宽度(仅右侧)维度上的最后一个像素。
    正数表示扩展(用0填充),负数表示裁剪。

    参数:
    tensor (torch.Tensor): 输入的4维张量,形状为 (batch_size, channels, height, width)
    height_crop (int): 高度方向上底部要裁剪或扩展的像素数量,默认为1
    width_crop (int): 宽度方向上右侧要裁剪或扩展的像素数量,默认为1
    pad_value (float or int): 填充时使用的值,默认为0

    返回:
    cropped_or_padded_tensor (torch.Tensor): 裁剪或扩展后的张量
    '''
    assert len(tensor.shape) == 4, '输入的tensor应为4维'

    # 获取原始的高度和宽度
    original_height, original_width = tensor.shape[2], tensor.shape[3]

    # 计算需要裁剪的数量(正值代表不裁剪,负值时代表裁剪)
    height_to_remove_from_bottom = min(original_height, -height_crop) if height_crop < 0 else 0
    width_to_remove_from_right = min(original_width, -width_crop) if width_crop < 0 else 0

    # 计算需要填充的数量(正值代表填充,负值代表不填充)
    pad_bottom = abs(height_crop) if height_crop > 0 else 0
    pad_right = abs(width_crop) if width_crop > 0 else 0

    # 先填充,再裁剪
    padded_tensor = F.pad(tensor, pad=(0, pad_right, 0, pad_bottom), mode='constant', value=pad_value)

    # 在高度和宽度维度上进行裁剪(如果需要)
    if height_to_remove_from_bottom > 0 and width_to_remove_from_right > 0:
        # 同时裁剪高度和宽度
        cropped_or_padded_tensor = padded_tensor[:, :, :-height_to_remove_from_bottom, :-width_to_remove_from_right]
    elif height_to_remove_from_bottom > 0:
        # 只裁剪高度
        cropped_or_padded_tensor = padded_tensor[:, :, :-height_to_remove_from_bottom, :]
    elif width_to_remove_from_right > 0:
        # 只裁剪宽度
        cropped_or_padded_tensor = padded_tensor[:, :, :, :-width_to_remove_from_right]
    else:
        # 不裁剪任何维度
        cropped_or_padded_tensor = padded_tensor

    return cropped_or_padded_tensor


def crop_or_pad_tensor_by_depth_height_width(tensor, depth_crop, height_crop, width_crop, pad_value=0):
    '''
    裁剪或扩展Tensor在深度(仅最后一个)、高度(仅底部)和宽度(仅右侧)维度上的最后一个像素。
    正数表示扩展(用0填充),负数表示裁剪。

    参数:
    tensor (torch.Tensor): 输入的5维张量,形状为 (batch_size, channels, depth, height, width)
    depth_crop (int): 深度方向上最后一个要裁剪或扩展的数量,默认为1
    height_crop (int): 高度方向上底部要裁剪或扩展的像素数量,默认为1
    width_crop (int): 宽度方向上右侧要裁剪或扩展的像素数量,默认为1
    pad_value (float or int): 填充时使用的值,默认为0

    返回:
    cropped_or_padded_tensor (torch.Tensor): 裁剪或扩展后的张量
    '''
    assert len(tensor.shape) == 5, '输入的tensor应为5维'

    # 获取原始的深度、高度和宽度
    original_depth, original_height, original_width = tensor.shape[2], tensor.shape[3], tensor.shape[4]

    # 计算需要裁剪的数量(正值代表不裁剪,负值时代表裁剪)
    depth_to_remove_from_end = min(original_depth, -depth_crop) if depth_crop < 0 else 0
    height_to_remove_from_bottom = min(original_height, -height_crop) if height_crop < 0 else 0
    width_to_remove_from_right = min(original_width, -width_crop) if width_crop < 0 else 0

    # 计算需要填充的数量(正值代表填充,负值代表不填充)
    pad_depth = abs(depth_crop) if depth_crop > 0 else 0
    pad_bottom = abs(height_crop) if height_crop > 0 else 0
    pad_right = abs(width_crop) if width_crop > 0 else 0

    # 先填充,再裁剪
    padded_tensor = F.pad(tensor, pad=(0, pad_right, 0, pad_bottom, 0, pad_depth), mode='constant', value=pad_value)

    # 在深度、高度和宽度维度上进行裁剪(如果需要)
    if depth_to_remove_from_end > 0 and height_to_remove_from_bottom > 0 and width_to_remove_from_right > 0:
        # 同时裁剪深度、高度和宽度
        cropped_or_padded_tensor = padded_tensor[:, :, :-depth_to_remove_from_end, :-height_to_remove_from_bottom,
                                   :-width_to_remove_from_right]
    elif depth_to_remove_from_end > 0 and height_to_remove_from_bottom > 0:
        # 只裁剪深度和高度
        cropped_or_padded_tensor = padded_tensor[:, :, :-depth_to_remove_from_end, :-height_to_remove_from_bottom, :]
    elif depth_to_remove_from_end > 0 and width_to_remove_from_right > 0:
        # 只裁剪深度和宽度
        cropped_or_padded_tensor = padded_tensor[:, :, :-depth_to_remove_from_end, :, :-width_to_remove_from_right]
    elif height_to_remove_from_bottom > 0 and width_to_remove_from_right > 0:
        # 只裁剪高度和宽度
        cropped_or_padded_tensor = padded_tensor[:, :, :, :-height_to_remove_from_bottom, :-width_to_remove_from_right]
    elif depth_to_remove_from_end > 0:
        # 只裁剪深度
        cropped_or_padded_tensor = padded_tensor[:, :, :-depth_to_remove_from_end, :, :]
    elif height_to_remove_from_bottom > 0:
        # 只裁剪高度
        cropped_or_padded_tensor = padded_tensor[:, :, :, :-height_to_remove_from_bottom, :]
    elif width_to_remove_from_right > 0:
        # 只裁剪宽度
        cropped_or_padded_tensor = padded_tensor[:, :, :, :, :-width_to_remove_from_right]
    else:
        # 不裁剪任何维度
        cropped_or_padded_tensor = padded_tensor

    return cropped_or_padded_tensor

Logo

GitCode 天启AI是一款由 GitCode 团队打造的智能助手,基于先进的LLM(大语言模型)与多智能体 Agent 技术构建,致力于为用户提供高效、智能、多模态的创作与开发支持。它不仅支持自然语言对话,还具备处理文件、生成 PPT、撰写分析报告、开发 Web 应用等多项能力,真正做到“一句话,让 Al帮你完成复杂任务”。

更多推荐