diff --git a/Cortex.t.egg-info/PKG-INFO b/Cortex.t.egg-info/PKG-INFO
new file mode 100644
index 00000000..bd0205ce
--- /dev/null
+++ b/Cortex.t.egg-info/PKG-INFO
@@ -0,0 +1,136 @@
+Metadata-Version: 2.1
+Name: Cortex.t
+Version: 3.1.6
+Summary: Decentralized APIs for synthetic data generation
+Home-page: https://github.com/corcel-api/cortex.t
+Author: Fish
+License: MIT
+Classifier: Development Status :: 3 - Alpha
+Classifier: Intended Audience :: Developers
+Classifier: License :: OSI Approved :: MIT License
+Classifier: Programming Language :: Python :: 3
+Classifier: Topic :: Software Development
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Requires-Dist: aiohttp==3.*
+Requires-Dist: bittensor==6.*
+Requires-Dist: datasets==2.*
+Requires-Dist: envparse==0.2.0
+Requires-Dist: openai==1.*,>=1.3.2
+Requires-Dist: Pillow==10.*
+Requires-Dist: requests==2.*
+Requires-Dist: scikit-learn==1.*
+Requires-Dist: torch==2.*
+Requires-Dist: transformers==4.*
+Requires-Dist: wandb
+Requires-Dist: anthropic
+Requires-Dist: stability-sdk
+Requires-Dist: boto3
+Requires-Dist: anthropic_bedrock
+Requires-Dist: pyOpenSSL
+
+
+
+# **Cortex.t Subnet**
+[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+---
+
+- [Introduction](#introduction)
+- [Setup](#setup)
+- [Mining](#mining)
+- [Validating](#validating)
+- [License](#license)
+
+
+## Introduction
+
+**IMPORTANT**: If you are new to Bittensor, please check out the [Bittensor Website](https://bittensor.com/) before proceeding to the [Setup](#setup) section.
+
+Introducing Bittensor Subnet 18 (Cortex.t): A Pioneering Platform for AI Development and Synthetic Data Generation.
+
+Cortex.t stands at the forefront of artificial intelligence, offering a dual-purpose solution that caters to the needs of app developers and innovators in the AI space. This platform is meticulously designed to deliver reliable, high-quality text and image responses through API usage, utilising the decentralised Bittensor network. It serves as a cornerstone for creating a fair, transparent, and manipulation-free environment for the incentivised production of intelligence (mining) and the generation and fulfilment of diverse user prompts.
+
+Our initiative is a leap forward in redefining the reward system for text and image prompting, with a commitment to providing stability and reassurance to developers. By focusing on the value delivered to clients, we alleviate the concerns of data inconsistencies that often plague app development. Cortex.t is seamlessly integrated within the Bittensor network, allowing developers to harness the power of multiple subnets and modalities by building directly onto an existing validator, or through an API key from [Corcel](https://corcel.io).
+
+Cortex.t is also a transformative platform leveraging advanced AI models to generate synthetic prompt-response pairs. This novel method yields a comprehensive dataset of interactions, archived in wandb [wandb.ai/cortex-t/synthetic-QA](https://wandb.ai/cortex-t/synthetic-QA). The process involves recycling model outputs back into the system, using a prompt evolution and data augmentation strategy similar to Microsoft's approach in developing WizardLM. This enables the distillation of sophisticated AI models into smaller, yet efficient counterparts, mirroring the performance of their larger predecessors. Ultimately, Cortex.t democratizes access to high-end AI technology, encouraging innovation and customization.
+
+By leveraging synthetic data, Cortex.t circumvents the traditional challenges of data collection and curation, accelerating the development of AI models that are both robust and adaptable. This platform is your gateway to AI mastery, offering the unique opportunity to train your models with data that reflects the depth and versatility of the parent model. With SynthPairPro, you're not just collecting data; you're capturing intelligence, providing a path to creating AI models that mirror the advanced understanding and response capabilities of their predecessors.
+
+Join us at Cortex.t, your bridge to AI excellence, and democratise access to top-level AI capabilities. Be part of the AI revolution and stay at the forefront of innovation with SynthPairPro – Synthesizing Intelligence, Empowering the Future!
+
+
+## Development
+
+### Testing
+
+Install `nox` (`pip install nox`), then run `nox -s test`.
+
+## Setup
+
+### Before you proceed
+Before you proceed with the installation of the subnet, note the following:
+
+**IMPORTANT**: We **strongly recommend** that, before proceeding, you test both your subtensor connection and your OpenAI API key. Ensure you are running subtensor locally to minimize the chance of outages and improve latency.
+
+After exporting your OpenAI API key to your bash profile, test the streaming service for both the gpt-3.5-turbo and gpt-4 engines using `./neurons/test_openai.py`. Neither the miner nor the validator will function without a valid and working [OpenAI API key](https://platform.openai.com/).
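+
+If you prefer to test from Python directly, here is a minimal sketch of an equivalent streaming check (assumes `OPENAI_API_KEY` is exported and the `openai` 1.x client pinned in the requirements):
+
+```python
+import asyncio
+from openai import AsyncOpenAI
+
+async def main() -> None:
+    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
+    for model in ("gpt-3.5-turbo", "gpt-4"):
+        stream = await client.chat.completions.create(
+            model=model,
+            messages=[{"role": "user", "content": "say hi"}],
+            stream=True,
+        )
+        async for chunk in stream:
+            # Each chunk carries an incremental delta; content can be None.
+            print(chunk.choices[0].delta.content or "", end="", flush=True)
+        print()
+
+asyncio.run(main())
+```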
+
+**IMPORTANT:** Make sure you are aware of the minimum compute requirements for cortex.t. See the [Minimum compute YAML configuration](./min_compute.yml).
+Note that this subnet requires very little compute: the main functionality is API calls, so the compute is outsourced to OpenAI. The cost of mining and validating on this subnet comes from API calls, not compute. Please be aware of your API costs and monitor them accordingly.
+
+A high-tier key is required for both mining and validating, so if you do not have one, it is important to work your way up slowly by running a single miner or a small number of miners whilst paying attention to your usage and limits.
+
+
+### Installation
+
+Download the repository, navigate to the folder, and then install the necessary requirements with the following chained command.
+
+```git clone https://github.com/corcel-api/cortex.t.git && cd cortex.t && pip install -e .```
+
+Prior to proceeding, ensure you have a registered hotkey on subnet 18 mainnet. If not, run the command `btcli s register --netuid 18 --wallet.name [wallet_name] --wallet.hotkey [wallet.hotkey]`.
+
+We recommend using [direnv](https://direnv.net). After installing it, copy `envrc.example` to `.envrc` and substitute
+all env vars with values appropriate for your accounts. After making changes to `.envrc`, run `direnv allow` and start a
+new terminal tab.
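+
+For reference, a minimal `.envrc` might look like the following (these variable names are the ones this repo reads; the values are placeholders):
+
+```bash
+export OPENAI_API_KEY="sk-..."
+export ANTHROPIC_API_KEY="sk-ant-..."
+export STABILITY_API_KEY="..."
+export GOOGLE_API_KEY="..."
+export WANDB_API_KEY="..."
+```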
+
+## Mining
+
+You can launch your miners via pm2 using the following command.
+
+`pm2 start ./miner/miner.py --interpreter python3 -- --netuid 18 --subtensor.network [network] --wallet.name [wallet_name] --wallet.hotkey [hotkey_name] --axon.port [port]`
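+
+For example, with placeholder values filled in:
+
+`pm2 start ./miner/miner.py --interpreter python3 --name miner -- --netuid 18 --subtensor.network finney --wallet.name default --wallet.hotkey miner1 --axon.port 8091`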
+
+
+## Validating
+
+You can launch your validator via pm2 using the following command.
+
+`pm2 start ./validators/validator.py --interpreter python3 -- --netuid 18 --subtensor.network [network] --wallet.name [wallet_name] --wallet.hotkey [hotkey_name]`
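+
+For example, with placeholder values filled in:
+
+`pm2 start ./validators/validator.py --interpreter python3 --name validator -- --netuid 18 --subtensor.network finney --wallet.name validator --wallet.hotkey default`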
+
+
+## Logging
+
+As cortex.t supports streaming natively, you do not need to (and should not) enable `logging.trace` or `logging.debug`, as all of the important information is already output at `logging.info`, which is enabled by default.
+
+---
+
+## License
+This repository is licensed under the MIT License.
+```text
+# The MIT License (MIT)
+# Copyright © 2023 Yuma Rao
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
+# the Software.
+
+# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
+# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+```
diff --git a/Cortex.t.egg-info/SOURCES.txt b/Cortex.t.egg-info/SOURCES.txt
new file mode 100644
index 00000000..8e060877
--- /dev/null
+++ b/Cortex.t.egg-info/SOURCES.txt
@@ -0,0 +1,14 @@
+LICENSE
+README.md
+setup.py
+Cortex.t.egg-info/PKG-INFO
+Cortex.t.egg-info/SOURCES.txt
+Cortex.t.egg-info/dependency_links.txt
+Cortex.t.egg-info/requires.txt
+Cortex.t.egg-info/top_level.txt
+base/__init__.py
+template/__init__.py
+template/protocol.py
+template/reward.py
+template/utils.py
+test_base/__init__.py
\ No newline at end of file
diff --git a/Cortex.t.egg-info/dependency_links.txt b/Cortex.t.egg-info/dependency_links.txt
new file mode 100644
index 00000000..8b137891
--- /dev/null
+++ b/Cortex.t.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
diff --git a/Cortex.t.egg-info/requires.txt b/Cortex.t.egg-info/requires.txt
new file mode 100644
index 00000000..7435c4fa
--- /dev/null
+++ b/Cortex.t.egg-info/requires.txt
@@ -0,0 +1,16 @@
+aiohttp==3.*
+bittensor==6.*
+datasets==2.*
+envparse==0.2.0
+openai==1.*,>=1.3.2
+Pillow==10.*
+requests==2.*
+scikit-learn==1.*
+torch==2.*
+transformers==4.*
+wandb
+anthropic
+stability-sdk
+boto3
+anthropic_bedrock
+pyOpenSSL
diff --git a/Cortex.t.egg-info/top_level.txt b/Cortex.t.egg-info/top_level.txt
new file mode 100644
index 00000000..6976312d
--- /dev/null
+++ b/Cortex.t.egg-info/top_level.txt
@@ -0,0 +1,3 @@
+base
+template
+test_base
diff --git a/api.py b/api.py
new file mode 100644
index 00000000..b3333558
--- /dev/null
+++ b/api.py
@@ -0,0 +1,86 @@
+import bittensor as bt
+import asyncio
+import json
+import traceback
+from template.protocol import StreamPrompting, TextPrompting, ImageResponse
+
+# Assuming initial setup remains the same
+wallet = bt.wallet(name="validator", hotkey="default")
+axon = bt.axon(wallet=wallet)
+dendrite = bt.dendrite(wallet=wallet)
+subtensor = bt.subtensor(network="test")
+metagraph = subtensor.metagraph(netuid=24)
+
+# StreamPrompting variables
+question = [{"role": "user", "content": "quick question"}]
+vali_uid = 1
+target_uid = 3
+provider = "OpenAI"
+model = "gpt-3.5-turbo"
+seed = 1234
+temperature = 0.5
+max_tokens = 2048
+top_p = 0.8
+top_k = 1000
+timeout = 3
+streaming = True
+
+synapse = StreamPrompting(
+ messages=question,
+ uid=target_uid,
+ provider=provider,
+ model=model,
+ seed=seed,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ top_p=top_p,
+ top_k=top_k,
+ timeout=timeout,
+ streaming=streaming,
+)
+
+# ImageResponse variables
+messages = "a thick white cloud over a river"
+
+synapse = ImageResponse(
+ messages=messages
+)
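+# NOTE: this second synapse assignment replaces the StreamPrompting one above;
+# comment out whichever request you do not want to send.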
+
+print("messages", messages)
+bt.trace()
+response = dendrite.query(metagraph.axons[vali_uid], synapse, deserialize=False, timeout=synapse.timeout)
+print('completion:', response.completion)
+
+# async def query_miner(synapse):
+# try:
+# axon = metagraph.axons[vali_uid]
+# responses = dendrite.query(
+# axons=[axon],
+# synapse=synapse,
+# deserialize=False,
+# timeout=timeout,
+# streaming=streaming,
+# )
+# return await handle_response(responses)
+# except Exception as e:
+# print(f"Exception during query: {traceback.format_exc()}")
+# return None
+
+# async def handle_response(responses):
+# full_response = ""
+# try:
+# for resp in responses:
+# async for chunk in resp:
+# if isinstance(chunk, str):
+# full_response += chunk
+# bt.logging.info(chunk)
+# except Exception as e:
+#         print(f"Error processing response: {e}")
+# return full_response
+
+# async def main():
+# response = await query_miner(synapse)
+# bt.logging.info(f"full_response = {response}")
+
+# if __name__ == "__main__":
+# asyncio.run(main())
diff --git a/template/__init__.py b/cortext/__init__.py
similarity index 99%
rename from template/__init__.py
rename to cortext/__init__.py
index bf12c693..99d85b68 100644
--- a/template/__init__.py
+++ b/cortext/__init__.py
@@ -19,7 +19,7 @@
# version must stay on line 22
-__version__ = "3.2.2"
+__version__ = "3.2.3"
version_split = __version__.split(".")
__spec_version__ = (
(1000 * int(version_split[0]))
@@ -51,6 +51,7 @@
# must have the test_key whitelisted to avoid a global blacklist
testnet_key = ["5EhEZN6soubtKJm8RN7ANx9FGZ2JezxBUFxr45cdsHtDp3Uk"]
test_key = ["5DcRHcCwD33YsHfj4PX5j2evWLniR1wSWeNmpf5RXaspQT6t"]
+VALIDATOR_API_WHITELIST = ["5EhEZN6soubtKJm8RN7ANx9FGZ2JezxBUFxr45cdsHtDp3Uk"]
valid_validators = [
'5FFApaS75bv5pJHfAp2FVLBj9ZaXuFDjEypsaBNc1wCfe52v',
'5EhvL1FVkQPpMjZX4MAADcW42i3xPSF1KiCpuaxTYVr28sux',
diff --git a/template/protocol.py b/cortext/protocol.py
similarity index 66%
rename from template/protocol.py
rename to cortext/protocol.py
index 5d7f4816..653c8b7b 100644
--- a/template/protocol.py
+++ b/cortext/protocol.py
@@ -107,6 +107,18 @@ class ImageResponse(bt.Synapse):
title="Quality",
description="The quality of the image."
)
+
+ uid: int = pydantic.Field(
+ default=3,
+ title="uid",
+ description="The UID to send the synapse to",
+ )
+
+ timeout: int = pydantic.Field(
+ default=60,
+ title="timeout",
+ description="The timeout for the dendrite of the synapse",
+ )
required_hash_fields: List[str] = pydantic.Field(
["messages"],
@@ -138,6 +150,18 @@ class Embeddings( bt.Synapse):
title="Embeddings",
description="The resulting list of embeddings, each corresponding to an input text."
)
+
+ uid: int = pydantic.Field(
+        default=3,
+ title="uid",
+ description="The UID to send the synapse to",
+ )
+
+ timeout: int = pydantic.Field(
+ default=60,
+ title="timeout",
+ description="The timeout for the dendrite of the synapse",
+ )
@@ -204,7 +228,8 @@ class StreamPrompting(bt.StreamingSynapse):
provider: str = pydantic.Field(
default="OpenAI",
title="Provider",
- description="The provider to use when calling for your response."
+ description="The provider to use when calling for your response. "
+ "Options: OpenAI, Anthropic, Gemini",
)
model: str = pydantic.Field(
@@ -213,6 +238,24 @@ class StreamPrompting(bt.StreamingSynapse):
description="The model to use when calling provider for your response.",
)
+ uid: int = pydantic.Field(
+ default=3,
+ title="uid",
+ description="The UID to send the streaming synapse to",
+ )
+
+ timeout: int = pydantic.Field(
+ default=60,
+ title="timeout",
+ description="The timeout for the dendrite of the streaming synapse",
+ )
+
+ streaming: bool = pydantic.Field(
+ default=True,
+ title="streaming",
+ description="whether to stream the output",
+ )
+
async def process_streaming_response(self, response: StreamingResponse) -> AsyncIterator[str]:
if self.completion is None:
self.completion = ""
@@ -249,3 +292,95 @@ def extract_info(prefix: str) -> dict[str, str]:
"messages": self.messages,
"completion": self.completion,
}
+
+
+class TextPrompting(bt.Synapse):
+
+ messages: List[Dict[str, str]] = pydantic.Field(
+ ...,
+ title="Messages",
+ description="A list of messages in the StreamPrompting scenario, "
+ "each containing a role and content. Immutable.",
+ allow_mutation=False,
+ )
+
+ required_hash_fields: List[str] = pydantic.Field(
+ ["messages"],
+ title="Required Hash Fields",
+ description="A list of required fields for the hash.",
+ allow_mutation=False,
+ )
+
+ seed: int = pydantic.Field(
+        default=1234,
+ title="Seed",
+ description="Seed for text generation. This attribute is immutable and cannot be updated.",
+ )
+
+ temperature: float = pydantic.Field(
+ default=0.0001,
+ title="Temperature",
+ description="Temperature for text generation. "
+ "This attribute is immutable and cannot be updated.",
+ )
+
+ max_tokens: int = pydantic.Field(
+ default=2048,
+ title="Max Tokens",
+ description="Max tokens for text generation. "
+ "This attribute is immutable and cannot be updated.",
+ )
+
+ top_p: float = pydantic.Field(
+ default=0.001,
+ title="Top_p",
+ description="Top_p for text generation. The sampler will pick one of "
+                    "the top p percent tokens in the logit distribution. "
+ "This attribute is immutable and cannot be updated.",
+ )
+
+ top_k: int = pydantic.Field(
+ default=1,
+ title="Top_k",
+ description="Top_k for text generation. Sampler will pick one of "
+                    "the k most probable tokens in the logit distribution. "
+ "This attribute is immutable and cannot be updated.",
+ )
+
+ completion: str = pydantic.Field(
+ None,
+ title="Completion",
+ description="Completion status of the current StreamPrompting object. "
+ "This attribute is mutable and can be updated.",
+ )
+
+ provider: str = pydantic.Field(
+ default="OpenAI",
+ title="Provider",
+ description="The provider to use when calling for your response. "
+ "Options: OpenAI, Anthropic, Gemini",
+ )
+
+ model: str = pydantic.Field(
+ default="gpt-3.5-turbo",
+ title="model",
+ description="The model to use when calling provider for your response.",
+ )
+
+ uid: int = pydantic.Field(
+ default=3,
+ title="uid",
+ description="The UID to send the streaming synapse to",
+ )
+
+ timeout: int = pydantic.Field(
+ default=60,
+ title="timeout",
+ description="The timeout for the dendrite of the streaming synapse",
+ )
+
+ streaming: bool = pydantic.Field(
+ default=True,
+ title="streaming",
+ description="whether to stream the output",
+ )
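+
+
+# A minimal usage sketch (wallet, network, and uid values are placeholders;
+# assumes a miner serving TextPrompting at the target uid):
+#
+#   import bittensor as bt
+#   from cortext.protocol import TextPrompting
+#
+#   wallet = bt.wallet(name="validator", hotkey="default")
+#   dendrite = bt.dendrite(wallet=wallet)
+#   metagraph = bt.subtensor(network="test").metagraph(netuid=24)
+#   synapse = TextPrompting(messages=[{"role": "user", "content": "hello"}])
+#   response = dendrite.query(metagraph.axons[synapse.uid], synapse,
+#                             deserialize=False, timeout=synapse.timeout)
+#   print(response.completion)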
\ No newline at end of file
diff --git a/template/reward.py b/cortext/reward.py
similarity index 96%
rename from template/reward.py
rename to cortext/reward.py
index 61b1bfe8..0c069430 100644
--- a/template/reward.py
+++ b/cortext/reward.py
@@ -31,10 +31,11 @@
import logging
import aiohttp
import requests
+import traceback
import numpy as np
from numpy.linalg import norm
import bittensor as bt
-from template import utils
+from cortext import utils
from PIL import Image
from scipy.spatial.distance import cosine
from sklearn.metrics.pairwise import cosine_similarity
@@ -54,28 +55,28 @@ def calculate_text_similarity(text1: str, text2: str):
# Calculate the Cosine Similarity
similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
- # bt.debug(f"Similarity: {similarity}")
+ bt.logging.debug(f"Similarity: {similarity}")
return similarity
except Exception as e:
- bt.logging.error(f"Error in calculate_text_similarity: {e}")
+ bt.logging.error(f"Error in calculate_text_similarity: {traceback.format_exc()}")
raise
-async def api_score(api_answer: str, response: str, weight: float) -> float:
+async def api_score(api_answer: str, response: str, weight: float, temperature: float, provider: str) -> float:
try:
loop = asyncio.get_running_loop()
similarity = await loop.run_in_executor(None, calculate_text_similarity, api_answer, response)
- bt.logging.debug(f"Similarity obtained: {similarity}")
words_in_response = len(response.split())
words_in_api = len(api_answer.split())
- # Answer must be within 10% the true answer's length
- word_count_threshold = words_in_api * 0.10
+        # Answer must be within 20% of the true answer's length
+        word_count_threshold = words_in_api * 0.20
+        # Check if the word count difference is within the threshold
if abs(words_in_response - words_in_api) <= word_count_threshold:
min_similarity = max(1 - 0.001 * (words_in_response - 1), 0.82)
- score = weight if similarity >= min_similarity else 0
+ # score = weight if similarity >= min_similarity else 0
+ score = weight * similarity
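+            # e.g. weight=1.0 and similarity=0.9 now yield score=0.9,
+            # giving partial credit instead of the old all-or-nothing cutoff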
else:
score = 0
diff --git a/template/utils.py b/cortext/utils.py
similarity index 82%
rename from template/utils.py
rename to cortext/utils.py
index 0ff45668..82b119da 100644
--- a/template/utils.py
+++ b/cortext/utils.py
@@ -15,16 +15,18 @@
from PIL import Image
import traceback
import anthropic
-from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
+from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT, AsyncAnthropic
from typing import Optional
from stability_sdk import client as stability_client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
+import google.generativeai as genai
+
import bittensor as bt
import requests
import wandb
-import template
+import cortext
from . import client
@@ -33,32 +35,58 @@
from anthropic_bedrock import AsyncAnthropicBedrock
# Set up the API connection
-stability_api = stability_client.StabilityInference(
- key=os.environ['STABILITY_KEY'],
- verbose=True,
- engine="stable-diffusion-xl-1024-v1-0"
-)
+# stability_api = stability_client.StabilityInference(
+# key=os.environ['STABILITY_API_KEY'],
+# verbose=True,
+# engine="stable-diffusion-xl-1024-v1-0"
+# )
+
+claude_key = os.environ.get("ANTHROPIC_API_KEY")
+if not claude_key:
+    raise ValueError("Claude API key not found in environment variables. Get one at https://console.anthropic.com/settings/keys and set it as ANTHROPIC_API_KEY in your .env")
+
+claude_client = AsyncAnthropic()
+claude_client.api_key = claude_key
+
+# google_key=os.environ.get('GOOGLE_API_KEY')
+# if not google_key:
+# raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
+
+# genai.configure(api_key=google_key)
+bedrock_client = AsyncAnthropicBedrock()
-def load_state_from_file(filename: str = "state.json"):
+def load_state_from_file(filename: str):
+ load_success = False
+
+ # Check if the file exists
if os.path.exists(filename):
with open(filename, "r") as file:
- bt.logging.debug("loaded previous state")
- return json.load(file)
- else:
+ try:
+                # Attempt to load JSON from the file
+                state = json.load(file)
+                bt.logging.debug("loaded previous state")
+                load_success = True  # Mark the load as successful
+                return state
+            except Exception as e:  # Any error while reading/parsing resets the state file
+ bt.logging.error(f"error loading state, deleting and resetting it. Error: {e}")
+ os.remove(filename) # Delete if error
+
+ # If the file does not exist or there was an error
+ if not load_success:
bt.logging.debug("initialized new global state")
+ # Return the default state structure
return {
"text": {"themes": None, "questions": None, "theme_counter": 0, "question_counter": 0},
"images": {"themes": None, "questions": None, "theme_counter": 0, "question_counter": 0}
}
-state = load_state_from_file()
-
+state = None
-def get_state():
+def get_state(path):
global state
- if state is None:
- load_state_from_file()
+ if not state:
+ state = load_state_from_file(path)
return state
@@ -68,12 +96,13 @@ def save_state_to_file(state, filename="state.json"):
json.dump(state, file)
+
def get_validators_with_runs_in_all_projects():
api = wandb.Api()
validators_runs = {project: set() for project in projects}
# Retrieve runs for each project and store validator UIDs
- for project in template.PROJECT_NAMES:
+ for project in cortext.PROJECT_NAMES:
runs = api.runs(f"cortex-t/{project}")
for run in runs:
if run.config['type'] == 'validator':
@@ -87,11 +116,11 @@ async def get_list(list_type, num_questions_needed, theme=None):
prompts_in_question = {'text_questions': 10, 'images_questions': 20}
list_type_mapping = {
"text_questions": {
- "default": template.INSTRUCT_DEFAULT_QUESTIONS,
+ "default": cortext.INSTRUCT_DEFAULT_QUESTIONS,
"prompt": "placeholder"
},
"images_questions": {
- "default": template.IMAGE_DEFAULT_QUESTIONS,
+ "default": cortext.IMAGE_DEFAULT_QUESTIONS,
"prompt": f"Provide a python-formatted list of {prompts_in_question[list_type]} creative and detailed scenarios for image generation, each inspired by the theme '{theme}'. The scenarios should be diverse, thoughtful, and possibly out-of-the-box interpretations related to '{theme}'. Each element in the list should be a concise, but a vividly descriptive situation designed to inspire visually rich stories. Format these elements as comma-separated, quote-encapsulated strings in a single Python list."
}
}
@@ -176,8 +205,8 @@ async def update_counters_and_get_new_list(category, item_type, num_questions_ne
async def get_items(category, item_type, theme=None):
if item_type == "themes":
if category == "images":
- return template.IMAGE_THEMES
- return template.INSTRUCT_DEFAULT_THEMES
+ return cortext.IMAGE_THEMES
+ return cortext.INSTRUCT_DEFAULT_THEMES
else:
# Never fail here, retry until valid list is found
while True:
@@ -360,7 +389,29 @@ async def call_openai(messages, temperature, model, seed=1234, max_tokens=2048,
return None
+async def call_gemini(messages, temperature, model, max_tokens, top_p, top_k):
+ print(f"Calling Gemini. Temperature = {temperature}, Model = {model}, Messages = {messages}")
+ try:
+ model = genai.GenerativeModel(model)
+ response = model.generate_content(
+ str(messages),
+ stream=False,
+ generation_config=genai.types.GenerationConfig(
+ # candidate_count=1,
+ # stop_sequences=['x'],
+ temperature=temperature,
+ # max_output_tokens=max_tokens,
+ top_p=top_p,
+ top_k=top_k,
+ # seed=seed,
+ )
+ )
+ print(f"validator response is {response.text}")
+ return response.text
+    except Exception:
+        print(f"error in call_gemini {traceback.format_exc()}")
+
# anthropic = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
@@ -389,9 +440,8 @@ async def call_openai(messages, temperature, model, seed=1234, max_tokens=2048,
async def call_anthropic(prompt, temperature, model, max_tokens=2048, top_p=1, top_k=10000):
try:
- client = AsyncAnthropicBedrock()
bt.logging.debug(f"Calling Anthropic. Model = {model}, Prompt = {prompt}, Temperature = {temperature}, Max Tokens = {max_tokens}")
- completion = await client.completions.create(
+ completion = await bedrock_client.completions.create(
model=model,
max_tokens_to_sample=max_tokens,
temperature=temperature,
@@ -406,6 +456,28 @@ async def call_anthropic(prompt, temperature, model, max_tokens=2048, top_p=1, t
bt.logging.error(f"Error when calling Anthropic: {traceback.format_exc()}")
await asyncio.sleep(0.5)
+async def call_claude(messages, temperature, model, max_tokens, top_p, top_k):
+ system_prompt = None
+ filtered_messages = []
+ for message in messages:
+ if message["role"] == "system":
+ system_prompt = message["content"]
+ else:
+ filtered_messages.append(message)
+
+ kwargs = {
+ "max_tokens": max_tokens,
+ "messages": filtered_messages,
+ "model": model,
+ }
+
+ if system_prompt:
+ kwargs["system"] = system_prompt
+
+ message = await claude_client.messages.create(**kwargs)
+ bt.logging.debug(f"validator response is {message.content[0].text}")
+ return message.content[0].text
+
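+# Example usage (model name as referenced elsewhere in this repo):
+#   text = await call_claude([{"role": "user", "content": "hi"}],
+#                            temperature=0.5, model="claude-2.1",
+#                            max_tokens=2048, top_p=0.8, top_k=1000)
+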
async def call_stability(prompt, seed, steps, cfg_scale, width, height, samples, sampler):
# bt.logging.info(f"calling stability for {prompt, seed, steps, cfg_scale, width, height, samples, sampler}")
bt.logging.info(f"calling stability for {prompt[:50]}...")
@@ -431,7 +503,7 @@ async def call_stability(prompt, seed, steps, cfg_scale, width, height, samples,
# Github unauthorized rate limit of requests per hour is 60. Authorized is 5000.
def get_version(line_number: int = 22) -> Optional[str]:
- url = "https://api.github.com/repos/corcel-api/cortex.t/contents/template/__init__.py"
+ url = "https://api.github.com/repos/corcel-api/cortex.t/contents/cortext/__init__.py"
response = requests.get(url, timeout=10)
if not response.ok:
bt.logging.error("github api call failed")
diff --git a/miner/claude_miner.py b/miner/claude_miner.py
index e2530d5e..3cbaed5f 100644
--- a/miner/claude_miner.py
+++ b/miner/claude_miner.py
@@ -29,9 +29,9 @@
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from anthropic_bedrock import AsyncAnthropicBedrock, HUMAN_PROMPT, AI_PROMPT, AnthropicBedrock
-import template
-from template.protocol import Embeddings, ImageResponse, IsAlive, StreamPrompting
-from template.utils import get_version
+import cortext
+from cortext.protocol import Embeddings, ImageResponse, IsAlive, StreamPrompting
+from cortext.utils import get_version
import sys
from starlette.types import Send
@@ -157,7 +157,7 @@ def base_blacklist(self, synapse, blacklist_amt = 20000) -> Tuple[bool, str]:
hotkey = synapse.dendrite.hotkey
synapse_type = type(synapse).__name__
- if hotkey in template.WHITELISTED_KEYS:
+ if hotkey in cortext.WHITELISTED_KEYS:
return False, f"accepting {synapse_type} request from {hotkey}"
if hotkey not in valid_hotkeys:
@@ -168,7 +168,7 @@ def base_blacklist(self, synapse, blacklist_amt = 20000) -> Tuple[bool, str]:
if _axon.hotkey == hotkey:
break
- if uid is None and template.ALLOW_NON_REGISTERED is False:
+ if uid is None and cortext.ALLOW_NON_REGISTERED is False:
return True, f"Blacklisted a non registered hotkey's {synapse_type} request from {hotkey}"
# check the stake
@@ -177,7 +177,7 @@ def base_blacklist(self, synapse, blacklist_amt = 20000) -> Tuple[bool, str]:
if tao < blacklist_amt:
return True, f"Blacklisted a low stake {synapse_type} request: {tao} < {blacklist_amt} from {hotkey}"
- time_window = template.MIN_REQUEST_PERIOD * 60
+ time_window = cortext.MIN_REQUEST_PERIOD * 60
current_time = time.time()
if hotkey not in self.request_timestamps:
@@ -188,12 +188,12 @@ def base_blacklist(self, synapse, blacklist_amt = 20000) -> Tuple[bool, str]:
self.request_timestamps[hotkey].popleft()
# Check if the number of requests exceeds the limit
- if len(self.request_timestamps[hotkey]) >= template.MAX_REQUESTS:
+ if len(self.request_timestamps[hotkey]) >= cortext.MAX_REQUESTS:
return (
True,
f"Request frequency for {hotkey} exceeded: "
- f"{len(self.request_timestamps[hotkey])} requests in {template.MIN_REQUEST_PERIOD} minutes. "
- f"Limit is {template.MAX_REQUESTS} requests."
+ f"{len(self.request_timestamps[hotkey])} requests in {cortext.MIN_REQUEST_PERIOD} minutes. "
+ f"Limit is {cortext.MAX_REQUESTS} requests."
)
self.request_timestamps[hotkey].append(current_time)
@@ -205,22 +205,22 @@ def base_blacklist(self, synapse, blacklist_amt = 20000) -> Tuple[bool, str]:
def blacklist_prompt( self, synapse: StreamPrompting ) -> Tuple[bool, str]:
- blacklist = self.base_blacklist(synapse, template.PROMPT_BLACKLIST_STAKE)
+ blacklist = self.base_blacklist(synapse, cortext.PROMPT_BLACKLIST_STAKE)
bt.logging.info(blacklist[1])
return blacklist
def blacklist_is_alive( self, synapse: IsAlive ) -> Tuple[bool, str]:
- blacklist = self.base_blacklist(synapse, template.ISALIVE_BLACKLIST_STAKE)
+ blacklist = self.base_blacklist(synapse, cortext.ISALIVE_BLACKLIST_STAKE)
bt.logging.debug(blacklist[1])
return blacklist
def blacklist_images( self, synapse: ImageResponse ) -> Tuple[bool, str]:
- blacklist = self.base_blacklist(synapse, template.IMAGE_BLACKLIST_STAKE)
+ blacklist = self.base_blacklist(synapse, cortext.IMAGE_BLACKLIST_STAKE)
bt.logging.info(blacklist[1])
return blacklist
def blacklist_embeddings( self, synapse: Embeddings ) -> Tuple[bool, str]:
- blacklist = self.base_blacklist(synapse, template.EMBEDDING_BLACKLIST_STAKE)
+ blacklist = self.base_blacklist(synapse, cortext.EMBEDDING_BLACKLIST_STAKE)
bt.logging.info(blacklist[1])
return blacklist
@@ -351,7 +351,7 @@ def __exit__(self, exc_type, exc_value, traceback):
self.stop_run_thread()
-class StreamingTemplateMiner(StreamMiner):
+class StreamingCortextMiner(StreamMiner):
def config(self) -> bt.config:
parser = argparse.ArgumentParser(description="Streaming Miner Configs")
self.add_args(parser)
@@ -584,7 +584,7 @@ def get_valid_hotkeys(config):
while True:
metagraph = subtensor.metagraph(18)
try:
- runs = api.runs(f"cortex-t/{template.PROJECT_NAME}")
+ runs = api.runs(f"cortex-t/{cortext.PROJECT_NAME}")
latest_version = get_version()
for run in runs:
if run.state == "running":
@@ -628,6 +628,6 @@ def get_valid_hotkeys(config):
if __name__ == "__main__":
- with StreamingTemplateMiner():
+    with StreamingCortextMiner():
while True:
time.sleep(1)
diff --git a/miner/miner.py b/miner/miner.py
index 927912fa..ca7c7ef8 100644
--- a/miner/miner.py
+++ b/miner/miner.py
@@ -1,84 +1,107 @@
-import base # noqa
-
import argparse
import asyncio
+import base64
import copy
import json
import os
-import io
-import base64
-import boto3
import pathlib
+import requests
import threading
import time
-import requests
import traceback
-import requests
import anthropic
-from abc import ABC, abstractmethod
from collections import deque
from functools import partial
from typing import Tuple
-from stability_sdk import client
import bittensor as bt
+import google.generativeai as genai
import wandb
+from PIL import Image
+from stability_sdk import client
from config import check_config, get_config
from openai import AsyncOpenAI, OpenAI
+from anthropic import AsyncAnthropic
+from stability_sdk import client as stability_client
from PIL import Image
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from anthropic_bedrock import AsyncAnthropicBedrock, HUMAN_PROMPT, AI_PROMPT, AnthropicBedrock
-import template
-from template.protocol import Embeddings, ImageResponse, IsAlive, StreamPrompting
-from template.utils import get_version
+import cortext
+from cortext.protocol import Embeddings, ImageResponse, IsAlive, StreamPrompting, TextPrompting
+from cortext.utils import get_version
import sys
from starlette.types import Send
+# Set up API keys from the .env file and initialize clients
+
+# OpenAI
OpenAI.api_key = os.environ.get("OPENAI_API_KEY")
if not OpenAI.api_key:
raise ValueError("Please set the OPENAI_API_KEY environment variable.")
-stability_api = client.StabilityInference(
- key=os.environ['STABILITY_KEY'],
+client = AsyncOpenAI(timeout=60.0)
+
+# Stability
+stability_key = os.environ.get("STABILITY_API_KEY")
+if not stability_key:
+    raise ValueError("Please set the STABILITY_API_KEY environment variable.")
+
+claude_key = os.environ.get("ANTHROPIC_API_KEY")
+if not claude_key:
+    raise ValueError("Claude API key not found in environment variables. Get one at https://console.anthropic.com/settings/keys and set it as ANTHROPIC_API_KEY in your .env")
+
+claude_client = AsyncAnthropic()
+claude_client.api_key = claude_key
+
+stability_api = stability_client.StabilityInference(
+ key=stability_key,
verbose=True,
- engine="stable-diffusion-xl-1024-v1-0"
)
+# Anthropic
+# Only if using the official claude for access instead of aws bedrock
api_key = os.environ.get("ANTHROPIC_API_KEY")
+anthropic_client = anthropic.Anthropic()
+anthropic_client.api_key = api_key
+# For AWS bedrock (default)
bedrock_client = AsyncAnthropicBedrock(
# default is 10 minutes
# more granular timeout options: timeout=httpx.Timeout(60.0, read=5.0, write=10.0, connect=2.0),
timeout=60.0,
)
-
-anthropic_client = anthropic.Anthropic()
-anthropic_client.api_key = api_key
+# For google/gemini
+google_key=os.environ.get('GOOGLE_API_KEY')
+if not google_key:
+ raise ValueError("Please set the GOOGLE_API_KEY environment variable.")
+
+genai.configure(api_key=google_key)
+
+
+# Wandb
netrc_path = pathlib.Path.home() / ".netrc"
wandb_api_key = os.getenv("WANDB_API_KEY")
-
bt.logging.info("WANDB_API_KEY is set")
bt.logging.info("~/.netrc exists:", netrc_path.exists())
if not wandb_api_key and not netrc_path.exists():
raise ValueError("Please log in to wandb using `wandb login` or set the WANDB_API_KEY environment variable.")
-client = AsyncOpenAI(timeout=60.0)
valid_hotkeys = []
-
-class StreamMiner(ABC):
+class StreamMiner:
def __init__(self, config=None, axon=None, wallet=None, subtensor=None):
bt.logging.info("starting stream miner")
base_config = copy.deepcopy(config or get_config())
self.config = self.config()
self.config.merge(base_config)
check_config(StreamMiner, self.config)
- bt.logging.info(self.config) # TODO: duplicate print?
+ bt.logging.info(self.config)
self.prompt_cache: dict[str, Tuple[str, int]] = {}
self.request_timestamps = {}
@@ -104,8 +127,8 @@ def __init__(self, config=None, axon=None, wallet=None, subtensor=None):
if self.wallet.hotkey.ss58_address not in self.metagraph.hotkeys:
bt.logging.error(
- f"\nYour validator: {self.wallet} if not registered to chain connection: {self.subtensor} "
- f"\nRun btcli register and try again. "
+                f"\nYour miner: {self.wallet} is not registered to this subnet."
+                f"\nRun btcli s register --netuid 18 and try again."
)
sys.exit()
else:
@@ -119,19 +142,21 @@ def __init__(self, config=None, axon=None, wallet=None, subtensor=None):
self.axon = axon or bt.axon(wallet=self.wallet, port=self.config.axon.port)
# Attach determiners which functions are called when servicing a request.
bt.logging.info("Attaching forward function to axon.")
- print(f"Attaching forward function to axon. {self._prompt}")
+ print(f"Attaching forward function to axon. {self.prompt}")
self.axon.attach(
- forward_fn=self._prompt,
+ forward_fn=self.prompt,
blacklist_fn=self.blacklist_prompt,
).attach(
- forward_fn=self._is_alive,
+ forward_fn=self.is_alive,
blacklist_fn=self.blacklist_is_alive,
).attach(
- forward_fn=self._images,
+ forward_fn=self.images,
blacklist_fn=self.blacklist_images,
).attach(
- forward_fn=self._embeddings,
+ forward_fn=self.embeddings,
blacklist_fn=self.blacklist_embeddings,
+ ).attach(
+ forward_fn=self.text,
)
bt.logging.info(f"Axon created: {self.axon}")
@@ -143,20 +168,21 @@ def __init__(self, config=None, axon=None, wallet=None, subtensor=None):
self.request_timestamps: dict = {}
thread = threading.Thread(target=get_valid_hotkeys, args=(self.config,))
# thread.start()
+
+ def text(self, synapse: TextPrompting) -> TextPrompting:
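+        # Minimal placeholder handler for TextPrompting: return a canned completion.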
+ synapse.completion = "completed by miner"
+ return synapse
- @abstractmethod
def config(self) -> bt.config:
- ...
-
- def _prompt(self, synapse: StreamPrompting) -> StreamPrompting:
- return self.prompt(synapse)
+ parser = argparse.ArgumentParser(description="Streaming Miner Configs")
+ return bt.config(parser)
def base_blacklist(self, synapse, blacklist_amt = 20000) -> Tuple[bool, str]:
try:
hotkey = synapse.dendrite.hotkey
synapse_type = type(synapse).__name__
- if hotkey in template.WHITELISTED_KEYS:
+ if hotkey in cortext.WHITELISTED_KEYS:
return False, f"accepting {synapse_type} request from {hotkey}"
if hotkey not in valid_hotkeys:
@@ -167,16 +193,16 @@ def base_blacklist(self, synapse, blacklist_amt = 20000) -> Tuple[bool, str]:
if _axon.hotkey == hotkey:
break
- if uid is None and template.ALLOW_NON_REGISTERED is False:
+ if uid is None and cortext.ALLOW_NON_REGISTERED is False:
return True, f"Blacklisted a non registered hotkey's {synapse_type} request from {hotkey}"
# check the stake
- tao = self.metagraph.neurons[uid].stake.tao
+            tao = self.metagraph.S[uid].item()
# metagraph.neurons[uid].S
if tao < blacklist_amt:
return True, f"Blacklisted a low stake {synapse_type} request: {tao} < {blacklist_amt} from {hotkey}"
- time_window = template.MIN_REQUEST_PERIOD * 60
+ time_window = cortext.MIN_REQUEST_PERIOD * 60
current_time = time.time()
if hotkey not in self.request_timestamps:
@@ -187,12 +213,12 @@ def base_blacklist(self, synapse, blacklist_amt = 20000) -> Tuple[bool, str]:
self.request_timestamps[hotkey].popleft()
# Check if the number of requests exceeds the limit
- if len(self.request_timestamps[hotkey]) >= template.MAX_REQUESTS:
+ if len(self.request_timestamps[hotkey]) >= cortext.MAX_REQUESTS:
return (
True,
f"Request frequency for {hotkey} exceeded: "
- f"{len(self.request_timestamps[hotkey])} requests in {template.MIN_REQUEST_PERIOD} minutes. "
- f"Limit is {template.MAX_REQUESTS} requests."
+ f"{len(self.request_timestamps[hotkey])} requests in {cortext.MIN_REQUEST_PERIOD} minutes. "
+ f"Limit is {cortext.MAX_REQUESTS} requests."
)
self.request_timestamps[hotkey].append(current_time)
@@ -204,56 +230,25 @@ def base_blacklist(self, synapse, blacklist_amt = 20000) -> Tuple[bool, str]:
def blacklist_prompt( self, synapse: StreamPrompting ) -> Tuple[bool, str]:
- blacklist = self.base_blacklist(synapse, template.PROMPT_BLACKLIST_STAKE)
+ blacklist = self.base_blacklist(synapse, cortext.PROMPT_BLACKLIST_STAKE)
bt.logging.info(blacklist[1])
return blacklist
def blacklist_is_alive( self, synapse: IsAlive ) -> Tuple[bool, str]:
- blacklist = self.base_blacklist(synapse, template.ISALIVE_BLACKLIST_STAKE)
+ blacklist = self.base_blacklist(synapse, cortext.ISALIVE_BLACKLIST_STAKE)
bt.logging.debug(blacklist[1])
return blacklist
def blacklist_images( self, synapse: ImageResponse ) -> Tuple[bool, str]:
- blacklist = self.base_blacklist(synapse, template.IMAGE_BLACKLIST_STAKE)
+ blacklist = self.base_blacklist(synapse, cortext.IMAGE_BLACKLIST_STAKE)
bt.logging.info(blacklist[1])
return blacklist
def blacklist_embeddings( self, synapse: Embeddings ) -> Tuple[bool, str]:
- blacklist = self.base_blacklist(synapse, template.EMBEDDING_BLACKLIST_STAKE)
+ blacklist = self.base_blacklist(synapse, cortext.EMBEDDING_BLACKLIST_STAKE)
bt.logging.info(blacklist[1])
return blacklist
- @classmethod
- @abstractmethod
- def add_args(cls, parser: argparse.ArgumentParser):
- ...
-
- def _prompt(self, synapse: StreamPrompting) -> StreamPrompting:
- return self.prompt(synapse)
-
- async def _images(self, synapse: ImageResponse) -> ImageResponse:
- return await self.images(synapse)
-
- async def _embeddings(self, synapse: Embeddings) -> Embeddings:
- return await self.embeddings(synapse)
-
- def _is_alive(self, synapse: IsAlive) -> IsAlive:
- bt.logging.info("answered to be active")
- synapse.completion = "True"
- return synapse
-
- @abstractmethod
- def prompt(self, synapse: StreamPrompting) -> StreamPrompting:
- ...
-
- @abstractmethod
- def images(self, synapse: ImageResponse) -> ImageResponse:
- ...
-
- @abstractmethod
- def embeddings(self, synapse: Embeddings) -> Embeddings:
- ...
-
def run(self):
if not self.subtensor.is_hotkey_registered(
netuid=self.config.netuid,
@@ -286,7 +281,7 @@ def run(self):
current_block - self.last_epoch_block
< self.config.miner.blocks_per_epoch
):
- # --- Wait for next bloc.
+ # --- Wait for next block.
time.sleep(1)
current_block = self.subtensor.get_current_block()
# --- Check if we should exit.
@@ -349,122 +344,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.stop_run_thread()
-
-class StreamingTemplateMiner(StreamMiner):
- def config(self) -> bt.config:
- parser = argparse.ArgumentParser(description="Streaming Miner Configs")
- self.add_args(parser)
- return bt.config(parser)
-
- def add_args(cls, parser: argparse.ArgumentParser):
- pass
-
- async def embeddings(self, synapse: Embeddings) -> Embeddings:
- bt.logging.info(f"entered embeddings processing for embeddings of len {len(synapse.texts)}")
-
- async def get_embeddings_in_batch(texts, model, batch_size=10):
- batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
- tasks = []
- for batch in batches:
- filtered_batch = [text for text in batch if text.strip()]
- if filtered_batch:
- task = asyncio.create_task(client.embeddings.create(
- input=filtered_batch, model=model, encoding_format='float'
- ))
- tasks.append(task)
- else:
- bt.logging.info("Skipped an empty batch.")
-
- all_embeddings = []
- results = await asyncio.gather(*tasks, return_exceptions=True)
- for result in results:
- if isinstance(result, Exception):
- bt.logging.error(f"Error in processing batch: {result}")
- else:
- batch_embeddings = [item.embedding for item in result.data]
- all_embeddings.extend(batch_embeddings)
- return all_embeddings
-
- try:
- texts = synapse.texts
- model = synapse.model
- batched_embeddings = await get_embeddings_in_batch(texts, model)
- synapse.embeddings = batched_embeddings
- # synapse.embeddings = [np.array(embed) for embed in batched_embeddings]
- bt.logging.info(f"synapse response is {synapse.embeddings[0][:10]}")
- return synapse
- except Exception:
- bt.logging.error(f"Exception in embeddings function: {traceback.format_exc()}")
-
-
- async def images(self, synapse: ImageResponse) -> ImageResponse:
- bt.logging.info(f"received image request: {synapse}")
- try:
- # Extract necessary information from synapse
- provider = synapse.provider
- model = synapse.model
- messages = synapse.messages
- size = synapse.size
- width = synapse.width
- height = synapse.height
- quality = synapse.quality
- style = synapse.style
- seed = synapse.seed
- steps = synapse.steps
- image_revised_prompt = None
- cfg_scale = synapse.cfg_scale
- sampler = synapse.sampler
- samples = synapse.samples
- image_data = {}
-
- bt.logging.debug(f"data = {provider, model, messages, size, width, height, quality, style, seed, steps, image_revised_prompt, cfg_scale, sampler, samples}")
-
- if provider == "OpenAI":
- meta = await client.images.generate(
- model=model,
- prompt=messages,
- size=size,
- quality=quality,
- style=style,
- )
- image_url = meta.data[0].url
- image_revised_prompt = meta.data[0].revised_prompt
- image_data["url"] = image_url
- image_data["image_revised_prompt"] = image_revised_prompt
- bt.logging.info(f"returning image response of {image_url}")
-
- elif provider == "Stability":
- bt.logging.debug(f"calling stability for {messages, seed, steps, cfg_scale, width, height, samples, sampler}")
-
- meta = stability_api.generate(
- prompt=messages,
- seed=seed,
- steps=steps,
- cfg_scale=cfg_scale,
- width=width,
- height=height,
- samples=samples,
- # sampler=sampler
- )
- # Process and upload the image
- b64s = []
- for image in meta:
- for artifact in image.artifacts:
- b64s.append(base64.b64encode(artifact.binary).decode())
-
- image_data["b64s"] = b64s
- bt.logging.info(f"returning image response to {messages}")
-
- else:
- bt.logging.error(f"Unknown provider: {provider}")
-
- synapse.completion = image_data
- return synapse
-
- except Exception as exc:
- bt.logging.error(f"error in images: {exc}\n{traceback.format_exc()}")
-
- def prompt(self, synapse: StreamPrompting) -> StreamPrompting:
+ async def prompt(self, synapse: StreamPrompting) -> StreamPrompting:
bt.logging.info(f"started processing for synapse {synapse}")
async def _prompt(synapse, send: Send):
@@ -478,23 +358,17 @@ async def _prompt(synapse, send: Send):
top_p = synapse.top_p
top_k = synapse.top_k
- if provider == "OpenAI":
- ### Store the args in a dictionary to be passed to OpenAI
- args = dict(
+ if provider == "OpenAI":
+ # Test seeds + higher temperature
+ response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
stream=True,
+ seed=seed,
max_tokens=max_tokens
)
-
- ### If -1 is specified, this indicates that the user does not want a seed added to the request
- if seed != -1:
- args["seed"] = seed
-
- response = await client.chat.completions.create(**args)
-
buffer = []
n = 1
async for chunk in response:
@@ -523,33 +397,6 @@ async def _prompt(synapse, send: Send):
)
bt.logging.info(f"Streamed tokens: {joined_buffer}")
- # # for official claude users, comment out the other elif
- # elif provider == "Anthropic":
- # models = ["anthropic.claude-v2:1", "anthropic.claude-instant-v1", "anthropic.claude-v1", "anthropic.claude-v2"]
- # if model == models[0]: model = "claude-2.1"
- # if model == models[1]: model = "claude-instant-1.2"
- # if model == models[2]: model = "claude-instant-1.2"
- # if model == models[3]: model = "claude-2.0"
-
- # with anthropic_client.beta.messages.stream(
- # max_tokens=max_tokens,
- # messages=messages,
- # model=model,
- # temperature=temperature,
- # top_p=top_p,
- # top_k=top_k,
- # ) as stream:
- # for text in stream.text_stream:
- # await send(
- # {
- # "type": "http.response.body",
- # "body": text.encode("utf-8"),
- # "more_body": True,
- # }
- # )
- # bt.logging.info(f"Streamed text: {text}")
-
- # For amazon bedrock users, comment out the other elif
elif provider == "Anthropic":
stream = await bedrock_client.completions.create(
prompt=f"\n\nHuman: {messages}\n\nAssistant:",
@@ -575,6 +422,68 @@ async def _prompt(synapse, send: Send):
# Send final message to close the stream
await send({"type": "http.response.body", "body": b'', "more_body": False})
+ elif provider == "Claude":
+ system_prompt = None
+ filtered_messages = []
+ for message in messages:
+ if message["role"] == "system":
+ system_prompt = message["content"]
+ else:
+ filtered_messages.append(message)
+
+ stream_kwargs = {
+ "max_tokens": max_tokens,
+ "messages": filtered_messages,
+ "model": model,
+ }
+
+ if system_prompt:
+ stream_kwargs["system"] = system_prompt
+
+ completion = claude_client.messages.stream(**stream_kwargs)
+ async with completion as stream:
+ async for text in stream.text_stream:
+ await send(
+ {
+ "type": "http.response.body",
+ "body": text.encode("utf-8"),
+ "more_body": True,
+ }
+ )
+ bt.logging.info(f"Streamed text: {text}")
+
+ # Send final message to close the stream
+ await send({"type": "http.response.body", "body": b'', "more_body": False})
+
+ elif provider == "Gemini":
+ model = genai.GenerativeModel(model)
+ stream = model.generate_content(
+ str(messages),
+ stream=True,
+ generation_config=genai.types.GenerationConfig(
+ # candidate_count=1,
+ # stop_sequences=['x'],
+ temperature=temperature,
+ # max_output_tokens=max_tokens,
+ top_p=top_p,
+ top_k=top_k,
+ # seed=seed,
+ )
+ )
+ for chunk in stream:
+ for part in chunk.candidates[0].content.parts:
+ await send(
+ {
+ "type": "http.response.body",
+                            "body": part.text.encode("utf-8"),
+ "more_body": True,
+ }
+ )
+                    bt.logging.info(f"Streamed text: {part.text}")
+
+ # Send final message to close the stream
+ await send({"type": "http.response.body", "body": b'', "more_body": False})
+
else:
bt.logging.error(f"Unknown provider: {provider}")
@@ -584,6 +493,116 @@ async def _prompt(synapse, send: Send):
token_streamer = partial(_prompt, synapse)
return synapse.create_streaming_response(token_streamer)
+ async def images(self, synapse: ImageResponse) -> ImageResponse:
+ bt.logging.info(f"received image request: {synapse}")
+ try:
+ # Extract necessary information from synapse
+ provider = synapse.provider
+ model = synapse.model
+ messages = synapse.messages
+ size = synapse.size
+ width = synapse.width
+ height = synapse.height
+ quality = synapse.quality
+ style = synapse.style
+ seed = synapse.seed
+ steps = synapse.steps
+ image_revised_prompt = None
+ cfg_scale = synapse.cfg_scale
+ sampler = synapse.sampler
+ samples = synapse.samples
+ image_data = {}
+
+ bt.logging.debug(f"data = {provider, model, messages, size, width, height, quality, style, seed, steps, image_revised_prompt, cfg_scale, sampler, samples}")
+
+ if provider == "OpenAI":
+ meta = await client.images.generate(
+ model=model,
+ prompt=messages,
+ size=size,
+ quality=quality,
+ style=style,
+ )
+ image_url = meta.data[0].url
+ image_revised_prompt = meta.data[0].revised_prompt
+ image_data["url"] = image_url
+ image_data["image_revised_prompt"] = image_revised_prompt
+ bt.logging.info(f"returning image response of {image_url}")
+
+ elif provider == "Stability":
+ bt.logging.debug(f"calling stability for {messages, seed, steps, cfg_scale, width, height, samples, sampler}")
+
+ meta = stability_api.generate(
+ prompt=messages,
+ seed=seed,
+ steps=steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ samples=samples,
+ # sampler=sampler
+ )
+ # Process and upload the image
+ b64s = []
+ for image in meta:
+ for artifact in image.artifacts:
+ b64s.append(base64.b64encode(artifact.binary).decode())
+
+ image_data["b64s"] = b64s
+ bt.logging.info(f"returning image response to {messages}")
+
+ else:
+ bt.logging.error(f"Unknown provider: {provider}")
+
+ synapse.completion = image_data
+ return synapse
+
+ except Exception as exc:
+ bt.logging.error(f"error in images: {exc}\n{traceback.format_exc()}")
+
+ async def embeddings(self, synapse: Embeddings) -> Embeddings:
+ bt.logging.info(f"entered embeddings processing for embeddings of len {len(synapse.texts)}")
+
+ async def get_embeddings_in_batch(texts, model, batch_size=10):
+ batches = [texts[i:i + batch_size] for i in range(0, len(texts), batch_size)]
+ tasks = []
+ for batch in batches:
+ filtered_batch = [text for text in batch if text.strip()]
+ if filtered_batch:
+ task = asyncio.create_task(client.embeddings.create(
+ input=filtered_batch, model=model, encoding_format='float'
+ ))
+ tasks.append(task)
+ else:
+ bt.logging.info("Skipped an empty batch.")
+
+ all_embeddings = []
+ results = await asyncio.gather(*tasks, return_exceptions=True)
+ for result in results:
+ if isinstance(result, Exception):
+ bt.logging.error(f"Error in processing batch: {result}")
+ else:
+ batch_embeddings = [item.embedding for item in result.data]
+ all_embeddings.extend(batch_embeddings)
+ return all_embeddings
+
+ try:
+ texts = synapse.texts
+ model = synapse.model
+ batched_embeddings = await get_embeddings_in_batch(texts, model)
+ synapse.embeddings = batched_embeddings
+ # synapse.embeddings = [np.array(embed) for embed in batched_embeddings]
+ bt.logging.info(f"synapse response is {synapse.embeddings[0][:10]}")
+ return synapse
+ except Exception:
+ bt.logging.error(f"Exception in embeddings function: {traceback.format_exc()}")
+
+ async def is_alive(self, synapse: IsAlive) -> IsAlive:
+ bt.logging.debug("answered to be active")
+ synapse.completion = "True"
+ return synapse
+
+
def get_valid_hotkeys(config):
global valid_hotkeys
api = wandb.Api()
@@ -591,7 +610,7 @@ def get_valid_hotkeys(config):
while True:
metagraph = subtensor.metagraph(18)
try:
- runs = api.runs(f"cortex-t/{template.PROJECT_NAME}")
+ runs = api.runs(f"cortex-t/{cortext.PROJECT_NAME}")
latest_version = get_version()
for run in runs:
if run.state == "running":
@@ -635,6 +654,8 @@ def get_valid_hotkeys(config):
if __name__ == "__main__":
- with StreamingTemplateMiner():
+ with StreamMiner():
while True:
time.sleep(1)
+
+
diff --git a/miner/prices.ipynb b/miner/prices.ipynb
new file mode 100644
index 00000000..8fc1e2d1
--- /dev/null
+++ b/miner/prices.ipynb
@@ -0,0 +1,84 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import copy\n",
+ "import torch\n",
+ "class ConstantTao:\n",
+ " def __init__( self, n, initial_tao ):\n",
+ " self.n = n\n",
+ " self.tao_reserve = [ initial_tao/n for i in range(self.n) ] \n",
+ " self.alpha_reserve = [ initial_tao for i in range(self.n) ]\n",
+ " self.alpha_outstanding = [ initial_tao for i in range(self.n) ]\n",
+ " self.k = [ self.tao_reserve[i] * self.alpha_reserve[i] for i in range(n) ]\n",
+ "\n",
+ " def __repr__( self ):\n",
+ " return self.__str__()\n",
+ "\n",
+ " def __str__( self ):\n",
+ " return f\"tao_reserve: {self.tao_reserve}\\ntao_prices: {self.tao_prices()}\\nalpha_prices: {self.alpha_prices()}\\nalpha_reserve: {self.alpha_reserve}\\nalpha_outstanding: {self.alpha_outstanding}\\nk: {self.k}\"\n",
+ " \n",
+ " def __state__dict__( self ):\n",
+ " return { \n",
+ " \"tao_reserve\": copy.deepcopy( self.tao_reserve ),\n",
+ " \"tao_prices\": copy.deepcopy( self.tao_prices() ) ,\n",
+ " \"alpha_prices\": copy.deepcopy( self.alpha_prices() ),\n",
+ " \"alpha_reserve\": copy.deepcopy( self.alpha_reserve ),\n",
+ " \"alpha_outstanding\":copy.deepcopy( self.alpha_outstanding ),\n",
+ " \"k\": copy.deepcopy( self.k ),\n",
+ " 'sum_marketcap': sum( self.marketcaps() ),\n",
+ " 'sum_price': sum( self.alpha_prices() ),\n",
+ " }\n",
+ "\n",
+ " def tao_prices( self ):\n",
+ " \"\"\" Calculate and return the current price of tao based on reserves. \"\"\"\n",
+ " return [ self.alpha_reserve[i] / self.tao_reserve[i] for i in range(len(self.tao_reserve)) ]\n",
+ "\n",
+ " def alpha_prices( self ):\n",
+ " \"\"\" Calculate and return the current price of alpha based on reserves. \"\"\"\n",
+ " return [ self.tao_reserve[i] / self.alpha_reserve[i] for i in range(len(self.tao_reserve)) ]\n",
+ " \n",
+ " def marketcaps( self ):\n",
+ " \"\"\" Calculate and return the market capitalization of alpha. \"\"\"\n",
+ " return [ (self.alpha_outstanding[i] + self.alpha_reserve[i]) * self.alpha_prices()[i] for i in range(len(self.tao_reserve)) ] \n",
+ "\n",
+ " def block_step( self ):\n",
+ " for i in range(self.n):\n",
+ " ratio_i = self.alpha_prices()[i] / sum( self.alpha_prices() )\n",
+ " self.tao_reserve[i] += 7200 * ratio_i\n",
+ " self.alpha_reserve[i] += 7200\n",
+ " self.alpha_outstanding[i] += 7200\n",
+ " self.k[i] = self.tao_reserve[i] * self.alpha_reserve[i]\n",
+ "\n",
+ " def swap( self, idx_a, idx_b, amount ):\n",
+ "\n",
+ " new_alpha_reserve_a = self.alpha_reserve[idx_a] + amount\n",
+ " new_tao_reserve_a = self.k[idx_a] / new_alpha_reserve_a\n",
+ " tao_bought_a = self.tao_reserve[idx_a] - new_tao_reserve_a\n",
+ " self.alpha_outstanding[idx_a] = self.alpha_outstanding[idx_a] - amount\n",
+ " self.alpha_reserve[idx_a] = new_alpha_reserve_a\n",
+ " self.tao_reserve[idx_a] = new_tao_reserve_a\n",
+ " self.k[idx_a] = self.tao_reserve[idx_a] * self.alpha_reserve[idx_a]\n",
+ "\n",
+ " new_tao_reserve_b = self.tao_reserve[idx_b] + tao_bought_a\n",
+ " new_alpha_reserve_b = self.k[idx_b] / new_tao_reserve_b\n",
+ " alpha_bought_b = self.alpha_reserve[idx_b] - new_alpha_reserve_b\n",
+ " self.alpha_outstanding[idx_b] = self.alpha_outstanding[idx_b] + alpha_bought_b\n",
+ " self.alpha_reserve[idx_b] = new_alpha_reserve_b\n",
+ " self.tao_reserve[idx_b] = new_tao_reserve_b\n",
+ " self.k[idx_b] = self.tao_reserve[idx_b] * self.alpha_reserve[idx_b]"
+ ]
+ }
+ ],
+ "metadata": {
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
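The notebook above is a scratch model of per-subnet constant-product pools: `k = tao_reserve * alpha_reserve` is held invariant through `swap`, while `block_step` emits 7200 tao per block split across pools by relative alpha price (plus 7200 alpha into each pool). A minimal usage sketch of the class as defined above, with illustrative numbers:

```python
# Run inside the notebook after the ConstantTao cell has been executed.
pool = ConstantTao(n=2, initial_tao=1000)

pool.block_step()     # one block of emission; tao is split by relative alpha price
pool.swap(0, 1, 50)   # sell 50 alpha into pool 0, spend the tao on pool 1's alpha
print(pool)           # reserves, prices, outstanding alpha, and k after the swap
print(sum(pool.marketcaps()))
```

Note that `import torch` in the cell is unused; only `copy` is needed.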
diff --git a/miner/test_miners.py b/miner/test_miners.py
index ab0b7a9d..a742d8ce 100644
--- a/miner/test_miners.py
+++ b/miner/test_miners.py
@@ -9,6 +9,8 @@
for i in range(num_miners):
miner_number = start_num + i
port = base_port + i
+ # export PYTHONPATH='/home/ec2-user/cortex.t/'
+ # export PYTHONPATH='/home/ubuntu/cortex.t/'
# Construct the PM2 start command
# command = f"pm2 start miner.py --interpreter python3 --name {wallet_name}:{miner_number} -- --wallet.name {wallet_name} --wallet.hotkey {miner_number} --subtensor.network finney --netuid 18 --axon.port {port*wallet_name} --logging.debug"
diff --git a/requirements.txt b/requirements.txt
index 2b824226..b7083412 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,8 +9,9 @@ scikit-learn==1.*
torch==2.*
transformers==4.*
wandb
-anthropic
+anthropic==0.19.2
stability-sdk
boto3
anthropic_bedrock
pyOpenSSL
+google-generativeai
\ No newline at end of file
diff --git a/setup.py b/setup.py
index f76ec9b2..8050aca6 100644
--- a/setup.py
+++ b/setup.py
@@ -63,34 +63,24 @@ def read_requirements(path):
version_string = version_match.group(1)
setup(
- name="bittensor_subnet_template", # TODO(developer): Change this value to your module subnet name.
+ name="Cortex.t",
version=version_string,
- description="bittensor_subnet_template", # TODO(developer): Change this value to your module subnet description.
+ description="Decentralized APIs for synthetic data generation",
long_description=long_description,
long_description_content_type="text/markdown",
- url="https://github.com/opentensor/bittensor-subnet-template", # TODO(developer): Change this url to your module subnet github url.
- author="bittensor.com", # TODO(developer): Change this value to your module subnet author name.
+ url="https://github.com/corcel-api/cortex.t",
+ author="Fish",
packages=find_packages(),
include_package_data=True,
- author_email="", # TODO(developer): Change this value to your module subnet author email.
+ author_discord="p383_54249",
license="MIT",
python_requires=">=3.8",
install_requires=requirements,
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Developers",
- "Topic :: Software Development :: Build Tools",
- # Pick your license as you wish
"License :: OSI Approved :: MIT License",
- "Programming Language :: Python :: 3 :: Only",
- "Programming Language :: Python :: 3.8",
- "Programming Language :: Python :: 3.9",
- "Programming Language :: Python :: 3.10",
- "Topic :: Scientific/Engineering",
- "Topic :: Scientific/Engineering :: Mathematics",
- "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Programming Language :: Python :: 3",
"Topic :: Software Development",
- "Topic :: Software Development :: Libraries",
- "Topic :: Software Development :: Libraries :: Python Modules",
],
-)
+)
\ No newline at end of file
diff --git a/start_validator.py b/start_validator.py
index 870b44cc..e30f823e 100644
--- a/start_validator.py
+++ b/start_validator.py
@@ -1,12 +1,12 @@
import argparse
import time
import subprocess
-import template
-from template.utils import get_version, send_discord_alert
+import cortext
+from cortext.utils import get_version, send_discord_alert
default_address = "wss://bittensor-finney.api.onfinality.io/public-ws"
webhook_url = ""
-current_version = template.__version__
+current_version = cortext.__version__
def update_and_restart(pm2_name, wallet_name, wallet_hotkey, address, autoupdate):
global current_version
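The hunk above only renames the `template` package to `cortext`; the surrounding `update_and_restart` loop compares `current_version` against `get_version()` and restarts the pm2 process when they diverge. A hypothetical sketch of that gate, with the restart step assumed rather than taken from the source:

```python
import subprocess

def check_and_restart(pm2_name: str) -> None:
    # Hypothetical reduction of update_and_restart's version gate; the real
    # function also receives wallet/hotkey/address/autoupdate arguments.
    global current_version
    latest = get_version()  # imported from cortext.utils above
    if current_version != latest:
        subprocess.run(["pm2", "restart", pm2_name], check=True)  # assumed restart step
        current_version = latest
```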
diff --git a/test_scripts/image.jpg b/test_scripts/image.jpg
new file mode 100644
index 00000000..7f7e53e5
Binary files /dev/null and b/test_scripts/image.jpg differ
diff --git a/test_scripts/images/resize.py b/test_scripts/images/resize.py
new file mode 100644
index 00000000..ec7e4a22
--- /dev/null
+++ b/test_scripts/images/resize.py
@@ -0,0 +1,25 @@
+from PIL import Image
+
+# Path to the image
+image_path = '/Users/colemiller/cortex.t/test_scripts/images/john2.png'
+
+# Desired dimensions
+desired_width = 576
+desired_height = 1024
+
+# Open the image
+with Image.open(image_path) as img:
+ # Get current dimensions
+ current_width, current_height = img.size
+
+ # Check if the image dimensions are already the desired ones
+ if (current_width, current_height) != (desired_width, desired_height):
+ # Resize the image
+ resized_img = img.resize((desired_width, desired_height))
+
+ # Save the resized image, overwriting the original
+ # You can also save to a different path if you don't want to overwrite
+ resized_img.save(image_path)
+ print(f"Image resized to {desired_width}x{desired_height} and saved.")
+ else:
+ print("Image is already at the desired dimensions.")
\ No newline at end of file
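`Image.resize` stretches to the target box, so a 576x1024 target will distort any source with a different aspect ratio. If that matters, Pillow's `ImageOps.fit` crops to the target aspect ratio before scaling; a sketch of the alternative (not what the script above does):

```python
from PIL import Image, ImageOps

# Center-crop to the 576x1024 aspect ratio, then resize: no stretching.
with Image.open(image_path) as img:
    ImageOps.fit(img, (desired_width, desired_height)).save(image_path)
```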
diff --git a/test_scripts/test_image_limit.py b/test_scripts/openai/test_image_limit.py
similarity index 100%
rename from test_scripts/test_image_limit.py
rename to test_scripts/openai/test_image_limit.py
diff --git a/test_scripts/output.mp4 b/test_scripts/output.mp4
new file mode 100644
index 00000000..9e3c8974
Binary files /dev/null and b/test_scripts/output.mp4 differ
diff --git a/test_scripts/output_audio.mp3 b/test_scripts/output_audio.mp3
new file mode 100644
index 00000000..f9979756
Binary files /dev/null and b/test_scripts/output_audio.mp3 differ
diff --git a/test_scripts/similarity_scores.txt b/test_scripts/similarity_scores.txt
deleted file mode 100644
index 7537c9e5..00000000
--- a/test_scripts/similarity_scores.txt
+++ /dev/null
@@ -1,170 +0,0 @@
-37|validator | 2023-12-04 07:05:00.024 | DEBUG | similarity for len 424 / 527: 0.8348887551785548, min_similarity is 0.8308
-[... 169 further deleted lines of validator DEBUG output logging per-response similarity scores against min_similarity thresholds, including entries with stray ANSI color codes ...]
\ No newline at end of file
diff --git a/test_scripts/compare_stability.py b/test_scripts/stability/compare_stability.py
similarity index 100%
rename from test_scripts/compare_stability.py
rename to test_scripts/stability/compare_stability.py
diff --git a/test_scripts/stability/get_balance.py b/test_scripts/stability/get_balance.py
new file mode 100644
index 00000000..96bfbad2
--- /dev/null
+++ b/test_scripts/stability/get_balance.py
@@ -0,0 +1,19 @@
+import os
+import requests
+
+api_host = os.getenv('API_HOST', 'https://api.stability.ai')
+url = f"{api_host}/v1/user/balance"
+
+api_key = os.getenv("STABILITY_API_KEY")
+if api_key is None:
+ raise Exception("Missing Stability API key.")
+
+response = requests.get(url, headers={
+ "Authorization": f"Bearer {api_key}"
+})
+
+if response.status_code != 200:
+ raise Exception("Non-200 response: " + str(response.text))
+
+payload = response.json()
+print(payload)
diff --git a/test_scripts/stability/get_engines.py b/test_scripts/stability/get_engines.py
new file mode 100644
index 00000000..74797baa
--- /dev/null
+++ b/test_scripts/stability/get_engines.py
@@ -0,0 +1,25 @@
+import os
+import requests
+
+# image model list
+# ['esrgan-v1-x2plus', 'stable-diffusion-xl-1024-v0-9', 'stable-diffusion-xl-1024-v1-0', 'stable-diffusion-v1-6', 'stable-diffusion-512-v2-1', 'stable-diffusion-xl-beta-v2-2-2']
+
+api_host = os.getenv('API_HOST', 'https://api.stability.ai')
+url = f"{api_host}/v1/engines/list"
+
+api_key = os.getenv("STABILITY_API_KEY")
+if api_key is None:
+ raise Exception("Missing Stability API key.")
+
+response = requests.get(url, headers={
+ "Authorization": f"Bearer {api_key}"
+})
+
+if response.status_code != 200:
+ raise Exception("Non-200 response: " + str(response.text))
+
+# Do something with the payload...
+payload = response.json()
+print(payload)
+engine_ids = [engine['id'] for engine in payload]
+print(engine_ids)
\ No newline at end of file
diff --git a/test_scripts/test_embeddings.py b/test_scripts/t2e/test_embeddings.py
similarity index 100%
rename from test_scripts/test_embeddings.py
rename to test_scripts/t2e/test_embeddings.py
diff --git a/test_scripts/t2i/...i2i/upscale.py b/test_scripts/t2i/...i2i/upscale.py
new file mode 100644
index 00000000..6100d02b
--- /dev/null
+++ b/test_scripts/t2i/...i2i/upscale.py
@@ -0,0 +1,16 @@
+
+
+# https://platform.stability.ai/docs/features/image-upscaling#Python
+# engine = esrgan-v1-x2plus
+
+# input image limit: 1024 x 1024
+# output image limit: 2048 x 2048
+
+# example:
+
+# img = Image.open('/img2upscale.png')
+
+# answers = stability_api.upscale(
+# init_image=img, # Pass our image to the API and call the upscaling process.
+# # width=1024, # Optional parameter to specify the desired output width.
+# )
\ No newline at end of file
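Fleshing out the commented fragment above into something runnable, as a sketch against the `stability_sdk` client (assumes `STABILITY_API_KEY` is set and the input respects the 1024x1024 limit noted above):

```python
import io
import os

from PIL import Image
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

# esrgan-v1-x2plus is the upscaling engine named in the comments above.
stability_api = client.StabilityInference(
    key=os.environ["STABILITY_API_KEY"],
    upscale_engine="esrgan-v1-x2plus",
    verbose=True,
)

img = Image.open("img2upscale.png")
answers = stability_api.upscale(init_image=img)  # width= is the optional output width

for resp in answers:
    for artifact in resp.artifacts:
        if artifact.finish_reason == generation.FILTER:
            print("Safety filter triggered; nothing to save.")
        elif artifact.type == generation.ARTIFACT_IMAGE:
            Image.open(io.BytesIO(artifact.binary)).save("img2upscale_2x.png")
```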
diff --git a/test_scripts/test_images.py b/test_scripts/t2i/dalle3.py
similarity index 100%
rename from test_scripts/test_images.py
rename to test_scripts/t2i/dalle3.py
diff --git a/test_scripts/t2i/download_image.py b/test_scripts/t2i/download_image.py
new file mode 100644
index 00000000..87259b73
--- /dev/null
+++ b/test_scripts/t2i/download_image.py
@@ -0,0 +1,17 @@
+import requests
+
+# URL of the image
+url = 'https://file.io/a9znfUZ9uKvR'
+
+# Send a GET request to the URL
+response = requests.get(url)
+
+# Check if the request was successful
+if response.status_code == 200:
+ # Open a file in binary write mode
+ with open('image.jpg', 'wb') as file:
+ # Write the content of the response to the file
+ file.write(response.content)
+ print("Image downloaded and saved as image.jpg")
+else:
+ print("Failed to download the image. Status code:", response.status_code)
\ No newline at end of file
diff --git a/test_scripts/t2i/image0.jpg b/test_scripts/t2i/image0.jpg
new file mode 100644
index 00000000..058955dd
Binary files /dev/null and b/test_scripts/t2i/image0.jpg differ
diff --git a/test_scripts/t2i/image1.jpg b/test_scripts/t2i/image1.jpg
new file mode 100644
index 00000000..1248a37b
Binary files /dev/null and b/test_scripts/t2i/image1.jpg differ
diff --git a/test_scripts/t2i/image2.jpg b/test_scripts/t2i/image2.jpg
new file mode 100644
index 00000000..a986318f
Binary files /dev/null and b/test_scripts/t2i/image2.jpg differ
diff --git a/test_scripts/t2i/image3.jpg b/test_scripts/t2i/image3.jpg
new file mode 100644
index 00000000..5fc6ce68
Binary files /dev/null and b/test_scripts/t2i/image3.jpg differ
diff --git a/test_scripts/t2i/image4.jpg b/test_scripts/t2i/image4.jpg
new file mode 100644
index 00000000..dbdd12e9
Binary files /dev/null and b/test_scripts/t2i/image4.jpg differ
diff --git a/test_scripts/t2i/stability.py b/test_scripts/t2i/stability.py
new file mode 100644
index 00000000..086a1b9c
--- /dev/null
+++ b/test_scripts/t2i/stability.py
@@ -0,0 +1,162 @@
+import os
+import io
+import requests
+from PIL import Image
+from stability_sdk import client
+import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
+
+engine_id = "stable-diffusion-v1-6"
+api_host = os.getenv('API_HOST', 'https://api.stability.ai')
+api_key = os.getenv("STABILITY_API_KEY")
+
+
+# Set up the API connection
+stability_api = client.StabilityInference(
+ key=os.environ['STABILITY_API_KEY'],
+ verbose=True,
+ engine="stable-diffusion-xl-1024-v1-0"
+)
+
+# https://platform.stability.ai/docs/api-reference#tag/v1generation/operation/textToImage
+
+
+# engine_ids
+# ['esrgan-v1-x2plus', 'stable-diffusion-xl-1024-v0-9', 'stable-diffusion-xl-1024-v1-0', 'stable-diffusion-v1-6', 'stable-diffusion-512-v2-1', 'stable-diffusion-xl-beta-v2-2-2']
+
+# height
+# multiple of 64 >= 128
+# default: 512 or 1024 depending on the model
+
+# width
+# multiple of 64 >= 128
+# default: 512 or 1024 depending on the model
+
+# Engine-specific dimension validation:
+# SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
+# SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
+# SDXL v1.0: same as SDXL v0.9
+# SD v1.6: must be between 320x320 and 1536x1536
+
+# text_prompts
+# a list of prompts to use for generation
+
+# weight
+# positive to prompt the image, negative to prompt the image away from the text
+# total possible range is [-10, 10] but we recommend staying within the range of [-2, 2].
+# example:
+# prompt= [generation.Prompt(text="beautiful night sky above japanese town, anime style",parameters=generation.PromptParameters(weight=1)),
+ # generation.Prompt(text="clouds",parameters=generation.PromptParameters(weight=-1))],
+# this will not have clouds in the output
+
+# cfg_scale
+# [ 0 .. 35 ]
+# default = 7
+# description: How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
+
+# clip_guidance_preset
+# FAST_BLUE FAST_GREEN NONE SIMPLE SLOW SLOWER SLOWEST
+# default = NONE
+
+
+# seed
+# [ 0 .. 4294967295 ]
+# default = 0
+# If a seed is provided, the resulting generated image will be deterministic.
+
+
+# samplers
+# ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_dpmpp_2s_ancestral, k_lms, k_dpmpp_2m, k_dpmpp_sde
+# default = will auto pick an appropriate one
+
+# samples
+# [ 1 .. 10 ]
+# default = 1
+# number of images to generate
+
+
+# steps
+# [ 10 .. 50 ]
+# default = 30
+# Number of diffusion steps to run.
+
+# style_preset
+# 3d-model analog-film anime cinematic comic-book digital-art enhance fantasy-art isometric line-art low-poly modeling-compound neon-punk origami photographic pixel-art tile-texture
+# Pass in a style preset to guide the image model towards a particular style.
+
+
+# # Generate the image
+# answers = stability_api.generate(
+# prompt="expansive landscape rolling greens with gargantuan yggdrasil, intricate world-spanning roots towering under a blue alien sky, masterful, ghibli",
+# seed=9283409,
+# steps=30,
+# cfg_scale=8.0,
+# width=1024,
+# height=1024,
+# samples=1,
+# sampler=generation.SAMPLER_K_DPMPP_2M
+# )
+
+# # Process and upload the image
+# for resp in answers:
+# for artifact in resp.artifacts:
+# if artifact.finish_reason == generation.FILTER:
+# print("Safety filters activated, prompt could not be processed.")
+# elif artifact.type == generation.ARTIFACT_IMAGE:
+# img = Image.open(io.BytesIO(artifact.binary))
+# img_buffer = io.BytesIO()
+# img.save(img_buffer, format="PNG")
+# img_buffer.seek(0)
+
+# # Upload to file.io
+# response = requests.post('https://file.io', files={'file': img_buffer})
+# if response.status_code == 200:
+# print(f"Image uploaded successfully. URL: {response.json()['link']}")
+# else:
+# print("Failed to upload the image.")
+
+prompt1 = 'royal king with all his riches'
+weight1 = 1.5
+prompt2 = 'a quick brown fox jumped over the lazy dog'
+weight2 = 1
+prompt3 = "crown"
+weight3 = -2
+
+seed = 0
+steps = 5
+cfg_scale = 17
+width = 1024
+height = 1024
+samples = 2
+
+meta = stability_api.generate(
+ prompt=
+ [generation.Prompt(text=prompt1,parameters=generation.PromptParameters(weight=weight1))
+ ,generation.Prompt(text=prompt2,parameters=generation.PromptParameters(weight=weight2))
+ ,generation.Prompt(text=prompt3,parameters=generation.PromptParameters(weight=weight3))
+ ],
+ seed=seed,
+ steps=steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ samples=samples,
+)
+
+# Process and upload the image
+import matplotlib.pyplot as plt
+import numpy as np
+
+image_files = []
+for i, image in enumerate(meta):
+ for artifact in image.artifacts:
+ img_array = np.frombuffer(artifact.binary, dtype=np.uint8)
+ img_filename = f"image{i}.jpg"
+ with open(img_filename, 'wb') as img_file:
+ img_file.write(img_array)
+ image_files.append(img_filename)
+
+for img_filename in image_files:
+ img = plt.imread(img_filename)
+ plt.imshow(img)
+ plt.axis('off') # Do not show axis to mimic the original display
+ plt.show()
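The save loop above writes every artifact's bytes to disk unconditionally. Mirroring the commented generate/upload example earlier in this file, a more defensive variant of the same loop checks the finish reason and artifact type first (a sketch; `meta` is a generator, so this replaces rather than follows the loop above):

```python
# Defensive artifact handling: skip safety-filtered and non-image artifacts.
for i, resp in enumerate(meta):
    for artifact in resp.artifacts:
        if artifact.finish_reason == generation.FILTER:
            print(f"Sample {i}: safety filter triggered, no image returned.")
        elif artifact.type == generation.ARTIFACT_IMAGE:
            with open(f"image{i}.jpg", "wb") as f:
                f.write(artifact.binary)
```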
diff --git a/test_scripts/test_anthropic.py b/test_scripts/t2t/test_anthropic.py
similarity index 91%
rename from test_scripts/test_anthropic.py
rename to test_scripts/t2t/test_anthropic.py
index 4f845de8..3b761774 100644
--- a/test_scripts/test_anthropic.py
+++ b/test_scripts/t2t/test_anthropic.py
@@ -10,7 +10,7 @@
question = """
Hey Claude! How can I recursively list all files in a directory in Python?
"""
-models = ["anthropic.claude-v2:1", "anthropic.claude-instant-v1", "anthropic.claude-v1", "anthropic.claude-v2"]
+models = ["anthropic.claude-v2:1", "anthropic.claude-instant-v1", "anthropic.claude-v1", "anthropic.claude-v2", "anthropic.claude-3-sonnet-20240229-v1:0"]
# Define an async function
async def run_async_code():
@@ -20,7 +20,7 @@ async def run_async_code():
max_tokens_to_sample=300,
temperature=0.01, # must be <= 1.0
top_p=1,
- model=models[0],
+ model=models[-1],
stream=True,
)
diff --git a/test_scripts/test_anthropicv2.py b/test_scripts/t2t/test_anthropicv2.py
similarity index 94%
rename from test_scripts/test_anthropicv2.py
rename to test_scripts/t2t/test_anthropicv2.py
index c8f4fd55..9aa575b0 100644
--- a/test_scripts/test_anthropicv2.py
+++ b/test_scripts/t2t/test_anthropicv2.py
@@ -50,8 +50,8 @@ async def call_anthropic(question, model, max_tokens):
print(completion.completion)
return completion.completion
-models = ["anthropic.claude-v2:1", "anthropic.claude-instant-v1", "anthropic.claude-v1", "anthropic.claude-v2"]
-model = models[1]
+models = ["anthropic.claude-v2:1", "anthropic.claude-instant-v1", "anthropic.claude-v1", "anthropic.claude-v2", "anthropic.claude-3-sonnet-20240229-v1:0"]
+model = models[-1]
question = "tell me a short story"
max_tokens = 2048
diff --git a/test_scripts/t2t/test_claude.py b/test_scripts/t2t/test_claude.py
new file mode 100644
index 00000000..36631caa
--- /dev/null
+++ b/test_scripts/t2t/test_claude.py
@@ -0,0 +1,79 @@
+from anthropic import AsyncAnthropic
+import os
+import asyncio
+# https://docs.anthropic.com/claude/reference/messages_post
+api_key = os.environ.get("ANTHROPIC_API_KEY")
+if not api_key:
+ raise ValueError("API key not found in environment variables")
+
+claude_client = AsyncAnthropic()
+claude_client.api_key = api_key
+
+messages = [
+ {
+ "role": "system",
+ "content": "respond in spanish"
+ },
+ {
+ "role": "user",
+ "content": "Hello!"
+ }
+]
+max_tokens = 100
+model = "claude-3-opus-20240229"
+
+
+# streaming
+async def call_claude(messages, max_tokens, model):
+ system_prompt = None
+ filtered_messages = []
+ for message in messages:
+ if message["role"] == "system":
+ system_prompt = message["content"]
+ else:
+ filtered_messages.append(message)
+
+ stream_kwargs = {
+ "max_tokens": max_tokens,
+ "messages": filtered_messages,
+ "model": model,
+ }
+
+ if system_prompt:
+ stream_kwargs["system"] = system_prompt
+
+ completion = claude_client.messages.stream(**stream_kwargs)
+ async with completion as stream:
+ async for text in stream.text_stream:
+ print(text, end="", flush=True)
+
+    # End the streamed output with a newline
+ print("\n")
+
+# non streaming
+# async def call_claude(messages, max_tokens, model):
+# filtered_messages = []
+# for message in messages:
+# if message["role"] == "system":
+# system_prompt = message["content"]
+# else:
+# filtered_messages.append(message)
+
+# kwargs = {
+# "max_tokens": max_tokens,
+# "messages": filtered_messages,
+# "model": model,
+# }
+
+# if system_prompt:
+# kwargs["system"] = system_prompt
+
+# message = await claude_client.messages.create(**kwargs)
+# print(message.content[0].text)
+# return message.content[0].text
+
+async def main():
+ await call_claude(messages, max_tokens, model)
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/test_scripts/t2t/test_gemini.py b/test_scripts/t2t/test_gemini.py
new file mode 100644
index 00000000..e40ed590
--- /dev/null
+++ b/test_scripts/t2t/test_gemini.py
@@ -0,0 +1,103 @@
+import os
+import google.generativeai as genai
+import traceback
+import asyncio
+
+google_api = os.environ.get('GOOGLE_API_KEY')
+genai.configure(api_key=google_api)
+
+# https://ai.google.dev/tutorials/python_quickstart
+model = 'gemini-pro'
+messages = [
+ {
+ "role": "system",
+ "content": "respond in spanish with at least 10 words, write a paragraph"
+ },
+ {
+ "role": "user",
+ "content": "Tell me about miami"
+ }
+]
+# messages = ', '.join(message['content'] for message in messages)
+messages = [{'role': 'user', 'content': 'Compare and contrast the differences between inductive and deductive reasoning in the context of scientific research.'}]
+temperature = 0.0001
+max_tokens = 100000
+top_p = 0.01
+top_k = 1
+seed = 1234
+
+# for m in genai.list_models():
+# if 'generateContent' in m.supported_generation_methods:
+# print(m.name)
+# model = genai.GenerativeModel(model)
+
+# Streaming
+async def call_gemini(messages, temperature, model, max_tokens, top_p, top_k):
+ print(f"Calling Gemini. Temperature = {temperature}, Model = {model}, Messages = {messages}, max tokens = {max_tokens}, top_p = {top_p}, top_k = {top_k}")
+ try:
+ model = genai.GenerativeModel(model)
+ stream = model.generate_content(
+ str(messages),
+ stream=True,
+ generation_config=genai.types.GenerationConfig(
+ # candidate_count=1,
+ # stop_sequences=['x'],
+ temperature=temperature,
+ max_output_tokens=max_tokens,
+ top_p=top_p,
+ top_k=top_k,
+ # seed=seed,
+ )
+ )
+ for chunk in stream:
+ # print(chunk)
+ for part in chunk.candidates[0].content.parts:
+            print(part.text, end="", flush=True)
+ print(f"\n")
+ print(stream)
+ return stream.text
+    except Exception:
+ print(f"error in call_gemini {traceback.format_exc()}")
+
+# Non streaming
+async def call_gemini(messages, temperature, model, max_tokens, top_p, top_k):
+ print(f"Calling Gemini. Temperature = {temperature}, Model = {model}, Messages = {messages}")
+ try:
+ model = genai.GenerativeModel(model)
+ response = model.generate_content(
+ str(messages),
+ stream=False,
+ generation_config=genai.types.GenerationConfig(
+ # candidate_count=1,
+ # stop_sequences=['x'],
+ temperature=temperature,
+ max_output_tokens=max_tokens,
+ top_p=top_p,
+ top_k=top_k,
+ # seed=seed,
+ )
+ )
+
+ print(f"validator response is {response.text}")
+ return response.text
+    except Exception:
+ print(f"error in call_gemini {traceback.format_exc()}")
+
+async def main():
+ answer = await call_gemini(messages, temperature, model, max_tokens, top_p, top_k)
+ # print(f"\nAnswer = {answer}")
+
+if __name__ == "__main__":
+ asyncio.run(main())
+
+
+# from PIL import Image
+
+# img = Image.open('image.jpg')
+# # img.show()
+
+# # Initialize and use the model
+# model = genai.GenerativeModel('gemini-pro-vision')
+# response = model.generate_content(img)
+
+# print(response.text)
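The script passes `str(messages)` because `generate_content` takes plain text, not OpenAI-style role dicts. The library's chat interface is the structured alternative; a sketch per the quickstart linked above:

```python
import os
import google.generativeai as genai

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])
model = genai.GenerativeModel("gemini-pro")

# Multi-turn history uses "user" and "model" roles in google-generativeai.
chat = model.start_chat(history=[])
print(chat.send_message("Tell me about miami").text)
```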
diff --git a/test_scripts/test_openai.py b/test_scripts/t2t/test_openai.py
similarity index 100%
rename from test_scripts/test_openai.py
rename to test_scripts/t2t/test_openai.py
diff --git a/test_scripts/t2v/dog.mp4 b/test_scripts/t2v/dog.mp4
new file mode 100644
index 00000000..4cf4ac3a
Binary files /dev/null and b/test_scripts/t2v/dog.mp4 differ
diff --git a/test_scripts/t2v/output.mp4 b/test_scripts/t2v/output.mp4
new file mode 100644
index 00000000..9e3c8974
Binary files /dev/null and b/test_scripts/t2v/output.mp4 differ
diff --git a/test_scripts/t2v/test_video.py b/test_scripts/t2v/test_video.py
new file mode 100644
index 00000000..47b050c5
--- /dev/null
+++ b/test_scripts/t2v/test_video.py
@@ -0,0 +1,202 @@
+import asyncio
+import requests
+import os
+import io
+import time
+from openai import AsyncOpenAI
+from playsound import playsound
+import base64
+import cv2
+from IPython.display import display, Image, Audio
+from PIL import Image
+from stability_sdk import client
+import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
+
+api_key = os.getenv("STABILITY_API_KEY")
+openai_key = os.getenv("OPENAI_API_KEY")
+gemini_key = os.getenv("GOOGLE_API_KEY")
+
+AsyncOpenAI.api_key = openai_key
+client = AsyncOpenAI(timeout=30)
+
+async def encode_image(image_path):
+ with open(image_path, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode('utf-8')
+
+# stability video parameters
+# ref: https://platform.stability.ai/docs/api-reference#tag/v2alphageneration/paths/~1v2alpha~1generation~1image-to-video/post
+
+# dimensions for video options:
+# 1024x576
+# 576x1024
+# 768x768
+
+# seed = [ 0 .. 2147483648 ]
+# default = 0
+# description = A specific value that is used to guide the 'randomness' of the generation. (Omit this parameter or pass 0 to use a random seed.)
+
+# cfg_scale = [ 0 .. 10 ]
+# default = 2.5
+# description = How strongly the video sticks to the original image. Use lower values to allow the model more freedom to make changes and higher values to correct motion distortions.
+
+# motion_bucket_id = [ 1 .. 255 ]
+# default = 40
+# description = Lower values generally result in less motion in the output video, while higher values generally result in more motion. This parameter corresponds to the motion_bucket_id parameter from here:
+# https://static1.squarespace.com/static/6213c340453c3f502425776e/t/655ce779b9d47d342a93c890/1700587395994/stable_video_diffusion.pdf
+
+
+# Step 1: generate an image with the appropriate parameters
+
+
+
+# async def read_video_frames(video_path):
+# video = cv2.VideoCapture(video_path)
+# base64Frames = []
+# while video.isOpened():
+# success, frame = video.read()
+# if not success:
+# break
+# _, buffer = cv2.imencode(".jpg", frame)
+# base64Frames.append(base64.b64encode(buffer).decode("utf-8"))
+# video.release()
+# return base64Frames
+
+# async def display_frames(base64Frames):
+# display_handle = display(None, display_id=True)
+# for img in base64Frames:
+# display(Image(data=base64.b64decode(img.encode("utf-8"))))
+# await asyncio.sleep(0.025)
+
+# async def generate_descriptions(base64Frames, model, max_tokens):
+# PROMPT_MESSAGES = [
+# {
+# "role": "user",
+# "content": [
+# "These are frames from a video that I want to upload. Generate a compelling description that I can upload along with the video.",
+# *map(lambda x: {"image": x, "resize": 768}, base64Frames[0::50]),
+# ],
+# },
+# ]
+# params = {
+# "model": model,
+# "messages": PROMPT_MESSAGES,
+# "max_tokens": max_tokens,
+# }
+
+# result = await client.chat.completions.create(**params)
+# return result.choices[0].message.content
+
+# async def generate_audio_speech(text):
+# response = requests.post(
+# "https://api.openai.com/v1/audio/speech",
+# headers={
+# "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}",
+# },
+# json={
+# "model": "tts-1-1106",
+# "input": text,
+# "voice": "onyx",
+# },
+# )
+
+# audio_file_path = "output_audio.mp3"
+# with open(audio_file_path, "wb") as audio_file:
+# for chunk in response.iter_content(chunk_size=1024 * 1024):
+# audio_file.write(chunk)
+# return audio_file_path
+
+# async def main():
+# image_path = "../t2i/image.jpg"
+# base64_image = await encode_image(image_path)
+
+# base64Frames = await read_video_frames("output.mp4")
+# print(len(base64Frames), "frames read.")
+
+# await display_frames(base64Frames)
+
+# description = await generate_descriptions(base64Frames, "gpt-4-vision-preview", 200)
+# print(description)
+
+# voiceover_script = await generate_descriptions(base64Frames, "gpt-4-vision-preview", 500)
+# print(voiceover_script)
+
+# audio_file_path = await generate_audio_speech(voiceover_script)
+# playsound(audio_file_path)
+
+
+# asyncio.run(main())
+
+
+# response = client.chat.completions.create(
+# model="gpt-4-vision-preview",
+# messages=[
+# {
+# "role": "user",
+# "content": [
+# {"type": "text", "text": "What’s in this image?"},
+# {
+#                     "type": "image_url",
+# "image_url": {
+# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
+# },
+# },
+# ],
+# }
+# ],
+# max_tokens=300,
+# )
+
+
+
+image_path = "../images/john2.png"
+seed = 0
+cfg_scale = 1
+motion_bucket_id = 150
+FPS = 30 # 0 to 30
+
+
+
+# response = requests.post(
+# "https://api.stability.ai/v2alpha/generation/image-to-video",
+# headers={
+# "authorization": api_key,
+# },
+# data={
+# "seed": seed,
+# "cfg_scale": cfg_scale,
+# "motion_bucket_id": motion_bucket_id,
+# "FPS": FPS,
+# },
+# files={
+# "image": ("file", open(image_path, "rb"), "image/png")
+# },
+# )
+
+# if response.status_code != 200:
+# raise Exception("Non-200 response: " + str(response.text))
+
+# data = response.json()
+# generation_id = data["id"]
+# print(generation_id)
+# time.sleep(30)
+
+
+generation_id = "0ca93f4172448c1f0b5f40f3c5bb3afce3f60e8da4da07e9f635228be4e6acea"
+
+response = requests.request(
+ "GET",
+ f"https://api.stability.ai/v2alpha/generation/image-to-video/result/{generation_id}",
+ headers={
+ 'Accept': None, # Use 'application/json' to receive base64 encoded JSON
+ 'authorization': api_key,
+ },
+)
+
+if response.status_code == 202:
+ print("Generation in-progress, try again in 10 seconds.")
+elif response.status_code == 200:
+ print("Generation complete!")
+ with open('./john5.mp4', 'wb') as file:
+ file.write(response.content)
+else:
+ raise Exception("Non-200 response: " + str(response.json()))
\ No newline at end of file
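The result check above is single-shot even though a 202 means the render is still in progress. A small polling wrapper around the same endpoint, as a sketch (10-second interval taken from the message printed above):

```python
import time
import requests

def wait_for_video(generation_id: str, api_key: str, out_path: str = "./john5.mp4") -> str:
    """Poll the image-to-video result endpoint until the render finishes."""
    url = f"https://api.stability.ai/v2alpha/generation/image-to-video/result/{generation_id}"
    while True:
        resp = requests.get(url, headers={"Accept": None, "authorization": api_key})
        if resp.status_code == 200:
            with open(out_path, "wb") as f:
                f.write(resp.content)
            return out_path
        if resp.status_code != 202:
            raise Exception("Non-200 response: " + str(resp.json()))
        time.sleep(10)  # still rendering; retry as the API suggests
```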
diff --git a/test_scripts/test_claude.py b/test_scripts/test_claude.py
deleted file mode 100644
index 9fd9937b..00000000
--- a/test_scripts/test_claude.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import anthropic
-import os
-
-
-# Model options are claude-instant-1.2, claude-2.1
-# Retrieve API key from environment variable
-api_key = os.environ.get("ANTHROPIC_API_KEY")
-if not api_key:
- raise ValueError("API key not found in environment variables")
-
-client = anthropic.Anthropic()
-client.api_key = api_key
-
-question = "Tell me a short joke"
-with client.beta.messages.stream(
- max_tokens=1024,
- messages=[{"role": "user", "content": question}],
- model="claude-2.1",
-) as stream:
- for text in stream.text_stream:
- print(text, end="", flush=True)
-
-print("\n")
-
-
-# non streaming, async
-import asyncio
-import traceback
-from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
-import bittensor as bt
-import os
-
-
-try:
- anthropic = Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
-except KeyError as exc:
- raise ValueError("Please set the ANTROPIC_API_KEY environment variable.") from exc
-
-async def call_anthropic(messages, temperature, model, seed=1234) -> str:
-
- for _ in range(2):
- bt.logging.debug(f"Calling Anthropics. Model = {model}, Prompt = {prompt}")
- try:
- completion = anthropic.completions.create(
- model=model,
- max_tokens_to_sample=1000,
- prompt=f"{HUMAN_PROMPT} {messages[0]['content']}{AI_PROMPT}",
- temperature=temperature,
- )
- response = completion.completion
- bt.logging.debug(f"Validator response is {response}")
- return response
-
- except Exception as e:
- bt.logging.error(f"Error when calling Anthropics: {traceback.format_exc()}")
- await asyncio.sleep(0.5)
-
- return None
-
-# Example usage of the function
-prompt = "Tell me a short joke"
-messages = [{'role': 'user', 'content': prompt}]
-model = "claude-2"
-temperature = 0.0001
-
-# Run the async function
-response = asyncio.run(call_anthropic(messages, temperature, model))
-print(response)
\ No newline at end of file
diff --git a/test_scripts/test_stability.py b/test_scripts/test_stability.py
deleted file mode 100644
index b938205e..00000000
--- a/test_scripts/test_stability.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import os
-import io
-import requests
-from PIL import Image
-from stability_sdk import client
-import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
-
-# Set up the API connection
-stability_api = client.StabilityInference(
- key=os.environ['STABILITY_KEY'],
- verbose=True,
- engine="stable-diffusion-xl-1024-v1-0"
-)
-
-# # Generate the image
-# answers = stability_api.generate(
-# prompt="expansive landscape rolling greens with gargantuan yggdrasil, intricate world-spanning roots towering under a blue alien sky, masterful, ghibli",
-# seed=9283409,
-# steps=30,
-# cfg_scale=8.0,
-# width=1024,
-# height=1024,
-# samples=1,
-# sampler=generation.SAMPLER_K_DPMPP_2M
-# )
-
-# # Process and upload the image
-# for resp in answers:
-# for artifact in resp.artifacts:
-# if artifact.finish_reason == generation.FILTER:
-# print("Safety filters activated, prompt could not be processed.")
-# elif artifact.type == generation.ARTIFACT_IMAGE:
-# img = Image.open(io.BytesIO(artifact.binary))
-# img_buffer = io.BytesIO()
-# img.save(img_buffer, format="PNG")
-# img_buffer.seek(0)
-
-# # Upload to file.io
-# response = requests.post('https://file.io', files={'file': img_buffer})
-# if response.status_code == 200:
-# print(f"Image uploaded successfully. URL: {response.json()['link']}")
-# else:
-# print("Failed to upload the image.")
-
-prompt = 'halla halla, we dem boys'
-
-data = ('Stability', 'stable-diffusion-xl-1024-v1-0', prompt, '1024x1024', 1024, 1024, 'standard', 'vivid', 656827, 30, None, 8.0, 'SAMPLER_K_DPMPP_2M', 9, 1)
-
-meta = stability_api.generate(
- prompt=data[2],
- seed=data[8],
- steps=data[-6],
- cfg_scale=data[-4],
- width=data[-11],
- height=data[-10],
- samples=data[-1],
- sampler=data[-3],
-)
-
-# Process and upload the image
-for image in meta:
- for artifact in image.artifacts:
- img = Image.open(io.BytesIO(artifact.binary))
- img_buffer = io.BytesIO()
- img.save(img_buffer, format="PNG")
- img_buffer.seek(0)
- # Upload to file.io
- response = requests.post('https://file.io', files={'file': img_buffer})
- if response.status_code == 200:
- image_url = response.json()['link']
- print(image_url)
- else:
- bt.logging.error("Failed to upload the image.")
\ No newline at end of file
diff --git a/validators/embeddings_validator.py b/validators/embeddings_validator.py
index d35f8f61..d292f6d6 100644
--- a/validators/embeddings_validator.py
+++ b/validators/embeddings_validator.py
@@ -5,10 +5,10 @@
import random
import asyncio
import bittensor as bt
-import template.reward
-from template import client
+import cortext.reward
+from cortext import client
from datasets import load_dataset
-from template.protocol import Embeddings
+from cortext.protocol import Embeddings
from base_validator import BaseValidator
class EmbeddingsValidator(BaseValidator):
@@ -123,7 +123,7 @@ async def score_responses(self, query_responses, uid_to_question, metagraph):
response = next(res for u, res in query_responses if u == uid)
response = response[0]
if response.embeddings is not None:
- task = template.reward.embeddings_score_dot(openai_answer, response.embeddings, self.weight)
+ task = cortext.reward.embeddings_score_dot(openai_answer, response.embeddings, self.weight)
scoring_tasks.append((uid, task))
else:
scores[uid] = 0
diff --git a/validators/image_validator.py b/validators/image_validator.py
index 9894d0ee..b43b67a4 100644
--- a/validators/image_validator.py
+++ b/validators/image_validator.py
@@ -6,14 +6,14 @@
import aiohttp
import base64
import traceback
-import template.reward
+import cortext.reward
import bittensor as bt
from PIL import Image
from io import BytesIO
-from template.utils import get_question
+from cortext.utils import get_question
from base_validator import BaseValidator
-from template.protocol import ImageResponse
+from cortext.protocol import ImageResponse
class ImageValidator(BaseValidator):
@@ -46,7 +46,7 @@ async def start_query(self, available_uids, metagraph):
uid_to_question = {}
# Randomly choose the provider based on specified probabilities
- providers = ["OpenAI"] * 8 + ["Stability"] * 2
+ providers = ["OpenAI"] * 100 + ["Stability"] * 0
self.provider = random.choice(providers)
if self.provider == "Stability":
@@ -122,10 +122,10 @@ async def score_responses(self, query_responses, uid_to_question, metagraph):
if will_score_all:
if syn.provider == "OpenAI":
- score_task = template.reward.dalle_score(uid, image_url, self.size, syn.messages, self.weight)
+ score_task = cortext.reward.dalle_score(uid, image_url, self.size, syn.messages, self.weight)
else:
continue
- score_task = template.reward.deterministic_score(uid, syn, self.weight)
+ score_task = cortext.reward.deterministic_score(uid, syn, self.weight)
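+                # NOTE: unreachable; the `continue` above skips Stability responses before scoring.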
score_tasks.append(asyncio.create_task(score_task))
diff --git a/validators/text_validator.py b/validators/text_validator.py
index 59f13e80..ae5c7d86 100644
--- a/validators/text_validator.py
+++ b/validators/text_validator.py
@@ -1,14 +1,15 @@
import asyncio
import random
+import traceback
from typing import AsyncIterator, Tuple
import bittensor as bt
import torch
from base_validator import BaseValidator
-import template.reward
-from template.protocol import StreamPrompting
-from template.utils import call_openai, get_question, call_anthropic
+import cortext.reward
+from cortext.protocol import StreamPrompting
+from cortext.utils import call_openai, get_question, call_anthropic, call_gemini, call_claude
class TextValidator(BaseValidator):
@@ -16,8 +17,8 @@ def __init__(self, dendrite, config, subtensor, wallet: bt.wallet):
super().__init__(dendrite, config, subtensor, wallet, timeout=75)
self.streaming = True
self.query_type = "text"
- self.model = "gpt-4-0125-preview"
- self.max_tokens = 2048
+ self.model = "gpt-4-1106-preview"
+ self.max_tokens = 4096
self.temperature = 0.0001
self.weight = 1
self.seed = 1234
@@ -41,7 +42,7 @@ async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> As
f"timeout {self.timeout}: {syn.messages[0]['content']}"
)
- self.wandb_data["prompts"][uid] = messages
+ # self.wandb_data["prompts"][uid] = messages
responses = await self.dendrite(
metagraph.axons[uid],
syn,
@@ -72,34 +73,46 @@ async def get_question(self, qty):
return await get_question("text", qty)
async def start_query(self, available_uids, metagraph) -> tuple[list, dict]:
- query_tasks = []
- uid_to_question = {}
- # Randomly choose the provider based on specified probabilities
- providers = ["OpenAI"] * 95 + ["Anthropic"] * 5
- self.provider = random.choice(providers)
-
- if self.provider == "Anthropic":
- # bedrock models = ["anthropic.claude-v2:1", "anthropic.claude-instant-v1", "anthropic.claude-v1", "anthropic.claude-v2"]
- # claude models = ["claude-2.1", "claude-2.0", "claude-instant-1.2"]
- self.model = "anthropic.claude-v2:1"
- elif self.provider == "OpenAI":
- self.model = "gpt-4-0125-preview"
-
- for uid in available_uids:
- prompt = await self.get_question(len(available_uids))
- uid_to_question[uid] = prompt
- messages = [{'role': 'user', 'content': prompt}]
- syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed, max_tokens=self.max_tokens, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k)
- bt.logging.info(
- f"Sending {syn.model} {self.query_type} request to uid: {uid}, "
- f"timeout {self.timeout}: {syn.messages[0]['content']}"
- )
- task = self.query_miner(metagraph, uid, syn)
- query_tasks.append(task)
- self.wandb_data["prompts"][uid] = prompt
-
- query_responses = await asyncio.gather(*query_tasks)
- return query_responses, uid_to_question
+ try:
+ query_tasks = []
+ uid_to_question = {}
+ # Randomly choose the provider based on specified probabilities
+ providers = ["OpenAI"] * 88 + ["Anthropic"] * 2 + ["Gemini"] * 0 + ["Claude"] * 10
+ self.provider = random.choice(providers)
+
+ if self.provider == "Anthropic":
+ # bedrock models = ["anthropic.claude-v2:1", "anthropic.claude-instant-v1", "anthropic.claude-v1", "anthropic.claude-v2"]
+ # claude models = ["claude-2.1", "claude-2.0", "claude-instant-1.2"]
+ # gemini models = ["gemini-pro"]
+ self.model = "anthropic.claude-v2:1"
+ elif self.provider == "OpenAI":
+ self.model = "gpt-4-1106-preview"
+ # self.model = "gpt-3.5-turbo"
+
+ elif self.provider == "Gemini":
+ self.model = "gemini-pro"
+
+ elif self.provider == "Claude":
+ self.model = "claude-3-opus-20240229"
+ # self.model = "claude-3-sonnet-20240229"
+ bt.logging.info(f"provider = {self.provider}\nmodel = {self.model}")
+ for uid in available_uids:
+ prompt = await self.get_question(len(available_uids))
+ uid_to_question[uid] = prompt
+ messages = [{'role': 'user', 'content': prompt}]
+ syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed, max_tokens=self.max_tokens, temperature=self.temperature, provider=self.provider, top_p=self.top_p, top_k=self.top_k)
+ bt.logging.info(
+ f"Sending {syn.model} {self.query_type} request to uid: {uid}, "
+ f"timeout {self.timeout}: {syn.messages[0]['content']}"
+ )
+ task = self.query_miner(metagraph, uid, syn)
+ query_tasks.append(task)
+ self.wandb_data["prompts"][uid] = prompt
+
+ query_responses = await asyncio.gather(*query_tasks)
+ return query_responses, uid_to_question
+    except Exception:
+ bt.logging.error(f"error in start_query = {traceback.format_exc()}")
def should_i_score(self):
random_number = random.random()
@@ -112,6 +125,10 @@ async def call_api(self, prompt: str, provider: str) -> str:
return await call_openai([{'role': 'user', 'content': prompt}], self.temperature, self.model, self.seed, self.max_tokens)
elif provider == "Anthropic":
return await call_anthropic(prompt, self.temperature, self.model, self.max_tokens, self.top_p, self.top_k)
+ elif provider == "Gemini":
+ return await call_gemini(prompt, self.temperature, self.model, self.max_tokens, self.top_p, self.top_k)
+ elif provider == "Claude":
+ return await call_claude([{'role': 'user', 'content': prompt}], self.temperature, self.model, self.max_tokens, self.top_p, self.top_k)
else:
bt.logging.error(f"provider {provider} not found")
@@ -121,6 +138,7 @@ async def score_responses(
uid_to_question: dict[int, str], # uid -> prompt
metagraph: bt.metagraph,
) -> tuple[torch.Tensor, dict[int, float], dict]:
+
scores = torch.zeros(len(metagraph.hotkeys))
uid_scores_dict = {}
response_tasks = []
@@ -140,7 +158,7 @@ async def score_responses(
for (uid, _), api_answer in zip(response_tasks, api_responses):
if api_answer:
response = next(res for u, res in query_responses if u == uid) # Find the matching response
- task = template.reward.api_score(api_answer, response, self.weight)
+ task = cortext.reward.api_score(api_answer, response, self.weight, self.temperature, self.provider)
scoring_tasks.append((uid, task))
scored_responses = await asyncio.gather(*[task for _, task in scoring_tasks])
@@ -154,50 +172,9 @@ async def score_responses(
else:
scores[uid] = 0
uid_scores_dict[uid] = 0
-
if uid_scores_dict != {}:
bt.logging.info(f"text_scores is {uid_scores_dict}")
return scores, uid_scores_dict, self.wandb_data
-class TestTextValidator(TextValidator):
- def __init__(
- self,
- dendrite,
- config,
- subtensor,
- wallet: bt.wallet,
- ):
- super().__init__(dendrite, config, subtensor, wallet)
- self.openai_prompt_to_contents: dict[str, list[str]] = {}
- self.questions: list[str] = []
- self._questions_retrieved = -1
- self._openai_prompts_used: dict[str, int] = {}
-
- def feed_mock_data(self, openai_prompt_to_contents, questions):
- self.questions = questions
- self.openai_prompt_to_contents = openai_prompt_to_contents
- self._openai_prompts_used = dict.fromkeys(self.openai_prompt_to_contents, -1)
- self._questions_retrieved = -1
-
- def should_i_score(self):
- return True
-
- async def call_openai(self, prompt: str) -> str:
- self._openai_prompts_used[prompt] += 1
- used = self._openai_prompts_used[prompt]
- contents = self.openai_prompt_to_contents[prompt]
- return contents[used % len(contents)]
-
- async def get_question(self, qty):
- self._questions_retrieved += 1
- return self.questions[self._questions_retrieved % len(self.questions)]
-
- async def query_miner(self, metagraph, uid, syn: StreamPrompting):
- return uid, await self.call_openai(syn.messages[0]['content'])
-
- async def organic(self, metagraph, query: dict[str, list[dict[str, str]]]) -> AsyncIterator[tuple[int, str]]:
- for uid, messages in query.items():
- for msg in messages:
- yield uid, await self.call_openai(msg['content'])
diff --git a/validators/validator.py b/validators/validator.py
index 44b8dc4c..15cd2291 100644
--- a/validators/validator.py
+++ b/validators/validator.py
@@ -14,40 +14,35 @@
import bittensor as bt
import torch
import wandb
-from aiohttp import web
-from aiohttp.web_response import Response
from image_validator import ImageValidator
from embeddings_validator import EmbeddingsValidator
-from text_validator import TextValidator, TestTextValidator
+from text_validator import TextValidator
from base_validator import BaseValidator
from envparse import env
-import template
-from template import utils
+import cortext
+from cortext import utils
import sys
-from weight_setter import WeightSetter, TestWeightSetter
+from weight_setter import WeightSetter
text_vali = None
image_vali = None
embed_vali = None
metagraph = None
wandb_runs = {}
-# organic requests are scored, the tasks are stored in this queue
-# for later being consumed by `query_synapse` cycle:
-organic_scoring_tasks = set()
-EXPECTED_ACCESS_KEY = os.environ.get('EXPECTED_ACCESS_KEY', "hello")
def get_config() -> bt.config:
parser = argparse.ArgumentParser()
parser.add_argument("--netuid", type=int, default=18)
parser.add_argument('--wandb_off', action='store_false', dest='wandb_on')
- parser.add_argument('--http_port', type=int, default=8100)
+ parser.add_argument('--axon.port', type=int, default=8000)
parser.set_defaults(wandb_on=True)
bt.subtensor.add_args(parser)
bt.logging.add_args(parser)
bt.wallet.add_args(parser)
+ bt.axon.add_args(parser)
config = bt.config(parser)
_args = parser.parse_args()
full_path = Path(
@@ -62,17 +57,17 @@ def init_wandb(config, my_uid, wallet: bt.wallet):
if not config.wandb_on:
return
- run_name = f'validator-{my_uid}-{template.__version__}'
+ run_name = f'validator-{my_uid}-{cortext.__version__}'
config.uid = my_uid
config.hotkey = wallet.hotkey.ss58_address
config.run_name = run_name
- config.version = template.__version__
+ config.version = cortext.__version__
config.type = 'validator'
# Initialize the wandb run for the single project
run = wandb.init(
name=run_name,
- project=template.PROJECT_NAME,
+ project=cortext.PROJECT_NAME,
entity='cortex-t',
config=config,
dir=config.full_path,
@@ -84,7 +79,7 @@ def init_wandb(config, my_uid, wallet: bt.wallet):
config.signature = signature
wandb.config.update(config, allow_val_change=True)
- bt.logging.success(f"Started wandb run for project '{template.PROJECT_NAME}'")
+ bt.logging.success(f"Started wandb run for project '{cortext.PROJECT_NAME}'")
def initialize_components(config: bt.config):
@@ -102,63 +97,19 @@ def initialize_components(config: bt.config):
f"{subtensor}. Run btcli register --netuid 18 and try again."
)
sys.exit()
-
return wallet, subtensor, dendrite, my_uid
def initialize_validators(vali_config, test=False):
global text_vali, image_vali, embed_vali
- text_vali = (TextValidator if not test else TestTextValidator)(**vali_config)
+ text_vali = TextValidator(**vali_config)
image_vali = ImageValidator(**vali_config)
embed_vali = EmbeddingsValidator(**vali_config)
bt.logging.info("initialized_validators")
-async def process_text_validator(request: web.Request):
- # Basic request validation
- if request.method != "POST" or request.path != '/text-validator/':
- return web.Response(status=400, text="Invalid request")
-
- # Check access key
- access_key = request.headers.get("access-key")
- if access_key != EXPECTED_ACCESS_KEY:
- return web.Response(status=401, text="Invalid access key")
-
- try:
- messages_dict = {int(k): [{'role': 'user', 'content': v}] for k, v in (await request.json()).items()}
- except ValueError:
- return web.Response(status=400, text="Bad request format")
-
- response = web.StreamResponse()
- await response.prepare(request)
-
- uid_to_response = dict.fromkeys(messages_dict, "")
- try:
- async for uid, content in text_vali.organic(validator_app.weight_setter.metagraph, messages_dict):
- uid_to_response[uid] += content
- await response.write(content.encode())
- validator_app.weight_setter.register_text_validator_organic_query(
- uid_to_response, {k: v[0]['content'] for k, v in messages_dict.items()}
- )
- except Exception as e:
- bt.logging.error(f'Encountered in {process_text_validator.__name__}:\n{traceback.format_exc()}')
- await response.write(b'<>')
-
- return response
-
-
-class ValidatorApplication(web.Application):
- def __init__(self, *a, **kw):
- super().__init__(*a, **kw)
- self.weight_setter: WeightSetter | None = None
-
-
-validator_app = ValidatorApplication()
-# validator_app.add_routes([web.post('/text-validator/', process_text_validator)])
-
-
-def main(run_aio_app=True, test=False) -> None:
+def main(test=False) -> None:
config = get_config()
wallet, subtensor, dendrite, my_uid = initialize_components(config)
validator_config = {
@@ -170,21 +121,18 @@ def main(run_aio_app=True, test=False) -> None:
initialize_validators(validator_config, test)
init_wandb(config, my_uid, wallet)
loop = asyncio.get_event_loop()
-
- weight_setter = (WeightSetter if not test else TestWeightSetter)(
- loop, dendrite, subtensor, config, wallet, text_vali, image_vali, embed_vali)
- validator_app.weight_setter = weight_setter
-
- if run_aio_app:
- try:
- web.run_app(validator_app, port=config.http_port, loop=loop)
- except KeyboardInterrupt:
- bt.logging.info("Keyboard interrupt detected. Exiting validator.")
- finally:
- state = utils.get_state()
- utils.save_state_to_file(state)
- if config.wandb_on:
- wandb.finish()
+ weight_setter = WeightSetter(loop, dendrite, subtensor, config, wallet, text_vali, image_vali, embed_vali)
+ state_path = os.path.join(config.full_path, "state.json")
+ utils.get_state(state_path)
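+    # assumption: get_state() primes the cached validator state from state.json before the loop starts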
+ try:
+ loop.run_forever()
+ except KeyboardInterrupt:
+ bt.logging.info("Keyboard interrupt detected. Exiting validator.")
+ finally:
+ state = utils.get_state(state_path)
+ utils.save_state_to_file(state, state_path)
+ if config.wandb_on:
+ wandb.finish()
if __name__ == "__main__":
diff --git a/validators/weight_setter.py b/validators/weight_setter.py
index 6ed62713..7215a788 100644
--- a/validators/weight_setter.py
+++ b/validators/weight_setter.py
@@ -4,7 +4,7 @@
import traceback
import random
from typing import Tuple
-import template
+import cortext
import bittensor as bt
import torch
@@ -12,7 +12,40 @@
import os
import shutil
-from template.protocol import IsAlive
+import argparse
+import asyncio
+import base64
+import copy
+import json
+import pathlib
+import requests
+import sys
+import threading
+import time
+from collections import deque
+from functools import partial
+import bittensor as bt
+import google.generativeai as genai
+import wandb
+from PIL import Image
+from openai import AsyncOpenAI, OpenAI
+from stability_sdk import client as stability_client
+import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
+import anthropic
+from anthropic_bedrock import AsyncAnthropicBedrock, HUMAN_PROMPT, AI_PROMPT, AnthropicBedrock
+
+import cortext
+from cortext.protocol import Embeddings, ImageResponse, IsAlive, StreamPrompting, TextPrompting
+from cortext.utils import get_version
+
+from starlette.types import Send
from text_validator import TextValidator
from image_validator import ImageValidator
from embeddings_validator import EmbeddingsValidator
@@ -21,39 +54,151 @@
scoring_organic_timeout = 60
-async def wait_for_coro_with_limit(coro, timeout: int) -> Tuple[bool, object]:
- try:
- result = await asyncio.wait_for(coro, timeout)
- except asyncio.TimeoutError:
- bt.logging.error('scoring task timed out')
- return False, None
- return True, result
-
-
class WeightSetter:
def __init__(self, loop: asyncio.AbstractEventLoop, dendrite, subtensor, config, wallet, text_vali, image_vali, embed_vali):
+ bt.logging.info("starting weight setter")
+ self.config = config
+ bt.logging.info(f"config:\n{self.config}")
+ self.prompt_cache: dict[str, Tuple[str, int]] = {}
+ self.request_timestamps = {}
self.loop = loop
self.dendrite = dendrite
self.subtensor = subtensor
- self.config = config
self.wallet = wallet
self.text_vali = text_vali
self.image_vali = image_vali
self.embed_vali = embed_vali
-
self.moving_average_scores = None
- self.metagraph = subtensor.metagraph(config.netuid)
+ self.axon = bt.axon(wallet=self.wallet, port=self.config.axon.port)
+ self.metagraph = self.subtensor.metagraph(config.netuid)
self.total_scores = torch.zeros(len(self.metagraph.hotkeys))
self.organic_scoring_tasks = set()
-
self.thread_executor = concurrent.futures.ThreadPoolExecutor(thread_name_prefix='asyncio')
# self.loop.create_task(self.consume_organic_scoring())
self.loop.create_task(self.perform_synthetic_scoring_and_update_weights())
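+
+    # NOTE: shadowed by the instance attribute `self.config` assigned in __init__, so this
+    # helper is never reached through normal attribute lookup.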
+ def config(self) -> bt.config:
+ parser = argparse.ArgumentParser(description="Validator Configs")
+ return bt.config(parser)
+
async def run_sync_in_async(self, fn):
return await self.loop.run_in_executor(self.thread_executor, fn)
+    def blacklist_prompt(self, synapse: StreamPrompting) -> Tuple[bool, str]:
+        blacklist = self.base_blacklist(synapse, cortext.PROMPT_BLACKLIST_STAKE)
+        bt.logging.info(blacklist[1])
+        return blacklist
+
+    def blacklist_is_alive(self, synapse: IsAlive) -> Tuple[bool, str]:
+        blacklist = self.base_blacklist(synapse, cortext.ISALIVE_BLACKLIST_STAKE)
+        bt.logging.debug(blacklist[1])
+        return blacklist
+
+    def blacklist_images(self, synapse: ImageResponse) -> Tuple[bool, str]:
+        blacklist = self.base_blacklist(synapse, cortext.IMAGE_BLACKLIST_STAKE)
+        bt.logging.info(blacklist[1])
+        return blacklist
+
+    def blacklist_embeddings(self, synapse: Embeddings) -> Tuple[bool, str]:
+        blacklist = self.base_blacklist(synapse, cortext.EMBEDDING_BLACKLIST_STAKE)
+        bt.logging.info(blacklist[1])
+        return blacklist
+
+    def base_blacklist(self, synapse, blacklist_amt=20000) -> Tuple[bool, str]:
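+        # Contract: returns (blacklisted, reason). False accepts the request (self-queries and
+        # whitelisted hotkeys); True rejects it. blacklist_amt appears unused here, and an
+        # unexpected exception falls through to an implicit None.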
+ try:
+ hotkey = synapse.dendrite.hotkey
+ synapse_type = type(synapse).__name__
+
+ if hotkey == self.wallet.hotkey.ss58_address:
+ return False, f"accepting {synapse_type} request from self"
+
+ elif hotkey in cortext.VALIDATOR_API_WHITELIST:
+ return False, f"accepting {synapse_type} request from whitelist: {hotkey}"
+
+ return True, f"rejecting {synapse_type} request from {hotkey}"
+
+ except Exception:
+ bt.logging.error(f"errror in blacklist {traceback.format_exc()}")
+
+ async def images(self, synapse: ImageResponse) -> ImageResponse:
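+        # Organic image request: forward the synapse to the target miner's axon and relay its reply.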
+ bt.logging.info(f"received {synapse}")
+
+ synapse = self.dendrite.query(self.metagraph.axons[synapse.uid], synapse, deserialize=False, timeout=synapse.timeout)
+
+ bt.logging.info(f"new synapse = {synapse}")
+ return synapse
+
+ async def embeddings(self, synapse: Embeddings) -> Embeddings:
+ bt.logging.info(f"received {synapse}")
+
+ synapse = await self.dendrite(self.metagraph.axons[synapse.uid], synapse, deserialize=False, timeout=synapse.timeout)
+
+ bt.logging.info(f"new synapse = {synapse}")
+ return synapse
+
+ async def prompt(self, synapse: StreamPrompting) -> StreamPrompting:
+ bt.logging.info(f"received {synapse}")
+
+ async def _prompt(synapse, send: Send):
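+            # Relay the miner's streamed chunks back to the caller over the ASGI send channel.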
+ bt.logging.info(
+ f"Sending {synapse} request to uid: {synapse.uid}, "
+ )
+ async def handle_response(responses):
+ for resp in responses:
+ async for chunk in resp:
+ if isinstance(chunk, str):
+ await send(
+ {
+ "type": "http.response.body",
+ "body": chunk.encode("utf-8"),
+ "more_body": True,
+ }
+ )
+ bt.logging.info(f"Streamed text: {chunk}")
+ await send({"type": "http.response.body", "body": b'', "more_body": False})
+
+ axon = self.metagraph.axons[synapse.uid]
+ responses = self.dendrite.query(
+ axons=[axon],
+ synapse=synapse,
+ deserialize=False,
+ timeout=synapse.timeout,
+ streaming=True,
+ )
+ return await handle_response(responses)
+
+ token_streamer = partial(_prompt, synapse)
+ return synapse.create_streaming_response(token_streamer)
+
+ def text(self, synapse: TextPrompting) -> TextPrompting:
+ synapse.completion = "completed"
+ bt.logging.info("completed")
+
+ synapse = self.dendrite.query(self.metagraph.axons[synapse.uid], synapse, deserialize=False, timeout=synapse.timeout)
+
+ bt.logging.info(f"synapse = {synapse}")
+ return synapse
+
async def consume_organic_scoring(self):
+ bt.logging.info("Attaching forward function to axon.")
+ self.axon.attach(
+ forward_fn=self.prompt,
+ blacklist_fn=self.blacklist_prompt,
+ ).attach(
+ forward_fn=self.images,
+ blacklist_fn=self.blacklist_images,
+ ).attach(
+ forward_fn=self.embeddings,
+ blacklist_fn=self.blacklist_embeddings,
+ ).attach(
+ forward_fn=self.text,
+ )
+ # self.axon.serve(netuid = self.config.netuid, subtensor = self.subtensor)
+ self.axon.start()
+ self.my_subnet_uid = self.metagraph.hotkeys.index(
+ self.wallet.hotkey.ss58_address
+ )
+ bt.logging.info(f"Running validator on uid: {self.my_subnet_uid}")
while True:
try:
if self.organic_scoring_tasks:
@@ -72,12 +217,12 @@ async def consume_organic_scoring(self):
self.total_scores += data[0]
self.organic_scoring_tasks.difference_update(completed)
else:
- await asyncio.sleep(1)
+ await asyncio.sleep(60)
except Exception as e:
bt.logging.error(f'Encountered in {self.consume_organic_scoring.__name__} loop:\n{traceback.format_exc()}')
await asyncio.sleep(10)
-
+
async def perform_synthetic_scoring_and_update_weights(self):
while True:
for steps_passed in itertools.count():
@@ -97,10 +242,11 @@ async def perform_synthetic_scoring_and_update_weights(self):
f"Updating weights in {iterations_per_set_weights - steps_since_last_update - 1} iterations."
)
- await asyncio.sleep(10)
+        # throttle the pace of validator steps
+ await asyncio.sleep(100)
def select_validator(self, steps_passed):
- return self.text_vali if steps_passed % 5 in (0, 1, 2, 3) else self.image_vali
+ return self.text_vali if steps_passed % 10 in (0, 1, 2, 3, 4, 5, 6, 7, 8) else self.image_vali
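+        # Text validator runs on 9 of every 10 steps; the image validator takes the remaining step.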
async def get_available_uids(self):
"""Get a dictionary of available UIDs and their axons asynchronously."""
@@ -176,34 +322,7 @@ async def set_weights(self, scores):
uids=self.metagraph.uids,
weights=self.moving_average_scores,
wait_for_inclusion=False,
- version_key=template.__weights_version__,
+ version_key=cortext.__weights_version__,
)
)
bt.logging.success("Successfully set weights.")
-
- def register_text_validator_organic_query(
- self,
- uid_to_response: dict[int, str], # [(uid, response)]
- messages_dict: dict[int, str],
- ):
- self.organic_scoring_tasks.add(asyncio.create_task(
- wait_for_coro_with_limit(
- self.text_vali.score_responses(
- query_responses=list(uid_to_response.items()),
- uid_to_question=messages_dict,
- metagraph=self.metagraph,
- ),
- scoring_organic_timeout
- )
- ))
-
-
-class TestWeightSetter(WeightSetter):
- def select_validator(self, steps_passed):
- return self.text_vali
-
- async def get_available_uids(self):
- return {i: None for i in range(len(self.metagraph.hotkeys))}
-
- def shuffled(self, list_: list) -> list:
- return list_