-
Notifications
You must be signed in to change notification settings - Fork 0
/
DNABert_visualize.py
183 lines (135 loc) · 5.58 KB
/
DNABert_visualize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import argparse
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
import pandas as pd
import shutil
import os
# Command-line interface. Both options are mandatory strings:
#   -n / --data_name     name of the CAGE dataset directory used initially
#   -user / --user_name  Bridges-2 user name; it is spliced into every path below
parser = argparse.ArgumentParser(description='Process input file')
_cli_options = (
    (('-n', '--data_name'), 'Name of cage directory that was used initially'),
    (('-user', '--user_name'), 'User name on bridges2'),
)
for _flags, _help_text in _cli_options:
    parser.add_argument(*_flags, type=str, required=True, help=_help_text)
# Parse sys.argv; argparse prints usage and exits if a required option is missing.
args = parser.parse_args()
# DeepTSS outputs for this dataset: one dev.tsv per confidence bucket.
DEEPTSS_DIR = "/ocean/projects/bio230007p/" + args.user_name + "/DeepTSS/data/" + args.data_name + "/OUT"
# High confidence
dt_path_high = DEEPTSS_DIR + "/high_prob/dev.tsv"
# Low confidence
dt_path_low = DEEPTSS_DIR + "/low_prob/dev.tsv"


def _load_deeptss_labels(tsv_path):
    """Return the integer labels stored in the second column of a DeepTSS TSV file.

    The file is parsed with ``np.genfromtxt`` as floats, so any non-numeric cell
    (for example a sequence column) becomes NaN. The first row is dropped,
    presumably a header line.

    Args:
        tsv_path: path to a tab-separated ``dev.tsv`` file.

    Returns:
        1-D integer ndarray of labels, one per remaining row.

    Raises:
        ValueError: if any label in column 2 could not be parsed as a number.
            Casting NaN to int yields an arbitrary, platform-dependent integer,
            which would silently corrupt the plots and metrics computed later.
    """
    with open(tsv_path, 'r') as handle:
        table = np.genfromtxt(handle, delimiter='\t')
    labels = table[1:, 1]
    n_bad = int(np.isnan(labels).sum())
    if n_bad:
        raise ValueError(f"{tsv_path}: {n_bad} label(s) in column 2 could not be parsed as numbers")
    return labels.astype(int)


deeptss_pred_high = _load_deeptss_labels(dt_path_high)
deeptss_pred_low = _load_deeptss_labels(dt_path_low)
# DNABert working directory and this dataset's output directory.
DNABERT_PATH = "/ocean/projects/bio230007p/" + args.user_name + "/DNABert"
OUT_PATH = "/ocean/projects/bio230007p/" + args.user_name + "/DNABert/OUT/" + args.data_name
high_path = OUT_PATH + "/high"
low_path = OUT_PATH + "/low"
# Create each sub-directory unconditionally (this also creates OUT_PATH). If
# OUT_PATH already existed without them, shutil.copy(src, high_path) would write
# to a *file* named "high" instead of into a directory, each copy clobbering the last.
os.makedirs(high_path, exist_ok=True)
os.makedirs(low_path, exist_ok=True)


def _copy_regular_files(src_dir, dst_dir):
    """Copy every regular file directly inside ``src_dir`` into ``dst_dir``.

    Not recursive: sub-directories are skipped, since ``shutil.copy`` would raise
    on a directory source. Same-named files already in ``dst_dir`` are overwritten.
    """
    for file_name in os.listdir(src_dir):
        src_file = os.path.join(src_dir, file_name)
        if os.path.isfile(src_file):
            shutil.copy(src_file, dst_dir)


# Copy DNABert's shared high/low output folders into this dataset's folders.
high_tmp_path = DNABERT_PATH + "/OUT/high"
low_tmp_path = DNABERT_PATH + "/OUT/low"
_copy_regular_files(high_tmp_path, high_path)
_copy_regular_files(low_tmp_path, low_path)
# Load the DNABert prediction results and attention weights.
pred_results_high = np.load(high_path + '/preds.npy')
pred_results_low = np.load(low_path + '/preds.npy')
atten_high = np.load(high_path + '/atten.npy')
atten_low = np.load(low_path + '/atten.npy')

# Quick sanity output: array shapes first.
print(deeptss_pred_high.shape, pred_results_high.shape)
print(deeptss_pred_low.shape, pred_results_low.shape)
print(atten_high.shape)

# The plots and metrics below pair DeepTSS labels with DNABert predictions row
# by row. Fail fast on a length mismatch instead of letting the confusion-matrix
# loop silently truncate or a later call fail with an opaque error.
# NOTE(review): the confusion matrix indexes with these prediction values, so
# preds.npy presumably holds integer class labels (0/1); if it holds
# probabilities they must be thresholded first. Confirm the stored format.
for _subset, _labels, _preds in (('high', deeptss_pred_high, pred_results_high),
                                  ('low', deeptss_pred_low, pred_results_low)):
    if len(_labels) != len(_preds):
        raise ValueError(
            f"{_subset}: {len(_labels)} DeepTSS labels but {len(_preds)} DNABert predictions"
        )

# First label/prediction pair, for eyeballing the value format.
print(deeptss_pred_high[0])
print(pred_results_high[0])
# Visualize the prediction results: one 2x2 confusion matrix per confidence
# subset, rows = DeepTSS label, columns = DNABert prediction.


def _save_confusion_heatmap(y_deeptss, y_dnabert, out_file):
    """Count (DeepTSS, DNABert) label pairs in a 2x2 matrix and save it as a heatmap.

    Both inputs must be equal-length arrays of integer class labels 0/1. Labels
    of 2 or more, or non-integer values, raise IndexError. As with any NumPy
    index, negative labels wrap around.

    Raises:
        ValueError: if the two label arrays differ in shape.
    """
    y_deeptss = np.asarray(y_deeptss)
    y_dnabert = np.asarray(y_dnabert)
    if y_deeptss.shape != y_dnabert.shape:
        raise ValueError(
            f"label/prediction shape mismatch: {y_deeptss.shape} vs {y_dnabert.shape}"
        )
    counts = np.zeros((2, 2), dtype=int)
    # np.add.at accumulates repeated (row, col) pairs; plain fancy-index `+= 1`
    # would count each distinct pair only once.
    np.add.at(counts, (y_deeptss, y_dnabert), 1)
    fig = plt.figure()
    sns.set()
    # fmt='d' prints exact integer counts; seaborn's default '.2g' would render
    # e.g. 1234 as "1.2e+03".
    sns.heatmap(counts, annot=True, fmt='d', cmap='Blues')
    plt.xlabel('DNABert Predictions')
    plt.ylabel('DeepTSS Predictions')
    plt.savefig(out_file)
    # The figure is only saved, never shown; release it.
    plt.close(fig)


_save_confusion_heatmap(deeptss_pred_high, pred_results_high, high_path + '/confusion.png')
_save_confusion_heatmap(deeptss_pred_low, pred_results_low, low_path + '/confusion.png')
# Calculate metrics using scikit-learn and plot them, once per confidence subset.


def _save_metrics_barplot(y_deeptss, y_dnabert, out_file):
    """Score DNABert predictions against DeepTSS labels and save a bar chart.

    DeepTSS labels are passed as ``y_true`` and DNABert output as ``y_pred``.
    Recall, precision and F1 use scikit-learn's binary defaults (positive
    class = 1).
    """
    scores = {
        'Accuracy': accuracy_score(y_deeptss, y_dnabert),
        'Recall': recall_score(y_deeptss, y_dnabert),
        'Precision': precision_score(y_deeptss, y_dnabert),
        'F1 Score': f1_score(y_deeptss, y_dnabert),
    }
    # Long format (one row per metric) for sns.barplot's x/y column names.
    metrics_long = pd.DataFrame({'Metric': list(scores), 'Score': list(scores.values())})
    fig = plt.figure()
    sns.set_style('whitegrid')
    sns.barplot(x='Metric', y='Score', data=metrics_long, palette='Blues_d')
    plt.title('Metrics Summary')
    plt.ylim(0, 1)
    plt.savefig(out_file)
    # The figure is only saved, never shown; release it.
    plt.close(fig)


_save_metrics_barplot(deeptss_pred_high, pred_results_high, high_path + '/bar.png')
_save_metrics_barplot(deeptss_pred_low, pred_results_low, low_path + '/bar.png')
# Visualize the attention weights, once per confidence subset.


def _save_attention_heatmap(attention, out_file):
    """Save ``attention`` as a heatmap whose colour scale is anchored at 0 (vmin=0).

    ``sns.heatmap`` requires a 2-D array.
    """
    fig = plt.figure()
    sns.set()
    # NOTE(review): the axis labels assume columns = input positions and
    # rows = output positions. Confirm against how atten.npy was produced;
    # it may instead hold one row per sequence.
    sns.heatmap(attention, cmap='YlGnBu', vmin=0)
    plt.xlabel('Input sequence position')
    plt.ylabel('Output sequence position')
    plt.title('Attention weights')
    plt.savefig(out_file)
    # The figure is only saved, never shown; release it.
    plt.close(fig)


_save_attention_heatmap(atten_high, high_path + '/attention_weights.png')
_save_attention_heatmap(atten_low, low_path + '/attention_weights.png')