| | """ |
| | MobiusNet Trainer with TensorBoard, SafeTensors, and HuggingFace Upload |
| | ======================================================================= |
| | """ |
| |
|
| | import os |
| | import re |
| | import json |
| | import math |
| | import shutil |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch import Tensor |
| | from typing import Tuple, Optional, Dict, Any |
| | from torchvision import datasets, transforms |
| | from torch.utils.data import DataLoader |
| | from torch.utils.tensorboard import SummaryWriter |
| | from tqdm.auto import tqdm |
| | from datetime import datetime |
| | from pathlib import Path |
| | from safetensors.torch import save_file as save_safetensors, load_file as load_safetensors |
| | from huggingface_hub import HfApi, login |
| |
|
| | |
# Best-effort Hugging Face login using the Colab secret 'HF_TOKEN'.
# Outside Colab the import fails and nothing happens here; uploads then rely on
# whatever credentials are configured elsewhere.
try:
    from google.colab import userdata
    token = userdata.get('HF_TOKEN')
    os.environ['HF_TOKEN'] = token
    login(token=token)
    print("Logged in to HuggingFace via Colab")
except ImportError:
    pass  # not running inside Colab
except Exception as e:  # missing/denied secret, invalid token, network error, ...
    # Keep training usable without uploads, but say why login did not happen
    # instead of silently swallowing the error.
    print(f"HuggingFace login via Colab skipped: {e}")
| |
|
# Select the compute device once at import time; the trainer below uses it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

# Allow TF32 math for CUDA matmuls and cuDNN convolutions, and tell PyTorch
# that float32 matmuls may use the faster 'high' precision mode.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')
| |
|
| |
|
| | |
| | |
| | |
| |
|
class MobiusLens(nn.Module):
    """Learned gating "lens" applied along the last dimension of its input.

    ``forward(x) = _twist_out(_center_lens(_twist_in(x)))``:

    * ``_twist_in`` / ``_twist_out`` blend ``x`` with a learned linear
      projection of ``x`` (initialised orthogonal), weighted by ``cos`` and
      ``sin`` of a learned angle.
    * ``_center_lens`` multiplies ``x`` by a sigmoid gate built from three
      wave responses (L, M, R); each response is a product over two scales of
      ``exp(-a * sin(pos)^2)``.

    Attribute names below are the ``state_dict`` keys written to checkpoints,
    so they must not be renamed.

    Args:
        dim: Size of the last input dimension.
        layer_idx: 0-based index of this lens within the network.
        total_layers: Number of lenses in the network. Together with
            ``layer_idx`` it defines ``t = layer_idx / (total_layers - 1)``,
            which places the scale pair and initialises the twist angles
            (``t * pi`` in, ``-t * pi`` out).
        scale_range: ``(low, high)`` range used to place the scale pair.
    """

    def __init__(
        self,
        dim: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
    ):
        super().__init__()

        self.dim = dim
        self.layer_idx = layer_idx
        self.total_layers = total_layers
        # Normalised depth in [0, 1]; the max() guard makes a 1-layer net use t = 0.
        self.t = layer_idx / max(total_layers - 1, 1)

        # Two scales per lens: the low one is placed by t across scale_range and
        # the high one sits one step (span / total_layers) above it.
        # NOTE(review): for the last layer (t = 1) scale_high lies one step past
        # scale_range[1]; kept as-is because 'scales' is a persistent buffer
        # stored in existing checkpoints.
        scale_span = scale_range[1] - scale_range[0]
        step = scale_span / max(total_layers, 1)
        scale_low = scale_range[0] + self.t * scale_span
        scale_high = scale_low + step

        self.register_buffer('scales', torch.tensor([scale_low, scale_high]))

        self.twist_in_angle = nn.Parameter(torch.tensor(self.t * math.pi))
        self.twist_in_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_in_proj.weight)

        # Wave frequency and sharpness (alpha is used as |alpha| + 0.1).
        self.omega = nn.Parameter(torch.tensor(math.pi))
        self.alpha = nn.Parameter(torch.tensor(1.5))

        # Per-scale phase and drift for the left / middle / right waves; drift
        # starts at +1 / 0 / -1 so the three waves begin differently.
        self.phase_l = nn.Parameter(torch.zeros(2))
        self.drift_l = nn.Parameter(torch.ones(2))
        self.phase_m = nn.Parameter(torch.zeros(2))
        self.drift_m = nn.Parameter(torch.zeros(2))
        self.phase_r = nn.Parameter(torch.zeros(2))
        self.drift_r = nn.Parameter(-torch.ones(2))

        # Softmax-normalised mix of L/M/R and sigmoid-normalised XOR vs AND blend.
        self.accum_weights = nn.Parameter(torch.tensor([0.4, 0.2, 0.4]))
        self.xor_weight = nn.Parameter(torch.tensor(0.7))

        self.gate_norm = nn.LayerNorm(dim)

        self.twist_out_angle = nn.Parameter(torch.tensor(-self.t * math.pi))
        self.twist_out_proj = nn.Linear(dim, dim, bias=False)
        nn.init.orthogonal_(self.twist_out_proj.weight)

    def _twist_in(self, x: Tensor) -> Tensor:
        """Blend ``x`` with ``twist_in_proj(x)`` by the learned input angle."""
        cos_t = torch.cos(self.twist_in_angle)
        sin_t = torch.sin(self.twist_in_angle)
        return x * cos_t + self.twist_in_proj(x) * sin_t

    def _center_lens(self, x: Tensor) -> Tensor:
        """Multiply ``x`` by a gate in (0, 1) computed from L/M/R wave responses."""
        x_norm = torch.tanh(x)
        # Mean |tanh(x)| over the last dim, shaped (..., 1, 1) so it broadcasts
        # against the (scale, feature) axes used in wave().
        t = x_norm.abs().mean(dim=-1, keepdim=True).unsqueeze(-2)

        x_exp = x_norm.unsqueeze(-2)   # (..., 1, D)
        s = self.scales.view(-1, 1)    # (2, 1): one row per scale
        # Sharpness is shared by all three waves; computed once per call and
        # kept strictly positive.
        a = self.alpha.abs() + 0.1

        def wave(phase: Tensor, drift: Tensor) -> Tensor:
            pos = s * self.omega * (x_exp + drift.view(-1, 1) * t) + phase.view(-1, 1)
            # (..., 2, D) -> product over the scale axis -> (..., D)
            return torch.exp(-a * torch.sin(pos).pow(2)).prod(dim=-2)

        L = wave(self.phase_l, self.drift_l)
        M = wave(self.phase_m, self.drift_m)
        R = wave(self.phase_r, self.drift_r)

        w = torch.softmax(self.accum_weights, dim=0)
        xor_w = torch.sigmoid(self.xor_weight)

        # Soft XOR / AND of the outer waves, blended by xor_w.
        xor_comp = (L + R - 2 * L * R).abs()
        and_comp = L * R
        lr = xor_w * xor_comp + (1 - xor_w) * and_comp

        gate = w[0] * L + w[1] * M + w[2] * R
        gate = gate * (0.5 + 0.5 * lr)
        gate = torch.sigmoid(self.gate_norm(gate))

        return x * gate

    def _twist_out(self, x: Tensor) -> Tensor:
        """Blend ``x`` with ``twist_out_proj(x)`` by the learned output angle."""
        cos_t = torch.cos(self.twist_out_angle)
        sin_t = torch.sin(self.twist_out_angle)
        return x * cos_t + self.twist_out_proj(x) * sin_t

    def forward(self, x: Tensor) -> Tensor:
        """Apply twist-in, the gating lens, then twist-out; shape is preserved."""
        return self._twist_out(self._center_lens(self._twist_in(x)))

    def get_lens_stats(self) -> Dict[str, float]:
        """Return scalar lens parameters (normalised where applicable) for logging."""
        accum = torch.softmax(self.accum_weights, dim=0)
        return {
            'omega': self.omega.item(),
            'alpha': self.alpha.item(),
            'twist_in_angle': self.twist_in_angle.item(),
            'twist_out_angle': self.twist_out_angle.item(),
            'xor_weight': torch.sigmoid(self.xor_weight).item(),
            'accum_weights_l': accum[0].item(),
            'accum_weights_m': accum[1].item(),
            'accum_weights_r': accum[2].item(),
        }
| |
|
| |
|
| | |
| | |
| | |
| |
|
class MobiusConvBlock(nn.Module):
    """Depthwise-separable conv followed by a MobiusLens gate, mixed with the input.

    The conv branch (3x3 depthwise, 1x1 pointwise, BatchNorm) keeps both the
    channel count and the spatial size, so the output has the input's shape.
    A fixed per-channel mask scales one third of the channels by
    ``reduction``; which third rotates with ``layer_idx % 3`` (the last third
    also absorbs the ``channels % 3`` remainder). The result is
    ``sigmoid(residual_weight) * x + (1 - sigmoid(residual_weight)) * branch``.

    Args:
        channels: Number of input/output channels.
        layer_idx: 0-based block index in the network (selects the masked third
            and is forwarded to the lens).
        total_layers: Total number of blocks (forwarded to the lens).
        scale_range: Forwarded to the lens.
        reduction: Multiplier applied to the masked third of the channels.
    """

    def __init__(
        self,
        channels: int,
        layer_idx: int,
        total_layers: int,
        scale_range: Tuple[float, float] = (1.0, 9.0),
        reduction: float = 0.5,
    ):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, groups=channels, bias=False),
            nn.Conv2d(channels, channels, 1, bias=False),
            nn.BatchNorm2d(channels),
        )

        self.lens = MobiusLens(channels, layer_idx, total_layers, scale_range)

        third = channels // 3
        which_third = layer_idx % 3
        mask = torch.ones(channels)
        start = which_third * third
        end = start + third + (channels % 3 if which_third == 2 else 0)
        mask[start:end] = reduction
        # Shaped (1, C, 1, 1) to broadcast over NCHW activations.
        self.register_buffer('thirds_mask', mask.view(1, -1, 1, 1))

        # Raw logit; sigmoid(0.9) ~= 0.71 of the output comes from the input at init.
        self.residual_weight = nn.Parameter(torch.tensor(0.9))

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated conv branch and blend it with ``x`` (NCHW in, NCHW out)."""
        identity = x

        h = self.conv(x)
        # MobiusLens gates along the last dim, so move channels last and back.
        h = self.lens(h.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        h = h * self.thirds_mask

        rw = torch.sigmoid(self.residual_weight)
        return rw * identity + (1 - rw) * h

    def get_residual_weight(self) -> float:
        """Return the effective (sigmoid-squashed) residual mixing weight."""
        return torch.sigmoid(self.residual_weight).item()
| |
|
| |
|
| | |
| | |
| | |
| |
|
class MobiusNet(nn.Module):
    """Image classifier: stem, stages of MobiusConvBlocks, optional integrator, linear head.

    Stage ``i`` has ``depths[i]`` blocks of width ``channels[i]``; between
    stages a stride-2 3x3 conv (+ BatchNorm) changes width and halves the
    spatial size. Blocks are numbered globally (``layer_idx``) across stages,
    which drives each lens's scale pair and twist initialisation.

    Args:
        in_chans: Number of input image channels.
        num_classes: Output size of the classification head.
        channels: Width per stage. If shorter than ``depths``, the last width is
            repeated for the remaining stages (``get_config`` reports the
            widths exactly as passed).
        depths: Number of blocks per stage.
        scale_range: Forwarded to every block's lens.
        use_integrator: If True, a 3x3 conv + BatchNorm + GELU runs after the
            last stage; otherwise that slot is an identity.
    """

    def __init__(
        self,
        in_chans: int = 3,
        num_classes: int = 200,
        channels: Tuple[int, ...] = (64, 128, 256, 512),
        depths: Tuple[int, ...] = (2, 2, 2, 2),
        scale_range: Tuple[float, float] = (0.5, 2.5),
        use_integrator: bool = True,
    ):
        super().__init__()

        num_stages = len(depths)
        total_layers = sum(depths)

        self.total_layers = total_layers
        self.scale_range = scale_range
        self.channels = tuple(channels)
        self.depths = tuple(depths)
        self.num_stages = num_stages
        self.use_integrator = use_integrator
        self.num_classes = num_classes
        self.in_chans = in_chans

        # Pad the width list so every stage has one.
        channels = list(channels)
        while len(channels) < num_stages:
            channels.append(channels[-1])

        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, channels[0], 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(channels[0]),
        )

        layer_idx = 0
        self.stages = nn.ModuleList()
        self.downsamples = nn.ModuleList()

        for stage_idx in range(num_stages):
            ch = channels[stage_idx]

            stage = nn.ModuleList()
            for _ in range(depths[stage_idx]):
                stage.append(MobiusConvBlock(ch, layer_idx, total_layers, scale_range))
                layer_idx += 1
            self.stages.append(stage)

            # No downsample after the final stage.
            if stage_idx < num_stages - 1:
                ch_next = channels[stage_idx + 1]
                self.downsamples.append(nn.Sequential(
                    nn.Conv2d(ch, ch_next, 3, stride=2, padding=1, bias=False),
                    nn.BatchNorm2d(ch_next),
                ))

        final_ch = channels[num_stages - 1]
        if use_integrator:
            self.integrator = nn.Sequential(
                nn.Conv2d(final_ch, final_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(final_ch),
                nn.GELU(),
            )
        else:
            self.integrator = nn.Identity()

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.head = nn.Linear(final_ch, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return class logits of shape (batch, num_classes) for an NCHW batch."""
        x = self.stem(x)

        for i, stage in enumerate(self.stages):
            for block in stage:
                x = block(x)
            if i < len(self.downsamples):
                x = self.downsamples[i](x)

        x = self.integrator(x)
        return self.head(self.pool(x).flatten(1))

    def get_config(self) -> Dict[str, Any]:
        """Return the constructor settings (plus derived counts) for saving."""
        return {
            'in_chans': self.in_chans,
            'num_classes': self.num_classes,
            'channels': self.channels,
            'depths': self.depths,
            'scale_range': self.scale_range,
            'use_integrator': self.use_integrator,
            'total_layers': self.total_layers,
            'num_stages': self.num_stages,
        }

    def get_all_lens_stats(self) -> Dict[str, Dict[str, float]]:
        """Return per-block lens stats (plus residual weight), keyed ``stage{i}_block{j}``."""
        stats = {}
        for stage_idx, stage in enumerate(self.stages):
            for block_idx, block in enumerate(stage):
                key = f"stage{stage_idx}_block{block_idx}"
                stats[key] = block.lens.get_lens_stats()
                stats[key]['residual_weight'] = block.get_residual_weight()
        return stats
| |
|
| |
|
| | |
| | |
| | |
| |
|
def get_tiny_imagenet_loaders(
    data_dir='./data/tiny-imagenet-200',
    batch_size=128,
    num_workers=8,
    val_batch_size=256,
    val_num_workers=4,
):
    """Build train/val DataLoaders for Tiny ImageNet stored as image folders.

    Expects ``<data_dir>/train/<class>/...`` and ``<data_dir>/val``. If the
    validation split is still flat (``val/images`` present) it is first
    reorganised in place into per-class folders via ``reorganize_val_folder``.

    Args:
        data_dir: Dataset root.
        batch_size: Training batch size (training data is shuffled).
        num_workers: Worker processes for the training loader.
        val_batch_size: Validation batch size.
        val_num_workers: Worker processes for the validation loader.

    Returns:
        ``(train_loader, val_loader)``.
    """
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')

    if os.path.exists(os.path.join(val_dir, 'images')):
        print("Reorganizing validation folder...")
        reorganize_val_folder(val_dir)

    # ImageNet channel mean/std, shared by both splits.
    normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    train_transform = transforms.Compose([
        transforms.RandomCrop(64, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
    val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)

    # Pinned memory only helps host->GPU copies; persistent workers require
    # at least one worker process (DataLoader raises otherwise).
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=pin_memory,
        persistent_workers=num_workers > 0,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=val_batch_size, shuffle=False,
        num_workers=val_num_workers, pin_memory=pin_memory,
        persistent_workers=val_num_workers > 0,
    )

    return train_loader, val_loader
| |
|
| |
|
def reorganize_val_folder(val_dir):
    """Move Tiny ImageNet validation images into one sub-folder per class.

    Reads ``val_annotations.txt`` (tab-separated: image file name, class id,
    ...), moves each listed image from ``<val_dir>/images/`` to
    ``<val_dir>/<class_id>/``, then deletes ``<val_dir>/images/`` (including
    any images the annotations did not list) and the annotations file.
    Does nothing if ``<val_dir>/images/`` does not exist.

    Blank or malformed annotation lines (fewer than two tab-separated fields)
    are skipped instead of aborting halfway through the move.

    Raises:
        FileNotFoundError: if ``images/`` exists but ``val_annotations.txt``
            does not.
    """
    val_images_dir = os.path.join(val_dir, 'images')
    val_annotations = os.path.join(val_dir, 'val_annotations.txt')

    if not os.path.exists(val_images_dir):
        return

    with open(val_annotations, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 2:
                continue  # e.g. a trailing empty line
            img_name, class_id = parts[0], parts[1]

            class_dir = os.path.join(val_dir, class_id)
            os.makedirs(class_dir, exist_ok=True)

            src = os.path.join(val_images_dir, img_name)
            dst = os.path.join(class_dir, img_name)

            if os.path.exists(src):
                shutil.move(src, dst)

    if os.path.exists(val_images_dir):
        shutil.rmtree(val_images_dir)
    if os.path.exists(val_annotations):
        os.remove(val_annotations)

    print("Validation folder reorganized.")
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Named MobiusNet architectures: per-stage widths, blocks per stage, and the
# (low, high) scale range passed to every lens. Unpacked into MobiusNet(**...).
PRESETS = dict(
    mobius_tiny_s=dict(
        channels=(64, 128, 256),
        depths=(2, 2, 2),
        scale_range=(0.5, 2.5),
    ),
    mobius_tiny_m=dict(
        channels=(64, 128, 256, 512, 768),
        depths=(2, 2, 4, 2, 2),
        scale_range=(0.25, 2.75),
    ),
    mobius_tiny_l=dict(
        channels=(96, 192, 384, 768),
        depths=(3, 3, 3, 3),
        scale_range=(0.5, 3.5),
    ),
    mobius_base=dict(
        channels=(128, 256, 512, 768, 1024),
        depths=(2, 2, 2, 2, 2),
        scale_range=(0.25, 2.75),
    ),
)
| |
|
| |
|
| | |
| | |
| | |
| |
|
class CheckpointManager:
    """Own the on-disk layout, TensorBoard writer and Hugging Face uploads of one run.

    Local layout (directories are created in ``__init__``)::

        <base_dir>/checkpoints/<variant>_<dataset>/<timestamp>/
            config.json, best_accuracy.json, final_accuracy.json
            checkpoints/   checkpoint_epoch_XXXX.{pt,safetensors},
                           best_model.{pt,safetensors}, final_model.{pt,safetensors}
            tensorboard/   SummaryWriter event files

    Uploads mirror the same relative paths under
    ``checkpoints/<variant>_<dataset>/<timestamp>/`` in ``hf_repo``.
    Passing an existing ``timestamp`` (e.g. when resuming) reuses that run directory.
    """

    def __init__(
        self,
        base_dir: str,
        variant_name: str,
        dataset_name: str,
        hf_repo: str = "AbstractPhil/mobiusnet",
        upload_every_n_epochs: int = 10,
        save_every_n_epochs: int = 10,
        timestamp: Optional[str] = None,
    ):
        """Create the run directories, the TensorBoard writer and the HF client.

        Args:
            base_dir: Root output directory.
            variant_name: Model preset name (first half of the run name).
            dataset_name: Dataset name (second half of the run name).
            hf_repo: Hugging Face model repo to upload into.
            upload_every_n_epochs: ``upload_to_hf`` only acts on multiples of this.
            save_every_n_epochs: Periodic epoch checkpoints are written on multiples of this.
            timestamp: Run id (``YYYYMMDD_HHMMSS``); defaults to now.
        """
        self.timestamp = timestamp or datetime.now().strftime("%Y%m%d_%H%M%S")
        self.variant_name = variant_name
        self.dataset_name = dataset_name
        self.hf_repo = hf_repo
        self.upload_every_n_epochs = upload_every_n_epochs
        self.save_every_n_epochs = save_every_n_epochs

        self.run_name = f"{variant_name}_{dataset_name}"
        self.run_dir = Path(base_dir) / "checkpoints" / self.run_name / self.timestamp
        self.checkpoints_dir = self.run_dir / "checkpoints"
        self.tensorboard_dir = self.run_dir / "tensorboard"

        self.checkpoints_dir.mkdir(parents=True, exist_ok=True)
        self.tensorboard_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(self.tensorboard_dir))

        self.hf_api = HfApi()
        self.uploaded_files = set()  # NOTE(review): not read anywhere in this class

        # Best-model tracking; the trainer overwrites these when resuming.
        self.best_acc = 0.0
        self.best_epoch = 0
        self.best_changed_since_upload = False

        print(f"Checkpoint directory: {self.run_dir}")

    @staticmethod
    def _unwrap_model(model: nn.Module) -> nn.Module:
        """Return the module wrapped by ``torch.compile`` (``_orig_mod``), else ``model``."""
        return model._orig_mod if hasattr(model, '_orig_mod') else model

    @staticmethod
    def extract_timestamp(checkpoint_path: str) -> Optional[str]:
        """Return the first ``YYYYMMDD_HHMMSS``-shaped substring of the path, or None."""
        match = re.search(r'(\d{8}_\d{6})', checkpoint_path)
        if match:
            return match.group(1)
        return None

    def save_config(self, config: Dict[str, Any], training_config: Dict[str, Any]):
        """Write model + training configuration to ``<run_dir>/config.json`` and return its path."""
        full_config = {
            'model': config,
            'training': training_config,
            'timestamp': self.timestamp,
            'variant_name': self.variant_name,
            'dataset_name': self.dataset_name,
        }

        config_path = self.run_dir / "config.json"
        with open(config_path, 'w') as f:
            json.dump(full_config, f, indent=2)

        return config_path

    def save_checkpoint(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        epoch: int,
        train_acc: float,
        val_acc: float,
        train_loss: float,
        is_best: bool = False,
    ):
        """Save a periodic checkpoint every N epochs and overwrite the best model when ``is_best``.

        Each save writes a ``.pt`` file (model, optimizer and scheduler state
        plus metrics) and a weights-only ``.safetensors`` file. When
        ``is_best``, ``best_accuracy.json`` is also rewritten.
        """
        raw_model = self._unwrap_model(model)

        # Update best-tracking BEFORE building the payload so the stored
        # 'best_acc'/'best_epoch' include this epoch when it is the new best.
        # (Writing the stale value made resuming from best_model.pt restore a
        # best_acc lower than the weights it contained.)
        if is_best:
            self.best_acc = val_acc
            self.best_epoch = epoch
            self.best_changed_since_upload = True

        save_periodic = epoch % self.save_every_n_epochs == 0
        if not (save_periodic or is_best):
            return

        state_dict = raw_model.state_dict()
        checkpoint = {
            'epoch': epoch,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'train_loss': train_loss,
            'best_acc': self.best_acc,
            'best_epoch': self.best_epoch,
            'model_state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
        }

        if save_periodic:
            epoch_pt_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.pt"
            torch.save(checkpoint, epoch_pt_path)

            epoch_st_path = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}.safetensors"
            save_safetensors(state_dict, str(epoch_st_path))

        if is_best:
            best_pt_path = self.checkpoints_dir / "best_model.pt"
            torch.save(checkpoint, best_pt_path)

            best_st_path = self.checkpoints_dir / "best_model.safetensors"
            save_safetensors(state_dict, str(best_st_path))

            acc_path = self.run_dir / "best_accuracy.json"
            with open(acc_path, 'w') as f:
                json.dump({
                    'best_acc': val_acc,
                    'best_epoch': epoch,
                    'train_acc': train_acc,
                    'train_loss': train_loss,
                }, f, indent=2)

    def save_final(self, model: nn.Module, final_acc: float, final_epoch: int):
        """Save final weights (.safetensors and .pt) plus ``final_accuracy.json``.

        Returns:
            ``(final_safetensors_path, final_pt_path)``.
        """
        state_dict = self._unwrap_model(model).state_dict()

        final_st_path = self.checkpoints_dir / "final_model.safetensors"
        save_safetensors(state_dict, str(final_st_path))

        final_pt_path = self.checkpoints_dir / "final_model.pt"
        torch.save({
            'model_state_dict': state_dict,
            'final_acc': final_acc,
            'final_epoch': final_epoch,
            'best_acc': self.best_acc,
            'best_epoch': self.best_epoch,
        }, final_pt_path)

        acc_path = self.run_dir / "final_accuracy.json"
        with open(acc_path, 'w') as f:
            json.dump({
                'final_acc': final_acc,
                'final_epoch': final_epoch,
                'best_acc': self.best_acc,
                'best_epoch': self.best_epoch,
            }, f, indent=2)

        return final_st_path, final_pt_path

    def log_scalars(self, epoch: int, scalars: Dict[str, float], prefix: str = ""):
        """Log each scalar to TensorBoard as ``<prefix>/<name>`` (or ``<name>``) at step ``epoch``."""
        for name, value in scalars.items():
            tag = f"{prefix}/{name}" if prefix else name
            self.writer.add_scalar(tag, value, epoch)

    def log_lens_stats(self, epoch: int, model: nn.Module):
        """Log every block's lens statistics as ``lens/<block>/<stat>``."""
        stats = self._unwrap_model(model).get_all_lens_stats()

        for block_name, block_stats in stats.items():
            for stat_name, value in block_stats.items():
                self.writer.add_scalar(f"lens/{block_name}/{stat_name}", value, epoch)

    def log_histograms(self, epoch: int, model: nn.Module):
        """Log histograms of trainable weights and, where present, their gradients."""
        raw_model = self._unwrap_model(model)

        for name, param in raw_model.named_parameters():
            if param.requires_grad:
                self.writer.add_histogram(f"weights/{name}", param.data, epoch)
                if param.grad is not None:
                    self.writer.add_histogram(f"gradients/{name}", param.grad, epoch)

    def upload_to_hf(self, epoch: int, force: bool = False):
        """Upload this run's files to the HF repo (best-effort; errors are printed).

        Acts only when ``epoch`` is a multiple of ``upload_every_n_epochs``
        unless ``force``. Uploads ``config.json``, the epoch checkpoint if one
        was saved this epoch, the best-model files if they changed since the
        last successful upload, and, when ``force``, any final-model files.
        """
        if not force and epoch % self.upload_every_n_epochs != 0:
            return

        try:
            hf_base_path = f"checkpoints/{self.run_name}/{self.timestamp}"

            files_to_upload = []

            config_path = self.run_dir / "config.json"
            if config_path.exists():
                files_to_upload.append(config_path)

            if epoch % self.save_every_n_epochs == 0:
                for suffix in (".safetensors", ".pt"):
                    ckpt = self.checkpoints_dir / f"checkpoint_epoch_{epoch:04d}{suffix}"
                    if ckpt.exists():
                        files_to_upload.append(ckpt)

            best_files = []
            if self.best_changed_since_upload:
                best_files = [
                    p for p in (
                        self.checkpoints_dir / "best_model.safetensors",
                        self.checkpoints_dir / "best_model.pt",
                        self.run_dir / "best_accuracy.json",
                    )
                    if p.exists()
                ]
                files_to_upload.extend(best_files)

            # save_final() writes these; previously they were never uploaded.
            if force:
                for p in (
                    self.checkpoints_dir / "final_model.safetensors",
                    self.checkpoints_dir / "final_model.pt",
                    self.run_dir / "final_accuracy.json",
                ):
                    if p.exists():
                        files_to_upload.append(p)

            failed = set()
            for local_path in files_to_upload:
                # Repo paths always use '/', regardless of the local OS.
                rel_path = local_path.relative_to(self.run_dir).as_posix()
                hf_path = f"{hf_base_path}/{rel_path}"

                try:
                    self.hf_api.upload_file(
                        path_or_fileobj=str(local_path),
                        path_in_repo=hf_path,
                        repo_id=self.hf_repo,
                        repo_type="model",
                    )
                    print(f"Uploaded: {hf_path}")
                except Exception as e:
                    failed.add(local_path)
                    print(f"Failed to upload {rel_path}: {e}")

            # Only forget the pending best upload once it actually succeeded,
            # so a transient failure is retried at the next upload epoch.
            if self.best_changed_since_upload and not failed.intersection(best_files):
                self.best_changed_since_upload = False

        except Exception as e:
            print(f"HuggingFace upload error: {e}")

    def close(self):
        """Flush and close the TensorBoard writer."""
        self.writer.close()

    @staticmethod
    def _find_local_checkpoint(checkpoint_path: str) -> Optional[str]:
        """Resolve a local .pt file or checkpoint directory to a file path, else None.

        A directory may be a run directory (with a ``checkpoints/`` child) or
        that ``checkpoints/`` directory itself. In each candidate directory
        ``best_model.pt`` is preferred, otherwise the last
        ``checkpoint_epoch_*.pt`` in sorted order (epochs are zero-padded, so
        this is the latest epoch).
        """
        if os.path.isfile(checkpoint_path):
            return checkpoint_path
        if not os.path.isdir(checkpoint_path):
            return None

        for ckpt_dir in (os.path.join(checkpoint_path, "checkpoints"), checkpoint_path):
            if not os.path.isdir(ckpt_dir):
                continue
            best_path = os.path.join(ckpt_dir, "best_model.pt")
            if os.path.exists(best_path):
                return best_path
            pt_files = sorted(
                f for f in os.listdir(ckpt_dir)
                if f.startswith("checkpoint_epoch_") and f.endswith(".pt")
            )
            if pt_files:
                return os.path.join(ckpt_dir, pt_files[-1])
        return None

    @staticmethod
    def _download_checkpoint(checkpoint_path: str, hf_repo: str) -> str:
        """Download a checkpoint from ``hf_repo`` and return the local cached path.

        A path ending in ``.pt`` is downloaded as-is. Otherwise it is treated
        as a run path: ``<path>/checkpoints/best_model.pt`` is tried first,
        then the latest ``checkpoint_epoch_*.pt`` under that path.

        Raises:
            FileNotFoundError: if nothing suitable could be found or downloaded.
        """
        # Imported lazily so loading a local checkpoint does not touch the hub API.
        from huggingface_hub import hf_hub_download, list_repo_files

        print(f"Attempting to download from HuggingFace: {hf_repo}/{checkpoint_path}")
        try:
            if checkpoint_path.endswith(".pt"):
                checkpoint_file = hf_hub_download(
                    repo_id=hf_repo,
                    filename=checkpoint_path,
                    repo_type="model",
                )
                print(f"Downloaded {checkpoint_path} from {hf_repo}")
                return checkpoint_file

            try:
                checkpoint_file = hf_hub_download(
                    repo_id=hf_repo,
                    filename=f"{checkpoint_path}/checkpoints/best_model.pt",
                    repo_type="model",
                )
                print(f"Downloaded best_model.pt from {hf_repo}")
                return checkpoint_file
            except Exception:
                # No best_model.pt for this run: fall back to the newest epoch file.
                files = list_repo_files(repo_id=hf_repo, repo_type="model")
                ckpt_files = sorted(
                    f for f in files
                    if checkpoint_path in f and f.endswith(".pt") and "checkpoint_epoch_" in f
                )
                if ckpt_files:
                    checkpoint_file = hf_hub_download(
                        repo_id=hf_repo,
                        filename=ckpt_files[-1],
                        repo_type="model",
                    )
                    print(f"Downloaded {ckpt_files[-1]} from {hf_repo}")
                    return checkpoint_file
        except Exception as e:
            raise FileNotFoundError(
                f"Could not find or download checkpoint: {checkpoint_path}. Error: {e}"
            ) from e

        raise FileNotFoundError(f"Could not find checkpoint: {checkpoint_path}")

    @staticmethod
    def load_checkpoint(
        checkpoint_path: str,
        model: nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
        hf_repo: str = "AbstractPhil/mobiusnet",
        device: torch.device = torch.device('cpu'),
    ) -> Dict[str, Any]:
        """
        Load checkpoint from local path or HuggingFace repo.

        Args:
            checkpoint_path: Either:
                - Local file path to .pt checkpoint
                - Local run directory, or its inner ``checkpoints/`` directory
                - HuggingFace path like "checkpoints/variant_dataset/timestamp"
                  (or a repo path to a specific .pt file)
            model: Model to load weights into (``torch.compile`` wrappers are unwrapped)
            optimizer: Optional optimizer to restore state
            scheduler: Optional scheduler to restore state
            hf_repo: HuggingFace repo ID
            device: Device to load tensors to

        Returns:
            Dict with ``epoch``, ``best_acc``, ``best_epoch`` (falls back to
            ``epoch`` for checkpoints that did not record it), ``train_acc``,
            ``val_acc`` and ``train_loss``.

        Raises:
            FileNotFoundError: if no checkpoint can be found locally or on the hub.
        """
        checkpoint_file = CheckpointManager._find_local_checkpoint(checkpoint_path)
        if checkpoint_file is None:
            checkpoint_file = CheckpointManager._download_checkpoint(checkpoint_path, hf_repo)

        print(f"Loading checkpoint from: {checkpoint_file}")
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load checkpoints (including
        # ones downloaded from the hub) that come from a trusted source.
        checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False)

        raw_model = CheckpointManager._unwrap_model(model)
        raw_model.load_state_dict(checkpoint['model_state_dict'])
        print("Loaded model weights")

        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print("Loaded optimizer state")

        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print("Loaded scheduler state")

        epoch = checkpoint.get('epoch', 0)
        info = {
            'epoch': epoch,
            'best_acc': checkpoint.get('best_acc', 0.0),
            'best_epoch': checkpoint.get('best_epoch', epoch),
            'train_acc': checkpoint.get('train_acc', 0.0),
            'val_acc': checkpoint.get('val_acc', 0.0),
            'train_loss': checkpoint.get('train_loss', 0.0),
        }

        print(f"Resuming from epoch {info['epoch']} (best_acc: {info['best_acc']:.4f})")

        return info
| |
|
| |
|
| | |
| | |
| | |
| |
|
def _train_one_epoch(model, loader, optimizer, epoch, device, max_grad_norm=1.0):
    """Run one training pass (one optimizer step per batch, grad-norm clipped).

    Returns:
        ``(mean per-sample cross-entropy loss, training accuracy)``.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0

    pbar = tqdm(loader, desc=f"Epoch {epoch:3d}")
    for x, y in pbar:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        batch_loss = loss.item()
        loss_sum += batch_loss * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)

        pbar.set_postfix(loss=f"{batch_loss:.4f}")

    total = max(total, 1)  # guard an empty loader
    return loss_sum / total, correct / total


@torch.no_grad()
def _evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` over ``loader`` (eval mode, no grad)."""
    model.eval()
    correct, total = 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        correct += (logits.argmax(1) == y).sum().item()
        total += x.size(0)
    return correct / max(total, 1)


def train_tiny_imagenet(
    preset: str = 'mobius_tiny_m',
    epochs: int = 100,
    lr: float = 1e-3,
    batch_size: int = 128,
    use_integrator: bool = True,
    data_dir: str = './data/tiny-imagenet-200',
    output_dir: str = './outputs',
    hf_repo: str = "AbstractPhil/mobiusnet",
    save_every_n_epochs: int = 10,
    upload_every_n_epochs: int = 10,
    log_histograms_every: int = 10,
    use_compile: bool = True,
    continue_from: Optional[str] = None,
):
    """
    Train MobiusNet on Tiny ImageNet.

    Args:
        preset: Model preset name (key of PRESETS)
        epochs: Total epochs to train (absolute, also when resuming)
        lr: Learning rate
        batch_size: Batch size
        use_integrator: Whether to use integrator layer
        data_dir: Path to Tiny ImageNet data
        output_dir: Output directory for checkpoints
        hf_repo: HuggingFace repo for uploads/downloads
        save_every_n_epochs: Save checkpoint every N epochs
        upload_every_n_epochs: Upload to HF every N epochs
        log_histograms_every: Log weight histograms every N epochs
        use_compile: Whether to use torch.compile
        continue_from: Resume from checkpoint. Can be:
            - Local .pt file path
            - Local checkpoint directory
            - HuggingFace path (e.g., "checkpoints/mobius_base_tiny_imagenet/20240101_120000")

    Returns:
        ``(model, best_acc)`` — the (possibly compiled) model and best validation accuracy.

    Raises:
        KeyError: if ``preset`` is not in PRESETS.
    """
    if preset not in PRESETS:
        raise KeyError(f"Unknown preset {preset!r}; available presets: {sorted(PRESETS)}")
    config = PRESETS[preset]
    dataset_name = "tiny_imagenet"

    print("=" * 70)
    print(f"MÖBIUS NET - {preset.upper()} - TINY IMAGENET")
    print("=" * 70)
    print(f"Device: {device}")
    print(f"Channels: {config['channels']}")
    print(f"Depths: {config['depths']}")
    print(f"Scale range: {config['scale_range']}")
    print(f"Integrator: {use_integrator}")
    if continue_from:
        print(f"Continuing from: {continue_from}")
    print()

    # Reuse the original run's timestamp so a resumed run writes into the same
    # local directory and HF path.
    resume_timestamp = None
    if continue_from:
        resume_timestamp = CheckpointManager.extract_timestamp(continue_from)
        if resume_timestamp:
            print(f"Using original timestamp: {resume_timestamp}")

    ckpt_manager = CheckpointManager(
        base_dir=output_dir,
        variant_name=preset,
        dataset_name=dataset_name,
        hf_repo=hf_repo,
        upload_every_n_epochs=upload_every_n_epochs,
        save_every_n_epochs=save_every_n_epochs,
        timestamp=resume_timestamp,
    )

    train_loader, val_loader = get_tiny_imagenet_loaders(data_dir, batch_size)

    model = MobiusNet(
        in_chans=3,
        num_classes=200,
        use_integrator=use_integrator,
        **config
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total params: {total_params:,}")
    print()

    training_config = {
        'epochs': epochs,
        'lr': lr,
        'batch_size': batch_size,
        'optimizer': 'AdamW',
        'weight_decay': 0.05,
        'scheduler': 'CosineAnnealingLR',
        'total_params': total_params,
    }
    ckpt_manager.save_config(model.get_config(), training_config)

    if use_compile:
        model = torch.compile(model, mode='reduce-overhead')

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    start_epoch = 1
    best_acc = 0.0
    # Used by save_final() even if no epoch runs (e.g. resuming a finished run).
    val_acc = 0.0

    if continue_from:
        ckpt_info = CheckpointManager.load_checkpoint(
            checkpoint_path=continue_from,
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            hf_repo=hf_repo,
            device=device,
        )
        start_epoch = ckpt_info['epoch'] + 1
        val_acc = ckpt_info['val_acc']
        best_acc = ckpt_info['best_acc']
        # Older checkpoints did not store best_epoch; fall back to the loaded epoch.
        best_epoch = ckpt_info.get('best_epoch', ckpt_info['epoch'])
        # Checkpoints written at a new-best epoch by older code stored the
        # previous best; the checkpoint's own val_acc is then the true best.
        if val_acc > best_acc:
            best_acc, best_epoch = val_acc, ckpt_info['epoch']
        ckpt_manager.best_acc = best_acc
        ckpt_manager.best_epoch = best_epoch
        print(f"Resuming training from epoch {start_epoch}")

    for epoch in range(start_epoch, epochs + 1):
        # LR actually used this epoch (get_last_lr() after scheduler.step()
        # would report next epoch's value).
        current_lr = optimizer.param_groups[0]['lr']

        avg_loss, train_acc = _train_one_epoch(model, train_loader, optimizer, epoch, device)
        scheduler.step()

        val_acc = _evaluate(model, val_loader, device)

        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc

        marker = " ★" if is_best else ""
        print(f"Epoch {epoch:3d} | Loss: {avg_loss:.4f} | "
              f"Train: {train_acc:.4f} | Val: {val_acc:.4f} | Best: {best_acc:.4f}{marker}")

        ckpt_manager.log_scalars(epoch, {
            'loss': avg_loss,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'best_acc': best_acc,
            'learning_rate': current_lr,
        }, prefix="train")

        ckpt_manager.log_lens_stats(epoch, model)

        if epoch % log_histograms_every == 0:
            ckpt_manager.log_histograms(epoch, model)

        ckpt_manager.save_checkpoint(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            epoch=epoch,
            train_acc=train_acc,
            val_acc=val_acc,
            train_loss=avg_loss,
            is_best=is_best,
        )

        ckpt_manager.upload_to_hf(epoch)

    ckpt_manager.save_final(model, val_acc, epochs)

    ckpt_manager.upload_to_hf(epochs, force=True)
    ckpt_manager.close()

    print()
    print("=" * 70)
    print("FINAL RESULTS")
    print("=" * 70)
    print(f"Preset: {preset}")
    print(f"Best accuracy: {best_acc:.4f}")
    print(f"Total params: {total_params:,}")
    print(f"Checkpoints: {ckpt_manager.run_dir}")
    print("=" * 70)

    return model, best_acc
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == '__main__':
    # Resume the mobius_base Tiny ImageNet run from its saved best checkpoint
    # and continue training to 200 epochs.
    run_kwargs = dict(
        preset='mobius_base',
        epochs=200,
        lr=3e-4,
        batch_size=128,
        use_integrator=True,
        data_dir='./data/tiny-imagenet-200',
        output_dir='./outputs',
        hf_repo='AbstractPhil/mobiusnet',
        save_every_n_epochs=10,
        upload_every_n_epochs=10,
        log_histograms_every=10,
        use_compile=True,
        continue_from='/content/outputs/checkpoints/mobius_base_tiny_imagenet/20260110_132436/checkpoints/best_model.pt',
    )
    model, best_acc = train_tiny_imagenet(**run_kwargs)