-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprobe.py
More file actions
100 lines (86 loc) · 4.04 KB
/
probe.py
File metadata and controls
100 lines (86 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import argparse
import time
import numpy as np
import torch
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import LassoCV
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
import warnings
from model import ResNet18ImageEncoder
from dataset_probe import CityLensDownstream, eval_transform
# LassoCV frequently hits max_iter on high-dimensional features; silence the
# resulting convergence warnings so per-task output stays readable.
warnings.simplefilter("ignore", category=ConvergenceWarning)
@torch.no_grad()
def extract_features(encoder, ds, batch_size=128, workers=4, device="cuda"):
    """Run ``encoder.features`` over every sample of ``ds`` and stack the outputs.

    Args:
        encoder: module exposing a ``features(imgs)`` method. It is switched to
            eval mode (and left in eval mode).
        ds: map-style dataset yielding ``(image, target)`` pairs.
        batch_size: DataLoader batch size.
        workers: number of DataLoader worker processes.
        device: device the image batches are moved to before the forward pass.

    Returns:
        ``(X, y)`` numpy arrays. ``X`` holds the feature batches concatenated along
        axis 0 and ``y`` the matching targets, both in dataset order (no shuffling).

    Raises:
        ValueError: if ``ds`` yields no samples.
    """
    # Pinned host memory only speeds up host->CUDA copies; on CPU it merely
    # triggers a DataLoader warning, so enable it only for CUDA targets.
    use_pinned = str(device).startswith("cuda")
    loader = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=use_pinned,
    )
    encoder.eval()
    feat_chunks, target_chunks = [], []
    for imgs, y in loader:
        imgs = imgs.to(device, non_blocking=True)
        feat_chunks.append(encoder.features(imgs).cpu().numpy())
        # np.asarray handles CPU tensors as well as non-tensor collated targets.
        target_chunks.append(np.asarray(y))
    if not feat_chunks:
        raise ValueError("extract_features: dataset produced no samples")
    return np.concatenate(feat_chunks), np.concatenate(target_chunks)
def run_task(encoder, task, args, device):
    """Fit and evaluate a Lasso linear probe on frozen encoder features for one task.

    Loads the ``task`` dataset, extracts features with ``extract_features``, makes a
    fixed 80/20 train/test split, standardises the features with train-split
    statistics, and fits ``LassoCV`` (5-fold CV over alpha) on the train split.

    Args:
        encoder: image encoder passed through to ``extract_features``.
        task: task name understood by ``CityLensDownstream``.
        args: parsed CLI namespace; ``batch_size`` and ``workers`` are read.
        device: device used for feature extraction.

    Returns:
        ``(r2, mse, nrmse, alpha)`` measured on the held-out split. ``nrmse`` is
        RMSE divided by the range of the test targets, or NaN when that range is 0.
    """
    ds = CityLensDownstream(task, transform=eval_transform(), split="all", target="reference")

    # perf_counter is monotonic and high-resolution; time.time() can jump
    # (NTP/clock adjustments) and is not meant for measuring durations.
    t0 = time.perf_counter()
    X, y = extract_features(encoder, ds, args.batch_size, args.workers, device)
    t_feat = time.perf_counter() - t0

    # Fixed seed so every checkpoint is scored on the identical partition.
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

    # Scaler is fit on the train split only to avoid leaking test statistics.
    scaler = StandardScaler().fit(X_tr)
    X_tr_s = scaler.transform(X_tr)
    X_te_s = scaler.transform(X_te)

    reg = LassoCV(cv=5, random_state=42, max_iter=10000, precompute=False).fit(X_tr_s, y_tr)
    pred = reg.predict(X_te_s)

    mse = mean_squared_error(y_te, pred)
    r2 = r2_score(y_te, pred)
    yrange = y_te.max() - y_te.min()
    # Normalising by the target range makes errors comparable across tasks;
    # a constant test target has no meaningful range.
    nrmse = np.sqrt(mse) / yrange if yrange > 0 else float("nan")

    print(f" n={len(ds):5d} X={X.shape[1]:4d} extract={t_feat:.1f}s alpha={reg.alpha_:.4g} MSE={mse:.4g} R2={r2:.4f} nRMSE={nrmse:.4f}")
    return r2, mse, nrmse, reg.alpha_
def _load_encoder(args, device):
    """Build the ResNet18 encoder: ImageNet weights, or weights from ``args.ckpt``."""
    use_imagenet = args.ckpt == "imagenet"
    encoder = ResNet18ImageEncoder(pretrained=use_imagenet).to(device)
    if use_imagenet:
        print("using ImageNet ResNet18 baseline")
        return encoder

    # NOTE(security): torch.load unpickles the file; only point --ckpt at
    # checkpoints from a trusted source.
    state = torch.load(args.ckpt, map_location=device)
    image_state = state.get("image_state")
    if image_state is None:
        # Fall back to the full model state and keep only the image-tower weights,
        # stripping their leading "cap.image." prefix.
        prefix = "cap.image."
        image_state = {
            k[len(prefix):]: v
            for k, v in state["model_state"].items()
            if k.startswith(prefix)
        }

    # strict=False tolerates partial mismatches (e.g. extra heads), but the
    # counts are reported so mismatches stay visible.
    missing, unexpected = encoder.load_state_dict(image_state, strict=False)
    print(f"loaded ResNet18 encoder missing={len(missing)} unexpected={len(unexpected)}")
    if len(unexpected) == len(image_state):
        # Nothing was actually loaded: the encoder was built with pretrained=False,
        # so probing it would silently evaluate random weights.
        raise RuntimeError(
            f"checkpoint {args.ckpt!r} provided no keys matching the ResNet18 encoder; "
            "refusing to probe a randomly initialised network"
        )
    return encoder


def _format_summary(args, results):
    """Render the run header and per-task result table as newline-terminated text."""
    lines = []
    if args.exp_name:
        lines.append(f"Experiment: {args.exp_name}")
    if args.exp_desc:
        lines.append(f"Ablation: {args.exp_desc}")
    lines.append(f"Checkpoint: {args.ckpt}")
    lines.append("=== summary (R^2) ===")
    lines.append(f"{'task':14s} {'R2':>8s} {'nRMSE':>8s} {'alpha':>10s}")
    for task, r2, _mse, nrmse, alpha in results:
        lines.append(f"{task:14s} {r2:>8.4f} {nrmse:>8.4f} {alpha:>10.4g}")
    if len(results) > 1:
        lines.append(f"mean R2 = {np.mean([r[1] for r in results]):.4f}")
    return "\n".join(lines) + "\n"


def main():
    """CLI entry point: probe one or all CityLens tasks and write a summary file.

    The summary is printed and written to ``--out``. The file is truncated
    unless ``--append`` is given; in append mode each run is followed by a
    blank separator line.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=str, default="imagenet",
                    help='"imagenet" for the ImageNet baseline, else a checkpoint path')
    # Unpacking works whether TASKS is a list or a tuple.
    ap.add_argument("--task", type=str, required=True, choices=[*CityLensDownstream.TASKS, "all"])
    ap.add_argument("--batch-size", type=int, default=128)
    ap.add_argument("--workers", type=int, default=4)
    ap.add_argument("--out", type=str, default="final_result.txt")
    ap.add_argument("--append", action="store_true")
    ap.add_argument("--exp-name", type=str, default="")
    ap.add_argument("--exp-desc", type=str, default="")
    args = ap.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = _load_encoder(args, device)

    tasks = CityLensDownstream.TASKS if args.task == "all" else [args.task]
    results = []
    for task in tasks:
        print(f"\n=== {task} ===")
        results.append((task, *run_task(encoder, task, args, device)))

    text = _format_summary(args, results)
    print(text)
    # Explicit UTF-8: --exp-name/--exp-desc may contain non-ASCII text and the
    # platform default encoding is not guaranteed to handle it.
    with open(args.out, "a" if args.append else "w", encoding="utf-8") as f:
        # The trailing newline leaves a blank line between appended runs.
        f.write(text + "\n")
# Run the probe only when executed as a script, not when imported.
if __name__ == "__main__":
    main()