import os
import traceback

import numpy as np
import torch
import horovod.torch as hvd
from tqdm import tqdm
from apex import amp
from torch.utils.data import DataLoader

from .BaseDetector import BaseDetection3D, Metric
from ..data import CenterReader
from ..data.collate import train_collate_center


class CenterDetection3D(BaseDetection3D):

    def __init__(self, cfg, mode="train"):
        super(CenterDetection3D, self).__init__(cfg, mode)
        self.DataReader = CenterReader

    def _create_dataloader(self):
        self.train_dataset = self.DataReader(self.cfg, mode='train')
        self.val_dataset = self.DataReader(self.cfg, mode='val')

        # Horovod: limit the number of CPU threads used per worker.
        torch.set_num_threads(4)

        kwargs = {'num_workers': self.cfg.DATA.DATA_LOADER.NUM_WORKERS, 'pin_memory': True}
        # When supported, use 'forkserver' to spawn dataloader workers instead of 'fork'
        # to prevent issues with Infiniband implementations that are not fork-safe.
        # if (kwargs.get('num_workers', 0) > 0 and hasattr(mp, '_supports_context') and
        #         mp._supports_context and 'forkserver' in mp.get_all_start_methods()):
        #     kwargs['multiprocessing_context'] = 'forkserver'

        # Partition the datasets among Horovod workers using a DistributedSampler
        # (self.DistributedSampler is assumed to be provided by BaseDetection3D).
        self.train_sampler = self.DistributedSampler(self.train_dataset,
                                                     num_replicas=hvd.size(),
                                                     rank=hvd.rank())
        self.val_sampler = self.DistributedSampler(self.val_dataset,
                                                   num_replicas=hvd.size(),
                                                   rank=hvd.rank())
        self.train_loader = DataLoader(self.train_dataset,
                                       batch_size=self.cfg.TRAINING.BATCH_SIZE,
                                       collate_fn=train_collate_center,
                                       sampler=self.train_sampler, **kwargs)
        # The validation loader reuses the training collate function so that the
        # loss targets (heatmaps, offsets, masks) are available during validation.
        self.val_loader = DataLoader(self.val_dataset,
                                     batch_size=self.cfg.TRAINING.BATCH_SIZE,
                                     collate_fn=train_collate_center,
                                     sampler=self.val_sampler, **kwargs)

    def train(self, epoch, verbose):
        # Re-seed the sampler so each epoch shuffles differently across workers.
        self.train_sampler.set_epoch(epoch)
        self.model.set_mode('train')

        hm_loss = Metric('hm_loss')
        reg_loss = Metric('reg_loss')
        off_loss = Metric('off_loss')
        total_loss = Metric('train_loss')

        with tqdm(total=len(self.train_loader), desc='Train Epoch #{}'.format(epoch),
                  disable=not verbose) as t:
            for j, (input_data, gaussian_hm, center_idxs, bboxes_diameters,
                    reg_offset, reg_mask) in enumerate(self.train_loader):
                # torch.autograd.Variable is a no-op since PyTorch 0.4;
                # move the batch to the GPU directly.
                input_data = input_data.cuda()
                gaussian_hm = gaussian_hm.cuda()
                center_idxs = center_idxs.cuda()
                bboxes_diameters = bboxes_diameters.cuda()
                reg_offset = reg_offset.cuda()
                reg_mask = reg_mask.cuda()

                if self.cfg.TRAINING.SHEDULER.LR_SHEDULE:
                    self.lr_shedule(epoch, j)
                self.optimizer.zero_grad()

                rpn_logits, rpn_deltas, rpn_offsets = self.model(input_data)
                loss, [hm_loss_, reg_loss_, off_loss_] = self.model.loss(
                    rpn_logits, rpn_deltas, rpn_offsets,
                    gaussian_hm, center_idxs, bboxes_diameters, reg_offset, reg_mask)

                if self.cfg.TRAINING.AMP:
                    # Apex AMP + Horovod: scale the loss, let Horovod finish its
                    # gradient allreduce, then step while skipping a second sync.
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                    self.optimizer.synchronize()
                    with self.optimizer.skip_synchronize():
                        # torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                        #                                max_norm=1.0, norm_type=2)
                        self.optimizer.step()
                else:
                    loss.backward()
                    self.optimizer.step()

                total_loss.update(loss)
                hm_loss.update(hm_loss_)
                reg_loss.update(reg_loss_)
                off_loss.update(off_loss_)
                t.set_postfix({'total_loss': total_loss.avg.item(),
                               'hm_loss': hm_loss.avg.item(),
                               'reg_loss': reg_loss.avg.item(),
                               'off_loss': off_loss.avg.item()})
                t.update(1)
        if self.train_writer:
            print('Train Epoch %d, loss %f' % (epoch, total_loss.avg.item()))
            print('hm_loss %f, reg_loss %f, off_loss %f'
                  % (hm_loss.avg.item(), reg_loss.avg.item(), off_loss.avg.item()))
            # Write to tensorboard
            self.train_writer.add_scalar('loss', total_loss.avg, epoch)
            self.train_writer.add_scalar('hm_loss', hm_loss.avg, epoch)
            self.train_writer.add_scalar('reg_loss', reg_loss.avg, epoch)
            self.train_writer.add_scalar('off_loss', off_loss.avg, epoch)

        torch.cuda.empty_cache()

    def validate(self, epoch, verbose):
        # self.val_sampler.set_epoch(epoch)
        self.model.set_mode('valid')

        hm_loss = Metric('hm_loss')
        reg_loss = Metric('reg_loss')
        off_loss = Metric('off_loss')
        total_loss = Metric('valid_loss')

        with tqdm(total=len(self.val_loader), desc='Validate Epoch #{}'.format(epoch),
                  disable=not verbose) as t:
            for j, (input_data, gaussian_hm, center_idxs, bboxes_diameters,
                    reg_offset, reg_mask) in enumerate(self.val_loader):
                with torch.no_grad():
                    input_data = input_data.cuda()
                    gaussian_hm = gaussian_hm.cuda()
                    center_idxs = center_idxs.cuda()
                    bboxes_diameters = bboxes_diameters.cuda()
                    reg_offset = reg_offset.cuda()
                    reg_mask = reg_mask.cuda()

                    rpn_logits, rpn_deltas, rpn_offsets = self.model(input_data)
                    loss, [hm_loss_, reg_loss_, off_loss_] = self.model.loss(
                        rpn_logits, rpn_deltas, rpn_offsets,
                        gaussian_hm, center_idxs, bboxes_diameters, reg_offset, reg_mask)

                # if torch.is_nonzero(rpn_reg_loss_):
                total_loss.update(loss)
                hm_loss.update(hm_loss_)
                reg_loss.update(reg_loss_)
                off_loss.update(off_loss_)
                t.update(1)

        if self.val_writer:
            print('Val Epoch %d, loss %f' % (epoch, total_loss.avg.item()))
            print('hm_loss %f, reg_loss %f, off_loss %f'
                  % (hm_loss.avg.item(), reg_loss.avg.item(), off_loss.avg.item()))
            # Write to tensorboard
            self.val_writer.add_scalar('loss', total_loss.avg, epoch)
            self.val_writer.add_scalar('hm_loss', hm_loss.avg, epoch)
            self.val_writer.add_scalar('reg_loss', reg_loss.avg, epoch)
            self.val_writer.add_scalar('off_loss', off_loss.avg, epoch)

        torch.cuda.empty_cache()

    def do_test(self):
        initial_checkpoint = self.cfg.TESTING.WEIGHT
        save_dir = self.cfg.TESTING.SAVER_DIR

        if initial_checkpoint:
            print('[Loading model from %s]' % initial_checkpoint)
            checkpoint = torch.load(initial_checkpoint)
            self.model.load_state_dict(checkpoint['state_dict'])
            epoch = checkpoint['epoch']
        else:
            print('No model weight file specified')
            return

        self.model.set_mode('test')
        # model_path = self.cfg.DEPLOY.TORCHSCRIPT_SAVE_PATH
        # print('[Loading torchscript from %s]' % model_path)
        # torchscript_model = torch.jit.load(model_path)

        self.test_dataset = self.DataReader(self.cfg, mode='test')

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        res_dir = os.path.join(save_dir, str(epoch))
        if not os.path.exists(res_dir):
            os.makedirs(res_dir)

        print('Total # of eval data %d' % (len(self.test_dataset)))
        for i, (input_data, image) in enumerate(self.test_dataset):
            # if i == 10:
            #     break
            try:
                D, H, W = image.shape
                pid = self.test_dataset.sample_bboxes[i].get_field("filename")
                pid = pid.split('/')[-2].replace('.nii.gz', '')
                print('[%d] Predicting %s' % (i, pid), image.shape)

                with torch.no_grad():
                    input_data = input_data.cuda().unsqueeze(0)
                    # self.model = self.model.half()
                    # input_data = input_data.half()
                    detections = self.model.forward(input_data)
                    # image = np.expand_dims(image, 0)
                    # image = np.expand_dims(image, 0)
                    # image = torch.from_numpy(image).cuda()
                    # detections = torchscript_model(image.half())

                detections = detections.cpu().numpy()
                if len(detections):
                    # Drop the leading batch-index column.
                    detections = detections[:, 1:]
                np.save(os.path.join(res_dir, '%s_detections.npy' % pid), detections)

                # Clear gpu memory
                torch.cuda.empty_cache()
            except Exception:
                torch.cuda.empty_cache()
                traceback.print_exc()

        self.npy2csv('detections', res_dir)
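
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): a minimal
# Horovod launch loop for this detector. It assumes BaseDetection3D wires up
# `model` and `optimizer`, and that a `load_config` helper and the config path
# exist in the project (both are hypothetical names here). Launch with e.g.:
#     horovodrun -np 4 python train_center.py
#
#     import horovod.torch as hvd
#     import torch
#
#     def main():
#         hvd.init()
#         # Pin each Horovod process to one GPU.
#         torch.cuda.set_device(hvd.local_rank())
#         cfg = load_config('configs/center.yaml')  # hypothetical helper
#         detector = CenterDetection3D(cfg, mode='train')
#         detector._create_dataloader()
#         verbose = hvd.rank() == 0  # only rank 0 shows progress bars and logs
#         for epoch in range(cfg.TRAINING.EPOCHS):
#             detector.train(epoch, verbose)
#             detector.validate(epoch, verbose)
#
#     if __name__ == '__main__':
#         main()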