import os
import numpy as np
import codecs
import SimpleITK as sitk
import pandas as pd
import torch

'''

Expected content of dice.txt

---number of cases (num_images)
---path of ground truth of image_001
---path of seg_mask of image_001
---path of ground truth of image_002
---path of seg_mask of image_002
    ...
    ...
    ...
'''
def readlines(file):
    """
    Read a UTF-8 text file and return its lines without the trailing newline.

    :param file: path of a text file
    :return: a list of line strings, each with trailing newline characters
             removed (other whitespace is left untouched)
    """
    # The context manager guarantees the handle is closed even when reading
    # or decoding raises, which the manual open/close pair did not.
    with codecs.open(file, 'r', encoding='utf-8') as fp:
        return [line.rstrip('\n') for line in fp.readlines()]

def read_test_txt(imlist_file):
    '''
    Parse a list file whose first line is the number of cases, followed by
    two path lines per case.

    :param imlist_file: image list file path
    :return: (im_list, seg_list) -- the first and the second path of every
             case, in file order
    :raises ValueError: if the file is empty or holds fewer than
                        2 * num_cases path lines
    :raises AssertionError: if a listed path does not exist
    '''
    lines = readlines(imlist_file)
    if not lines:
        # Fail with a clear message instead of an IndexError on lines[0].
        raise ValueError('empty imlist file')
    num_cases = int(lines[0])

    if (len(lines) - 1) < (num_cases * 2):
        raise ValueError('too few lines in imlist file')
    im_list, seg_list = [], []
    for i in range(num_cases):
        im_path, seg_path = lines[1 + i * 2].strip(), lines[2 + i * 2].strip()
        # Explicit raises instead of `assert`, so the checks are not stripped
        # when Python runs with -O; the exception type is kept unchanged for
        # existing callers.
        if not os.path.exists(im_path):
            raise AssertionError('image not exist: {}'.format(im_path))
        if not os.path.exists(seg_path):
            raise AssertionError('mask not exist: {}'.format(seg_path))
        im_list.append(im_path)
        seg_list.append(seg_path)

    return im_list, seg_list

def cal_dice(input_tensor, target, num_class, epsilon=1e-6):
    '''
    Compute the dice score of every foreground label (1 .. num_class - 1).

    :params input_tensor:   the result of segmentation (label tensor)
    :params target:         ground truth mask (label tensor)
    :params num_class:      number of labels including background (label 0);
                            background is not scored
    :params epsilon         small constant that avoids division by zero
    :return:                list of 0-dim float tensors, one dice score per
                            foreground label, in label order
    '''
    dice_score = []
    for label in range(1, num_class):
        # reshape (unlike view) also flattens non-contiguous tensors, e.g.
        # masks that were transposed or permuted before being passed in.
        pred_mask = (input_tensor == label).reshape(-1)
        gt_mask = (target == label).reshape(-1)

        # compute dice score
        intersect = torch.sum(pred_mask & gt_mask)
        input_area = torch.sum(pred_mask)
        target_area = torch.sum(gt_mask)
        sum_area = input_area + target_area + 2 * epsilon

        dice_score_i = 2 * intersect.float() / sum_area.float()
        dice_score.append(dice_score_i)
        print('class = {}, dice = {}'.format(label, dice_score_i))

    return dice_score

def val(input_path, results_csv, case_name_index=5):
    '''
    Compute per-case dice scores for the (ground truth, segmentation) pairs
    listed in a txt file and write them to a csv file.

    :param input_path:      list file (.txt) as parsed by read_test_txt
    :param results_csv:     path of the csv file to write
    :param case_name_index: index of the '/'-separated component of the
                            segmentation path that is used as case name
                            (default 5, the previously hard-coded value)
    :raises ValueError: if input_path is not a txt file
    '''
    if input_path.endswith('txt'):
        gt_list, pre_list = read_test_txt(input_path)
    else:
        # Only txt lists are implemented, so the message no longer claims
        # csv support.
        raise ValueError('image test_list must be a txt file')

    # Same leading columns as before; further score columns are added in the
    # order they are first encountered.
    columns = ['case_name', 'tumor']
    rows = []
    for gt_path, pre_path in zip(gt_list, pre_list):
        print('{}: {}'.format(gt_path, pre_path))

        gt_image = sitk.ReadImage(gt_path)
        pre_image = sitk.ReadImage(pre_path)
        case_name = pre_path.split('/')[case_name_index]
        print(case_name)
        gt_mask_np = sitk.GetArrayFromImage(gt_image).astype(float)
        pre_mask_np = sitk.GetArrayFromImage(pre_image).astype(float)
        # NOTE(review): the class count is taken from the labels present in
        # the ground truth, so a case missing a label is scored as if it had
        # fewer classes -- confirm this is intended.
        num_class = len(np.unique(gt_mask_np))

        # get tensor
        gt_mask = torch.unsqueeze(torch.from_numpy(gt_mask_np), 0).float()
        pre_mask = torch.unsqueeze(torch.from_numpy(pre_mask_np), 0).float()
        dice_score = cal_dice(pre_mask, gt_mask, num_class)

        if num_class == 2:
            score_names = ['tumor']
        elif num_class == 3:
            score_names = ['left_testis', 'right_testis']
        else:
            # Previously this case reused the row of the preceding case (or
            # raised NameError on the first case); use generic names instead.
            score_names = ['label_{}'.format(i) for i in range(1, num_class)]

        row = {'case_name': case_name}
        for score_name, score in zip(score_names, dice_score):
            row[score_name] = score.item()
        rows.append(row)
        for key in row:
            if key not in columns:
                columns.append(key)

    # DataFrame.append was removed in pandas 2.0; build the table once from
    # the collected rows and write it after all cases are processed.
    dice_score_record = pd.DataFrame(rows, columns=columns)
    dice_score_record.to_csv(results_csv, index=False)

# Script entry: score the pairs listed in dice.txt and store the result in dice.csv.
# NOTE(review): these statements also run when the module is imported, and the
# paths are machine-specific -- consider an `if __name__ == '__main__':` guard.
input_path, results_csv = (
    '/home/xxx/06_datalist/dice.txt',
    '/home/xxx/06_datalist/dice.csv',
)
val(input_path, results_csv)
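# NOTE: the text below is web-page residue captured together with the script;
# it is not part of the program and is kept only as a comment.
#
# Logo
#
# GitCode Tianqi AI is an intelligent assistant built by the GitCode team on
# advanced LLM (large language model) and multi-agent technology, dedicated to
# providing efficient, intelligent, multimodal support for creation and
# development. Beyond natural-language conversation, it can process files,
# generate PPT slides, write analysis reports, develop web applications and
# more -- "one sentence, and AI completes the complex task for you".
#
# More recommendations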