diff --git a/XPointMLTest.py b/XPointMLTest.py index 3027485..c635303 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -150,23 +150,38 @@ def getPgkylData(paramFile, frameNumber, verbosity): params["polyOrderOverride"] = 0 #Override default dg interpolation and interpolate to given number of points constrcutBandJ = 1 #Read vector potential + import time as _time + _t0 = _time.time() var = gkData.gkData(str(paramFile),frameNumber,'psi',params).compactRead() + _t1 = _time.time() + print(f" [PROFILE] compactRead: {_t1-_t0:.1f}s") psi = var.data coords = var.coords axesNorm = var.d[ var.speciesFileIndex.index('ion') ] if verbosity > 0: print(f"psi shape: {psi.shape}, min={psi.min()}, max={psi.max()}") #Construct B and J (first and second derivatives) + _t2 = _time.time() [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,var.dx) [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,var.dx) [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,var.dx) + _t3 = _time.time() + print(f" [PROFILE] 3x genGradient: {_t3-_t2:.1f}s") bx = df_dy by = -df_dx jz = -(d2f_dxdx + d2f_dydy) / var.mu0 - del df_dx,df_dy,df_dz,d2f_dxdx,d2f_dxdy,d2f_dxdz,d2f_dydx,d2f_dydy,d2f_dydz + #Precompute Hessian from already-computed derivatives (avoid redundant gradient calls) + Hess = np.array([d2f_dxdx, d2f_dxdy, d2f_dxdy, d2f_dydy]) + del df_dz,d2f_dxdz,d2f_dydz,d2f_dxdx,d2f_dxdy,d2f_dydx,d2f_dydy,df_dx,df_dy #Indicies of critical points, X points, and O points (max and min) + _t4 = _time.time() critPoints = auxFuncs.getCritPoints(psi) - [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints) + _t5 = _time.time() + print(f" [PROFILE] getCritPoints: {_t5-_t4:.1f}s") + [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints, hessian=Hess) + _t6 = _time.time() + print(f" [PROFILE] getXOPoints: {_t6-_t5:.1f}s") + print(f" [PROFILE] TOTAL: {_t6-_t0:.1f}s") return [var.filenameBase, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] def 
cachedPgkylDataExists(cacheDir, frameNumber, fieldName): @@ -329,7 +344,8 @@ def load(self, fnum): class XPointPatchDataset(Dataset): """On‑the‑fly square crops with data augmentation, balancing positive / background patches.""" - def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30, augment=False, seed=None): + def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30, augment=False, seed=None, + noise_std_max=0.02, cutout_prob=0.20): """ Parameters: ----------- @@ -345,12 +361,18 @@ def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30, augment=False, If True, apply on-the-fly data augmentation (use for training only) seed : int or None Random seed for reproducibility (None for non-deterministic) + noise_std_max : float + Maximum std for Gaussian noise augmentation (default: 0.02) + cutout_prob : float + Probability of applying cutout augmentation (default: 0.20) """ self.base_ds = base_ds self.patch = patch self.pos_ratio = pos_ratio self.retries = retries self.augment = augment + self.noise_std_max = noise_std_max + self.cutout_prob = cutout_prob # Initialize RNG with seed if provided if seed is not None: @@ -362,6 +384,13 @@ def __len__(self): # give each full frame K random crops per epoch (K=32 for more samples) return len(self.base_ds) * 32 + def reset_rng(self, seed=None): + """Reset RNG for deterministic cropping (useful for fixed validation).""" + if seed is not None: + self.rng = np.random.default_rng(seed) + else: + self.rng = np.random.default_rng() + def _crop(self, arr, top, left): return arr[..., top:top+self.patch, left:left+self.patch] @@ -400,22 +429,23 @@ def _apply_augmentation(self, all_data, mask): # 4. Add Gaussian noise (30% chance) # Small noise helps prevent overfitting to exact pixel values if self.rng.random() < 0.3: - noise_std = self.rng.uniform(0.005, 0.02) + noise_std = self.rng.uniform(0.005, self.noise_std_max) noise = torch.randn_like(all_data) * noise_std all_data = all_data + noise - # 5. 
Random brightness/contrast adjustment per channel (30% chance) - # Helps model become invariant to intensity variations + # 5. Random brightness/contrast adjustment (30% chance) + # CHANGED: Applied globally across channels to preserve physical relationships + # (e.g., keeping the derivative relationship between psi and B fields) if self.rng.random() < 0.3: - for c in range(all_data.shape[0]): - brightness = self.rng.uniform(-0.1, 0.1) - contrast = self.rng.uniform(0.9, 1.1) - mean = all_data[c].mean() - all_data[c] = contrast * (all_data[c] - mean) + mean + brightness + brightness = self.rng.uniform(-0.1, 0.1) + contrast = self.rng.uniform(0.9, 1.1) + # Apply same transformation to all channels + mean = all_data.mean(dim=(-2, -1), keepdim=True) + all_data = contrast * (all_data - mean) + mean + brightness - # 6. Cutout/Random erasing (20% chance) + # 6. Cutout/Random erasing # Prevents model from relying too heavily on specific spatial features - if self.rng.random() < 0.2: + if self.rng.random() < self.cutout_prob: h, w = all_data.shape[-2:] cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25)) if cutout_size > 0: @@ -451,7 +481,9 @@ def __getitem__(self, _): want_pos = (attempt / self.retries) < self.pos_ratio if has_pos == want_pos or attempt == self.retries - 1: - all_crop = self._crop(frame["all"], y0, x0) + # Clone to avoid in-place augmentation modifying cached base frames + all_crop = self._crop(frame["all"], y0, x0).clone() + crop_mask = crop_mask.clone() # Apply augmentation if enabled all_crop, crop_mask = self._apply_augmentation(all_crop, crop_mask) @@ -618,6 +650,35 @@ def forward(self, inputs, targets): # Return Dice loss (1 - Dice coefficient) return 1.0 - dice + +class FocalDiceLoss(nn.Module): + """Combined Focal + Dice loss for extreme class imbalance. + + Focal loss downweights easy negatives so the model focuses on hard + positives near X-point boundaries. Combined with Dice loss + which directly optimizes region overlap. 
+ """ + def __init__(self, alpha=0.75, gamma=2.0, dice_weight=0.5, smooth=1.0): + super().__init__() + self.alpha = alpha + self.gamma = gamma + self.dice_weight = dice_weight + self.dice = DiceLoss(smooth=smooth) + + def forward(self, inputs, targets): + # Focal loss + bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') + p = torch.sigmoid(inputs) + pt = p * targets + (1 - p) * (1 - targets) + alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) + focal = alpha_t * ((1 - pt) ** self.gamma) * bce + focal_loss = focal.mean() + + # Dice loss + dice_loss = self.dice(inputs, targets) + + return (1 - self.dice_weight) * focal_loss + self.dice_weight * dice_loss + # TRAIN & VALIDATION UTILS def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark=None): model.train() @@ -864,6 +925,25 @@ def parseCommandLineArgs(): help='specify the weight decay (L2 regularization) for optimizer') parser.add_argument('--dropoutRate', type=float, default=0.2, help='specify the dropout rate for regularization') + parser.add_argument('--baseChannels', type=int, default=64, + help='base number of channels in the UNet encoder (default: 64)') + parser.add_argument('--posRatio', type=float, default=0.5, + help='target ratio of patches containing X-points (default: 0.5)') + parser.add_argument('--lossFunction', type=str, default='dice', + choices=['dice', 'focal_dice'], + help='loss function: dice (default) or focal_dice (combined focal + dice)') + parser.add_argument('--focalAlpha', type=float, default=0.75, + help='focal loss alpha (class balance weight, default: 0.75)') + parser.add_argument('--focalGamma', type=float, default=2.0, + help='focal loss gamma (focusing parameter, default: 2.0)') + parser.add_argument('--focalDiceWeight', type=float, default=0.5, + help='weight of dice component in FocalDiceLoss (default: 0.5)') + parser.add_argument('--warmupEpochs', type=int, default=0, + help='number of linear 
warmup epochs before cosine decay (default: 0)') + parser.add_argument('--swa', action='store_true', + help='enable Stochastic Weight Averaging for better generalization') + parser.add_argument('--swaStart', type=float, default=0.75, + help='fraction of total epochs after which SWA begins (default: 0.75)') parser.add_argument('--batchSize', type=int, default=1, help='specify the batch size') parser.add_argument('--epochs', type=int, default=2000, @@ -903,6 +983,17 @@ def parseCommandLineArgs(): choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)') parser.add_argument('--patience', type=int, default=15, help='patience for early stopping (default: 15)') + parser.add_argument('--early-stop-min-delta', type=float, default=0.0, + help='minimum improvement in validation loss to reset early stopping (default: 0.0)') + parser.add_argument('--scheduler', type=str, default='cosine', + choices=['cosine', 'plateau'], + help='learning rate scheduler type (cosine or plateau)') + parser.add_argument('--plateau-factor', type=float, default=0.5, + help='ReduceLROnPlateau factor (default: 0.5)') + parser.add_argument('--plateau-patience', type=int, default=5, + help='ReduceLROnPlateau patience in epochs (default: 5)') + parser.add_argument('--plateau-min-lr', type=float, default=1e-6, + help='ReduceLROnPlateau minimum learning rate (default: 1e-6)') parser.add_argument('--benchmark', action='store_true', help='enable performance benchmarking (tracks timing, throughput, GPU memory)') parser.add_argument('--benchmark-output', type=Path, default='./benchmark_results.json', @@ -911,6 +1002,8 @@ def parseCommandLineArgs(): help='path to save evaluation metrics JSON file (default: ./evaluation_metrics.json)') parser.add_argument('--seed', type=int, default=None, help='random seed for reproducibility (default: None for non-deterministic)') + parser.add_argument('--fixed-val-crops', action=argparse.BooleanOptionalAction, default=False, + help='use 
deterministic validation crops each epoch (default: False)') parser.add_argument('--require-gpu', action='store_true', help='require GPU to be available, exit if not found') @@ -966,6 +1059,22 @@ def checkCommandLineArgs(args): print(f"minTrainingLoss must be >= 0... exiting") sys.exit() + if args.early_stop_min_delta < 0: + print("early-stop-min-delta must be >= 0... exiting") + sys.exit() + + if args.plateau_factor <= 0 or args.plateau_factor >= 1: + print("plateau-factor must be in (0, 1)... exiting") + sys.exit() + + if args.plateau_patience < 0: + print("plateau-patience must be >= 0... exiting") + sys.exit() + + if args.plateau_min_lr < 0: + print("plateau-min-lr must be >= 0... exiting") + sys.exit() + if args.checkPointFrequency < 0: print(f"checkPointFrequency must be >= 0... exiting") sys.exit() @@ -1129,7 +1238,7 @@ def main(): xptCacheDir=args.xptCacheDir, rotateAndReflect=False) # Enable augmentation for training, disable for validation - train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=0.5, retries=30, + train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=args.posRatio, retries=30, augment=True, seed=args.seed) val_crop = XPointPatchDataset(val_dataset, patch=64, pos_ratio=0.5, retries=30, augment=False, seed=args.seed) @@ -1141,6 +1250,7 @@ def main(): print(f"number of training patches per epoch: {len(train_crop)}") print(f"number of validation patches per epoch: {len(val_crop)}") print(f"Data augmentation: ENABLED for training, DISABLED for validation") + print(f"Validation cropping: {'FIXED' if args.fixed_val_crops else 'RANDOM'} per epoch") if args.seed is not None: print(f"Random seed: {args.seed} (reproducible mode)") else: @@ -1174,7 +1284,7 @@ def main(): benchmark.print_hardware_info() # Use the improved model - model = UNet(input_channels=4, base_channels=32, dropout_rate=args.dropoutRate).to(device) + model = UNet(input_channels=4, base_channels=args.baseChannels, dropout_rate=args.dropoutRate).to(device) # 
Count parameters total_params = sum(p.numel() for p in model.parameters()) @@ -1183,14 +1293,55 @@ def main(): print(f"Trainable parameters: {trainable_params:,}") print(f"Dropout rate: {args.dropoutRate}") - criterion = DiceLoss(smooth=1.0) + if args.lossFunction == 'focal_dice': + criterion = FocalDiceLoss(alpha=args.focalAlpha, gamma=args.focalGamma, + dice_weight=args.focalDiceWeight, smooth=1.0) + print(f"Loss function: FocalDiceLoss (alpha={args.focalAlpha}, gamma={args.focalGamma}, dice_weight={args.focalDiceWeight})") + else: + criterion = DiceLoss(smooth=1.0) + print("Loss function: DiceLoss") # Use AdamW optimizer with weight decay for better generalization optimizer = optim.AdamW(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay) print(f"Optimizer: AdamW with learning_rate={args.learningRate}, weight_decay={args.weightDecay}") - # Learning rate scheduler with cosine annealing - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) + # Learning rate scheduler (with optional warmup) + if args.scheduler == 'plateau': + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode='min', + factor=args.plateau_factor, + patience=args.plateau_patience, + min_lr=args.plateau_min_lr + ) + print(f"Scheduler: ReduceLROnPlateau (factor={args.plateau_factor}, patience={args.plateau_patience}, min_lr={args.plateau_min_lr})") + else: + if args.warmupEpochs > 0: + cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=args.epochs - args.warmupEpochs, eta_min=1e-6 + ) + warmup_scheduler = optim.lr_scheduler.LinearLR( + optimizer, start_factor=0.01, total_iters=args.warmupEpochs + ) + scheduler = optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_scheduler, cosine_scheduler], + milestones=[args.warmupEpochs] + ) + print(f"Scheduler: CosineAnnealingLR with {args.warmupEpochs}-epoch linear warmup") + else: + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 
T_max=args.epochs, eta_min=1e-6) + print("Scheduler: CosineAnnealingLR") + + # SWA setup + swa_model = None + swa_scheduler = None + swa_start_epoch = int(args.epochs * args.swaStart) + if args.swa: + from torch.optim.swa_utils import AveragedModel, SWALR + swa_model = AveragedModel(model) + swa_scheduler = SWALR(optimizer, swa_lr=args.learningRate * 0.1) + print(f"SWA: Enabled (starts at epoch {swa_start_epoch})") # --- AMP Setup (bfloat16 aware) --- use_amp = args.use_amp and torch.cuda.is_available() @@ -1213,7 +1364,7 @@ def main(): best_val_loss = float('inf') if os.path.exists(latest_checkpoint_path) and not args.smoke_test: - model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint( + model, optimizer, start_epoch, train_loss, val_loss, scaler, best_val_loss = load_model_checkpoint( model, optimizer, latest_checkpoint_path ) print(f"Resuming training from epoch {start_epoch+1}") @@ -1230,6 +1381,9 @@ def main(): num_epochs = args.epochs for epoch in range(start_epoch, num_epochs): train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark) + if args.fixed_val_crops: + val_seed = args.seed if args.seed is not None else 0 + val_crop.reset_rng(val_seed) val_loss_epoch = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype) train_loss.append(train_loss_epoch) @@ -1246,10 +1400,16 @@ def main(): print(log_msg) # Learning rate scheduling - scheduler.step() + if args.swa and epoch >= swa_start_epoch: + swa_model.update_parameters(model) + swa_scheduler.step() + elif args.scheduler == 'plateau': + scheduler.step(val_loss_epoch) + else: + scheduler.step() # Check for improvement - if val_loss[-1] < best_val_loss: + if val_loss[-1] < best_val_loss - args.early_stop_min_delta: best_val_loss = val_loss[-1] patience_counter = 0 print(f" New best validation loss: {best_val_loss:.6f}") @@ -1262,11 +1422,43 @@ def main(): if (epoch+1) % args.checkPointFrequency == 0: 
save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir, scaler, best_val_loss) - # Early stopping - if patience_counter >= args.patience: + # Early stopping (disabled during SWA phase) + if patience_counter >= args.patience and not (args.swa and epoch >= swa_start_epoch): print(f"Early stopping triggered after {epoch+1} epochs (patience={args.patience})") break + # Finalize SWA: update batch normalization statistics + if args.swa and swa_model is not None: + print("Updating SWA batch normalization statistics...") + # Custom BN update since our DataLoader yields dicts, not raw tensors + momenta = {} + for module in swa_model.modules(): + if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + module.running_mean = torch.zeros_like(module.running_mean) + module.running_var = torch.ones_like(module.running_var) + momenta[module] = module.momentum + module.momentum = None + module.num_batches_tracked *= 0 + swa_model.train() + with torch.no_grad(): + for batch in train_loader: + all_data = batch["all"].to(device) + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + swa_model(all_data) + for module, momentum in momenta.items(): + module.momentum = momentum + # Save SWA model as best if it improves val loss + swa_model.eval() + swa_val_loss = validate_one_epoch(swa_model, val_loader, criterion, device, use_amp, amp_dtype) + print(f"SWA model validation loss: {swa_val_loss:.6f} (vs best: {best_val_loss:.6f})") + if swa_val_loss < best_val_loss: + print("SWA model is better - saving as best model") + # Extract the inner module for saving + torch.save(swa_model.module.state_dict(), os.path.join(checkpoint_dir, "best_model.pt")) + best_val_loss = swa_val_loss + else: + print("SWA model did not improve - keeping original best model") + plot_training_history(train_loss, val_loss, save_path='plots/training_history.png') print("time (s) to train model: " + str(timer()-t2)) diff --git a/build_transfer_cache.py 
b/build_transfer_cache.py new file mode 100644 index 0000000..7463799 --- /dev/null +++ b/build_transfer_cache.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +""" +Build X-point cache for 5M/10M datasets. + +This pre-computes and caches the X-point finder results so that +test_xpoint_transfer.py can load all frames quickly. + +Usage: + python build_transfer_cache.py --dataset 5M --start 1 --end 75 + python build_transfer_cache.py --dataset 10M --start 76 --end 150 +""" + +import argparse +import multiprocessing as mp +import os +import sys +import time +from pathlib import Path + +import numpy as np + +# -- Monkey-patch: fix 5m/10m .gkyl component indexing bug in getData.py -- +import postgkyl as pg +_orig_GData_init = pg.data.GData.__init__ +def _fixed_GData_init(self, *args, comp=None, **kwargs): + _orig_GData_init(self, *args, **kwargs) +pg.data.GData.__init__ = _fixed_GData_init +# -- End monkey-patch -- + +RC_ROOT = Path(__file__).resolve().parent +sys.path.insert(0, str(RC_ROOT)) + +from XPointMLTest import XPointDataset + +EXTRACT_DIR = Path(os.environ.get( + "RC_EXTRACT_DIR", + "/work/nvme/bfim/ssridhar6/mlReconnection2025", +)) +CACHE_BASE = Path(os.environ.get( + "RC_CACHE_BASE", + "/work/nvme/bfim/ssridhar6/mlReconnection2025/cache", +)) + +DATASETS = { + "5M": { + "extract_subdir": EXTRACT_DIR / "5M", + "param_file": "rt_5M_2d_turb_local-params.txt", + }, + "10M": { + "extract_subdir": EXTRACT_DIR / "10M", + "param_file": "rt_10M_2d_turb_local-params.txt", + }, +} + + +def find_param_file(extract_dir, param_file_name): + for p in extract_dir.rglob(param_file_name): + return p + for p in extract_dir.rglob("*-params.txt"): + return p + raise FileNotFoundError(f"No params file found in {extract_dir}") + + +def discover_all_frames(extract_dir): + """Find all frame numbers from field files.""" + field_files = list(extract_dir.rglob("*-field_*.gkyl")) + frame_nums = set() + for f in field_files: + parts = f.stem.rsplit("_", 1) + if len(parts) == 2 and 
parts[1].isdigit(): + frame_nums.add(int(parts[1])) + return sorted(f for f in frame_nums if f > 0) + + +def _process_frame(task): + """Worker function: process a single frame. Returns (frame_num, elapsed).""" + param_path, fnum, cache_dir = task + t0 = time.time() + dataset = XPointDataset( + str(param_path), + [fnum], + xptCacheDir=cache_dir, + rotateAndReflect=False, + ) + elapsed = time.time() - t0 + return fnum, elapsed + + +def main(): + parser = argparse.ArgumentParser(description="Build X-point cache for transfer datasets") + parser.add_argument("--dataset", required=True, choices=["5M", "10M"]) + parser.add_argument("--start", type=int, default=None, help="First frame index (inclusive)") + parser.add_argument("--end", type=int, default=None, help="Last frame index (inclusive)") + parser.add_argument("--workers", type=int, default=1, + help="Number of parallel workers (default: 1)") + args = parser.parse_args() + + ds = DATASETS[args.dataset] + cache_dir = CACHE_BASE / args.dataset + cache_dir.mkdir(parents=True, exist_ok=True) + + param_path = find_param_file(ds["extract_subdir"], ds["param_file"]) + print(f"Dataset: {args.dataset}") + print(f"Param file: {param_path}") + print(f"Cache dir: {cache_dir}") + print(f"Workers: {args.workers}") + + all_frames = discover_all_frames(ds["extract_subdir"]) + print(f"Total frames available: {len(all_frames)} ({all_frames[0]}-{all_frames[-1]})") + + # Apply range filter + start = args.start if args.start is not None else all_frames[0] + end = args.end if args.end is not None else all_frames[-1] + frames = [f for f in all_frames if start <= f <= end] + print(f"Processing frames {start}-{end}: {len(frames)} frames") + + # Check which frames are already cached + from XPointMLTest import cachedPgkylDataExists + uncached = [f for f in frames if not cachedPgkylDataExists(cache_dir, f, "psi")] + cached = len(frames) - len(uncached) + print(f"Already cached: {cached}, need to compute: {len(uncached)}") + + if not uncached: + 
print("All frames already cached!") + return + + tasks = [(param_path, fnum, cache_dir) for fnum in uncached] + + if args.workers <= 1: + # Sequential mode (original behavior) + total_time = 0 + for i, task in enumerate(tasks): + fnum = task[1] + print(f"\n[{i+1}/{len(uncached)}] Frame {fnum}...", flush=True) + _, elapsed = _process_frame(task) + total_time += elapsed + avg = total_time / (i + 1) + remaining = avg * (len(uncached) - i - 1) + print(f" Done in {elapsed:.1f}s | Avg: {avg:.1f}s/frame | " + f"ETA: {remaining/3600:.1f}h remaining", flush=True) + else: + # Parallel mode + print(f"\nStarting parallel processing with {args.workers} workers...", flush=True) + total_time = 0 + completed = 0 + wall_start = time.time() + with mp.Pool(processes=args.workers) as pool: + for fnum, elapsed in pool.imap_unordered(_process_frame, tasks): + completed += 1 + total_time += elapsed + wall_elapsed = time.time() - wall_start + avg_wall = wall_elapsed / completed + remaining = avg_wall * (len(uncached) - completed) + print(f"[{completed}/{len(uncached)}] Frame {fnum} done in {elapsed:.1f}s | " + f"Wall avg: {avg_wall:.1f}s/frame | " + f"ETA: {remaining/60:.1f}m remaining", flush=True) + total_time = time.time() - wall_start + + print(f"\nCache building complete! 
Wall time: {total_time/60:.1f}m ({total_time/3600:.1f}h)") + print(f"Cached {len(uncached)} frames to {cache_dir}") + + +if __name__ == "__main__": + main() diff --git a/optuna_tuner.py b/optuna_tuner.py new file mode 100644 index 0000000..7d59766 --- /dev/null +++ b/optuna_tuner.py @@ -0,0 +1,614 @@ +""" +Optuna hyperparameter tuner for XPointMLTest.py + +Usage: + python optuna_tuner.py \ + --paramFile /path/to/params.txt \ + --xptCacheDir /path/to/cache \ + --n-trials 50 \ + --study-name xpoint-tuning \ + --db sqlite:///optuna_xpoint.db + +The script wraps the existing training pipeline and searches over: + - Learning rate + - Weight decay + - Dropout rate + - Batch size + - Patch size + - Base channels (model capacity) + - Scheduler type + params + - Augmentation positive-patch ratio + +Results persist in a SQLite database so you can: + - Resume after SSH disconnects + - Analyze results with Optuna's built-in visualization + - Run multiple workers in parallel (each with their own process) +""" + +import argparse +import json +import os +import sys +import time +from pathlib import Path + +import numpy as np +import torch +import torch.optim as optim +from torch.amp import autocast, GradScaler +from torch.utils.data import DataLoader + +try: + import optuna + from optuna.exceptions import TrialPruned +except ImportError: + print("Optuna not installed. Install with:") + print(" pip install optuna --break-system-packages") + print("Optional visualization: pip install plotly --break-system-packages") + sys.exit(1) + +# Import from existing codebase +from XPointMLTest import ( + XPointDataset, + XPointPatchDataset, + UNet, + DiceLoss, + train_one_epoch, + validate_one_epoch, + set_seed, +) +from eval_metrics import evaluate_model_on_dataset +from ci_tests import SyntheticXPointDataset + + +def objective(trial, args): + """ + Optuna objective function. Trains the model with trial-suggested + hyperparameters and returns the best validation loss. 
+ """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # --- Seed: vary per trial for diversity, but keep reproducible --- + seed = args.seed + if seed is not None: + set_seed(seed + trial.number) + + # 1. Suggest ALL hyperparameters in one place + + # Model architecture + base_channels = trial.suggest_categorical("base_channels", [16, 32, 48, 64]) + dropout_rate = trial.suggest_float("dropout_rate", 0.05, 0.5) + + # Optimizer + lr = trial.suggest_float("learning_rate", 1e-5, 5e-3, log=True) + weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True) + + # Data pipeline + batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64]) + patch_size = trial.suggest_categorical("patch_size", [48, 64, 96]) + pos_ratio = trial.suggest_float("pos_ratio", 0.3, 0.7) + + # Scheduler + scheduler_name = trial.suggest_categorical("scheduler", ["cosine", "plateau"]) + + # 2. Load data (pre-loaded and cached on args to avoid repeated I/O) + + if args.smoke_test: + train_dataset = SyntheticXPointDataset(nframes=10, shape=(64, 64), nxpoints=3) + val_dataset = SyntheticXPointDataset( + nframes=2, shape=(64, 64), nxpoints=3, seed=123 + ) + else: + train_dataset = args._train_dataset + val_dataset = args._val_dataset + + train_crop = XPointPatchDataset( + train_dataset, + patch=patch_size, + pos_ratio=pos_ratio, + retries=30, + augment=True, + seed=seed, + ) + val_crop = XPointPatchDataset( + val_dataset, + patch=patch_size, + pos_ratio=0.5, + retries=30, + augment=False, + seed=seed, + ) + + train_loader = DataLoader( + train_crop, batch_size=batch_size, shuffle=True, num_workers=0 + ) + val_loader = DataLoader( + val_crop, batch_size=batch_size, shuffle=False, num_workers=0 + ) + + + # 3. 
Create model, optimizer, scheduler + + model = UNet( + input_channels=4, base_channels=base_channels, dropout_rate=dropout_rate + ).to(device) + + num_epochs = args.epochs + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + + if scheduler_name == "plateau": + plateau_factor = trial.suggest_float("plateau_factor", 0.2, 0.8) + plateau_patience = trial.suggest_int("plateau_patience", 3, 15) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="min", + factor=plateau_factor, + patience=plateau_patience, + min_lr=1e-6, + ) + else: + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=num_epochs, eta_min=1e-6 + ) + + criterion = DiceLoss(smooth=1.0) + + # --- AMP setup --- + use_amp = args.use_amp and torch.cuda.is_available() + amp_dtype = ( + torch.bfloat16 + if args.amp_dtype == "bfloat16" and torch.cuda.is_bf16_supported() + else torch.float16 + ) + scaler = GradScaler(enabled=(use_amp and amp_dtype == torch.float16)) + + + # 4. 
Training loop with Optuna pruning + + best_val_loss = float("inf") + patience_counter = 0 + patience = args.patience + epochs_trained = 0 + + for epoch in range(num_epochs): + train_loss = train_one_epoch( + model, + train_loader, + criterion, + optimizer, + device, + scaler, + use_amp, + amp_dtype, + ) + + # Reset validation RNG for deterministic crops each epoch + if seed is not None: + val_crop.reset_rng(seed) + + val_loss = validate_one_epoch( + model, val_loader, criterion, device, use_amp, amp_dtype + ) + + # LR scheduling + if scheduler_name == "plateau": + scheduler.step(val_loss) + else: + scheduler.step() + + # Track best + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + + epochs_trained = epoch + 1 + + # === Report intermediate value to Optuna for pruning === + trial.report(val_loss, epoch) + + if trial.should_prune(): + raise TrialPruned() + + # Early stopping + if patience_counter >= patience: + break + + + # 5. 
(Optional) Full-frame evaluation for richer metrics + + if args.eval_on_full_frames and not args.smoke_test: + model.eval() + evaluator = evaluate_model_on_dataset( + model, val_dataset, device, use_amp=use_amp, amp_dtype=amp_dtype + ) + global_metrics = evaluator.get_global_metrics() + + trial.set_user_attr("val_f1", global_metrics["f1_score"]) + trial.set_user_attr("val_iou", global_metrics["iou"]) + trial.set_user_attr("val_precision", global_metrics["precision"]) + trial.set_user_attr("val_recall", global_metrics["recall"]) + + trial.set_user_attr("best_val_loss", best_val_loss) + trial.set_user_attr("epochs_trained", epochs_trained) + + return best_val_loss + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Optuna hyperparameter tuning for X-point classifier" + ) + + # --- Optuna settings --- + parser.add_argument( + "--n-trials", + type=int, + default=50, + help="Number of Optuna trials (default: 50)", + ) + parser.add_argument( + "--study-name", + type=str, + default="xpoint-tuning", + help="Optuna study name (default: xpoint-tuning)", + ) + parser.add_argument( + "--db", + type=str, + default="sqlite:///optuna_xpoint.db", + help="Optuna storage URL (default: sqlite:///optuna_xpoint.db)", + ) + parser.add_argument( + "--timeout", + type=int, + default=None, + help="Stop after this many seconds (default: None, run all trials)", + ) + parser.add_argument( + "--pruner", + type=str, + default="median", + choices=["median", "hyperband", "none"], + help="Pruning strategy (default: median)", + ) + + # --- Data settings (same as XPointMLTest.py) --- + parser.add_argument("--paramFile", type=Path, default=None) + parser.add_argument("--xptCacheDir", type=Path, default=None) + parser.add_argument("--trainFrameFirst", type=int, default=1) + parser.add_argument("--trainFrameLast", type=int, default=140) + parser.add_argument("--validationFrameFirst", type=int, default=141) + parser.add_argument("--validationFrameLast", type=int, default=150) + + # 
--- Training settings (fixed across all trials) --- + parser.add_argument( + "--epochs", + type=int, + default=300, + help="Max epochs PER TRIAL (default: 300, lower than full training for speed)", + ) + parser.add_argument( + "--patience", + type=int, + default=30, + help="Early stopping patience per trial (default: 30)", + ) + parser.add_argument("--use-amp", action="store_true", help="Enable AMP") + parser.add_argument( + "--amp-dtype", + type=str, + default="bfloat16", + choices=["float16", "bfloat16"], + ) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--require-gpu", action="store_true") + + # --- Evaluation --- + parser.add_argument( + "--eval-on-full-frames", + action="store_true", + help="Run full-frame evaluation after each trial (slower but gives F1/IoU)", + ) + + # --- Output --- + parser.add_argument( + "--results-dir", + type=Path, + default="./optuna_results", + help="Directory for result files (default: ./optuna_results)", + ) + + # --- Testing --- + parser.add_argument( + "--smoke-test", + action="store_true", + help="Run 3 trials with synthetic data (no paramFile needed)", + ) + + return parser.parse_args() + + +def print_study_summary(study, results_dir): + """Print and save a summary of the Optuna study results.""" + + print("\n" + "=" * 70) + print("OPTUNA STUDY SUMMARY") + print("=" * 70) + + completed = [ + t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE + ] + pruned = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED] + failed = [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL] + + print(f"\nStudy name: {study.study_name}") + print(f"Total trials: {len(study.trials)}") + print(f" Completed: {len(completed)}") + print(f" Pruned: {len(pruned)}") + print(f" Failed: {len(failed)}") + + if not completed: + print("\nNo completed trials to summarize.") + return + + best = study.best_trial + print(f"\nBest trial: #{best.number}") + print(f" Best validation 
def print_study_summary(study, results_dir):
    """Print and save a summary of the Optuna study results.

    Prints trial counts, the best trial's hyperparameters and user attrs,
    and a top-5 table; then writes optuna_results.json and an executable
    retrain_best_params.sh into *results_dir*.

    Parameters
    ----------
    study : optuna.Study
        The (possibly resumed) study to summarize.
    results_dir : Path
        Output directory; created if missing.
    """

    print("\n" + "=" * 70)
    print("OPTUNA STUDY SUMMARY")
    print("=" * 70)

    # Partition trials by terminal state; only COMPLETE trials have a value.
    completed = [
        t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE
    ]
    pruned = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
    failed = [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL]

    print(f"\nStudy name: {study.study_name}")
    print(f"Total trials: {len(study.trials)}")
    print(f"  Completed: {len(completed)}")
    print(f"  Pruned: {len(pruned)}")
    print(f"  Failed: {len(failed)}")

    if not completed:
        # Nothing to rank or export; bail out before touching study.best_trial,
        # which would raise with zero completed trials.
        print("\nNo completed trials to summarize.")
        return

    best = study.best_trial
    print(f"\nBest trial: #{best.number}")
    print(f"  Best validation loss: {best.value:.6f}")
    print(f"  Hyperparameters:")
    for key, value in sorted(best.params.items()):
        if isinstance(value, float):
            print(f"    {key:25s} {value:.6g}")
        else:
            print(f"    {key:25s} {value}")

    # user_attrs hold extra metrics recorded by objective() (e.g. val_f1/iou).
    if best.user_attrs:
        print(f"  Additional metrics:")
        for key, value in sorted(best.user_attrs.items()):
            if isinstance(value, float):
                print(f"    {key:25s} {value:.4f}")
            else:
                print(f"    {key:25s} {value}")

    # --- Top 5 trials (ascending val loss; study direction is "minimize") ---
    sorted_trials = sorted(completed, key=lambda t: t.value)
    print(f"\nTop 5 trials:")
    print(
        f"  {'#':>4s} {'Val Loss':>10s} {'LR':>10s} {'WD':>10s} "
        f"{'Drop':>6s} {'BS':>4s} {'Patch':>5s} {'Ch':>4s} {'Sched':>8s}"
    )
    for t in sorted_trials[:5]:
        p = t.params
        # .get defaults keep the table printable even if a search-space
        # revision removed a parameter from older trials.
        print(
            f"  {t.number:4d} {t.value:10.6f} "
            f"{p.get('learning_rate', 0):10.2e} "
            f"{p.get('weight_decay', 0):10.2e} "
            f"{p.get('dropout_rate', 0):6.3f} "
            f"{p.get('batch_size', 0):4d} "
            f"{p.get('patch_size', 0):5d} "
            f"{p.get('base_channels', 0):4d} "
            f"{p.get('scheduler', 'n/a'):>8s}"
        )

    print("=" * 70)

    # --- Save results to JSON ---
    results_dir.mkdir(parents=True, exist_ok=True)

    results = {
        "study_name": study.study_name,
        "n_trials_total": len(study.trials),
        "n_completed": len(completed),
        "n_pruned": len(pruned),
        "n_failed": len(failed),
        "best_trial": {
            "number": best.number,
            "value": best.value,
            "params": best.params,
            "user_attrs": best.user_attrs,
        },
        "all_completed_trials": [
            {
                "number": t.number,
                "value": t.value,
                "params": t.params,
                "user_attrs": t.user_attrs,
            }
            for t in sorted_trials
        ],
    }

    results_path = results_dir / "optuna_results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nResults saved to: {results_path}")

    # --- Generate shell command for retraining with best params ---
    # NOTE(review): the generated command forwards only learning_rate,
    # weight_decay, dropout_rate, batch_size and scheduler settings, but the
    # top-5 table shows patch_size and base_channels were also tuned — confirm
    # whether XPointMLTest.py should receive those too, otherwise "retrain
    # with best hyperparameters" silently drops them.
    p = best.params
    cmd_lines = [
        "#!/bin/bash",
        f"# Auto-generated from Optuna study: {study.study_name}",
        f"# Best trial #{best.number} with val_loss={best.value:.6f}",
        "",
        "python -u ${rcRoot}/reconClassifier/XPointMLTest.py \\",
        f"    --learningRate {p.get('learning_rate', 5e-4):.6g} \\",
        f"    --weightDecay {p.get('weight_decay', 5e-5):.6g} \\",
        f"    --dropoutRate {p.get('dropout_rate', 0.15):.4g} \\",
        f"    --batchSize {p.get('batch_size', 64)} \\",
    ]

    sched = p.get("scheduler", "cosine")
    cmd_lines.append(f"    --scheduler {sched} \\")
    # plateau-specific knobs only exist when the plateau scheduler was chosen
    if sched == "plateau":
        cmd_lines.append(
            f"    --plateau-factor {p.get('plateau_factor', 0.5):.4g} \\"
        )
        cmd_lines.append(
            f"    --plateau-patience {p.get('plateau_patience', 5)} \\"
        )

    # Fixed flags for the full-length retraining run (not tuned here).
    cmd_lines.extend(
        [
            "    --use-amp \\",
            "    --seed 42 \\",
            "    --require-gpu \\",
            "    --fixed-val-crops \\",
            "    --epochs 1200 \\",
            "    --patience 200 \\",
            "    --checkPointFrequency 200 \\",
            "    --paramFile=${PARAM_FILE} \\",
            "    --xptCacheDir=${CACHE_DIR}",
        ]
    )

    cmd_path = results_dir / "retrain_best_params.sh"
    with open(cmd_path, "w") as f:
        f.write("\n".join(cmd_lines) + "\n")
    os.chmod(cmd_path, 0o755)  # make the generated script directly executable

    print(f"Retrain command saved to: {cmd_path}")
    print(f"\nTo retrain with best hyperparameters:")
    print(f"  bash {cmd_path}")
    print("=" * 70)
def main():
    """Entry point: parse args, preload datasets, run the Optuna study.

    Datasets are loaded once and attached to ``args`` so every trial's
    objective() reuses them instead of repeating the expensive I/O.
    """
    args = parse_args()

    # --- Smoke test overrides ---
    if args.smoke_test:
        print("=" * 60)
        print("SMOKE TEST MODE: 3 trials with synthetic data")
        print("=" * 60)
        args.n_trials = 3
        args.epochs = 5
        args.patience = 3
        args.eval_on_full_frames = False

    # --- Validate args ---
    if not args.smoke_test:
        if args.paramFile is None:
            print("ERROR: --paramFile is required (or use --smoke-test)")
            sys.exit(1)
        if not args.paramFile.exists():
            print(f"ERROR: paramFile {args.paramFile} does not exist")
            sys.exit(1)

    if args.require_gpu and not torch.cuda.is_available():
        print("ERROR: --require-gpu set but no CUDA device found")
        sys.exit(1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    # Pre-load datasets ONCE (expensive I/O — shared across all trials)
    if not args.smoke_test:
        print("\nLoading datasets (shared across all trials)...")
        t0 = time.time()
        # NOTE(review): range() excludes the upper bound, so with the default
        # flags this trains on frames 1-139 and validates on 141-149 — confirm
        # the "...Last" frame is meant to be exclusive.
        train_dataset = XPointDataset(
            args.paramFile,
            range(args.trainFrameFirst, args.trainFrameLast),
            xptCacheDir=args.xptCacheDir,
            rotateAndReflect=False,
        )
        val_dataset = XPointDataset(
            args.paramFile,
            range(args.validationFrameFirst, args.validationFrameLast),
            xptCacheDir=args.xptCacheDir,
            rotateAndReflect=False,
        )
        print(f"Datasets loaded in {time.time() - t0:.1f}s")
        print(f"  Training frames: {len(train_dataset)}")
        print(f"  Validation frames: {len(val_dataset)}")

        # Attach to args for access in objective()
        args._train_dataset = train_dataset
        args._val_dataset = val_dataset

    # Create Optuna study
    if args.pruner == "median":
        pruner = optuna.pruners.MedianPruner(
            n_startup_trials=5, n_warmup_steps=10, interval_steps=5
        )
    elif args.pruner == "hyperband":
        pruner = optuna.pruners.HyperbandPruner(
            min_resource=10, max_resource=args.epochs, reduction_factor=3
        )
    else:
        pruner = optuna.pruners.NopPruner()

    study = optuna.create_study(
        study_name=args.study_name,
        storage=args.db,
        load_if_exists=True,
        direction="minimize",
        pruner=pruner,
    )

    n_existing = len(study.trials)
    if n_existing > 0:
        print(f"\nResuming study '{args.study_name}' with {n_existing} existing trials")
        # BUGFIX: study.best_trial raises ValueError when the resumed study
        # has no COMPLETE trials (e.g. every prior trial was pruned or
        # failed); the previous `if study.best_trial:` crashed in that case.
        try:
            current_best = study.best_trial
        except ValueError:
            current_best = None
        if current_best is not None:
            print(f"Current best val_loss: {current_best.value:.6f}")

    # Run optimization
    print(f"\nStarting Optuna optimization")
    print(f"  Trials: {args.n_trials}")
    print(f"  DB: {args.db}")
    print(f"  Pruner: {args.pruner}")
    print(f"  Max epochs: {args.epochs} per trial")
    print(f"  Patience: {args.patience} per trial")
    if args.timeout:
        print(f"  Timeout: {args.timeout}s")
    print()

    study.optimize(
        lambda trial: objective(trial, args),
        n_trials=args.n_trials,
        timeout=args.timeout,
        show_progress_bar=True,
    )

    # Results
    print_study_summary(study, args.results_dir)

    # --- Try to generate visualization plots (optional plotly dependency) ---
    try:
        from optuna.visualization import (
            plot_param_importances,
            plot_optimization_history,
            plot_parallel_coordinate,
        )

        fig_dir = args.results_dir / "figures"
        fig_dir.mkdir(parents=True, exist_ok=True)

        for name, plot_fn in [
            ("param_importances", plot_param_importances),
            ("optimization_history", plot_optimization_history),
            ("parallel_coordinate", plot_parallel_coordinate),
        ]:
            # Each plot is best-effort: a failure in one should not block
            # the others or the overall run.
            try:
                fig = plot_fn(study)
                fig.write_html(str(fig_dir / f"{name}.html"))
                print(f"  Saved: {fig_dir / name}.html")
            except Exception as e:
                print(f"  Skipped {name}: {e}")

    except ImportError:
        print("\nNote: pip install plotly for interactive visualizations")


if __name__ == "__main__":
    main()
def extract_tarball(tarball_path, extract_dir, name):
    """Extract *tarball_path* into *extract_dir* unless already extracted.

    Parameters
    ----------
    tarball_path : Path
        The .tgz archive to extract.
    extract_dir : Path
        Destination directory (created if missing).
    name : str
        Dataset label used only for log prefixes.

    Raises
    ------
    subprocess.CalledProcessError
        If the external ``tar`` command fails (check=True).
    """
    # BUGFIX: use rglob, not glob. Some tarballs (e.g. 10M.tgz with
    # tar_prefix "10M/") extract into a nested subdirectory, so a flat
    # glob never saw the params file and the archive was re-extracted on
    # every run. rglob also matches direct children, so the flat layout
    # (5M.tgz) still short-circuits as before.
    param_candidates = list(extract_dir.rglob("*-params.txt"))
    if param_candidates:
        print(f"  [{name}] Already extracted ({len(param_candidates)} param file(s) found)")
        return

    extract_dir.mkdir(parents=True, exist_ok=True)
    print(f"  [{name}] Extracting {tarball_path} -> {extract_dir} ...")
    t0 = time.time()
    # Shell-free invocation (list argv) — no injection risk from paths.
    subprocess.run(
        ["tar", "xzf", str(tarball_path), "-C", str(extract_dir), "--strip-components=0"],
        check=True,
    )
    elapsed = time.time() - t0
    print(f"  [{name}] Extraction complete in {elapsed:.1f}s")
def discover_frames(extract_dir, param_file_name):
    """Locate the params file and list usable frame numbers.

    Parameters
    ----------
    extract_dir : Path
        Root of the extracted dataset (searched recursively).
    param_file_name : str
        Expected params file name; falls back to any ``*-params.txt``.

    Returns
    -------
    (Path, list[int])
        The params file path and sorted frame numbers with frame 0 removed
        (frame 0 is typically initial conditions, not interesting).

    Raises
    ------
    FileNotFoundError
        If no params file, or no field frames, are found. (Previously an
        empty frame list crashed with a bare IndexError on frame_nums[0],
        which the caller's ``except FileNotFoundError`` did not catch.)
    """
    # Find the param file (recursive: files may sit in a nested subdir).
    param_path = next(extract_dir.rglob(param_file_name), None)
    if param_path is None:
        # Try searching more broadly for any params file.
        param_path = next(extract_dir.rglob("*-params.txt"), None)
    if param_path is None:
        raise FileNotFoundError(f"No params file found in {extract_dir}")

    print(f"  Found param file: {param_path}")

    # Discover frame numbers from field files in the same directory.
    param_dir = param_path.parent
    field_files = sorted(param_dir.glob("*-field_*.gkyl"))
    frame_nums = set()
    for f in field_files:
        # Extract the trailing number from "...-field_42" stems.
        stem = f.stem
        parts = stem.rsplit("_", 1)
        if len(parts) == 2 and parts[1].isdigit():
            frame_nums.add(int(parts[1]))

    frame_nums = sorted(frame_nums)
    if not frame_nums:
        raise FileNotFoundError(f"No field frames found next to {param_path}")
    print(f"  Found {len(frame_nums)} frames: {frame_nums[0]}-{frame_nums[-1]}")

    # Exclude frame 0 (often initial conditions, not interesting)
    frame_nums = [f for f in frame_nums if f > 0]

    return param_path, frame_nums
    # (continuation of main(): per-dataset evaluation loop, PKPM baseline,
    # and summary reporting — the function header is earlier in the file)
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # AMP only on CUDA; prefer bf16 where the GPU supports it.
    use_amp = torch.cuda.is_available()
    amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16

    all_results = {}

    # ── Process each dataset ─────────────────────────────────────────────
    for ds_name, ds_config in DATASETS.items():
        print(f"\n{'='*70}")
        print(f"EVALUATING ON {ds_name} DATA")
        print(f"{'='*70}")

        # Extract
        extract_tarball(ds_config["tarball"], ds_config["extract_subdir"], ds_name)

        # Discover frames
        try:
            param_path, frame_nums = discover_frames(
                ds_config["extract_subdir"], ds_config["param_file"]
            )
        except FileNotFoundError as e:
            # Skip this dataset but keep evaluating the others.
            print(f"  ERROR: {e}")
            continue

        # Load dataset using cache (pre-built by build_transfer_cache.py)
        cache_dir = CACHE_BASE / ds_name
        if not cache_dir.is_dir():
            print(f"  ERROR: Cache directory {cache_dir} not found.")
            print(f"  Run build_transfer_cache.py --dataset {ds_name} first.")
            continue

        print(f"  Loading {ds_name} dataset ({len(frame_nums)} frames, cache={cache_dir})...", flush=True)
        t0 = time.time()
        try:
            dataset = XPointDataset(
                str(param_path), frame_nums,
                xptCacheDir=cache_dir, rotateAndReflect=False,
            )
        except Exception as e:
            # Best-effort: report the failure and move on to the next dataset.
            print(f"\n  ERROR loading dataset: {e}")
            import traceback
            traceback.print_exc()
            continue
        elapsed = time.time() - t0
        print(f"  All {len(frame_nums)} frames loaded in {elapsed:.1f}s")

        # Check grid size — transfer data may differ from the training grid.
        sample = dataset[0]
        grid_shape = sample["all"].shape
        print(f"  Grid shape: {grid_shape} (channels, H, W)")

        # Evaluate
        print(f"  Running inference...")
        t0 = time.time()
        evaluator = evaluate_model_on_dataset(
            model, dataset, device,
            use_amp=use_amp, amp_dtype=amp_dtype, threshold=0.5,
        )
        elapsed = time.time() - t0

        evaluator.print_summary()

        # Save per-dataset results
        output_file = OUTPUT_DIR / f"eval_{ds_name.lower()}.json"
        evaluator.save_json(str(output_file))

        metrics = evaluator.get_global_metrics()
        metrics["inference_time_s"] = elapsed
        metrics["num_frames"] = len(frame_nums)
        metrics["grid_shape"] = list(grid_shape)
        all_results[ds_name] = metrics
        print(f"  Inference time: {elapsed:.1f}s ({elapsed/len(frame_nums):.2f}s/frame)")

        # Save partial results incrementally in case of timeout
        with open(OUTPUT_DIR / "transfer_summary.json", "w") as f:
            json.dump(all_results, f, indent=2)

    # ── Also re-evaluate on original PKPM validation set for comparison ──
    print(f"\n{'='*70}")
    print(f"RE-EVALUATING ON PKPM VALIDATION (baseline comparison)")
    print(f"{'='*70}")

    pkpm_param = "/work/nvme/bfim/cwsmith/mlReconnection2025/1024Res_v0/pkpm_2d_turb_p2-params.txt"
    pkpm_cache = "/work/nvme/bfim/cwsmith/mlReconnection2025/1024Res_v0/cache04082025"
    # NOTE(review): range(141, 150) gives frames 141-149 (9 frames, 150
    # excluded) — matches the exclusive-upper-bound convention used by the
    # training script's validation range; confirm this is intended.
    pkpm_val_frames = list(range(141, 150))

    print(f"  Loading PKPM validation ({len(pkpm_val_frames)} frames)...")
    t0 = time.time()
    pkpm_dataset = XPointDataset(
        pkpm_param, pkpm_val_frames,
        xptCacheDir=Path(pkpm_cache),
        rotateAndReflect=False,
    )
    print(f"  Loaded in {time.time()-t0:.1f}s")

    t0 = time.time()
    pkpm_evaluator = evaluate_model_on_dataset(
        model, pkpm_dataset, device,
        use_amp=use_amp, amp_dtype=amp_dtype, threshold=0.5,
    )
    elapsed = time.time() - t0
    pkpm_evaluator.print_summary()
    pkpm_evaluator.save_json(str(OUTPUT_DIR / "eval_pkpm_val.json"))

    pkpm_metrics = pkpm_evaluator.get_global_metrics()
    pkpm_metrics["inference_time_s"] = elapsed
    pkpm_metrics["num_frames"] = len(pkpm_val_frames)
    # Note: no "grid_shape" key here, hence the m.get(...) guard below.
    all_results["PKPM_val"] = pkpm_metrics

    # ── Summary comparison ───────────────────────────────────────────────
    print(f"\n{'='*70}")
    print("CROSS-DOMAIN TRANSFER SUMMARY")
    print(f"{'='*70}")
    print(f"{'Dataset':>12s} {'F1':>7s} {'Prec':>7s} {'Rec':>7s} {'IoU':>7s} {'Frames':>6s} {'Grid':>12s}")
    print("-" * 70)
    for name, m in all_results.items():
        grid_str = "x".join(str(x) for x in m.get("grid_shape", [])) if "grid_shape" in m else "N/A"
        print(f"{name:>12s} {m['f1_score']:7.4f} {m['precision']:7.4f} "
              f"{m['recall']:7.4f} {m['iou']:7.4f} {m.get('num_frames','?'):>6} {grid_str:>12s}")
    print(f"{'='*70}")

    # Save combined summary (overwrites the incremental partial file).
    summary_path = OUTPUT_DIR / "transfer_summary.json"
    with open(summary_path, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"\nCombined summary saved to: {summary_path}")


if __name__ == "__main__":
    main()