app.py — 415 lines (323 loc) · 14.8 KB
import numpy as np
from flask import Flask, request, render_template, session, send_from_directory, Response, url_for, send_file, jsonify
import pickle
import base64
from PIL import Image, ImageEnhance, ImageOps
import io
from flask_uploads import IMAGES, UploadSet, configure_uploads
import torch
import cv2
from flask_wtf import FlaskForm
from flask_wtf.file import FileField, FileRequired, FileAllowed
from wtforms import SubmitField
from werkzeug.utils import secure_filename
from segment.segmentAnything import better_cropped_mask, cropped_objects
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# from lama_inpaint import inpaint_img_with_lama, build_lama_model, inpaint_img_with_builded_lama
import os
import replicate
import openai
import requests
os.environ['REPLICATE_API_TOKEN'] = 'r8_U7fxcLt6Vd17mFdkLX4SAnvVimxb8Cn2drpdj'#'r8_W7GHOzDckeNyNAbaCUf8ts80TY71ve31LRUT6'
app = Flask(__name__, static_url_path='/static', static_folder='static')
app.secret_key = 'your_secret_key' # Set a secret key for session encryption
app.config['UPLOADED_PHOTO'] = 'uploads'
app.config['UPLOADED_PHOTOS_DEST'] = './uploads'
SEGMENTS_PATH = './segmented_images/'
photos = UploadSet('photos', IMAGES)
configure_uploads(app, photos)
# Load models when initializing the app
model = {}
# build the sam model
model_type="vit_h"
sam_ckpt="./pretrained_models/sam_vit_h_4b8939.pth"
model_sam = sam_model_registry[model_type](checkpoint=sam_ckpt)
print("sam_ckpt is loaded")
device = "cuda" if torch.cuda.is_available() else "cpu"
model_sam.to(device=device)
model['sam'] = SamPredictor(model_sam)
print("model_sam is loaded")
mask_generator = SamAutomaticMaskGenerator(model_sam)
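# Rough sketch of what the automatic mask generator produces (based on the
# segment_anything README; verify against the installed version):
#   masks = mask_generator.generate(image)   # list of dicts, one per mask
#   masks[0]['segmentation']  -> boolean HxW mask
#   masks[0]['bbox']          -> [x, y, w, h] box around the mask
#   masks[0]['area']          -> mask area in pixels
# The /predict route below iterates over this list and saves one crop per mask.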
class UploadForm(FlaskForm):
    photo = FileField(
        validators=[FileAllowed(photos, 'Only images are allowed'),
                    FileRequired('File field should not be empty')])
    submit = SubmitField('Upload')
@app.route('/uploads/<filename>')
def get_file(filename):
    return send_from_directory(app.config['UPLOADED_PHOTO'], filename)
@app.route('/backend-image-endpoint/<image_name>')
def serve_image(image_name):
    image_path = f'./segmented_images/{image_name}.png'
    return send_file(image_path, mimetype='image/png')
@app.route('/', methods=['GET', 'POST'])
def upload_image():
    form = UploadForm()
    file_url = None  # Initialize file_url to None
    if form.validate_on_submit():
        filename = photos.save(form.photo.data)
        file_url = url_for('get_file', filename=filename)
        session['file_url'] = file_url  # Store the file_url in the session
    else:
        session['file_url'] = None
    return render_template('index.html', form=form, file_url=session.get('file_url'))
@app.route('/')
def home():
    # Note: '/' is also registered on upload_image above, so this view is likely shadowed.
    return render_template('index.html')
@app.route('/next-page')
def next_page():
    file_url = session.get('file_url')  # Retrieve the file_url from the session
    return render_template('next-page.html', file_url=file_url)
@app.route('/third-page')
def third_page():
    file_url = session.get('file_url')  # Retrieve the file_url from the session
    return render_template('third-page.html', file_url=file_url)
# @app.route('/forth-page')
# def forth_page():
#     file_url = session.get('file_url')  # Retrieve the file_url from the session
#     return render_template('forth-page.html', file_url=file_url)
def apply_style_changes(image, grayscale=False, saturation=None, brightness=None, hue_rotate=None):
    # Convert the image to grayscale if requested
    if grayscale:
        image = ImageOps.grayscale(image)
    # Adjust image saturation if a value is provided
    if saturation is not None:
        saturation = (float(saturation) - 100.0) / 100.0  # Convert to float and scale to the range [-1, 1]
        enhancer = ImageEnhance.Color(image)
        image = enhancer.enhance(1.0 + saturation)
    # Adjust image brightness if a value is provided
    if brightness is not None:
        brightness = (float(brightness) - 100.0) / 100.0  # Convert to float and scale to the range [-1, 1]
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(1.0 + brightness)
    # Adjust image hue if a value is provided
    if hue_rotate is not None:
        hue_rotate = float(hue_rotate) * 360.0 / 300.0  # Convert to float and scale to the range [0, 360]
        image = image.convert('RGB')
        image = image.convert('HSV')
        h, s, v = image.split()
        np_h = np.array(h)
        np_h = (np_h + hue_rotate) % 256  # PIL stores hue in an 8-bit channel, so the shift wraps modulo 256
        h = Image.fromarray(np_h.astype('uint8'), mode='L')
        image = Image.merge('HSV', (h, s, v))
        image = image.convert('RGB')
    return image
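# Minimal usage sketch for apply_style_changes (the file names and slider values
# below are illustrative assumptions, not values taken from the frontend):
#   img = Image.open('uploads/example.png')          # hypothetical input path
#   img = apply_style_changes(img, grayscale=False,
#                             saturation=150,        # 100 means "unchanged"
#                             brightness=100,
#                             hue_rotate=30)
#   img.save('edited_images/example_styled.png')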
@app.route('/save-images', methods=['POST'])
def save_images():
    vector_data = request.json.get('vector')  # Retrieve the vector data from the request
    print("Received vector data:")
    print(vector_data)
    # Check if vector_data is not empty
    if vector_data:
        # Create the directory if it doesn't exist
        if not os.path.exists('./edited_images'):
            os.makedirs('./edited_images')
        # Check if the vector_data contains at least one image object
        if isinstance(vector_data, list) and len(vector_data) > 0:
            # Counter to keep track of the image index
            image_index = 0
            # Iterate over each image object in the vector data
            for image_object in vector_data:
                if isinstance(image_object, dict):
                    # Retrieve the image data and style change parameters from the image object
                    image_data = image_object.get('data')
                    grayscale = image_object.get('grayscale')
                    saturation = image_object.get('saturation')
                    brightness = image_object.get('brightness')
                    hue_rotate = image_object.get('hueRotate')
                    # Check if all required parameters are present
                    if image_data is not None and grayscale is not None and saturation is not None and brightness is not None and hue_rotate is not None:
                        # Remove the prefix "data:image/png;base64," from the image data
                        image_data = image_data.replace('data:image/png;base64,', '')
                        # Decode the base64 image data and convert it to bytes
                        image_bytes = base64.b64decode(image_data)
                        # Create a PIL image object from the image data
                        image = Image.open(io.BytesIO(image_bytes))
                        # Apply style changes
                        image = apply_style_changes(image, grayscale=grayscale, saturation=saturation, brightness=brightness, hue_rotate=hue_rotate)
                        # Save the image to the server with the correct index
                        image_name = f'image_{image_index}.png'
                        image_path = os.path.join('./edited_images', image_name)
                        image.save(image_path)
                        # Increment the image index
                        image_index += 1
                    else:
                        print("Missing parameters for image object:", image_object)
                else:
                    print("Invalid image object:", image_object)
        else:
            print("Invalid vector data: vector_data should be a non-empty list")
            return jsonify({'success': False, 'message': 'Invalid vector data'})
        return jsonify({'success': True})
    else:
        return jsonify({'success': False})
@app.route('/forth-page')
def forth_page():
    output_image_url = session.get('output_image_url')  # Retrieve the output image URL from the session
    return render_template('forth-page.html', image_url=output_image_url)
@app.route('/adjust-contrast-server', methods=['POST'])
def adjust_contrast_server():
    if 'imageData' in request.files:
        image_file = request.files['imageData']
        contrast = float(request.form.get('contrast', 1.0))  # Get the contrast value from the request
        try:
            # Read the image file and decode it
            decoded_data = image_file.read()
            # Open the image using PIL
            img = Image.open(io.BytesIO(decoded_data))
            # The image data is valid, proceed with processing
            resized_image = img.resize((500, 500))  # Adjust the size as needed
            # Adjust the contrast of the image
            enhancer = ImageEnhance.Contrast(resized_image)
            adjusted_image = enhancer.enhance(contrast)
            # Convert the adjusted image back to base64 format
            buffered = io.BytesIO()
            adjusted_image.save(buffered, format='PNG')
            encoded_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
            return encoded_image
        except Exception as e:
            print("Error processing image:", e)
            return "Error processing image"
    else:
        return "No image data received"
# @app.route('/upload-image', methods=['POST'])
# def upload_image():
#     image_file = request.files['image']
#     image_data = image_file.read()
#     session['uploaded_image'] = base64.b64encode(image_data).decode('utf-8')  # Store the base64-encoded image data in the session
#     return base64.b64encode(image_data).decode('utf-8')
@app.route('/predict', methods=['GET'])
def predict():
    # get uploaded image
    file_url = session.get('file_url')
    if not file_url:
        return "No image data found."
    file_url_complete = '.' + file_url
    image_file = cv2.imread(file_url_complete)
    # image_file = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if image_file is None:
        return "No image data found."
    print("finally image_file is not None")
    # resize the image
    # scale_percent = 30  # percent of original size
    # width = int(image_file.shape[1] * scale_percent / 100)
    # height = int(image_file.shape[0] * scale_percent / 100)
    # dim = (width, height)
    # image_file = cv2.resize(image_file, dim, interpolation=cv2.INTER_AREA)
    # print(image_file.shape)
    masks = mask_generator.generate(image_file)
    print("masks are generated")
    # numMasks = len(masks) if len(masks) < 10 else 10
    # masks = masks[:numMasks]
    destination_path = './segmented_images/'
    segment_map = np.zeros((image_file.shape[0], image_file.shape[1]))
    segment_index = 0
    for i in range(len(masks)):
        image_file = cv2.imread(file_url_complete)
        # image_file = cv2.resize(image_file, dim, interpolation=cv2.INTER_AREA)
        segmentname = "segment" + str(segment_index)
        # s = better_cropped_mask(masks, i, image_file)
        s = cropped_objects(masks, i, image_file, segment_map)
        if s is not False:
            img, tmask = s
            cv2.imwrite(destination_path + segmentname + ".png", img)
            cv2.imwrite(destination_path + segmentname + "_tmask.png", tmask)
            segment_index += 1
    return "segments are generated"
@app.route('/dalle_edit1', methods=['GET'])
def edit_dalle1():
    # prompt = "a golden retriever on the sofa"
    prompt = "black rug"
    img_path = "./uploads/Dog.png"
    mask_path = "./segmented_images/segment2_tmask.png"
    openai.api_key = ""
    response = openai.Image.create_edit(
        image=open(img_path, "rb"),
        mask=open(mask_path, "rb"),
        prompt=prompt,
        n=1,
        size="512x512"
    )
    image_url = response['data'][0]['url']
    print(image_url)
    response = requests.get(image_url)
    response.raise_for_status()
    with open("./dalle_images/dalle_segment2.png", "wb") as file:
        file.write(response.content)
    print("Image downloaded successfully.")
    dalle_img = cv2.imread("./dalle_images/dalle_segment2.png")
    msk = cv2.imread("./segmented_images/segment2.png")
    dalle_img = cv2.resize(dalle_img, (msk.shape[1], msk.shape[0]))
    # Black out every pixel of the DALL-E output that falls outside the segment mask
    for x in range(msk.shape[0]):
        for y in range(msk.shape[1]):
            if (msk[x][y] == 0).all():
                dalle_img[x][y][:] = 0
    cv2.imwrite("./edited_images/image_2.png", dalle_img)
    print("segment cropped successfully.")
    return "dalle edit is done"
@app.route('/dalle_edit2', methods=['GET'])
def edit_dalle2():
    prompt = "a golden retriever"
    img_path = "./uploads/Dog.png"
    mask_path = "./segmented_images/segment3_tmask.png"
    openai.api_key = ""
    response = openai.Image.create_edit(
        image=open(img_path, "rb"),
        mask=open(mask_path, "rb"),
        prompt=prompt,
        n=1,
        size="512x512"
    )
    image_url = response['data'][0]['url']
    print(image_url)
    response = requests.get(image_url)
    response.raise_for_status()
    with open("./dalle_images/dalle_segment3.png", "wb") as file:
        file.write(response.content)
    print("Image downloaded successfully.")
    dalle_img = cv2.imread("./dalle_images/dalle_segment3.png")
    msk = cv2.imread("./segmented_images/segment3.png")
    dalle_img = cv2.resize(dalle_img, (msk.shape[1], msk.shape[0]))
    for x in range(msk.shape[0]):
        for y in range(msk.shape[1]):
            if (msk[x][y] == 0).all():
                dalle_img[x][y][:] = 0
    cv2.imwrite("./edited_images/image_3.png", dalle_img)
    print("segment cropped successfully.")
    return "dalle edit is done"
@app.route('/run-diffusion-model', methods=['POST'])
def run_diffusion_model():
    prompt = request.get_json().get('prompt')
    if not prompt:
        return jsonify({'error': 'Missing prompt'}), 400
    output = replicate.run(
        "stability-ai/stable-diffusion:db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
        input={"prompt": prompt}
    )
    session['output_image_url'] = output[0]
    print(output[0])
    return jsonify({'success': True})
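# Assumed client-side usage of /run-diffusion-model (JSON body with a "prompt" key):
#   curl -X POST http://localhost:5000/run-diffusion-model \
#        -H "Content-Type: application/json" -d '{"prompt": "a cozy living room"}'
# On success the generated image URL is stored in the session and shown on /forth-page.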
@app.route('/combine-images', methods=['POST'])
def combine_images():
    # Overlay the edited segments: black pixels in each layer are replaced with the
    # corresponding pixels from the image combined so far.
    print("Combine images-----------------------------")
    print("-------------------------------------------")
    path = './edited_images'
    images = [Image.open(os.path.join(path, i)).convert("RGBA") for i in os.listdir(path) if i.endswith(".png")]
    # Convert the first image to an np.array
    combined = np.array(images[0])
    for img in images[1:]:
        img_np = np.array(img)
        # Create a mask of where the image is black
        mask = np.all(img_np == [0, 0, 0, 255], axis=-1)
        # Where the mask is True, replace with pixels from the combined image
        img_np[mask] = combined[mask]
        combined = img_np
    print("shape of combined image in array format: ", combined.shape)
    # Convert the combined np.array back to an image
    combined_image = Image.fromarray(combined)
    combined_image.save('static/combined.png')
    return url_for('static', filename='combined.png')  # url_for('static', ...) already prefixes the /static path
if __name__ == '__main__':
    app.run(debug=True)  # Enable debug mode