-
Notifications
You must be signed in to change notification settings - Fork 104
/
classify_alexnet.py
76 lines (55 loc) · 1.99 KB
/
classify_alexnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#coding=utf-8
import numpy as np
import pickle
import os
import time
import sys
import shutil
import skimage
# Location of the Caffe checkout whose Python bindings are imported next.
# Can be overridden with the CAFFE_ROOT environment variable; the default is
# the original author's checkout path.
caffe_root = os.environ.get('CAFFE_ROOT', '/home/cscl/caffe-master/')
# Put the bindings first on the import path so this checkout's ``caffe`` wins.
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe
# Network definition, trained weights and mean file used for preprocessing.
# All paths are relative to the current working directory.
net_file = 'alexnet_deploy.prototxt'
caffe_model = 'models/alexnet_hccr.caffemodel'
mean_file = 'meanfiles/CASIA1.0_1.1_1.2_mean_108.npy'
# Lookup table from network output index to the integer character label
# (the original note says it has 7534 entries). The builtin ``int`` is used
# because the ``np.int`` alias no longer exists in current NumPy releases.
unicode_index = np.loadtxt('util/unicode_index.txt', delimiter=',', dtype=int)
# The network is loaded once, at import time, in the TEST phase.
net = caffe.Net(net_file, caffe_model, caffe.TEST)
def _ink_bounding_box(gray, white=1.0):
    """Return the slice bounds of the non-white region of a 2-D image.

    A pixel counts as ink when its value is strictly below ``white``.

    Returns:
        ``(row_start, row_stop, col_start, col_stop)`` with exclusive stops,
        ready to be used directly as slice bounds, or ``None`` when every
        pixel is white.
    """
    rows, cols = np.nonzero(np.asarray(gray) < white)
    if rows.size == 0:
        return None
    # +1 because the last ink row/column must be included in the slice.
    return (int(rows.min()), int(rows.max()) + 1,
            int(cols.min()), int(cols.max()) + 1)


def get_crop_image(imagepath, img_name):
    """Load an image and crop it to the bounding box of its non-white pixels.

    Args:
        imagepath: directory holding the image; joined with ``img_name`` via
            ``os.path.join``, so a trailing separator is optional.
        img_name: file name of the image inside ``imagepath``.

    Returns:
        The cropped H x W x 3 array returned by ``caffe.io.load_image``
        (float values in [0, 1]). If the image has no ink at all, the
        full, uncropped image is returned.
    """
    image = caffe.io.load_image(os.path.join(imagepath, img_name))
    # A pixel is ink when any channel is below full intensity (1.0 after
    # load_image's float conversion). For a grayscale file all three channels
    # are identical, so this is the same as testing the 8-bit value < 255.
    bbox = _ink_bounding_box(image.min(axis=2))
    if bbox is None:
        # Blank image: there is nothing to crop to, so keep the whole frame.
        return image
    top, bottom, left, right = bbox
    return image[top:bottom, left:right, :]
def _top_k_indices(probs, top_k):
    """Return indices of the ``top_k`` largest entries of ``probs``, best first."""
    flat = np.asarray(probs).ravel()
    return flat.argsort()[::-1][:top_k]


def _accuracy(right, total):
    """Return ``right / total`` as a float, or 0.0 when ``total`` is zero."""
    return float(right) / total if total else 0.0


def evaluate(imagepath, top_k):
    """Classify every file in ``imagepath`` and report top-k accuracy.

    The ground-truth label of each file is the part of its name before the
    first '.', parsed as an integer. An image is a hit when that label is
    among the ``unicode_index`` entries of the network's ``top_k`` most
    probable outputs.

    Args:
        imagepath: directory containing the test images.
        top_k: number of highest-probability predictions that count as a hit.

    Returns:
        The top-k accuracy as a float (0.0 for an empty directory). Per-image
        predictions and the final totals are also printed.

    Raises:
        ValueError: if a file name's stem is not an integer.
    """
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))  # H x W x C -> C x H x W
    # Channel means; assumes the .npy mean is stored as (C, H, W) - TODO confirm.
    transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
    # Scale the [0, 1] pixels from load_image to [0, 255], presumably the
    # range the model was trained on.
    transformer.set_raw_scale('data', 255)

    rightcount = 0
    allcount = 0
    # Sorted so that runs print images in a reproducible order.
    for img_name in sorted(os.listdir(imagepath)):
        allcount += 1
        label_truth = img_name.split('.')[0]
        print('----------------------')
        image = get_crop_image(imagepath, img_name)
        net.blobs['data'].data[...] = transformer.preprocess('data', image)
        net.forward()
        label_index = _top_k_indices(net.blobs['prob'].data[0], top_k)
        labels = unicode_index[label_index]  # output unicode
        print('Top-%d Label:  %s' % (top_k, labels))
        print('label_truth:  %s' % label_truth)
        # Membership test instead of indexing labels[0..top_k-1], which
        # would raise IndexError if top_k exceeds the number of classes.
        if int(label_truth) in labels:
            rightcount += 1
    accuracy = _accuracy(rightcount, allcount)
    print('%d %d %s' % (rightcount, allcount, accuracy))
    return accuracy
if __name__ == '__main__':
    # Default run: every image under images/, scored by top-1 accuracy.
    imagepath, top_k = 'images/', 1
    evaluate(imagepath, top_k)