import torch import torch.nn.functional as F from tqdm import tqdm import torch.nn as nn from .loss_func import BCEFocalLoss,SeLossWithBCEFocalLoss,FocalLoss,SE_FocalLoss,FeatureSeLossWithBCEFocalLoss def eval_net(net, dataset,loader, device,alpha=None,cfg=1): """Evaluation without the densecrf with the dice coefficient""" num_class = cfg['num_class'] num_task = cfg['num_task'] gamma = cfg['gamma'] if 'gamma' in cfg.keys() else 1 sigmoid_choice = cfg['sigmoid_choice'] if 'sigmoid_choice' in cfg.keys() else False if net.num_class>1: criterion = nn.CrossEntropyLoss() if alpha is not None: criterion = FocalLoss(cfg['num_class'],alpha=alpha,gamma=gamma,size_average=False,sigmoid_choice=True) else: if alpha is not None: criterion = BCEFocalLoss(alpha=alpha,gamma=gamma,reduction='sum') net.eval() mask_type = torch.float32 n_val = len(dataset) # the number of batch tot = 0 count = 0 with tqdm(total=int(n_val/loader.batch_size)+1, desc='Validation round', unit='batch', leave=False) as pbar: for batch in loader: imgs = batch[0][0] label = batch[1][0] imgs = imgs.to(device=device,dtype=torch.float32) if num_task==1: label = label.to(device=device,dtype=torch.float32) else: label = [val.to(device=device,dtype=torch.float32) for val in label] if cfg['num_class']>2: label = label.to(device=device,dtype=torch.long) with torch.no_grad(): prediction = net(imgs) if num_task==1: if cfg['num_class']>2: prediction = prediction else: if type(prediction) == list: prediction = prediction[0] prediction = prediction.squeeze(-1) try: single_loss = criterion([prediction,label]).item() tot += single_loss except: continue else: for x,y in zip(prediction,label): x = x.squeeze(-1) try: tot += criterion([x,y]).item except: continue count += 1 pbar.update() net.train() print ('n_val is',tot,n_val,count) return tot / n_val def eval_netSE(net,dataset,loader, device,cfg=1,criterion=None,reg_loss=None): """Evaluation without the densecrf with the dice coefficient""" lambda_val = cfg['lambda_val'] if 'lambda_val' in cfg.keys() else 1 gamma = cfg['gamma'] if 'gamma' in cfg.keys() else 1 num_task = cfg['num_task'] prob_diff = cfg['prob_diff'] if 'prob_diff' in cfg.keys() else False feature_sim = True if 'Se_choice' in cfg.keys() and cfg['Se_choice']==1 else False batch_size = cfg['batch_size'] if net.num_class>1: label_dtype = torch.long else: label_dtype = torch.float32 net.eval() mask_type = torch.float32 n_val = int(round(len(dataset)/batch_size)) # the number of batch tot = 0 with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar: batch_count = 0 for batch in zip(loader): batch_count += 1 imgs,labels,label_diff = batch[0] label_diff = [single_label_diff.cuda().to(device=device,dtype=label_dtype) for single_label_diff in label_diff] imgs = [single_set_img.cuda().to(device=device,dtype=torch.float32) for single_set_img in imgs] if reg_loss: cls_labels = [single_label[0].cuda().to(device=device,dtype=label_dtype) for single_label in labels] reg_labels = [single_label[1].cuda().to(device=device,dtype=torch.float32) for single_label in labels] else: cls_labels = [single_label.cuda().to(device=device,dtype=label_dtype) for single_label in labels] with torch.no_grad(): prediction = net(imgs,labels=cls_labels) if reg_loss: logits = [pred[0].squeeze(-1) for pred in prediction[:len(labels)]] reg_outputs = [pred[1].squeeze(-1) for pred in prediction[:len(labels)]] cons = prediction[-1][0].squeeze(-1) else: logits = [pred.squeeze(-1) for pred in prediction[:len(labels)]] cons = prediction[-1].squeeze(-1) loss_input = [logits]+[cls_labels]+[label_diff] loss = criterion(loss_input) if reg_loss: reg_loss_val = 0 for sample_idx in range(len(reg_outputs)):