-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patheval.py
More file actions
36 lines (30 loc) · 1.11 KB
/
eval.py
File metadata and controls
36 lines (30 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import os
import argparse
from dataloader import valid_dataloader, valid_dataset
from config import cfg
from model import SimpleConv, DenseConv
from train import test
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` whose ``weight`` attribute holds the path
        of the checkpoint to evaluate. An empty string means "do not load
        any checkpoint".
    """
    parser = argparse.ArgumentParser("locate weight")
    parser.add_argument("--weight",
                        type=str,
                        default="./weights/model1/dense121_2019_12_7_10.pth")
    return parser.parse_args(argv)


def main(argv=None):
    """Build the model, optionally load a checkpoint, and print its accuracy.

    Args:
        argv: Optional argument list forwarded to ``parse_args``.
    """
    args = parse_args(argv)
    model = DenseConv(cfg.NUM_CLASSES)

    # Truthiness check instead of `is not ""`: identity comparison against a
    # string literal depends on interning and raises a SyntaxWarning on
    # Python 3.8+.
    if args.weight:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
        # CPU-only host; the model is moved to the GPU below when available.
        state_dict = torch.load(args.weight, map_location="cpu")
        model.load_state_dict(state_dict)

    if torch.cuda.is_available():
        model = model.cuda()

    # Put the model in inference mode before evaluating.
    # NOTE(review): train.test may already do this; the call is harmless then.
    model.eval()

    # test() comes from train.py; presumably it runs the validation loop and
    # returns the accuracy as a number (it is formatted with %.3f below).
    test_acc = test(model)
    print("Accuracy of %s is %.3f" % (os.path.basename(args.weight), test_acc))


if __name__ == "__main__":
    main()