import argparse
import os
import random
import re
import subprocess
import sys
import time
from copy import deepcopy  # used by model_info() when profiling FLOPs
from datetime import datetime
from glob import glob
from queue import Queue
from threading import Thread

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import yaml
from loguru import logger
from torch.utils.tensorboard import SummaryWriter
from torch_geometric.data import Data, HeteroData


class AverageMeter:
    """Computes and stores the average and current value.

    With ``length > 0`` the meter keeps a sliding window over the last
    ``length`` values; with ``length == 0`` it accumulates a running mean.
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        self.val = 0.0
        self.avg = 0.0

    def update(self, val, num=1):
        if self.length > 0:
            # Sliding-window mode only supports single-value updates.
            assert num == 1
            self.history.append(val)
            if len(self.history) > self.length:
                del self.history[0]
            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val * num
            self.count += num
            self.avg = self.sum / self.count
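

# Illustrative usage (not part of the original module): a windowed meter
# smooths a noisy training signal, while length=0 keeps a running mean.
#
#     loss_meter = AverageMeter(length=20)   # moving average over 20 steps
#     for step_loss in (0.9, 0.7, 0.8):
#         loss_meter.update(step_loss)
#     print(loss_meter.val, loss_meter.avg)  # latest value, window mean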


class AVGMeter:
    """Minimal running-average meter: accumulate values, then aggregate."""

    def __init__(self):
        self.value = 0
        self.cnt = 0

    def update(self, v_new):
        self.value += v_new
        self.cnt += 1

    def agg(self):
        # Guard against aggregating before any update.
        if self.cnt == 0:
            return 0.0
        return self.value / self.cnt

    def reset(self):
        self.value = 0
        self.cnt = 0


class Reporter:
    """Thin wrapper around TensorBoard's SummaryWriter."""

    def __init__(self, cfg, log_dir) -> None:
        logger.info("{} log path: {}".format("=" * 20, cfg['log_path']))
        self.writer = SummaryWriter(log_dir)
        self.cfg = cfg

    def record(self, value_dict, epoch):
        # AVGMeter entries are aggregated before logging; plain scalars pass through.
        for key, value in value_dict.items():
            if isinstance(value, AVGMeter):
                self.writer.add_scalar(key, value.agg(), epoch)
            else:
                self.writer.add_scalar(key, value, epoch)

    def close(self):
        self.writer.close()
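

# Illustrative usage (assumed calling pattern, not from the original module):
#
#     reporter = Reporter(cfg, log_dir=cfg['log_path'])
#     reporter.record({'train/loss': loss_meter, 'lr': 1e-3}, epoch)
#     reporter.close()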


class Timer:
    """Context manager that times a block and estimates the remaining hours."""

    def __init__(self, rest_epochs):
        self.elapsed_time = None
        self.rest_epochs = rest_epochs
        self.eta = None

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapsed_time = time.time() - self.start_time
        # ETA in hours, assuming remaining epochs take as long as this one.
        self.eta = round((self.rest_epochs * self.elapsed_time) / 3600, 2)
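

# Illustrative usage (assumed training loop, not from the original module):
#
#     for epoch in range(num_epochs):
#         with Timer(rest_epochs=num_epochs - epoch - 1) as t:
#             train_one_epoch(...)
#         logger.info(f"epoch took {t.elapsed_time:.1f}s, ETA {t.eta}h")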


def get_argparse():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='./configs/default.yaml')
    parser.add_argument('--distributed', default=False, action='store_true')
    parser.add_argument('--local-rank', default=0, type=int, help='local rank for distributed training')
    parser.add_argument("--seed", type=int, default=2024)
    parser.add_argument("--ngpus", type=int, default=1)
    args = parser.parse_args()
    return args


def count_parameters(model):
    """Return the number of trainable parameters, in millions."""
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params / 1_000_000


def model_info(model, verbose=False, img_size=640):
    n_p = sum(x.numel() for x in model.parameters())  # total parameters
    n_g = sum(x.numel() for x in model.parameters() if x.requires_grad)  # with gradients
    if verbose:
        print('%5s %40s %9s %12s %20s %10s %10s' % ('layer', 'name', 'gradient', 'parameters', 'shape', 'mu', 'sigma'))
        for i, (name, p) in enumerate(model.named_parameters()):
            name = name.replace('module_list.', '')
            print('%5g %40s %9s %12g %20s %10.3g %10.3g' %
                  (i, name, p.requires_grad, p.numel(), list(p.shape), p.mean(), p.std()))

    try:
        from thop import profile
        # thop reports MACs; multiply by 2 to convert to FLOPs.
        flops = profile(deepcopy(model), inputs=(torch.zeros(1, 3, img_size, img_size),), verbose=False)[0] / 1E9 * 2
        fs = ', %.1f GFLOPS' % flops
    except Exception:
        # thop is unavailable or profiling failed; skip the FLOPs report.
        fs = ''

    logger.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")


def get_cfg():
    args = get_argparse()

    with open(args.config, 'r') as file:
        cfg = yaml.safe_load(file)

    # Command-line arguments override the corresponding YAML entries.
    for key, value in vars(args).items():
        if value is not None:
            cfg[key] = value

    # Append the config file's stem (name without extension) to the log path.
    cfg['log_path'] = os.path.join(cfg['log_path'], os.path.splitext(os.path.basename(args.config))[0])

    # PyG-style heterogeneous-graph metadata: (node_types, edge_types).
    metadata = (cfg['data']['meta']['node'],
                list(map(tuple, cfg['data']['meta']['edge'])))
    return cfg, metadata
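

# A minimal config sketch matching the keys get_cfg() reads; the node and
# edge names below are hypothetical placeholders, not from the original repo:
#
#     log_path: ./logs
#     data:
#       meta:
#         node: [user, item]
#         edge: [[user, clicks, item], [item, rev_clicks, user]]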


def init_seeds(seed=0):
    # Lightweight seeding (Python, NumPy, torch CPU); see set_random_seed()
    # for the CUDA/cuDNN-aware variant.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def set_random_seed(seed, deterministic=False):
    """Set random seed for Python, NumPy and PyTorch (CPU and all GPUs)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if deterministic:
        # Reproducible, but potentially slower kernels.
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        # Let cuDNN auto-tune kernels for speed.
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0
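

# Typical usage (illustrative): restrict side effects such as checkpointing
# to rank 0 so they happen once per job rather than once per process.
#
#     if is_main_process():
#         torch.save(model.state_dict(), 'ckpt.pth')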


logs = set()


def time_str(fmt=None):
    if fmt is None:
        fmt = '%Y-%m-%d_%H:%M:%S'
    return datetime.today().strftime(fmt)


def setup_default_logging(save_path, flag_multigpus=False, l_level='INFO'):
    # In multi-GPU runs only rank 0 writes a log file; other ranks return None.
    if flag_multigpus:
        rank = dist.get_rank()
        if rank != 0:
            return None

    tmp_timestr = time_str(fmt='%Y_%m_%d_%H_%M_%S')
    logger.add(
        os.path.join(save_path, f'{tmp_timestr}.log'),
        level=l_level,
        format='{level}|{time:YYYY-MM-DD HH:mm:ss}: {message}',
        enqueue=True,  # thread- and process-safe logging
        encoding='utf-8',
    )
    return tmp_timestr


def world_info_from_env():
    """Read (local_rank, global_rank, world_size) from the env vars set by
    common launchers (torchrun, MPI, SLURM, OpenMPI), defaulting to a
    single-process world."""
    local_rank = 0
    for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'):
        if v in os.environ:
            local_rank = int(os.environ[v])
            break
    global_rank = 0
    for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'):
        if v in os.environ:
            global_rank = int(os.environ[v])
            break
    world_size = 1
    for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'):
        if v in os.environ:
            world_size = int(os.environ[v])
            break

    return local_rank, global_rank, world_size


def setup_distributed(backend="nccl", port=None):
    """Initialize torch.distributed for SLURM or torchrun-style launches.
    Adapted from https://github.com/BIGBALLON/distribuuuu/blob/master/distribuuuu/utils.py
    Originally licensed MIT, Copyright (c) 2020 Wei Li
    """
    num_gpus = torch.cuda.device_count()

    if "SLURM_JOB_ID" in os.environ and "ZHENSALLOC" not in os.environ:
        # SLURM launch: derive rank, world size and the master address from
        # the scheduler's environment.
        _, rank, world_size = world_info_from_env()
        node_list = os.environ["SLURM_NODELIST"]
        addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")

        if port is not None:
            os.environ["MASTER_PORT"] = str(port)
        elif "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "10685"
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank % num_gpus)
        os.environ["RANK"] = str(rank)
    else:
        # torchrun / torch.distributed.launch already export RANK and WORLD_SIZE.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    torch.cuda.set_device(rank % num_gpus)

    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
    )

    return rank, world_size
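

# Illustrative usage (assumed entry-point pattern, not from the original
# module): initialize the process group, then wrap the model in DDP.
#
#     rank, world_size = setup_distributed(backend="nccl")
#     model = model.cuda()
#     model = torch.nn.parallel.DistributedDataParallel(
#         model, device_ids=[rank % torch.cuda.device_count()])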


def setup_default_logging_wt_dir(save_path, flag_multigpus=False, l_level='INFO'):
    # Same as setup_default_logging(), but writes into a per-run timestamped
    # subdirectory; only rank 0 logs in multi-GPU runs.
    if flag_multigpus:
        rank = dist.get_rank()
        if rank != 0:
            return None

    tmp_timestr = time_str(fmt='%Y_%m_%d_%H_%M_%S')
    new_log_path = os.path.join(save_path, tmp_timestr)
    os.makedirs(new_log_path, exist_ok=True)
    logger.add(
        os.path.join(new_log_path, f'{tmp_timestr}.log'),
        level=l_level,
        format='{level}|{time:YYYY-MM-DD HH:mm:ss}: {message}',
        enqueue=True,  # thread- and process-safe logging
        encoding='utf-8',
    )
    return tmp_timestr


def seed_worker(worker_id):
    # Derive a distinct, reproducible seed for each DataLoader worker from
    # the current NumPy seed state.
    cur_seed = np.random.get_state()[1][0]
    cur_seed += worker_id
    np.random.seed(cur_seed)
    random.seed(cur_seed)
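

# Illustrative usage (assumed pattern, not from the original module): pass
# seed_worker to a DataLoader so workers are seeded deterministically.
#
#     loader = torch.utils.data.DataLoader(
#         dataset, batch_size=32, num_workers=4, worker_init_fn=seed_worker)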