Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Sep 13, 2024
1 parent adc7422 commit 9d61ef6
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 241 deletions.
4 changes: 2 additions & 2 deletions cortext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from openai import AsyncOpenAI

from cortext.protocol import StreamPrompting, TextPrompting, Embeddings, ImageResponse, IsAlive
from cortext.protocol import StreamPrompting, Embeddings, ImageResponse, IsAlive

load_dotenv()
try:
Expand Down Expand Up @@ -3768,4 +3768,4 @@
]


ALL_SYNAPSE_TYPE = Union[StreamPrompting, TextPrompting, Embeddings, ImageResponse, IsAlive]
ALL_SYNAPSE_TYPE = Union[StreamPrompting, Embeddings, ImageResponse, IsAlive]
103 changes: 11 additions & 92 deletions cortext/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,95 +322,14 @@ def extract_info(prefix: str) -> dict[str, str]:
"axon": extract_info("bt_header_axon"),
"messages": self.messages,
"completion": self.completion,
}


class TextPrompting(bt.Synapse):
messages: List[Dict[str, Union[str, List[Dict[str, Union[str, Dict[str, str]]]]]]] = pydantic.Field(
...,
title="Messages",
description="A list of messages in the StreamPrompting scenario, "
"each containing a role and content. Immutable.",
allow_mutation=False,
)

required_hash_fields: List[str] = pydantic.Field(
["messages"],
title="Required Hash Fields",
description="A list of required fields for the hash.",
allow_mutation=False,
)

seed: int = pydantic.Field(
default="1234",
title="Seed",
description="Seed for text generation. This attribute is immutable and cannot be updated.",
)

temperature: float = pydantic.Field(
default=0.0001,
title="Temperature",
description="Temperature for text generation. "
"This attribute is immutable and cannot be updated.",
)

max_tokens: int = pydantic.Field(
default=2048,
title="Max Tokens",
description="Max tokens for text generation. "
"This attribute is immutable and cannot be updated.",
)

top_p: float = pydantic.Field(
default=0.001,
title="Top_p",
description="Top_p for text generation. The sampler will pick one of "
"the top p percent tokens in the logit distirbution. "
"This attribute is immutable and cannot be updated.",
)

top_k: int = pydantic.Field(
default=1,
title="Top_k",
description="Top_k for text generation. Sampler will pick one of "
"the k most probablistic tokens in the logit distribtion. "
"This attribute is immutable and cannot be updated.",
)

completion: str = pydantic.Field(
None,
title="Completion",
description="Completion status of the current StreamPrompting object. "
"This attribute is mutable and can be updated.",
)

provider: str = pydantic.Field(
default="OpenAI",
title="Provider",
description="The provider to use when calling for your response. "
"Options: OpenAI, Anthropic, Gemini, Groq, Bedrock",
)

model: str = pydantic.Field(
default="gpt-3.5-turbo",
title="model",
description="The model to use when calling provider for your response.",
)

uid: int = pydantic.Field(
default=3,
title="uid",
description="The UID to send the streaming synapse to",
)

timeout: int = pydantic.Field(
default=60,
title="timeout",
description="The timeout for the dendrite of the streaming synapse",
)

streaming: bool = pydantic.Field(
default=True,
title="streaming",
description="whether to stream the output",
)
"provider": self.provider,
"model": self.model,
"seed": self.seed,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
"timeout": self.timeout,
"streaming": self.streaming,
"uid": self.uid,
}
11 changes: 3 additions & 8 deletions cortext/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,8 @@
import re
import io
import torch
import openai
import typing
import difflib
import asyncio
import logging
import aiohttp
import requests
import traceback
import numpy as np
from numpy.linalg import norm
Expand Down Expand Up @@ -71,8 +66,8 @@ async def api_score(api_answer: str, response: str, weight: float, temperature:
words_in_response = len(response.split())
words_in_api = len(api_answer.split())

word_count_over_threshold = words_in_api * 1.20
word_count_under_threshold = words_in_api * 0.60
word_count_over_threshold = words_in_api * 1.4
word_count_under_threshold = words_in_api * 0.50

# Check if the word count of the response is within the thresholds
if words_in_response <= word_count_over_threshold and words_in_response >= word_count_under_threshold:
Expand Down Expand Up @@ -158,7 +153,7 @@ def calculate_image_similarity(image, description, max_length: int = 77):
# Calculate cosine similarity
return torch.cosine_similarity(image_embedding, text_embedding, dim=1).item()

async def dalle_score(uid, url, desired_size, description, weight, similarity_threshold=0.23) -> float:
async def dalle_score(uid, url, desired_size, description, weight, similarity_threshold=0.21) -> float:
"""Calculate the image score based on similarity and size asynchronously."""

if not re.match(url_regex, url):
Expand Down
8 changes: 4 additions & 4 deletions env.example
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
ENV=test
# for validators
WANDB_API_KEY=

# used both by validator and miner:
OPENAI_API_KEY=
PIXABAY_API_KEY=

# For validators and miners
GOOGLE_API_KEY=
ANTHROPIC_API_KEY=
GROQ_API_KEY=test
AWS_ACCESS_KEY=
AWS_SECRET_KEY=
PIXABAY_API_KEY=

2 changes: 1 addition & 1 deletion miner/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cortext import ImageResponse, TextPrompting, StreamPrompting
from cortext import ImageResponse, StreamPrompting
from miner.providers import OpenAI, Anthropic, AnthropicBedrock, Groq, Gemini, Bedrock

task_image = ImageResponse.__name__
Expand Down
4 changes: 2 additions & 2 deletions miner/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from starlette.types import Send
from abc import abstractmethod

from cortext.protocol import StreamPrompting, TextPrompting, Embeddings, ImageResponse, IsAlive
from cortext.protocol import StreamPrompting, Embeddings, ImageResponse, IsAlive
from cortext import ALL_SYNAPSE_TYPE
from cortext.metaclasses import ProviderRegistryMeta
from miner.error_handler import error_handler
Expand All @@ -15,7 +15,7 @@ def __init__(self, synapse: ALL_SYNAPSE_TYPE):
self.model = synapse.model
self.uid = synapse.uid
self.timeout = synapse.timeout
if type(synapse) in [StreamPrompting, TextPrompting]:
if type(synapse) in [StreamPrompting]:
self.messages = synapse.messages
self.required_hash_fields = synapse.required_hash_fields
self.seed = synapse.seed
Expand Down
19 changes: 0 additions & 19 deletions miner/services/text.py

This file was deleted.

21 changes: 9 additions & 12 deletions start_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import cortext
from cortext.utils import get_version, send_discord_alert

default_address = "wss://bittensor-finney.api.onfinality.io/public-ws"
webhook_url = ""
current_version = cortext.__version__

Expand All @@ -25,7 +24,7 @@ def update_and_restart(pm2_name, netuid, wallet_name, wallet_hotkey, address, au
if current_version != latest_version and latest_version != None:
if not autoupdate:
send_discord_alert(
f"Your validator not running the latest code ({current_version}). You will quickly lose vturst if you don't update to version {latest_version}",
f"Your validator not running the latest code ({current_version}). You will quickly lose vtrust if you don't update to version {latest_version}",
webhook_url)
print("Updating to the latest version...")
subprocess.run(["pm2", "delete", pm2_name])
Expand All @@ -49,16 +48,14 @@ def update_and_restart(pm2_name, netuid, wallet_name, wallet_hotkey, address, au
description="Automatically update and restart the validator process when a new version is released."
)

parser.add_argument("--pm2_name", required=True, help="Name of the PM2 process.")
parser.add_argument("--wallet_name", required=True, help="Name of the wallet.")
parser.add_argument("--wallet_hotkey", required=True, help="Hotkey for the wallet.")
parser.add_argument("--netuid", required=True, help="netuid for validator")
parser.add_argument("--subtensor.chain_endpoint", default=default_address, dest='address',
help="Subtensor chain_endpoint, defaults to 'wss://bittensor-finney.api.onfinality.io/public-ws' if not provided.")
parser.add_argument("--autoupdate", action='store_true', dest='autoupdate',
help="Disable automatic update. Only send a Discord alert. Add your webhook at the top of the script.")
parser.add_argument("--logging", required=False, default="debug")
parser.add_argument("--wandb_on", action='store_true', required=False, dest='wandb_on')
parser.add_argument("--pm2_name", required=False, default="autoupdater", help="Name of the PM2 process.")
parser.add_argument("--wallet_name", required=False, default="default", help="Name of the wallet.")
parser.add_argument("--wallet_hotkey", required=False, default="default", help="Hotkey for the wallet.")
parser.add_argument("--netuid", required=False, default=18, help="netuid for validator")
parser.add_argument("--subtensor.chain_endpoint", required=False, default="wss://entrypoint-finney.opentensor.ai:443", dest="address")
parser.add_argument("--autoupdate", action='store_true', dest="autoupdate")
parser.add_argument("--logging", required=False, default="info")
parser.add_argument("--wandb_on", action='store_true', required=False, dest="wandb_on")
parser.add_argument("--max_miners_cnt", type=int, default=30)

args = parser.parse_args()
Expand Down
2 changes: 1 addition & 1 deletion validators/services/validators/base_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, config, metagraph):
self.metagraph = metagraph
self.dendrite = config.dendrite
self.wallet = config.wallet
self.timeout = config.ASYNC_TIME_OUT
self.timeout = config.async_time_out
self.streaming = False
self.provider = None
self.model = None
Expand Down
6 changes: 2 additions & 4 deletions validators/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@ def get(self, key, default=None):

class Config:
def __init__(self, args):
self.ENV = os.getenv('ENV')
self.ASYNC_TIME_OUT = int(os.getenv('ASYNC_TIME_OUT', 60))
self.SLEEP_PER_ITERATION = 1
self.IMAGE_VALIDATOR_CHOOSE_PROBABILITY = 0.01

# Add command-line arguments to the Config object
for key, value in vars(args).items():
Expand Down Expand Up @@ -66,6 +62,8 @@ def parse_arguments():
parser.add_argument("--axon.port", type=int, default=8000)
parser.add_argument("--logging.level", choices=['info', 'debug', 'trace'], default='info')
parser.add_argument("--autoupdate", action="store_true", help="Enable auto-updates")
parser.add_argument("--image_validator_probability", type=float, default=0.001)
parser.add_argument("--async_time_out", type=int, default=60)
return parser.parse_args(namespace=NestedNamespace())


Expand Down
Loading

0 comments on commit 9d61ef6

Please sign in to comment.