From 665fb2b0c68f9e2ca4ddb62bcf4dccb10e9fae58 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Wed, 19 Nov 2025 00:48:59 -0500 Subject: [PATCH 1/7] Reduce augmentation probabilities to fix underfitting --- XPointMLTest.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 3027485..d56e05e 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -381,8 +381,8 @@ def _apply_augmentation(self, all_data, mask): return all_data, mask # 1. Random rotation (0, 90, 180, 270 degrees) - # 75% chance to apply rotation - if self.rng.random() < 0.75: + # 50% chance to apply rotation + if self.rng.random() < 0.50: k = self.rng.integers(1, 4) # 1, 2, or 3 (90°, 180°, 270°) all_data = torch.rot90(all_data, k=k, dims=(-2, -1)) mask = torch.rot90(mask, k=k, dims=(-2, -1)) @@ -397,9 +397,9 @@ def _apply_augmentation(self, all_data, mask): all_data = torch.flip(all_data, dims=(-2,)) mask = torch.flip(mask, dims=(-2,)) - # 4. Add Gaussian noise (30% chance) + # 4. Add Gaussian noise (10% chance) # Small noise helps prevent overfitting to exact pixel values - if self.rng.random() < 0.3: + if self.rng.random() < 0.1: noise_std = self.rng.uniform(0.005, 0.02) noise = torch.randn_like(all_data) * noise_std all_data = all_data + noise @@ -413,9 +413,9 @@ def _apply_augmentation(self, all_data, mask): mean = all_data[c].mean() all_data[c] = contrast * (all_data[c] - mean) + mean + brightness - # 6. Cutout/Random erasing (20% chance) + # 6. Cutout/Random erasing (5% chance) # Prevents model from relying too heavily on specific spatial features - if self.rng.random() < 0.2: + if self.rng.random() < 0.05: h, w = all_data.shape[-2:] cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25)) if cutout_size > 0: From 9fbdff2fe0ae2a16cd78137e8e17068909be94dc Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Tue, 2 Dec 2025 12:44:18 -0500 Subject: [PATCH 2/7] Fix augmentation bug: apply brightness/contrast globally to preserve physical field relationships --- XPointMLTest.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index d56e05e..d50e355 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -404,14 +404,15 @@ def _apply_augmentation(self, all_data, mask): noise = torch.randn_like(all_data) * noise_std all_data = all_data + noise - # 5. Random brightness/contrast adjustment per channel (30% chance) - # Helps model become invariant to intensity variations + # 5. Random brightness/contrast adjustment (30% chance) + # CHANGED: Applied globally across channels to preserve physical relationships + # (e.g., keeping the derivative relationship between psi and B fields) if self.rng.random() < 0.3: - for c in range(all_data.shape[0]): - brightness = self.rng.uniform(-0.1, 0.1) - contrast = self.rng.uniform(0.9, 1.1) - mean = all_data[c].mean() - all_data[c] = contrast * (all_data[c] - mean) + mean + brightness + brightness = self.rng.uniform(-0.1, 0.1) + contrast = self.rng.uniform(0.9, 1.1) + # Apply same transformation to all channels + mean = all_data.mean(dim=(-2, -1), keepdim=True) + all_data = contrast * (all_data - mean) + mean + brightness # 6. Cutout/Random erasing (5% chance) # Prevents model from relying too heavily on specific spatial features From 7ce3e1488da0cc093d2c48c91a35f734a3f0d27d Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Wed, 24 Dec 2025 22:35:23 -0500 Subject: [PATCH 3/7] reverting percentages back to original for testing purposes --- XPointMLTest.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index d50e355..696462c 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -381,8 +381,8 @@ def _apply_augmentation(self, all_data, mask): return all_data, mask # 1. Random rotation (0, 90, 180, 270 degrees) - # 50% chance to apply rotation - if self.rng.random() < 0.50: + # 75% chance to apply rotation + if self.rng.random() < 0.75: k = self.rng.integers(1, 4) # 1, 2, or 3 (90°, 180°, 270°) all_data = torch.rot90(all_data, k=k, dims=(-2, -1)) mask = torch.rot90(mask, k=k, dims=(-2, -1)) @@ -397,9 +397,9 @@ def _apply_augmentation(self, all_data, mask): all_data = torch.flip(all_data, dims=(-2,)) mask = torch.flip(mask, dims=(-2,)) - # 4. Add Gaussian noise (10% chance) + # 4. Add Gaussian noise (30% chance) # Small noise helps prevent overfitting to exact pixel values - if self.rng.random() < 0.1: + if self.rng.random() < 0.3: noise_std = self.rng.uniform(0.005, 0.02) noise = torch.randn_like(all_data) * noise_std all_data = all_data + noise @@ -414,9 +414,9 @@ def _apply_augmentation(self, all_data, mask): mean = all_data.mean(dim=(-2, -1), keepdim=True) all_data = contrast * (all_data - mean) + mean + brightness - # 6. Cutout/Random erasing (5% chance) + # 6. Cutout/Random erasing (20% chance) # Prevents model from relying too heavily on specific spatial features - if self.rng.random() < 0.05: + if self.rng.random() < 0.20: h, w = all_data.shape[-2:] cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25)) if cutout_size > 0: From 56c7448a377643713e35eb55147d5a7eb309e824 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Tue, 27 Jan 2026 11:06:22 -0500 Subject: [PATCH 4/7] Fixed issue where on-the-fly augmentation was mutating caches base frames via in-place cutout --- XPointMLTest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 696462c..ae3fe34 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -452,7 +452,9 @@ def __getitem__(self, _): want_pos = (attempt / self.retries) < self.pos_ratio if has_pos == want_pos or attempt == self.retries - 1: - all_crop = self._crop(frame["all"], y0, x0) + # Clone to avoid in-place augmentation modifying cached base frames + all_crop = self._crop(frame["all"], y0, x0).clone() + crop_mask = crop_mask.clone() # Apply augmentation if enabled all_crop, crop_mask = self._apply_augmentation(all_crop, crop_mask) From 6c9eba335f01eb079918b9101f7413c1c4c144dd Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Fri, 13 Feb 2026 23:13:37 -0500 Subject: [PATCH 5/7] Add Optuna hyperparameter tuning with scheduler options and fixed validation crops --- XPointMLTest.py | 64 ++++- optuna_tuner.py | 614 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 673 insertions(+), 5 deletions(-) create mode 100644 optuna_tuner.py diff --git a/XPointMLTest.py b/XPointMLTest.py index ae3fe34..7a5eb72 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -362,6 +362,13 @@ def __len__(self): # give each full frame K random crops per epoch (K=32 for more samples) return len(self.base_ds) * 32 + def reset_rng(self, seed=None): + """Reset RNG for deterministic cropping (useful for fixed validation).""" + if seed is not None: + self.rng = np.random.default_rng(seed) + else: + self.rng = np.random.default_rng() + def _crop(self, arr, top, left): return arr[..., top:top+self.patch, left:left+self.patch] @@ -906,6 +913,17 @@ def parseCommandLineArgs(): choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)') parser.add_argument('--patience', type=int, default=15, help='patience for early stopping (default: 15)') + parser.add_argument('--early-stop-min-delta', type=float, default=0.0, + help='minimum improvement in validation loss to reset early stopping (default: 0.0)') + parser.add_argument('--scheduler', type=str, default='cosine', + choices=['cosine', 'plateau'], + help='learning rate scheduler type (cosine or plateau)') + parser.add_argument('--plateau-factor', type=float, default=0.5, + help='ReduceLROnPlateau factor (default: 0.5)') + parser.add_argument('--plateau-patience', type=int, default=5, + help='ReduceLROnPlateau patience in epochs (default: 5)') + parser.add_argument('--plateau-min-lr', type=float, default=1e-6, + help='ReduceLROnPlateau minimum learning rate (default: 1e-6)') parser.add_argument('--benchmark', action='store_true', help='enable performance benchmarking (tracks timing, throughput, GPU memory)') parser.add_argument('--benchmark-output', type=Path, default='./benchmark_results.json', @@ -914,6 +932,8 @@ def parseCommandLineArgs(): help='path to save evaluation metrics JSON file (default: ./evaluation_metrics.json)') parser.add_argument('--seed', type=int, default=None, help='random seed for reproducibility (default: None for non-deterministic)') + parser.add_argument('--fixed-val-crops', action=argparse.BooleanOptionalAction, default=False, + help='use deterministic validation crops each epoch (default: False)') parser.add_argument('--require-gpu', action='store_true', help='require GPU to be available, exit if not found') @@ -969,6 +989,22 @@ def checkCommandLineArgs(args): print(f"minTrainingLoss must be >= 0... exiting") sys.exit() + if args.early_stop_min_delta < 0: + print("early-stop-min-delta must be >= 0... exiting") + sys.exit() + + if args.plateau_factor <= 0 or args.plateau_factor >= 1: + print("plateau-factor must be in (0, 1)... exiting") + sys.exit() + + if args.plateau_patience < 0: + print("plateau-patience must be >= 0... exiting") + sys.exit() + + if args.plateau_min_lr < 0: + print("plateau-min-lr must be >= 0... exiting") + sys.exit() + if args.checkPointFrequency < 0: print(f"checkPointFrequency must be >= 0... exiting") sys.exit() @@ -1144,6 +1180,7 @@ def main(): print(f"number of training patches per epoch: {len(train_crop)}") print(f"number of validation patches per epoch: {len(val_crop)}") print(f"Data augmentation: ENABLED for training, DISABLED for validation") + print(f"Validation cropping: {'FIXED' if args.fixed_val_crops else 'RANDOM'} per epoch") if args.seed is not None: print(f"Random seed: {args.seed} (reproducible mode)") else: @@ -1192,8 +1229,19 @@ def main(): optimizer = optim.AdamW(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay) print(f"Optimizer: AdamW with learning_rate={args.learningRate}, weight_decay={args.weightDecay}") - # Learning rate scheduler with cosine annealing - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) + # Learning rate scheduler + if args.scheduler == 'plateau': + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode='min', + factor=args.plateau_factor, + patience=args.plateau_patience, + min_lr=args.plateau_min_lr + ) + print(f"Scheduler: ReduceLROnPlateau (factor={args.plateau_factor}, patience={args.plateau_patience}, min_lr={args.plateau_min_lr})") + else: + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) + print("Scheduler: CosineAnnealingLR") # --- AMP Setup (bfloat16 aware) --- use_amp = args.use_amp and torch.cuda.is_available() @@ -1216,7 +1264,7 @@ def main(): best_val_loss = float('inf') if os.path.exists(latest_checkpoint_path) and not args.smoke_test: - model, optimizer, start_epoch, train_loss, val_loss = load_model_checkpoint( + model, optimizer, start_epoch, train_loss, val_loss, scaler, best_val_loss = load_model_checkpoint( model, optimizer, latest_checkpoint_path ) print(f"Resuming training from epoch {start_epoch+1}") @@ -1233,6 +1281,9 @@ def main(): num_epochs = args.epochs for epoch in range(start_epoch, num_epochs): train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark) + if args.fixed_val_crops: + val_seed = args.seed if args.seed is not None else 0 + val_crop.reset_rng(val_seed) val_loss_epoch = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype) train_loss.append(train_loss_epoch) @@ -1249,10 +1300,13 @@ def main(): print(log_msg) # Learning rate scheduling - scheduler.step() + if args.scheduler == 'plateau': + scheduler.step(val_loss_epoch) + else: + scheduler.step() # Check for improvement - if val_loss[-1] < best_val_loss: + if val_loss[-1] < best_val_loss - args.early_stop_min_delta: best_val_loss = val_loss[-1] patience_counter = 0 print(f" New best validation loss: {best_val_loss:.6f}") diff --git a/optuna_tuner.py b/optuna_tuner.py new file mode 100644 index 0000000..7d59766 --- /dev/null +++ b/optuna_tuner.py @@ -0,0 +1,614 @@ +""" +Optuna hyperparameter tuner for XPointMLTest.py + +Usage: + python optuna_tuner.py \ + --paramFile /path/to/params.txt \ + --xptCacheDir /path/to/cache \ + --n-trials 50 \ + --study-name xpoint-tuning \ + --db sqlite:///optuna_xpoint.db + +The script wraps the existing training pipeline and searches over: + - Learning rate + - Weight decay + - Dropout rate + - Batch size + - Patch size + - Base channels (model capacity) + - Scheduler type + params + - Augmentation positive-patch ratio + +Results persist in a SQLite database so you can: + - Resume after SSH disconnects + - Analyze results with Optuna's built-in visualization + - Run multiple workers in parallel (each with their own process) +""" + +import argparse +import json +import os +import sys +import time +from pathlib import Path + +import numpy as np +import torch +import torch.optim as optim +from torch.amp import autocast, GradScaler +from torch.utils.data import DataLoader + +try: + import optuna + from optuna.exceptions import TrialPruned +except ImportError: + print("Optuna not installed. Install with:") + print(" pip install optuna --break-system-packages") + print("Optional visualization: pip install plotly --break-system-packages") + sys.exit(1) + +# Import from existing codebase +from XPointMLTest import ( + XPointDataset, + XPointPatchDataset, + UNet, + DiceLoss, + train_one_epoch, + validate_one_epoch, + set_seed, +) +from eval_metrics import evaluate_model_on_dataset +from ci_tests import SyntheticXPointDataset + + +def objective(trial, args): + """ + Optuna objective function. Trains the model with trial-suggested + hyperparameters and returns the best validation loss. + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # --- Seed: vary per trial for diversity, but keep reproducible --- + seed = args.seed + if seed is not None: + set_seed(seed + trial.number) + + # 1. Suggest ALL hyperparameters in one place + + # Model architecture + base_channels = trial.suggest_categorical("base_channels", [16, 32, 48, 64]) + dropout_rate = trial.suggest_float("dropout_rate", 0.05, 0.5) + + # Optimizer + lr = trial.suggest_float("learning_rate", 1e-5, 5e-3, log=True) + weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True) + + # Data pipeline + batch_size = trial.suggest_categorical("batch_size", [8, 16, 32, 64]) + patch_size = trial.suggest_categorical("patch_size", [48, 64, 96]) + pos_ratio = trial.suggest_float("pos_ratio", 0.3, 0.7) + + # Scheduler + scheduler_name = trial.suggest_categorical("scheduler", ["cosine", "plateau"]) + + # 2. Load data (pre-loaded and cached on args to avoid repeated I/O) + + if args.smoke_test: + train_dataset = SyntheticXPointDataset(nframes=10, shape=(64, 64), nxpoints=3) + val_dataset = SyntheticXPointDataset( + nframes=2, shape=(64, 64), nxpoints=3, seed=123 + ) + else: + train_dataset = args._train_dataset + val_dataset = args._val_dataset + + train_crop = XPointPatchDataset( + train_dataset, + patch=patch_size, + pos_ratio=pos_ratio, + retries=30, + augment=True, + seed=seed, + ) + val_crop = XPointPatchDataset( + val_dataset, + patch=patch_size, + pos_ratio=0.5, + retries=30, + augment=False, + seed=seed, + ) + + train_loader = DataLoader( + train_crop, batch_size=batch_size, shuffle=True, num_workers=0 + ) + val_loader = DataLoader( + val_crop, batch_size=batch_size, shuffle=False, num_workers=0 + ) + + + # 3. Create model, optimizer, scheduler + + model = UNet( + input_channels=4, base_channels=base_channels, dropout_rate=dropout_rate + ).to(device) + + num_epochs = args.epochs + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) + + if scheduler_name == "plateau": + plateau_factor = trial.suggest_float("plateau_factor", 0.2, 0.8) + plateau_patience = trial.suggest_int("plateau_patience", 3, 15) + scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + mode="min", + factor=plateau_factor, + patience=plateau_patience, + min_lr=1e-6, + ) + else: + scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=num_epochs, eta_min=1e-6 + ) + + criterion = DiceLoss(smooth=1.0) + + # --- AMP setup --- + use_amp = args.use_amp and torch.cuda.is_available() + amp_dtype = ( + torch.bfloat16 + if args.amp_dtype == "bfloat16" and torch.cuda.is_bf16_supported() + else torch.float16 + ) + scaler = GradScaler(enabled=(use_amp and amp_dtype == torch.float16)) + + + # 4. Training loop with Optuna pruning + + best_val_loss = float("inf") + patience_counter = 0 + patience = args.patience + epochs_trained = 0 + + for epoch in range(num_epochs): + train_loss = train_one_epoch( + model, + train_loader, + criterion, + optimizer, + device, + scaler, + use_amp, + amp_dtype, + ) + + # Reset validation RNG for deterministic crops each epoch + if seed is not None: + val_crop.reset_rng(seed) + + val_loss = validate_one_epoch( + model, val_loader, criterion, device, use_amp, amp_dtype + ) + + # LR scheduling + if scheduler_name == "plateau": + scheduler.step(val_loss) + else: + scheduler.step() + + # Track best + if val_loss < best_val_loss: + best_val_loss = val_loss + patience_counter = 0 + else: + patience_counter += 1 + + epochs_trained = epoch + 1 + + # === Report intermediate value to Optuna for pruning === + trial.report(val_loss, epoch) + + if trial.should_prune(): + raise TrialPruned() + + # Early stopping + if patience_counter >= patience: + break + + + # 5. (Optional) Full-frame evaluation for richer metrics + + if args.eval_on_full_frames and not args.smoke_test: + model.eval() + evaluator = evaluate_model_on_dataset( + model, val_dataset, device, use_amp=use_amp, amp_dtype=amp_dtype + ) + global_metrics = evaluator.get_global_metrics() + + trial.set_user_attr("val_f1", global_metrics["f1_score"]) + trial.set_user_attr("val_iou", global_metrics["iou"]) + trial.set_user_attr("val_precision", global_metrics["precision"]) + trial.set_user_attr("val_recall", global_metrics["recall"]) + + trial.set_user_attr("best_val_loss", best_val_loss) + trial.set_user_attr("epochs_trained", epochs_trained) + + return best_val_loss + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Optuna hyperparameter tuning for X-point classifier" + ) + + # --- Optuna settings --- + parser.add_argument( + "--n-trials", + type=int, + default=50, + help="Number of Optuna trials (default: 50)", + ) + parser.add_argument( + "--study-name", + type=str, + default="xpoint-tuning", + help="Optuna study name (default: xpoint-tuning)", + ) + parser.add_argument( + "--db", + type=str, + default="sqlite:///optuna_xpoint.db", + help="Optuna storage URL (default: sqlite:///optuna_xpoint.db)", + ) + parser.add_argument( + "--timeout", + type=int, + default=None, + help="Stop after this many seconds (default: None, run all trials)", + ) + parser.add_argument( + "--pruner", + type=str, + default="median", + choices=["median", "hyperband", "none"], + help="Pruning strategy (default: median)", + ) + + # --- Data settings (same as XPointMLTest.py) --- + parser.add_argument("--paramFile", type=Path, default=None) + parser.add_argument("--xptCacheDir", type=Path, default=None) + parser.add_argument("--trainFrameFirst", type=int, default=1) + parser.add_argument("--trainFrameLast", type=int, default=140) + parser.add_argument("--validationFrameFirst", type=int, default=141) + parser.add_argument("--validationFrameLast", type=int, default=150) + + # --- Training settings (fixed across all trials) --- + parser.add_argument( + "--epochs", + type=int, + default=300, + help="Max epochs PER TRIAL (default: 300, lower than full training for speed)", + ) + parser.add_argument( + "--patience", + type=int, + default=30, + help="Early stopping patience per trial (default: 30)", + ) + parser.add_argument("--use-amp", action="store_true", help="Enable AMP") + parser.add_argument( + "--amp-dtype", + type=str, + default="bfloat16", + choices=["float16", "bfloat16"], + ) + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--require-gpu", action="store_true") + + # --- Evaluation --- + parser.add_argument( + "--eval-on-full-frames", + action="store_true", + help="Run full-frame evaluation after each trial (slower but gives F1/IoU)", + ) + + # --- Output --- + parser.add_argument( + "--results-dir", + type=Path, + default="./optuna_results", + help="Directory for result files (default: ./optuna_results)", + ) + + # --- Testing --- + parser.add_argument( + "--smoke-test", + action="store_true", + help="Run 3 trials with synthetic data (no paramFile needed)", + ) + + return parser.parse_args() + + +def print_study_summary(study, results_dir): + """Print and save a summary of the Optuna study results.""" + + print("\n" + "=" * 70) + print("OPTUNA STUDY SUMMARY") + print("=" * 70) + + completed = [ + t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE + ] + pruned = [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED] + failed = [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL] + + print(f"\nStudy name: {study.study_name}") + print(f"Total trials: {len(study.trials)}") + print(f" Completed: {len(completed)}") + print(f" Pruned: {len(pruned)}") + print(f" Failed: {len(failed)}") + + if not completed: + print("\nNo completed trials to summarize.") + return + + best = study.best_trial + print(f"\nBest trial: #{best.number}") + print(f" Best validation loss: {best.value:.6f}") + print(f" Hyperparameters:") + for key, value in sorted(best.params.items()): + if isinstance(value, float): + print(f" {key:25s} {value:.6g}") + else: + print(f" {key:25s} {value}") + + if best.user_attrs: + print(f" Additional metrics:") + for key, value in sorted(best.user_attrs.items()): + if isinstance(value, float): + print(f" {key:25s} {value:.4f}") + else: + print(f" {key:25s} {value}") + + # --- Top 5 trials --- + sorted_trials = sorted(completed, key=lambda t: t.value) + print(f"\nTop 5 trials:") + print( + f" {'#':>4s} {'Val Loss':>10s} {'LR':>10s} {'WD':>10s} " + f"{'Drop':>6s} {'BS':>4s} {'Patch':>5s} {'Ch':>4s} {'Sched':>8s}" + ) + for t in sorted_trials[:5]: + p = t.params + print( + f" {t.number:4d} {t.value:10.6f} " + f"{p.get('learning_rate', 0):10.2e} " + f"{p.get('weight_decay', 0):10.2e} " + f"{p.get('dropout_rate', 0):6.3f} " + f"{p.get('batch_size', 0):4d} " + f"{p.get('patch_size', 0):5d} " + f"{p.get('base_channels', 0):4d} " + f"{p.get('scheduler', 'n/a'):>8s}" + ) + + print("=" * 70) + + # --- Save results to JSON --- + results_dir.mkdir(parents=True, exist_ok=True) + + results = { + "study_name": study.study_name, + "n_trials_total": len(study.trials), + "n_completed": len(completed), + "n_pruned": len(pruned), + "n_failed": len(failed), + "best_trial": { + "number": best.number, + "value": best.value, + "params": best.params, + "user_attrs": best.user_attrs, + }, + "all_completed_trials": [ + { + "number": t.number, + "value": t.value, + "params": t.params, + "user_attrs": t.user_attrs, + } + for t in sorted_trials + ], + } + + results_path = results_dir / "optuna_results.json" + with open(results_path, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to: {results_path}") + + # --- Generate shell command for retraining with best params --- + p = best.params + cmd_lines = [ + "#!/bin/bash", + f"# Auto-generated from Optuna study: {study.study_name}", + f"# Best trial #{best.number} with val_loss={best.value:.6f}", + "", + "python -u ${rcRoot}/reconClassifier/XPointMLTest.py \\", + f" --learningRate {p.get('learning_rate', 5e-4):.6g} \\", + f" --weightDecay {p.get('weight_decay', 5e-5):.6g} \\", + f" --dropoutRate {p.get('dropout_rate', 0.15):.4g} \\", + f" --batchSize {p.get('batch_size', 64)} \\", + ] + + sched = p.get("scheduler", "cosine") + cmd_lines.append(f" --scheduler {sched} \\") + if sched == "plateau": + cmd_lines.append( + f" --plateau-factor {p.get('plateau_factor', 0.5):.4g} \\" + ) + cmd_lines.append( + f" --plateau-patience {p.get('plateau_patience', 5)} \\" + ) + + cmd_lines.extend( + [ + " --use-amp \\", + " --seed 42 \\", + " --require-gpu \\", + " --fixed-val-crops \\", + " --epochs 1200 \\", + " --patience 200 \\", + " --checkPointFrequency 200 \\", + " --paramFile=${PARAM_FILE} \\", + " --xptCacheDir=${CACHE_DIR}", + ] + ) + + cmd_path = results_dir / "retrain_best_params.sh" + with open(cmd_path, "w") as f: + f.write("\n".join(cmd_lines) + "\n") + os.chmod(cmd_path, 0o755) + + print(f"Retrain command saved to: {cmd_path}") + print(f"\nTo retrain with best hyperparameters:") + print(f" bash {cmd_path}") + print("=" * 70) + + +def main(): + args = parse_args() + + # --- Smoke test overrides --- + if args.smoke_test: + print("=" * 60) + print("SMOKE TEST MODE: 3 trials with synthetic data") + print("=" * 60) + args.n_trials = 3 + args.epochs = 5 + args.patience = 3 + args.eval_on_full_frames = False + + # --- Validate args --- + if not args.smoke_test: + if args.paramFile is None: + print("ERROR: --paramFile is required (or use --smoke-test)") + sys.exit(1) + if not args.paramFile.exists(): + print(f"ERROR: paramFile {args.paramFile} does not exist") + sys.exit(1) + + if args.require_gpu and not torch.cuda.is_available(): + print("ERROR: --require-gpu set but no CUDA device found") + sys.exit(1) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + + + # Pre-load datasets ONCE (expensive I/O — shared across all trials) + + if not args.smoke_test: + print("\nLoading datasets (shared across all trials)...") + t0 = time.time() + train_dataset = XPointDataset( + args.paramFile, + range(args.trainFrameFirst, args.trainFrameLast), + xptCacheDir=args.xptCacheDir, + rotateAndReflect=False, + ) + val_dataset = XPointDataset( + args.paramFile, + range(args.validationFrameFirst, args.validationFrameLast), + xptCacheDir=args.xptCacheDir, + rotateAndReflect=False, + ) + print(f"Datasets loaded in {time.time() - t0:.1f}s") + print(f" Training frames: {len(train_dataset)}") + print(f" Validation frames: {len(val_dataset)}") + + # Attach to args for access in objective() + args._train_dataset = train_dataset + args._val_dataset = val_dataset + + + # Create Optuna study + + if args.pruner == "median": + pruner = optuna.pruners.MedianPruner( + n_startup_trials=5, n_warmup_steps=10, interval_steps=5 + ) + elif args.pruner == "hyperband": + pruner = optuna.pruners.HyperbandPruner( + min_resource=10, max_resource=args.epochs, reduction_factor=3 + ) + else: + pruner = optuna.pruners.NopPruner() + + study = optuna.create_study( + study_name=args.study_name, + storage=args.db, + load_if_exists=True, + direction="minimize", + pruner=pruner, + ) + + n_existing = len(study.trials) + if n_existing > 0: + print(f"\nResuming study '{args.study_name}' with {n_existing} existing trials") + if study.best_trial: + print(f"Current best val_loss: {study.best_trial.value:.6f}") + + + # Run optimization + + print(f"\nStarting Optuna optimization") + print(f" Trials: {args.n_trials}") + print(f" DB: {args.db}") + print(f" Pruner: {args.pruner}") + print(f" Max epochs: {args.epochs} per trial") + print(f" Patience: {args.patience} per trial") + if args.timeout: + print(f" Timeout: {args.timeout}s") + print() + + study.optimize( + lambda trial: objective(trial, args), + n_trials=args.n_trials, + timeout=args.timeout, + show_progress_bar=True, + ) + + + # Results + + print_study_summary(study, args.results_dir) + + # --- Try to generate visualization plots --- + try: + from optuna.visualization import ( + plot_param_importances, + plot_optimization_history, + plot_parallel_coordinate, + ) + + fig_dir = args.results_dir / "figures" + fig_dir.mkdir(parents=True, exist_ok=True) + + for name, plot_fn in [ + ("param_importances", plot_param_importances), + ("optimization_history", plot_optimization_history), + ("parallel_coordinate", plot_parallel_coordinate), + ]: + try: + fig = plot_fn(study) + fig.write_html(str(fig_dir / f"{name}.html")) + print(f" Saved: {fig_dir / name}.html") + except Exception as e: + print(f" Skipped {name}: {e}") + + except ImportError: + print("\nNote: pip install plotly for interactive visualizations") + + +if __name__ == "__main__": + main() \ No newline at end of file From ce112bff5b94bd5a9b4a89f44cef03ae586ae743 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Fri, 10 Apr 2026 11:55:57 -0500 Subject: [PATCH 6/7] Added a bunch of new CLI args to XPointMLTest.py like baseChannels, posRatio, lossFunction, warmupEpochs, and swa so we can actually tune all the stuff that was hardcoded before, especially base_channels which was stuck at 32 while all the optuna tuners used 64. Also threw in a FocalDiceLoss class that combines focal and dice loss to help with the crazy class imbalance, and hooked up linear LR warmup and stochastic weight averaging with a custom BN update that works with our dict based dataloader. Created test_xpoint_transfer.py to evaluate our best PKPM trained model on the 5M and 10M datasets, which includes a monkey patch for the double component indexing bug in getData.py since we cant modify files outside reconClassifier. Then made build_transfer_cache.py to precompute and cache the xpoint finder results for all 150 frames of 5M and 10M data so we dont have to wait 20 minutes per frame every time we want to run the transfer evaluation. --- XPointMLTest.py | 146 +++++++++++++++++++-- build_transfer_cache.py | 128 +++++++++++++++++++ test_xpoint_transfer.py | 273 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 534 insertions(+), 13 deletions(-) create mode 100644 build_transfer_cache.py create mode 100644 test_xpoint_transfer.py diff --git a/XPointMLTest.py b/XPointMLTest.py index 7a5eb72..d613551 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -329,7 +329,8 @@ def load(self, fnum): class XPointPatchDataset(Dataset): """On‑the‑fly square crops with data augmentation, balancing positive / background patches.""" - def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30, augment=False, seed=None): + def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30, augment=False, seed=None, + noise_std_max=0.02, cutout_prob=0.20): """ Parameters: ----------- @@ -345,12 +346,18 @@ def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30, augment=False, If True, apply on-the-fly data augmentation (use for training only) seed : int or None Random seed for reproducibility (None for non-deterministic) + noise_std_max : float + Maximum std for Gaussian noise augmentation (default: 0.02) + cutout_prob : float + Probability of applying cutout augmentation (default: 0.20) """ self.base_ds = base_ds self.patch = patch self.pos_ratio = pos_ratio self.retries = retries self.augment = augment + self.noise_std_max = noise_std_max + self.cutout_prob = cutout_prob # Initialize RNG with seed if provided if seed is not None: @@ -407,7 +414,7 @@ def _apply_augmentation(self, all_data, mask): # 4. Add Gaussian noise (30% chance) # Small noise helps prevent overfitting to exact pixel values if self.rng.random() < 0.3: - noise_std = self.rng.uniform(0.005, 0.02) + noise_std = self.rng.uniform(0.005, self.noise_std_max) noise = torch.randn_like(all_data) * noise_std all_data = all_data + noise @@ -421,9 +428,9 @@ def _apply_augmentation(self, all_data, mask): mean = all_data.mean(dim=(-2, -1), keepdim=True) all_data = contrast * (all_data - mean) + mean + brightness - # 6. Cutout/Random erasing (20% chance) + # 6. Cutout/Random erasing # Prevents model from relying too heavily on specific spatial features - if self.rng.random() < 0.20: + if self.rng.random() < self.cutout_prob: h, w = all_data.shape[-2:] cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25)) if cutout_size > 0: @@ -628,6 +635,35 @@ def forward(self, inputs, targets): # Return Dice loss (1 - Dice coefficient) return 1.0 - dice + +class FocalDiceLoss(nn.Module): + """Combined Focal + Dice loss for extreme class imbalance. + + Focal loss downweights easy negatives so the model focuses on hard + positives near X-point boundaries. Combined with Dice loss + which directly optimizes region overlap. + """ + def __init__(self, alpha=0.75, gamma=2.0, dice_weight=0.5, smooth=1.0): + super().__init__() + self.alpha = alpha + self.gamma = gamma + self.dice_weight = dice_weight + self.dice = DiceLoss(smooth=smooth) + + def forward(self, inputs, targets): + # Focal loss + bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none') + p = torch.sigmoid(inputs) + pt = p * targets + (1 - p) * (1 - targets) + alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets) + focal = alpha_t * ((1 - pt) ** self.gamma) * bce + focal_loss = focal.mean() + + # Dice loss + dice_loss = self.dice(inputs, targets) + + return (1 - self.dice_weight) * focal_loss + self.dice_weight * dice_loss + # TRAIN & VALIDATION UTILS def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark=None): model.train() @@ -874,6 +910,25 @@ def parseCommandLineArgs(): help='specify the weight decay (L2 regularization) for optimizer') parser.add_argument('--dropoutRate', type=float, default=0.2, help='specify the dropout rate for regularization') + parser.add_argument('--baseChannels', type=int, default=64, + help='base number of channels in the UNet encoder (default: 64)') + parser.add_argument('--posRatio', type=float, default=0.5, + help='target ratio of patches containing X-points (default: 0.5)') + parser.add_argument('--lossFunction', type=str, default='dice', + choices=['dice', 'focal_dice'], + help='loss function: dice (default) or focal_dice (combined focal + dice)') + parser.add_argument('--focalAlpha', type=float, default=0.75, + help='focal loss alpha (class balance weight, default: 0.75)') + parser.add_argument('--focalGamma', type=float, default=2.0, + help='focal loss gamma (focusing parameter, default: 2.0)') + parser.add_argument('--focalDiceWeight', type=float, default=0.5, + help='weight of dice component in FocalDiceLoss (default: 0.5)') + parser.add_argument('--warmupEpochs', type=int, default=0, + help='number of linear warmup epochs before cosine decay (default: 0)') + parser.add_argument('--swa', action='store_true', + help='enable Stochastic Weight Averaging for better generalization') + parser.add_argument('--swaStart', type=float, default=0.75, + help='fraction of total epochs after which SWA begins (default: 0.75)') parser.add_argument('--batchSize', type=int, default=1, help='specify the batch size') parser.add_argument('--epochs', type=int, default=2000, @@ -1168,7 +1223,7 @@ def main(): xptCacheDir=args.xptCacheDir, rotateAndReflect=False) # Enable augmentation for training, disable for validation - train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=0.5, retries=30, + train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=args.posRatio, retries=30, augment=True, seed=args.seed) val_crop = XPointPatchDataset(val_dataset, patch=64, pos_ratio=0.5, retries=30, augment=False, seed=args.seed) @@ -1214,7 +1269,7 @@ def main(): benchmark.print_hardware_info() # Use the improved model - model = UNet(input_channels=4, base_channels=32, dropout_rate=args.dropoutRate).to(device) + model = UNet(input_channels=4, base_channels=args.baseChannels, dropout_rate=args.dropoutRate).to(device) # Count parameters total_params = sum(p.numel() for p in model.parameters()) @@ -1223,13 +1278,19 @@ def main(): print(f"Trainable parameters: {trainable_params:,}") print(f"Dropout rate: {args.dropoutRate}") - criterion = DiceLoss(smooth=1.0) + if args.lossFunction == 'focal_dice': + criterion = FocalDiceLoss(alpha=args.focalAlpha, gamma=args.focalGamma, + dice_weight=args.focalDiceWeight, smooth=1.0) + print(f"Loss function: FocalDiceLoss (alpha={args.focalAlpha}, gamma={args.focalGamma}, dice_weight={args.focalDiceWeight})") + else: + criterion = DiceLoss(smooth=1.0) + print("Loss function: DiceLoss") # Use AdamW optimizer with weight decay for better generalization optimizer = optim.AdamW(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay) print(f"Optimizer: AdamW with learning_rate={args.learningRate}, weight_decay={args.weightDecay}") - # Learning rate scheduler + # Learning rate scheduler (with optional warmup) if args.scheduler == 'plateau': scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, @@ -1240,8 +1301,32 @@ def main(): ) print(f"Scheduler: ReduceLROnPlateau (factor={args.plateau_factor}, patience={args.plateau_patience}, min_lr={args.plateau_min_lr})") else: - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) - print("Scheduler: CosineAnnealingLR") + if args.warmupEpochs > 0: + cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=args.epochs - args.warmupEpochs, eta_min=1e-6 + ) + warmup_scheduler = optim.lr_scheduler.LinearLR( + optimizer, start_factor=0.01, total_iters=args.warmupEpochs + ) + scheduler = optim.lr_scheduler.SequentialLR( + optimizer, + schedulers=[warmup_scheduler, cosine_scheduler], + milestones=[args.warmupEpochs] + ) + print(f"Scheduler: CosineAnnealingLR with {args.warmupEpochs}-epoch linear warmup") + else: + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6) + print("Scheduler: CosineAnnealingLR") + + # SWA setup + swa_model = None + swa_scheduler = None + swa_start_epoch = int(args.epochs * args.swaStart) + if args.swa: + from torch.optim.swa_utils import AveragedModel, SWALR + swa_model = AveragedModel(model) + swa_scheduler = SWALR(optimizer, swa_lr=args.learningRate * 0.1) + print(f"SWA: Enabled (starts at epoch {swa_start_epoch})") # --- AMP Setup (bfloat16 aware) --- use_amp = args.use_amp and torch.cuda.is_available() @@ -1300,7 +1385,10 @@ def main(): print(log_msg) # Learning rate scheduling - if args.scheduler == 'plateau': + if args.swa and epoch >= swa_start_epoch: + swa_model.update_parameters(model) + swa_scheduler.step() + elif args.scheduler == 'plateau': scheduler.step(val_loss_epoch) else: scheduler.step() @@ -1319,11 +1407,43 @@ def main(): if (epoch+1) % args.checkPointFrequency == 0: save_model_checkpoint(model, optimizer, train_loss, val_loss, epoch+1, checkpoint_dir, scaler, best_val_loss) - # Early stopping - if patience_counter >= args.patience: + # Early stopping (disabled during SWA phase) + if patience_counter >= args.patience and not (args.swa and epoch >= swa_start_epoch): print(f"Early stopping triggered after {epoch+1} epochs (patience={args.patience})") break + # Finalize SWA: update batch normalization statistics + if args.swa and swa_model is not None: + print("Updating SWA batch normalization statistics...") + # Custom BN update since our DataLoader yields dicts, not raw tensors + momenta = {} + for module in swa_model.modules(): + if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + module.running_mean = torch.zeros_like(module.running_mean) + module.running_var = torch.ones_like(module.running_var) + momenta[module] = module.momentum + module.momentum = None + module.num_batches_tracked *= 0 + swa_model.train() + with torch.no_grad(): + for batch in train_loader: + all_data = batch["all"].to(device) + with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp): + swa_model(all_data) + for module, momentum in momenta.items(): + module.momentum = momentum + # Save SWA model as best if it improves val loss + swa_model.eval() + swa_val_loss = validate_one_epoch(swa_model, val_loader, criterion, device, use_amp, amp_dtype) + print(f"SWA model validation loss: {swa_val_loss:.6f} (vs best: {best_val_loss:.6f})") + if swa_val_loss < best_val_loss: + print("SWA model is better - saving as best model") + # Extract the inner module for saving + torch.save(swa_model.module.state_dict(), os.path.join(checkpoint_dir, "best_model.pt")) + best_val_loss = swa_val_loss + else: + print("SWA model did not improve - keeping original best model") + plot_training_history(train_loss, val_loss, save_path='plots/training_history.png') print("time (s) to train model: " + str(timer()-t2)) diff --git a/build_transfer_cache.py b/build_transfer_cache.py new file mode 100644 index 0000000..b19db44 --- /dev/null +++ b/build_transfer_cache.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +""" +Build X-point cache for 5M/10M datasets. + +This pre-computes and caches the X-point finder results so that +test_xpoint_transfer.py can load all frames quickly. + +Usage: + python build_transfer_cache.py --dataset 5M --start 1 --end 75 + python build_transfer_cache.py --dataset 10M --start 76 --end 150 +""" + +import argparse +import sys +import time +from pathlib import Path + +import numpy as np + +# -- Monkey-patch: fix 5m/10m .gkyl component indexing bug in getData.py -- +import postgkyl as pg +_orig_GData_init = pg.data.GData.__init__ +def _fixed_GData_init(self, *args, comp=None, **kwargs): + _orig_GData_init(self, *args, **kwargs) +pg.data.GData.__init__ = _fixed_GData_init +# -- End monkey-patch -- + +RC_ROOT = Path(__file__).resolve().parent +sys.path.insert(0, str(RC_ROOT)) + +from XPointMLTest import XPointDataset + +EXTRACT_DIR = Path("/work/nvme/bfim/ssridhar6/mlReconnection2025") +CACHE_BASE = Path("/work/nvme/bfim/ssridhar6/mlReconnection2025/cache") + +DATASETS = { + "5M": { + "extract_subdir": EXTRACT_DIR / "5M", + "param_file": "rt_5M_2d_turb_local-params.txt", + }, + "10M": { + "extract_subdir": EXTRACT_DIR / "10M", + "param_file": "rt_10M_2d_turb_local-params.txt", + }, +} + + +def find_param_file(extract_dir, param_file_name): + for p in extract_dir.rglob(param_file_name): + return p + for p in extract_dir.rglob("*-params.txt"): + return p + raise FileNotFoundError(f"No params file found in {extract_dir}") + + +def discover_all_frames(extract_dir): + """Find all frame numbers from field files.""" + field_files = list(extract_dir.rglob("*-field_*.gkyl")) + frame_nums = set() + for f in field_files: + parts = f.stem.rsplit("_", 1) + if len(parts) == 2 and parts[1].isdigit(): + frame_nums.add(int(parts[1])) + return sorted(f for f in frame_nums if f > 0) + + +def main(): + parser = argparse.ArgumentParser(description="Build X-point cache for transfer datasets") + parser.add_argument("--dataset", required=True, choices=["5M", "10M"]) + parser.add_argument("--start", type=int, default=None, help="First frame index (inclusive)") + parser.add_argument("--end", type=int, default=None, help="Last frame index (inclusive)") + args = parser.parse_args() + + ds = DATASETS[args.dataset] + cache_dir = CACHE_BASE / args.dataset + cache_dir.mkdir(parents=True, exist_ok=True) + + param_path = find_param_file(ds["extract_subdir"], ds["param_file"]) + print(f"Dataset: {args.dataset}") + print(f"Param file: {param_path}") + print(f"Cache dir: {cache_dir}") + + all_frames = discover_all_frames(ds["extract_subdir"]) + print(f"Total frames available: {len(all_frames)} ({all_frames[0]}-{all_frames[-1]})") + + # Apply range filter + start = args.start if args.start is not None else all_frames[0] + end = args.end if args.end is not None else all_frames[-1] + frames = [f for f in all_frames if start <= f <= end] + print(f"Processing frames {start}-{end}: {len(frames)} frames") + + # Check which frames are already cached + from XPointMLTest import cachedPgkylDataExists + uncached = [f for f in frames if not cachedPgkylDataExists(cache_dir, f, "psi")] + cached = len(frames) - len(uncached) + print(f"Already cached: {cached}, need to compute: {len(uncached)}") + + if not uncached: + print("All frames already cached!") + return + + # Process frames one at a time, with progress and timing + total_time = 0 + for i, fnum in enumerate(uncached): + print(f"\n[{i+1}/{len(uncached)}] Frame {fnum}...", flush=True) + t0 = time.time() + + # XPointDataset will compute and cache for us + dataset = XPointDataset( + str(param_path), + [fnum], + xptCacheDir=cache_dir, + rotateAndReflect=False, + ) + + elapsed = time.time() - t0 + total_time += elapsed + avg = total_time / (i + 1) + remaining = avg * (len(uncached) - i - 1) + print(f" Done in {elapsed:.1f}s | Avg: {avg:.1f}s/frame | " + f"ETA: {remaining/3600:.1f}h remaining", flush=True) + + print(f"\nCache building complete! Total time: {total_time/3600:.1f}h") + print(f"Cached {len(uncached)} frames to {cache_dir}") + + +if __name__ == "__main__": + main() diff --git a/test_xpoint_transfer.py b/test_xpoint_transfer.py new file mode 100644 index 0000000..f059178 --- /dev/null +++ b/test_xpoint_transfer.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +""" +Cross-domain inference: evaluate the best PKPM-trained model on 5M and 10M data. + +This script: + 1. Extracts 5M.tgz and 10M.tgz (if not already extracted) + 2. Loads the best model checkpoint + 3. Runs evaluate_model_on_dataset on each dataset + 4. Saves per-dataset evaluation metrics + +Usage (within SLURM job): + python -u reconClassifier/test_xpoint_transfer.py +""" + +import json +import os +import subprocess +import sys +import time +from pathlib import Path + +import numpy as np +import torch + +# -- Monkey-patch: fix 5m/10m .gkyl component indexing bug in getData.py -- +# The bug: getData.py passes comp=N to pg.data.GData which pre-selects the +# component, then tries data[..., N] which fails (array only has 1 element). +# Fix: drop the comp kwarg so GData returns all components. +import postgkyl as pg +_orig_GData_init = pg.data.GData.__init__ +def _fixed_GData_init(self, *args, comp=None, **kwargs): + _orig_GData_init(self, *args, **kwargs) +pg.data.GData.__init__ = _fixed_GData_init +# -- End monkey-patch -- + +# Add reconClassifier to path +RC_ROOT = Path(__file__).resolve().parent +sys.path.insert(0, str(RC_ROOT)) + +from XPointMLTest import XPointDataset, UNet +from eval_metrics import evaluate_model_on_dataset + +# ── Configuration ────────────────────────────────────────────────────── +SOURCE_DIR = Path("/work/nvme/bfim/cwsmith/mlReconnection2025") +EXTRACT_DIR = Path("/work/nvme/bfim/ssridhar6/mlReconnection2025") +CACHE_BASE = EXTRACT_DIR / "cache" +BEST_MODEL = Path.home() / "mlReconnection/testdir_2026-04-01-19-57-11/checkpoints/best_model.pt" +OUTPUT_DIR = Path.home() / "mlReconnection/transfer_eval_results" + +DATASETS = { + "5M": { + "tarball": SOURCE_DIR / "5M.tgz", + "extract_subdir": EXTRACT_DIR / "5M", + "param_file": "rt_5M_2d_turb_local-params.txt", + "tar_prefix": "./", # files are at root of tarball + }, + "10M": { + "tarball": SOURCE_DIR / "10M.tgz", + "extract_subdir": EXTRACT_DIR / "10M", + "param_file": "rt_10M_2d_turb_local-params.txt", + "tar_prefix": "10M/", # files are under 10M/ subdir + }, +} + +# Model config must match training +BASE_CHANNELS = 64 +DROPOUT_RATE = 0.055 # Trial #36 value (doesn't matter for eval, just architecture) + + +def extract_tarball(tarball_path, extract_dir, name): + """Extract tarball if not already done.""" + param_candidates = list(extract_dir.glob("*-params.txt")) + if param_candidates: + print(f" [{name}] Already extracted ({len(param_candidates)} param file(s) found)") + return + + extract_dir.mkdir(parents=True, exist_ok=True) + print(f" [{name}] Extracting {tarball_path} -> {extract_dir} ...") + t0 = time.time() + subprocess.run( + ["tar", "xzf", str(tarball_path), "-C", str(extract_dir), "--strip-components=0"], + check=True, + ) + elapsed = time.time() - t0 + print(f" [{name}] Extraction complete in {elapsed:.1f}s") + + +def discover_frames(extract_dir, param_file_name): + """Discover available frame numbers from field files.""" + # Find the param file + param_path = None + for p in extract_dir.rglob(param_file_name): + param_path = p + break + + if param_path is None: + # Try searching more broadly + for p in extract_dir.rglob("*-params.txt"): + param_path = p + break + + if param_path is None: + raise FileNotFoundError(f"No params file found in {extract_dir}") + + print(f" Found param file: {param_path}") + + # Discover frame numbers from field files in same directory + param_dir = param_path.parent + field_files = sorted(param_dir.glob("*-field_*.gkyl")) + frame_nums = set() + for f in field_files: + # Extract number from pattern like "...-field_42.gkyl" + stem = f.stem # e.g. "rt_5M_2d_turb_local-field_42" + parts = stem.rsplit("_", 1) + if len(parts) == 2 and parts[1].isdigit(): + frame_nums.add(int(parts[1])) + + frame_nums = sorted(frame_nums) + print(f" Found {len(frame_nums)} frames: {frame_nums[0]}-{frame_nums[-1]}") + + # Exclude frame 0 (often initial conditions, not interesting) + frame_nums = [f for f in frame_nums if f > 0] + + return param_path, frame_nums + + +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + + # ── Load model ───────────────────────────────────────────────────── + print(f"\nLoading model from {BEST_MODEL}") + model = UNet(input_channels=4, base_channels=BASE_CHANNELS, dropout_rate=DROPOUT_RATE).to(device) + state_dict = torch.load(str(BEST_MODEL), map_location=device, weights_only=False) + model.load_state_dict(state_dict) + model.eval() + total_params = sum(p.numel() for p in model.parameters()) + print(f"Model loaded: {total_params:,} parameters") + + # ── Setup output ─────────────────────────────────────────────────── + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + use_amp = torch.cuda.is_available() + amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16 + + all_results = {} + + # ── Process each dataset ─────────────────────────────────────────── + for ds_name, ds_config in DATASETS.items(): + print(f"\n{'='*70}") + print(f"EVALUATING ON {ds_name} DATA") + print(f"{'='*70}") + + # Extract + extract_tarball(ds_config["tarball"], ds_config["extract_subdir"], ds_name) + + # Discover frames + try: + param_path, frame_nums = discover_frames( + ds_config["extract_subdir"], ds_config["param_file"] + ) + except FileNotFoundError as e: + print(f" ERROR: {e}") + continue + + # Load dataset using cache (pre-built by build_transfer_cache.py) + cache_dir = CACHE_BASE / ds_name + if not cache_dir.is_dir(): + print(f" ERROR: Cache directory {cache_dir} not found.") + print(f" Run build_transfer_cache.py --dataset {ds_name} first.") + continue + + print(f" Loading {ds_name} dataset ({len(frame_nums)} frames, cache={cache_dir})...", flush=True) + t0 = time.time() + try: + dataset = XPointDataset( + str(param_path), frame_nums, + xptCacheDir=cache_dir, rotateAndReflect=False, + ) + except Exception as e: + print(f"\n ERROR loading dataset: {e}") + import traceback + traceback.print_exc() + continue + elapsed = time.time() - t0 + print(f" All {len(frame_nums)} frames loaded in {elapsed:.1f}s") + + # Check grid size + sample = dataset[0] + grid_shape = sample["all"].shape + print(f" Grid shape: {grid_shape} (channels, H, W)") + + # Evaluate + print(f" Running inference...") + t0 = time.time() + evaluator = evaluate_model_on_dataset( + model, dataset, device, + use_amp=use_amp, amp_dtype=amp_dtype, threshold=0.5, + ) + elapsed = time.time() - t0 + + evaluator.print_summary() + + # Save per-dataset results + output_file = OUTPUT_DIR / f"eval_{ds_name.lower()}.json" + evaluator.save_json(str(output_file)) + + metrics = evaluator.get_global_metrics() + metrics["inference_time_s"] = elapsed + metrics["num_frames"] = len(frame_nums) + metrics["grid_shape"] = list(grid_shape) + all_results[ds_name] = metrics + print(f" Inference time: {elapsed:.1f}s ({elapsed/len(frame_nums):.2f}s/frame)") + + # Save partial results incrementally in case of timeout + with open(OUTPUT_DIR / "transfer_summary.json", "w") as f: + json.dump(all_results, f, indent=2) + + # ── Also re-evaluate on original PKPM validation set for comparison ── + print(f"\n{'='*70}") + print(f"RE-EVALUATING ON PKPM VALIDATION (baseline comparison)") + print(f"{'='*70}") + + pkpm_param = "/work/nvme/bfim/cwsmith/mlReconnection2025/1024Res_v0/pkpm_2d_turb_p2-params.txt" + pkpm_cache = "/work/nvme/bfim/cwsmith/mlReconnection2025/1024Res_v0/cache04082025" + pkpm_val_frames = list(range(141, 150)) + + print(f" Loading PKPM validation ({len(pkpm_val_frames)} frames)...") + t0 = time.time() + pkpm_dataset = XPointDataset( + pkpm_param, pkpm_val_frames, + xptCacheDir=Path(pkpm_cache), + rotateAndReflect=False, + ) + print(f" Loaded in {time.time()-t0:.1f}s") + + t0 = time.time() + pkpm_evaluator = evaluate_model_on_dataset( + model, pkpm_dataset, device, + use_amp=use_amp, amp_dtype=amp_dtype, threshold=0.5, + ) + elapsed = time.time() - t0 + pkpm_evaluator.print_summary() + pkpm_evaluator.save_json(str(OUTPUT_DIR / "eval_pkpm_val.json")) + + pkpm_metrics = pkpm_evaluator.get_global_metrics() + pkpm_metrics["inference_time_s"] = elapsed + pkpm_metrics["num_frames"] = len(pkpm_val_frames) + all_results["PKPM_val"] = pkpm_metrics + + # ── Summary comparison ───────────────────────────────────────────── + print(f"\n{'='*70}") + print("CROSS-DOMAIN TRANSFER SUMMARY") + print(f"{'='*70}") + print(f"{'Dataset':>12s} {'F1':>7s} {'Prec':>7s} {'Rec':>7s} {'IoU':>7s} {'Frames':>6s} {'Grid':>12s}") + print("-" * 70) + for name, m in all_results.items(): + grid_str = "x".join(str(x) for x in m.get("grid_shape", [])) if "grid_shape" in m else "N/A" + print(f"{name:>12s} {m['f1_score']:7.4f} {m['precision']:7.4f} " + f"{m['recall']:7.4f} {m['iou']:7.4f} {m.get('num_frames','?'):>6} {grid_str:>12s}") + print(f"{'='*70}") + + # Save combined summary + summary_path = OUTPUT_DIR / "transfer_summary.json" + with open(summary_path, "w") as f: + json.dump(all_results, f, indent=2) + print(f"\nCombined summary saved to: {summary_path}") + + +if __name__ == "__main__": + main() From 76d0f79e0acf4b390a55e11a7b4513c6a1b26556 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Thu, 7 May 2026 06:14:31 -0500 Subject: [PATCH 7/7] Add parallel cache build with env-var path overrides (RC_EXTRACT_DIR, RC_CACHE_BASE) for ramdisk staging, and repoint transfer eval to the production checkpoint testdir_2026-04-02-13-23-05. XPointMLTest.py now profiles getPgkylData stages and reuses precomputed second derivatives as the Hessian for getXOPoints. --- XPointMLTest.py | 19 +++++++++- build_transfer_cache.py | 84 +++++++++++++++++++++++++++++------------ test_xpoint_transfer.py | 2 +- 3 files changed, 78 insertions(+), 27 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index d613551..c635303 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -150,23 +150,38 @@ def getPgkylData(paramFile, frameNumber, verbosity): params["polyOrderOverride"] = 0 #Override default dg interpolation and interpolate to given number of points constrcutBandJ = 1 #Read vector potential + import time as _time + _t0 = _time.time() var = gkData.gkData(str(paramFile),frameNumber,'psi',params).compactRead() + _t1 = _time.time() + print(f" [PROFILE] compactRead: {_t1-_t0:.1f}s") psi = var.data coords = var.coords axesNorm = var.d[ var.speciesFileIndex.index('ion') ] if verbosity > 0: print(f"psi shape: {psi.shape}, min={psi.min()}, max={psi.max()}") #Construct B and J (first and second derivatives) + _t2 = _time.time() [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,var.dx) [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,var.dx) [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,var.dx) + _t3 = _time.time() + print(f" [PROFILE] 3x genGradient: {_t3-_t2:.1f}s") bx = df_dy by = -df_dx jz = -(d2f_dxdx + d2f_dydy) / var.mu0 - del df_dx,df_dy,df_dz,d2f_dxdx,d2f_dxdy,d2f_dxdz,d2f_dydx,d2f_dydy,d2f_dydz + #Precompute Hessian from already-computed derivatives (avoid redundant gradient calls) + Hess = np.array([d2f_dxdx, d2f_dxdy, d2f_dxdy, d2f_dydy]) + del df_dz,d2f_dxdz,d2f_dydz,d2f_dxdx,d2f_dxdy,d2f_dydx,d2f_dydy,df_dx,df_dy #Indicies of critical points, X points, and O points (max and min) + _t4 = _time.time() critPoints = auxFuncs.getCritPoints(psi) - [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints) + _t5 = _time.time() + print(f" [PROFILE] getCritPoints: {_t5-_t4:.1f}s") + [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints, hessian=Hess) + _t6 = _time.time() + print(f" [PROFILE] getXOPoints: {_t6-_t5:.1f}s") + print(f" [PROFILE] TOTAL: {_t6-_t0:.1f}s") return [var.filenameBase, axesNorm, critPoints, xpts, optsMax, optsMin, coords, psi, bx, by, jz] def cachedPgkylDataExists(cacheDir, frameNumber, fieldName): diff --git a/build_transfer_cache.py b/build_transfer_cache.py index b19db44..7463799 100644 --- a/build_transfer_cache.py +++ b/build_transfer_cache.py @@ -11,6 +11,8 @@ """ import argparse +import multiprocessing as mp +import os import sys import time from pathlib import Path @@ -30,8 +32,14 @@ def _fixed_GData_init(self, *args, comp=None, **kwargs): from XPointMLTest import XPointDataset -EXTRACT_DIR = Path("/work/nvme/bfim/ssridhar6/mlReconnection2025") -CACHE_BASE = Path("/work/nvme/bfim/ssridhar6/mlReconnection2025/cache") +EXTRACT_DIR = Path(os.environ.get( + "RC_EXTRACT_DIR", + "/work/nvme/bfim/ssridhar6/mlReconnection2025", +)) +CACHE_BASE = Path(os.environ.get( + "RC_CACHE_BASE", + "/work/nvme/bfim/ssridhar6/mlReconnection2025/cache", +)) DATASETS = { "5M": { @@ -64,11 +72,27 @@ def discover_all_frames(extract_dir): return sorted(f for f in frame_nums if f > 0) +def _process_frame(task): + """Worker function: process a single frame. Returns (frame_num, elapsed).""" + param_path, fnum, cache_dir = task + t0 = time.time() + dataset = XPointDataset( + str(param_path), + [fnum], + xptCacheDir=cache_dir, + rotateAndReflect=False, + ) + elapsed = time.time() - t0 + return fnum, elapsed + + def main(): parser = argparse.ArgumentParser(description="Build X-point cache for transfer datasets") parser.add_argument("--dataset", required=True, choices=["5M", "10M"]) parser.add_argument("--start", type=int, default=None, help="First frame index (inclusive)") parser.add_argument("--end", type=int, default=None, help="Last frame index (inclusive)") + parser.add_argument("--workers", type=int, default=1, + help="Number of parallel workers (default: 1)") args = parser.parse_args() ds = DATASETS[args.dataset] @@ -79,6 +103,7 @@ def main(): print(f"Dataset: {args.dataset}") print(f"Param file: {param_path}") print(f"Cache dir: {cache_dir}") + print(f"Workers: {args.workers}") all_frames = discover_all_frames(ds["extract_subdir"]) print(f"Total frames available: {len(all_frames)} ({all_frames[0]}-{all_frames[-1]})") @@ -99,28 +124,39 @@ def main(): print("All frames already cached!") return - # Process frames one at a time, with progress and timing - total_time = 0 - for i, fnum in enumerate(uncached): - print(f"\n[{i+1}/{len(uncached)}] Frame {fnum}...", flush=True) - t0 = time.time() - - # XPointDataset will compute and cache for us - dataset = XPointDataset( - str(param_path), - [fnum], - xptCacheDir=cache_dir, - rotateAndReflect=False, - ) - - elapsed = time.time() - t0 - total_time += elapsed - avg = total_time / (i + 1) - remaining = avg * (len(uncached) - i - 1) - print(f" Done in {elapsed:.1f}s | Avg: {avg:.1f}s/frame | " - f"ETA: {remaining/3600:.1f}h remaining", flush=True) - - print(f"\nCache building complete! Total time: {total_time/3600:.1f}h") + tasks = [(param_path, fnum, cache_dir) for fnum in uncached] + + if args.workers <= 1: + # Sequential mode (original behavior) + total_time = 0 + for i, task in enumerate(tasks): + fnum = task[1] + print(f"\n[{i+1}/{len(uncached)}] Frame {fnum}...", flush=True) + _, elapsed = _process_frame(task) + total_time += elapsed + avg = total_time / (i + 1) + remaining = avg * (len(uncached) - i - 1) + print(f" Done in {elapsed:.1f}s | Avg: {avg:.1f}s/frame | " + f"ETA: {remaining/3600:.1f}h remaining", flush=True) + else: + # Parallel mode + print(f"\nStarting parallel processing with {args.workers} workers...", flush=True) + total_time = 0 + completed = 0 + wall_start = time.time() + with mp.Pool(processes=args.workers) as pool: + for fnum, elapsed in pool.imap_unordered(_process_frame, tasks): + completed += 1 + total_time += elapsed + wall_elapsed = time.time() - wall_start + avg_wall = wall_elapsed / completed + remaining = avg_wall * (len(uncached) - completed) + print(f"[{completed}/{len(uncached)}] Frame {fnum} done in {elapsed:.1f}s | " + f"Wall avg: {avg_wall:.1f}s/frame | " + f"ETA: {remaining/60:.1f}m remaining", flush=True) + total_time = time.time() - wall_start + + print(f"\nCache building complete! Wall time: {total_time/60:.1f}m ({total_time/3600:.1f}h)") print(f"Cached {len(uncached)} frames to {cache_dir}") diff --git a/test_xpoint_transfer.py b/test_xpoint_transfer.py index f059178..5175995 100644 --- a/test_xpoint_transfer.py +++ b/test_xpoint_transfer.py @@ -44,7 +44,7 @@ def _fixed_GData_init(self, *args, comp=None, **kwargs): SOURCE_DIR = Path("/work/nvme/bfim/cwsmith/mlReconnection2025") EXTRACT_DIR = Path("/work/nvme/bfim/ssridhar6/mlReconnection2025") CACHE_BASE = EXTRACT_DIR / "cache" -BEST_MODEL = Path.home() / "mlReconnection/testdir_2026-04-01-19-57-11/checkpoints/best_model.pt" +BEST_MODEL = Path.home() / "mlReconnection/testdir_2026-04-02-13-23-05/checkpoints/best_model.pt" OUTPUT_DIR = Path.home() / "mlReconnection/transfer_eval_results" DATASETS = {