"""Build models by name, using parameters read from a JSON config file."""
import sys
import os

from .model_utils import get_block

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from ..utils.RWconfig import LoadJson

import torch


def build_model(model_name, cfg):
    """Instantiate the model registered under ``model_name``.

    Args:
        model_name: Name passed to ``get_block`` to look up the model
            constructor.
        cfg: Mapping of keyword arguments for that constructor. For model
            names containing "ResNet", ``cfg['block']`` is resolved through
            ``get_block`` before construction; a value that is already
            callable is passed through unchanged. The caller's mapping is
            never modified.

    Returns:
        The object returned by calling ``get_block(model_name)(**cfg)``.

    Raises:
        KeyError: If ``model_name`` contains "ResNet" and ``cfg`` has no
            ``'block'`` entry.
    """
    func = get_block(model_name)
    # Shallow copy: the original code wrote the resolved block class back into
    # the caller's dict, so a second build_model() call with the same cfg fed a
    # class (not a name) into get_block().
    cfg = dict(cfg)
    if "ResNet" in model_name:
        block = cfg['block']
        if not callable(block):
            cfg['block'] = get_block(block)
    return func(**cfg)


if __name__ == "__main__":
    # NOTE(review): the relative imports above mean this must be run as a
    # module (python -m <package>.<module>), and cfg_path is resolved against
    # the current working directory, not this file's location.
    cfg_path = '../files/modelParams/model_parameters.json'
    model_name = "FSe_ResNet"
    cfg = LoadJson(cfg_path)[model_name]
    print('cfg is ', cfg)
    model = build_model(model_name, cfg)
    # Counts every element of every parameter tensor, trainable or not.
    total_params = sum(p.numel() for p in model.parameters())
    print(f'{total_params:,} total parameters.')