# DecisionTree.py
#
# NOTE: the lines originally captured here were repository web-viewer residue
# (notification/fork banner, "Copy path", a file-size summary, and a gutter of
# line numbers 1-63); they were never part of the script itself.
import pandas as pd
import numpy as np
from sklearn import metrics
from sklearn.preprocessing import LabelEncoder
import random
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
# Fit a decision tree that predicts the 'High School' column from the numeric
# grade columns of the student data set, then render the fitted tree with
# Graphviz.
fullData = pd.read_csv('/mnt/RTC/PredictionData/data003.csv')
# Alternative column configuration, kept for reference:
#ID_col = ['StudentID']
#target_col = ['Performance Type']
#cat_cols = ['Gender','Program','Cohort','Student Type','Section','High School']
# Show, per column, whether any value is missing. The single-argument call
# form prints the same text on Python 2 and Python 3.
print(fullData.isnull().any())
ID_col = ['StudentID']
target_col = ['High School']
cat_cols = []
num_cols = ['Class12PercentGrade','Sem Avg']
other_col=['Type'] #Test and Train Data set identifier
# Take an explicit copy: the imputation below assigns columns, and doing that
# on a selection of fullData triggers SettingWithCopyWarning and may or may
# not write through to the original frame.
analysisData = fullData[num_cols+cat_cols+target_col].copy()
#Impute numerical missing values with mean
# fillna(..., inplace=True) returns None, so assigning its result wiped the
# numeric columns; assign the filled frame that fillna returns instead.
analysisData[num_cols] = analysisData[num_cols].fillna(analysisData[num_cols].mean())
#Impute categorical missing values with -9999
# Skipped when no categorical columns are configured (nothing to impute).
if cat_cols:
    analysisData[cat_cols] = analysisData[cat_cols].fillna(value=-9999)
#create label encoders for categorical features
for var in cat_cols:
    number = LabelEncoder()
    analysisData[var] = number.fit_transform(analysisData[var].astype('str'))
#Target variable is also a categorical so convert it
#fullData["Rising"] = number.fit_transform(fullData["Rising"].astype('str'))
features = num_cols + cat_cols
X = analysisData[features]
Y = analysisData[target_col]
from sklearn import tree
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, Y)
# export_graphviz expects one name per class, in ascending class order.
# clf.classes_ holds exactly the fitted class labels in that order, whereas
# the raw target column (one entry per row, in row order) mislabels the nodes.
# This also drops the Series.get_values() call, which pandas 1.0 removed.
targets = [str(label) for label in clf.classes_]
import graphviz
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=features,
    class_names=targets,
    filled=True,
    rounded=True,
    special_characters=True,
)
graph = graphviz.Source(dot_data)
# Writes the DOT source to this path and the rendered output next to it.
graph.render("/mnt/RTC/outfile")
print("done")