added stability support
surcyf123 committed Jan 6, 2024
1 parent 9f26b2a commit 37f446e
Showing 11 changed files with 359 additions and 167 deletions.
64 changes: 37 additions & 27 deletions miner/miner.py
@@ -5,21 +5,28 @@
 import copy
 import json
 import os
+import io
+import base64
+import boto3
 import pathlib
 import threading
 import time
+import requests
 import traceback
-import requests
 import anthropic
 from abc import ABC, abstractmethod
 from collections import deque
 from functools import partial
 from typing import Tuple
+from stability_sdk import client

 import bittensor as bt
 import wandb
 from config import check_config, get_config
 from openai import AsyncOpenAI, OpenAI
 from PIL import Image
+import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
 from anthropic_bedrock import AsyncAnthropicBedrock, HUMAN_PROMPT, AI_PROMPT, AnthropicBedrock

 import template

@@ -34,6 +41,12 @@
 if not OpenAI.api_key:
     raise ValueError("Please set the OPENAI_API_KEY environment variable.")

+stability_api = client.StabilityInference(
+    key=os.environ['STABILITY_KEY'],
+    verbose=True,
+    engine="stable-diffusion-xl-1024-v1-0"
+)

 api_key = os.environ.get("ANTHROPIC_API_KEY")

 bedrock_client = AsyncAnthropicBedrock(
@@ -392,14 +405,19 @@ async def images(self, synapse: ImageResponse) -> ImageResponse:
             model = synapse.model
             messages = synapse.messages
             size = synapse.size
-            width, height = self.size.split('x')
-            width = int(width)
-            height = int(height)
+            width = synapse.width
+            height = synapse.height
             quality = synapse.quality
             style = synapse.style
             seed = synapse.seed
             steps = synapse.steps
+            image_revised_prompt = None
+            cfg_scale = synapse.cfg_scale
+            sampler = synapse.sampler
+            samples = synapse.samples
+            image_data = {}

+            bt.logging.debug(f"data = {provider, model, messages, size, width, height, quality, style, seed, steps, image_revised_prompt, cfg_scale, sampler, samples}")

             if provider == "OpenAI":
                 meta = await client.images.generate(
@@ -411,44 +429,36 @@ async def images(self, synapse: ImageResponse) -> ImageResponse:
                 )
                 image_url = meta.data[0].url
                 image_revised_prompt = meta.data[0].revised_prompt
+                image_data["url"] = image_url
+                image_data["image_revised_prompt"] = image_revised_prompt
                 bt.logging.info(f"returning image response of {image_url}")

             elif provider == "Stability":
+                bt.logging.debug(f"calling stability for {messages, seed, steps, cfg_scale, width, height, samples, sampler}")

                 meta = stability_api.generate(
                     prompt=messages,
-                    seed=steps,
+                    seed=seed,
                     steps=steps,
-                    cfg_scale=8.0,
+                    cfg_scale=cfg_scale,
                     width=width,
                     height=height,
-                    samples=1,
-                    sampler=generation.SAMPLER_K_DPMPP_2M,
+                    samples=samples,
+                    # sampler=sampler
                 )
-                # Process and upload the image
-                for artifact in meta.artifacts:
-                    if artifact.finish_reason == generation.FILTER:
-                        bt.logging.error("Safety filters activated, prompt could not be processed.")
-                    elif artifact.type == generation.ARTIFACT_IMAGE:
-                        img = Image.open(io.BytesIO(artifact.binary))
-                        img_buffer = io.BytesIO()
-                        img.save(img_buffer, format="PNG")
-                        img_buffer.seek(0)
-                        # Upload to file.io
-                        response = requests.post('https://file.io', files={'file': img_buffer})
-                        if response.status_code == 200:
-                            image_url = response.json()['link']
-                        else:
-                            raise Exception("Failed to upload the image.")
+                b64s = []
+                for image in meta:
+                    for artifact in image.artifacts:
+                        b64s.append(base64.b64encode(artifact.binary).decode())

+                image_data["b64s"] = b64s
+                bt.logging.info(f"returning image response to {messages}")

             else:
                 bt.logging.error(f"Unknown provider: {provider}")

-            image_data = {
-                "url": image_url,
-                "revised_prompt": image_revised_prompt,
-            }

             synapse.completion = image_data
             bt.logging.info(f"returning image response of {synapse.completion}")
             return synapse

         except Exception as exc:
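The Stability branch above leaves `sampler=sampler` commented out, presumably because the synapse carries the sampler as a plain string while `stability_api.generate` expects one of the SDK's enum values (the removed code hard-coded `generation.SAMPLER_K_DPMPP_2M`). A minimal sketch of the lookup that could bridge the two; the mapping itself is not part of this commit, and only enum names I believe exist in `generation_pb2` are included:

import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

# Hypothetical lookup from the protocol's sampler strings to the SDK's
# DiffusionSampler enum; "k_dpmpp_2m" mirrors the previously hard-coded value.
SAMPLER_MAP = {
    "ddim": generation.SAMPLER_DDIM,
    "k_euler": generation.SAMPLER_K_EULER,
    "k_euler_ancestral": generation.SAMPLER_K_EULER_ANCESTRAL,
    "k_heun": generation.SAMPLER_K_HEUN,
    "k_dpm_2": generation.SAMPLER_K_DPM_2,
    "k_dpm_2_ancestral": generation.SAMPLER_K_DPM_2_ANCESTRAL,
    "k_dpmpp_2s_ancestral": generation.SAMPLER_K_DPMPP_2S_ANCESTRAL,
    "k_lms": generation.SAMPLER_K_LMS,
    "k_dpmpp_2m": generation.SAMPLER_K_DPMPP_2M,
    "k_dpmpp_sde": generation.SAMPLER_K_DPMPP_SDE,
}

# Usage: pass sampler=SAMPLER_MAP[sampler] when the string is non-empty,
# and omit the argument to keep the API default otherwise.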
2 changes: 1 addition & 1 deletion state.json

Large diffs are not rendered by default.

60 changes: 46 additions & 14 deletions template/protocol.py
@@ -21,6 +21,7 @@ class IsAlive( bt.Synapse ):

 class ImageResponse(bt.Synapse):
     """ A class to represent the response for an image-related request. """
+    # https://platform.stability.ai/docs/api-reference#tag/v1generation/operation/textToImage

     completion: Optional[Dict] = pydantic.Field(
         None,
@@ -41,37 +42,68 @@ class ImageResponse(bt.Synapse):
     )

     seed: int = pydantic.Field(
-        ...,
+        default=1234,
         title="Seed",
         description="The seed with which to generate the image"
     )

+    samples: int = pydantic.Field(
+        default=1,
+        title="Samples",
+        description="The number of samples to generate"
+    )
+
+    cfg_scale: float = pydantic.Field(
+        default=8.0,
+        title="cfg_scale",
+        description="The cfg_scale to use for image generation"
+    )
+
+    # (Available Samplers: ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_dpmpp_2s_ancestral, k_lms, k_dpmpp_2m, k_dpmpp_sde)
+    sampler: str = pydantic.Field(
+        default="",
+        title="Sampler",
+        description="The sampler to use for image generation"
+    )

     steps: int = pydantic.Field(
-        ...,
+        default=30,
         title="Steps",
         description="The steps to take in generating the image"
     )

     model: str = pydantic.Field(
-        ...,
+        default="dall-e-2",
         title="Model",
         description="The model used for generating the image."
     )

     style: str = pydantic.Field(
-        ...,
+        default="vivid",
         title="Style",
         description="The style of the image."
     )

     size: str = pydantic.Field(
-        ...,
+        default="1024x1024",
         title="Size",
-        description="The size of the image."
+        description="The size of the image, used for OpenAI generation. Options for dall-e-3 are 1024x1024, 1792x1024, 1024x1792."
     )

+    height: int = pydantic.Field(
+        default=1024,
+        title="Height",
+        description="The image height, used for non-OpenAI generation."
+    )
+
+    width: int = pydantic.Field(
+        default=1024,
+        title="Width",
+        description="The image width, used for non-OpenAI generation."
+    )

     quality: str = pydantic.Field(
-        ...,
+        default="standard",
         title="Quality",
         description="The quality of the image."
     )
@@ -96,7 +128,7 @@ class Embeddings( bt.Synapse):
     )

     model: str = pydantic.Field(
-        "text-embedding-ada-002",
+        default="text-embedding-ada-002",
         title="Model",
         description="The model used for generating embeddings."
     )
@@ -127,7 +159,7 @@ class StreamPrompting(bt.StreamingSynapse):
     )

     seed: int = pydantic.Field(
-        "",
+        default=1234,
         title="Seed",
         description="Seed for text generation. This attribute is immutable and cannot be updated.",
     )
@@ -140,28 +172,28 @@ class StreamPrompting(bt.StreamingSynapse):
     )

     max_tokens: int = pydantic.Field(
-        2048,
+        default=2048,
         title="Max Tokens",
         description="Max tokens for text generation. "
                     "This attribute is immutable and cannot be updated.",
     )

     top_p: float = pydantic.Field(
-        0.001,
+        default=0.001,
         title="Top P",
         description="Top-p sampling parameter for text generation. "
                     "This attribute is immutable and cannot be updated.",
     )

     top_k: int = pydantic.Field(
-        1,
+        default=1,
         title="Top K",
         description="Top-k sampling parameter for text generation. "
                     "This attribute is immutable and cannot be updated.",
     )

     completion: str = pydantic.Field(
-        "",
+        None,
         title="Completion",
         description="Completion status of the current StreamPrompting object. "
                     "This attribute is mutable and can be updated.",
@@ -174,7 +206,7 @@ class StreamPrompting(bt.StreamingSynapse):
     )

     model: str = pydantic.Field(
-        "",
+        default="gpt-3.5-turbo",
         title="Model",
         description="The model to use when calling the provider for your response.",
     )
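As with ImageResponse, the new defaults make the text synapse cheap to construct. A sketch under the assumption that the class also declares a `messages` field in the folded part of this diff:

from template.protocol import StreamPrompting  # import path assumed

# seed, max_tokens, top_p, top_k, and model all fall back to the defaults
# above (1234, 2048, 0.001, 1, and "gpt-3.5-turbo" respectively).
syn = StreamPrompting(
    messages=[{"role": "user", "content": "Explain cfg_scale in one sentence."}],
)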
18 changes: 16 additions & 2 deletions template/reward.py
@@ -34,13 +34,13 @@
 import numpy as np
 from numpy.linalg import norm
 import bittensor as bt
+from template import utils
 from PIL import Image
 from scipy.spatial.distance import cosine
 from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.feature_extraction.text import TfidfVectorizer
 from transformers import CLIPProcessor, CLIPModel


 # ==== TEXT ====

 def calculate_text_similarity(text1: str, text2: str):
@@ -151,7 +151,7 @@ def calculate_image_similarity(image, description, max_length: int = 77):
     # Calculate cosine similarity
     return torch.cosine_similarity(image_embedding, text_embedding, dim=1).item()

-async def image_score(uid, url, desired_size, description, weight, similarity_threshold=0.23) -> float:
+async def dalle_score(uid, url, desired_size, description, weight, similarity_threshold=0.23) -> float:
     """Calculate the image score based on similarity and size asynchronously."""

     if not re.match(url_regex, url):
@@ -188,6 +188,20 @@ async def image_score(uid, url, desired_size, description, weight, similarity_threshold=0.23) -> float:
         return 0


+# IMAGES ---- DETERMINISTIC
+
+async def deterministic_score(uid: int, syn, weight: float):
+    vali_b64s = await utils.call_stability(syn.messages, syn.seed, syn.steps, syn.cfg_scale, syn.width, syn.height, syn.samples, syn.sampler)
+
+    for miner_b64, vali_b64 in zip(syn.completion["b64s"], vali_b64s):
+        if miner_b64[:50] != vali_b64[:50]:
+            return 0
+
+    return weight
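`utils.call_stability` itself is not part of this diff. A plausible sketch inferred from the call site above and from the miner's Stability branch; the real helper in template/utils.py may differ:

import base64
import os

from stability_sdk import client

stability_api = client.StabilityInference(
    key=os.environ["STABILITY_KEY"],
    verbose=True,
    engine="stable-diffusion-xl-1024-v1-0",
)

async def call_stability(messages, seed, steps, cfg_scale, width, height, samples, sampler):
    # Regenerate with the exact parameters the miner was given.
    meta = stability_api.generate(
        prompt=messages,
        seed=seed,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        samples=samples,
    )
    # Encode exactly like the miner does, so the prefix comparison is
    # apples to apples.
    return [
        base64.b64encode(artifact.binary).decode()
        for image in meta
        for artifact in image.artifacts
    ]

Comparing only the first 50 base64 characters is a cheap determinism check: with seed, steps, and cfg_scale pinned on the same engine, an honest miner's output should be byte-identical to the validator's, so a prefix mismatch is enough to reject.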



 # ==== Embeddings =====

 async def embeddings_score(openai_answer: list, response: list, weight: float, threshold: float = .95) -> float:
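The embeddings_score body is folded away in this diff. One plausible reading of the signature (pairwise cosine similarity averaged against the 0.95 threshold; the aggregation choice is my assumption, not the repo's confirmed logic):

import numpy as np
from numpy.linalg import norm

async def embeddings_score(openai_answer: list, response: list, weight: float, threshold: float = .95) -> float:
    if len(openai_answer) != len(response):
        return 0
    # Cosine similarity of each validator/miner embedding pair.
    sims = [
        float(np.dot(np.asarray(a), np.asarray(b)) / (norm(a) * norm(b)))
        for a, b in zip(openai_answer, response)
    ]
    return weight if np.mean(sims) > threshold else 0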
