-
Notifications
You must be signed in to change notification settings - Fork 0
/
utility.py
130 lines (113 loc) · 4.2 KB
/
utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
#----------------------------------------------------------------------
# This file is a collection of utility functions that are used
# to load and display clustering results.
#----------------------------------------------------------------------
from sklearn_extra.cluster import KMedoids
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pandas.core.frame import DataFrame
import seaborn as sns
from time import time
from pyclustering.cluster.kmedoids import kmedoids
def centroid_to_index(centroids, df):
    '''
    Translate centroid coordinates into index labels of df.

    For every point p in centroids, selects the rows of df whose
    'x' equals p[0] and whose 'y' equals p[1] and keeps the index
    label of the first such row. An IndexError is raised when a
    point has no matching row.
    '''
    found = []
    for p in centroids:
        same_x = df['x'] == p[0]
        same_y = df['y'] == p[1]
        matching_rows = df[same_x & same_y]
        found.append(matching_rows.index[0])
    return found
def get_pyclustering_clustering(file_name, k, initial_centroids):
    '''
    Given a file name and the initial centroids, loads the
    2-D data from the file and performs KMedoids clustering
    using pyclustering.

    Parameters:
        file_name: path of the 2-D points file, read through
            get_df_from_txt.
        k: number of clusters. Not used by this function; the
            initial medoids passed to pyclustering come from
            initial_centroids.
        initial_centroids: iterable of (x, y) points. Each one must
            be a row of the data, because it is converted to a row
            index with centroid_to_index.

    Returns:
        (df, medoids): df has the columns 'x', 'y' and 'c', where 'c'
        is the cluster id of each point; medoids holds the (x, y)
        coordinates of the final medoids.
    '''
    df = get_df_from_txt(file_name)
    # Resolve the seed rows before starting the clock, so that only the
    # clustering itself is timed and the lookup runs once.
    index_of_centroids = centroid_to_index(initial_centroids, df)
    start = time()
    _km = kmedoids(df.values, initial_index_medoids=index_of_centroids)
    _km.process()
    end = time()
    print(f"Time taken for clustering: {end-start} seconds")
    # Read the medoid coordinates while df still holds only 'x' and 'y'.
    medoids = df.iloc[_km.get_medoids()].values
    # get_clusters() gives one list of row positions per cluster, not one
    # label per row, so expand it into a per-point label vector.
    labels = np.full(len(df), -1, dtype=int)
    for cluster_id, members in enumerate(_km.get_clusters()):
        labels[members] = cluster_id
    df['c'] = labels
    return df, medoids
def get_df_from_txt(file_name):
    '''
    Read the 2-D points written by points_generator.py.

    The file is parsed as space-separated values without a header
    row; the two columns of the resulting DataFrame are named
    'x' and 'y'.
    '''
    column_names = ['x', 'y']
    points = pd.read_csv(file_name, header=None, sep=' ')
    points.columns = column_names
    return points
def get_df_from_generated_csv(file_name):
    '''
    Deprecated. Loads a CSV file into a DataFrame.
    Currently not used directly.
    '''
    return pd.read_csv(file_name)
def get_centroids_from_csv(file_name):
    '''
    Read a CSV file and return its rows as an array
    (the .values of the loaded DataFrame).
    Currently not used directly.
    '''
    centroid_table = pd.read_csv(file_name)
    return centroid_table.values
def get_our_clustering(dirname):
    '''
    Given a directory path, load clustering results from
    it. This is in line with the format that our C++ code
    uses to store results after clustering.

    Parameters:
        dirname: directory (str or path-like object) that contains
            'clusters.csv' and 'centroids.csv'.

    Returns:
        (clusters, centroids): clusters is the DataFrame read from
        'clusters.csv'; centroids is the content of 'centroids.csv'
        as an array.
    '''
    # Imported locally so the function does not rely on a
    # module-level import of os.
    import os

    # os.path.join accepts pathlib.Path as well as str, whereas
    # dirname + '/...' raises TypeError for path-like objects.
    clusters = pd.read_csv(os.path.join(dirname, 'clusters.csv'))
    centroids = pd.read_csv(os.path.join(dirname, 'centroids.csv')).values
    return clusters, centroids
def get_sklearn_clustering(file_name, k, init='heuristic'):
    '''
    Cluster the 2-D points stored in file_name into k groups
    with KMedoids from sklearn_extra, using the given init method.

    Prints the time spent creating and fitting the model, then
    returns the points DataFrame (with the label of each point
    stored in a new 'c' column) together with the cluster centers
    of the fitted model.
    '''
    points = get_df_from_txt(file_name)
    start = time()
    model = KMedoids(n_clusters=k, init=init)
    model.fit(points[['x', 'y']])
    end = time()
    print(f"Time taken for clustering: {end-start} seconds")
    points['c'] = model.labels_
    return points, model.cluster_centers_
def _cluster_counting(df: DataFrame):
    '''Return the number of distinct values in the 'c' column of df.'''
    distinct_labels = set(df.c)
    return len(distinct_labels)
def show_cluster(file_name: str=None, img_title: str = "",
                 x_label: str = "", y_label: str = "", df=None):
    '''
    Display points coloured according to their cluster.

    The points are characterized by x, y, c (i.e. their coordinates
    and respective clusters). They are taken from df when it is
    given; otherwise they are read from the CSV file file_name.
    Optionally you can display a title and axis labels
    (which are empty by default).
    Currently not using this.

    Parameters:
        file_name: path of the CSV file; only read when df is None.
        img_title: title of the plot.
        x_label: label of the x axis.
        y_label: label of the y axis.
        df: DataFrame with the columns 'x', 'y' and 'c'.

    Raises:
        ValueError: if neither file_name nor df is given.
    '''
    if df is None:
        if file_name is None:
            raise ValueError("show_cluster needs either file_name or df")
        df = pd.read_csv(file_name)
    # Open the figure only once the data is available, so that a failed
    # load does not leave an empty figure behind.
    plt.figure()
    # Count the distinct clusters so that the palette gets exactly one
    # colour per cluster.
    k = _cluster_counting(df)
    sns.scatterplot(x=df.x, y=df.y,
                    hue=df.c,
                    palette=sns.color_palette("hls", n_colors=k))
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(img_title)
    plt.show()
def compare_cluster_plots(df1, df2, titles=(None, None)):
    '''
    Use this function to compare two clusterings, mostly comparing
    our results with scikit-learn's.

    Draws df1 and df2 side by side (one row, two columns), colouring
    the points of each plot by cluster.

    Parameters:
        df1: DataFrame with the columns 'x', 'y' and 'c'; first plot.
        df2: DataFrame with the columns 'x', 'y' and 'c'; second plot.
        titles: indexable pair of titles, one per plot. The default
            is an immutable tuple so it cannot be shared and altered
            between calls.
    '''
    for i, df in enumerate((df1, df2)):
        plt.subplot(1, 2, i + 1)
        # One palette colour per distinct cluster of this frame.
        k = _cluster_counting(df)
        sns.scatterplot(x=df.x, y=df.y,
                        hue=df.c,
                        palette=sns.color_palette("hls", n_colors=k))
        plt.title(titles[i])
        # NOTE(review): the legend is positioned for each subplot here;
        # confirm that this is the intended layout.
        plt.legend(bbox_to_anchor=(0.1, 0.9))
    plt.tight_layout()