-
Notifications
You must be signed in to change notification settings - Fork 0
/
kNN.py
40 lines (30 loc) · 1018 Bytes
/
kNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from numpy import *
import operator
import matplotlib
import matplotlib.pyplot as plt
def classify0(intX, dataSet, labels, k):
    """Classify ``intX`` by majority vote among its ``k`` nearest rows.

    The Euclidean distance from ``intX`` to every row of ``dataSet`` is
    computed, the ``k`` closest rows are selected, and the label that
    occurs most often among them is returned.

    Args:
        intX: The point to classify; it is tiled against every row of
            ``dataSet``, so it must be compatible with one row.
        dataSet: 2-D numpy array with one sample per row.
        labels: Sequence indexed by row number giving each row's label.
        k: Number of nearest neighbours that take part in the vote.

    Returns:
        The label with the highest vote count among the ``k`` nearest rows.

    Raises:
        IndexError: If ``k`` is less than 1 or larger than the number of
            rows in ``dataSet``.
    """
    dataSetSize = dataSet.shape[0]

    # Repeat the query point once per sample so the subtraction is row-wise.
    diffMat = tile(intX, (dataSetSize, 1)) - dataSet
    sqDistances = (diffMat ** 2).sum(axis=1)
    distances = sqDistances ** 0.5

    # Row indices ordered from the closest sample to the farthest.
    sortedDisIndex = distances.argsort()

    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDisIndex[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1

    # dict.iteritems() only exists on Python 2; items() works on Python 3,
    # which the rest of this file targets (print is used as a function).
    # sorted() is stable, so on a tie the label that entered the dict first
    # (i.e. was seen first among the nearest neighbours) wins.
    sortedClassCount = sorted(
        classCount.items(), key=operator.itemgetter(1), reverse=True
    )
    return sortedClassCount[0][0]
def createDataSet():
    """Return the small hard-coded demo data set.

    Returns:
        A ``(group, labels)`` pair: ``group`` is a 4x2 numpy array of
        sample points and ``labels`` is the list of their class labels,
        in the same order as the rows of ``group``.
    """
    # Each entry pairs a 2-D point with the class it belongs to.
    samples = [
        ([1.0, 1.1], 'A'),
        ([1.0, 1.0], 'A'),
        ([0, 0], 'B'),
        ([0, 0.1], 'B'),
    ]
    points = [point for point, _ in samples]
    tags = [tag for _, tag in samples]
    return array(points), tags
def main():
    """Run the demo: classify one point, then plot the sample data.

    Builds the demo data set, prints the label predicted for the point
    ``[3, 3]`` using the 3 nearest neighbours, and finally shows a scatter
    plot of the data set's first column against its second column.
    """
    dataSet, labels = createDataSet()

    predicted = classify0([3, 3], dataSet, labels, 3)
    print(predicted)

    figure = plt.figure()
    axes = figure.add_subplot(111)
    xs, ys = dataSet[:, 0], dataSet[:, 1]
    axes.scatter(xs, ys)
    plt.show()
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()