import torch
import torch.nn as nn
from tqdm import tqdm

from .loss_func import (BCEFocalLoss, SeLossWithBCEFocalLoss, FocalLoss,
                        SE_FocalLoss, FeatureSeLossWithBCEFocalLoss)


def eval_net(net, dataset, loader, device, alpha=None, cfg=None):
    """Run one validation pass and return the average classification loss per sample."""
    num_class = cfg['num_class']
    num_task = cfg['num_task']
    gamma = cfg.get('gamma', 1)

    # The custom losses from .loss_func take a single [prediction, label] list.
    if net.num_class > 1:
        criterion = nn.CrossEntropyLoss()
        if alpha is not None:
            criterion = FocalLoss(num_class, alpha=alpha, gamma=gamma,
                                  size_average=False, sigmoid_choice=True)
    else:
        if alpha is None:
            raise ValueError('alpha is required for binary evaluation (BCEFocalLoss)')
        criterion = BCEFocalLoss(alpha=alpha, gamma=gamma, reduction='sum')

    net.eval()
    n_val = len(dataset)  # number of samples, not batches
    tot = 0
    count = 0
    with tqdm(total=int(n_val / loader.batch_size) + 1, desc='Validation round',
              unit='batch', leave=False) as pbar:
        for batch in loader:
            imgs = batch[0][0]
            label = batch[1][0]
            imgs = imgs.to(device=device, dtype=torch.float32)
            if num_task == 1:
                label = label.to(device=device, dtype=torch.float32)
            else:
                label = [val.to(device=device, dtype=torch.float32) for val in label]
            if num_class > 2:
                label = label.to(device=device, dtype=torch.long)

            with torch.no_grad():
                prediction = net(imgs)

            if num_task == 1:
                if num_class <= 2:
                    # Multi-head models return a list; score the first head only.
                    if isinstance(prediction, list):
                        prediction = prediction[0]
                    prediction = prediction.squeeze(-1)
                try:
                    tot += criterion([prediction, label]).item()
                except Exception:
                    continue  # skip batches the criterion cannot score
            else:
                # One prediction/label pair per task.
                for x, y in zip(prediction, label):
                    x = x.squeeze(-1)
                    try:
                        tot += criterion([x, y]).item()
                    except Exception:
                        continue
            count += 1
            pbar.update()
    net.train()
    print('validation: total loss %s over %s samples (%s batches scored)' % (tot, n_val, count))
    return tot / n_val


def eval_netSE(net, dataset, loader, base_cfg=None, cfg=None, criterion=None, reg_loss=None):
    """Validation pass for the SE (sample-pair) model; returns the average combined loss per batch."""
    lambda_val = cfg.get('lambda_val', 1)
    gamma = base_cfg.lr_gamma if hasattr(base_cfg, 'lr_gamma') else 1
    num_task = cfg['num_task']
    prob_diff = cfg.get('prob_diff', False)
    feature_sim = cfg.get('Se_choice') == 1
    batch_size = base_cfg.batch_size

    label_dtype = torch.long if cfg['num_class'] > 1 else torch.float32

    net.eval()
    n_val = int(round(len(dataset) / batch_size))  # number of batches
    tot = 0
    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for batch in loader:
            imgs, labels, label_diff = batch
            imgs = [single_set_img.type(torch.float32).cuda() for single_set_img in imgs]
            if reg_loss:
                # Each label carries a (classification, regression) pair.
                cls_labels = [single_label[0].type(label_dtype).cuda() for single_label in labels]
                reg_labels = [single_label[1].type(torch.float32).cuda() for single_label in labels]
            else:
                cls_labels = [single_label.type(label_dtype).cuda() for single_label in labels]

            with torch.no_grad():
                prediction = net(imgs, labels=cls_labels)

            # The network returns one prediction per input set plus a
            # trailing consistency output.
            if reg_loss:
                logits = [pred[0].squeeze(-1) for pred in prediction[:len(labels)]]
                reg_outputs = [pred[1].squeeze(-1) for pred in prediction[:len(labels)]]
                cons = prediction[-1][0].squeeze(-1)
            else:
                logits = [pred.squeeze(-1) for pred in prediction[:len(labels)]]
                cons = prediction[-1].squeeze(-1)

            loss = criterion([logits] + [cls_labels] + [label_diff])
            if reg_loss:
                # The regression head contributes a small weighted term.
                reg_loss_val = sum(reg_loss(out, lbl)
                                   for out, lbl in zip(reg_outputs, reg_labels))
                loss += 0.05 * reg_loss_val
            tot += loss.item()
            pbar.update()
    net.train()
    return tot / n_val


def eval_netFSE(net, dataset, loader, loader_se, device, cfg=None, criterion=None):
    """Validation pass for the feature-SE model: paired loaders feed image pairs,
    and label_diff marks whether the two samples in a pair share a label."""
    lambda_val = cfg.get('lambda_val', 1)
    gamma = cfg.get('gamma', 1)
    num_task = cfg['num_task']
    label_dtype = torch.long if net.num_class > 1 else torch.float32

    net.eval()
    n_val = len(dataset)  # number of samples; per-batch losses are summed
    tot = 0
    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for batch, batch_se in zip(loader, loader_se):
            imgs, label = batch
            imgs_se, label_se = batch_se
            label_diff = (label == label_se)

            imgs = imgs.to(device=device, dtype=torch.float32)
            imgs_se = imgs_se.to(device=device, dtype=torch.float32)
            label = label.to(device=device, dtype=label_dtype)
            label_se = label_se.to(device=device, dtype=label_dtype)
            label_diff = label_diff.to(device=device, dtype=label_dtype)

            with torch.no_grad():
                prediction = net(imgs, imgs_se)

            if num_task == 1:
                tot += criterion(prediction[0].squeeze(-1), prediction[1].squeeze(-1),
                                 prediction[2], prediction[3], prediction[4].squeeze(-1),
                                 label, label_se, label_diff).item()
            else:
                for x, y in zip(prediction, label):
                    x = x.squeeze(-1)
                    tot += criterion(x, y).item()
            pbar.update()
    net.train()
    return tot / n_val
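

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal way to drive `eval_net` with toy
# data. `ToyNet`, `ToyDataset`, and the `cfg` values are hypothetical
# placeholders, not part of this repo; the real cfg dict comes from the
# project's config, and the real net must expose a `num_class` attribute as
# `eval_net` assumes. Because of the relative import above, run this as part
# of the package (e.g. `python -m <package>.<module>`), not as a script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, Dataset

    class ToyNet(nn.Module):
        num_class = 1  # binary head -> the BCEFocalLoss branch of eval_net

        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 1)

        def forward(self, x):
            return self.fc(x)

    class ToyDataset(Dataset):
        # Items are ([img], [label]) so a collated batch is indexed exactly
        # the way eval_net expects: batch[0][0] / batch[1][0].
        def __len__(self):
            return 32

        def __getitem__(self, idx):
            return [torch.randn(16)], [torch.tensor(float(idx % 2))]

    dataset = ToyDataset()
    loader = DataLoader(dataset, batch_size=8)
    cfg = {'num_class': 1, 'num_task': 1, 'gamma': 2}  # hypothetical config
    val_loss = eval_net(ToyNet(), dataset, loader, device='cpu', alpha=0.25, cfg=cfg)
    print('toy validation loss:', val_loss)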