
Fine-tuning support #248

Draft
wants to merge 41 commits into base: main

Commits
d4c7d05
First pass at fine tuning support in SDK
pharmapsychotic May 22, 2023
c221c62
Fine-tuning updates:
pharmapsychotic May 23, 2023
d9a40b8
Finetune updates:
pharmapsychotic May 26, 2023
6c56a6d
Update to latest api-interfaces
pharmapsychotic Jun 8, 2023
ab4c022
Update to latest api-interfaces
pharmapsychotic Jun 15, 2023
2b3178e
Project.project_type -> Project.type
pharmapsychotic Jun 16, 2023
6665684
add pydantic dependency
brianfitzgerald Jun 20, 2023
039a5c6
Renamed object_name to object_prompt and addition of extras field for…
pharmapsychotic Jun 26, 2023
42457d2
Bf/platform 339 (#243)
brianfitzgerald Jun 30, 2023
9d1395a
Merge branch main into PLATFORM-339
pharmapsychotic Jun 30, 2023
45ed1d2
Update for multi-LoRA and using model names in prompt
pharmapsychotic Jun 30, 2023
6b8397e
Fix zip handling and fold install/cell cells
pharmapsychotic Jun 30, 2023
8326654
Extra error handling and logging for image uploads
pharmapsychotic Jun 30, 2023
8171e66
Add fine tune support to StabilityInference
pharmapsychotic Jul 5, 2023
75188fb
Improve training status display
pharmapsychotic Jul 6, 2023
33071cc
Fine-tuning updates:
pharmapsychotic Jul 11, 2023
99ab453
Better prompt model substitution
pharmapsychotic Jul 12, 2023
a2784e3
Clean up
pharmapsychotic Jul 14, 2023
f011706
Update notebook for SDXL 0.9 fine-tuning
pharmapsychotic Jul 19, 2023
0a04175
Fixes to <model:weight> syntax ensure the :weight is stripped and not…
pharmapsychotic Jul 21, 2023
fb1fa2f
Merge branch main into PLATFORM-339
pharmapsychotic Jul 21, 2023
c519f6a
Separate max image counts for face/object/style
pharmapsychotic Jul 24, 2023
e905c2a
Update client.StabilityInference to also use `parse_models_from_promp…
pharmapsychotic Jul 25, 2023
c35a25a
Increase max image size
pharmapsychotic Aug 2, 2023
4c6dcc3
Add `stable-diffusion-xl-1024-v1-0` to finetune notebook
pharmapsychotic Aug 2, 2023
25d92f9
Bugfix: regex for token replacement in prompt
brianfitzgerald Aug 10, 2023
5b01ef3
fix another edge case, and add unit tests
brianfitzgerald Aug 10, 2023
4855603
update notebook defaults
brianfitzgerald Aug 10, 2023
5d3b669
update image upload logic
brianfitzgerald Aug 14, 2023
00611e6
add external test link
brianfitzgerald Aug 28, 2023
374c3fa
feat: added fine-tuning REST API Colab
todd-elvers Sep 20, 2023
3bb15ae
feat: added fine-tuning REST API Colab
todd-elvers Sep 20, 2023
81418cd
chore: updated REST API colab to refrence prod preview url
todd-elvers Sep 22, 2023
08bf3ac
fix: corrected an indentation error
todd-elvers Sep 22, 2023
4fb309a
fix: corrected an indentation error
todd-elvers Sep 22, 2023
d5ef8f2
fix: corrected an indentation error
todd-elvers Sep 22, 2023
cde43e1
chore: added a Colab I made for Adil
todd-elvers Sep 28, 2023
e8c9145
chore: fixed a small bug in existing colab
todd-elvers Oct 2, 2023
f0fffd4
chore: more corrections to REST colabs
todd-elvers Oct 2, 2023
7208fa5
Use production endpoint for fine tuning notebook
brianfitzgerald Oct 5, 2023
61dd932
feat: updated REST API colab to allow for using more than one fine-tu…
todd-elvers Oct 5, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -8,3 +8,4 @@ pyenv/
.env
generation-*.pb.json
Pipfile*
.DS_Store
992 changes: 992 additions & 0 deletions nbs/Fine_Tuning_REST_API_External_Test.ipynb

Large diffs are not rendered by default.

357 changes: 357 additions & 0 deletions nbs/Fine_tuning_SDK_external_test.ipynb
@@ -0,0 +1,357 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "GM-nO4oBegPm"
},
"source": [
"# Stability Fine-Tuning SDK - Dev Test\n",
"\n",
"Thank you for trying the first external beta of the Stability Fine-Tuning SDK! Note that this is a **developer beta** - bugs and quality issues with the generated fine-tunes may occur. Please reach out to Stability if you have any questions or run into issues, and share what you've made as well!\n",
"\n",
"Feel free to integrate the gRPC SDK below into your own code, though be warned that the API below is subject to change before public release. A REST API will also be available in the near future.\n",
"\n",
"Known issues:\n",
"\n",
"* Style fine-tunes may result in overfitting - if so, lower the model strength in the prompt, e.g. the `0.7` in `<model_id:0.7>`. You may need to go as low as 0.2 or 0.1.\n",
"* We will be exposing test parameters soon - please reach out with examples of datasets that produce overfitting or errors if you have them.\n",
"* Current input image limits are 3 minimum for all modes, 128 maximum for style fine-tuning, and 64 maximum for all other modes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "T9ma2X7bhH8y",
"outputId": "857cf355-b95c-43c6-c630-d2e680ee201b"
},
"outputs": [],
"source": [
"#@title Install Stability SDK with fine-tuning support\n",
"import getpass\n",
"import io\n",
"import logging\n",
"import os\n",
"import shutil\n",
"import sys\n",
"import time\n",
"from IPython.display import clear_output\n",
"from pathlib import Path\n",
"from zipfile import ZipFile\n",
"\n",
"if os.path.exists(\"../src/stability_sdk\"):\n",
" sys.path.append(\"../src\") # use local SDK src\n",
"else:\n",
" path = Path('stability-sdk')\n",
" if path.exists():\n",
" shutil.rmtree(path)\n",
" !pip uninstall -y stability-sdk\n",
" !git clone -b \"PLATFORM-339\" --recurse-submodules https://github.com/Stability-AI/stability-sdk\n",
" !pip install ./stability-sdk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5LM5SUUOhH8z",
"outputId": "ace3ee80-7405-40f0-f18c-476101200900"
},
"outputs": [],
"source": [
"#@title Connect to the Stability API\n",
"from stability_sdk.api import Context, generation\n",
"from stability_sdk.finetune import (\n",
" create_model, delete_model, get_model, list_models, resubmit_model, update_model,\n",
" FineTuneMode, FineTuneParameters, FineTuneStatus\n",
")\n",
"\n",
"# @markdown To get your API key visit https://dreamstudio.ai/account. Ensure you are added to the whitelist during the external test!\n",
"STABILITY_HOST = \"grpc.stability.ai:443\"\n",
"STABILITY_KEY = \"\"\n",
"\n",
"engine_id = \"stable-diffusion-xl-1024-v1-0\"\n",
"\n",
"# Create API context to query user info and generate images\n",
"context = Context(STABILITY_HOST, STABILITY_KEY, generate_engine_id=engine_id)\n",
"(balance, pfp) = context.get_user_info()\n",
"print(f\"Logged in org:{context._user_organization_id} with balance:{balance}\")\n",
"\n",
"# Redirect logs to print statements so we can see them in the notebook\n",
"class PrintHandler(logging.Handler):\n",
" def emit(self, record):\n",
" print(self.format(record))\n",
"logging.getLogger().addHandler(PrintHandler())\n",
"logging.getLogger().setLevel(logging.INFO)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iQL5dFsfhH8z",
"outputId": "bfc71c3d-e235-4e10-e352-14add4a100c9"
},
"outputs": [],
"source": [
"#@title List fine-tuned models for this user / organization.\n",
"models = list_models(context, org_id=context._user_organization_id)\n",
"print(f\"Found {len(models)} models\")\n",
"for model in models:\n",
" print(f\" Model {model.id} {model.name} {model.status}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "t_gOf0i_gmCd"
},
"source": [
"For training, we need a dataset of images. Please only upload images that you have permission to use. This can be a folder of images or a .zip file containing your images. Images can be any aspect ratio, as long as they are at least 384px on the shortest side and at most 1024px on the longest side. Datasets can range from a minimum of 4 images to a maximum of 128 images.\n",
"\n",
"A larger dataset tends to result in a more accurate model, but will also take longer to train.\n",
"\n",
"While each mode can accept up to 128 images, we have a few suggestions for a starter dataset based on the mode you are using:\n",
"\n",
"* Face: 6 or more images.\n",
"* Object: 6 - 10 images.\n",
"* Style: 20 - 30 images.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 395
},
"id": "9C1YOFxIhTJp",
"outputId": "f65c6762-6bd3-4cf6-f3e4-a736c206e32f"
},
"outputs": [],
"source": [
"#@title Upload ZIP file of images.\n",
"training_dir = \"./train\"\n",
"Path(training_dir).mkdir(exist_ok=True)\n",
"try:\n",
" from google.colab import files\n",
"\n",
" upload_res = files.upload()\n",
" extracted_dir = list(upload_res.keys())[0]\n",
" print(f\"Received {extracted_dir}\")\n",
" if not extracted_dir.endswith(\".zip\"):\n",
" raise ValueError(\"Uploaded file must be a zip file\")\n",
"\n",
" zf = ZipFile(io.BytesIO(upload_res[extracted_dir]), \"r\")\n",
" extracted_dir = Path(extracted_dir).stem\n",
" print(f\"Extracting to {extracted_dir}\")\n",
" zf.extractall(extracted_dir)\n",
"\n",
" for root, dirs, filenames in os.walk(extracted_dir):\n",
" for file in filenames:\n",
" source_path = os.path.join(root, file)\n",
" target_path = os.path.join(training_dir, file)\n",
"\n",
" # Skip macOS metadata such as __MACOSX folders and .DS_Store files\n",
" if '__MACOSX' in source_path or file.startswith('.'):\n",
" continue\n",
" print('Adding input image:', source_path, target_path)\n",
" # Move the file to the training directory\n",
" shutil.move(source_path, target_path)\n",
"\n",
"except ImportError:\n",
" pass\n",
"\n",
"print(f\"Using training images from: {training_dir}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8xJLYZ4fgoU9"
},
"source": [
"Now we're ready to train our model. Specify parameters like the name of your model, the training mode, and the guiding prompt for object mode training.\n",
"\n",
"Please note that the training duration will vary based on the size of your dataset, the training mode, and the engine being fine-tuned.\n",
"\n",
"However, the following are some rough estimates for the training duration for each mode based on our recommended dataset sizes:\n",
"\n",
"* Face: 4 - 5 minutes.\n",
"* Object: 5 - 10 minutes.\n",
"* Style: 20 - 30 minutes.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "VLyYQVM3hH8z",
"outputId": "ab994fbd-71cc-49fe-9829-d1b456dc15c4"
},
"outputs": [],
"source": [
"#@title Perform fine-tuning\n",
"model_name = \"elliot-dev\" #@param {type:\"string\"}\n",
"#@markdown > Model names are unique, and may only contain numbers, letters, and hyphens.\n",
"training_mode = \"face\" #@param [\"face\", \"style\", \"object\"] {type:\"string\"}\n",
"#@markdown > The face mode expects pictures containing a face, and automatically crops and centers on the face detected in the input photos. The object mode segments out the subject specified with the prompt below, and the style mode simply crops the images and filters for image quality.\n",
"object_prompt = \"cat\" #@param {type:\"string\"}\n",
"#@markdown > The Object Prompt is used for segmenting out your subject in the object fine-tuning mode - e.g. to fine-tune on a cat, put `cat`; for a bottle of liquor, use `bottle`. In general, it's best to use the most general word you can to describe your object.\n",
"\n",
"print(f\"Scanning {training_dir} ({len(os.listdir(training_dir))} files)\")\n",
"# Gather training images\n",
"images = []\n",
"for filename in os.listdir(training_dir):\n",
" if os.path.splitext(filename)[1].lower() in ['.png', '.jpg', '.jpeg']:\n",
" images.append(os.path.join(training_dir, filename))\n",
"\n",
"# Create the fine-tune model\n",
"params = FineTuneParameters(\n",
" name=model_name,\n",
" mode=FineTuneMode(training_mode),\n",
" object_prompt=object_prompt,\n",
" engine_id=engine_id,\n",
")\n",
"model = create_model(context, params, images)\n",
"print(f\"Model {model_name} created.\")\n",
"print(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yEKyO3-bhH8z",
"outputId": "50b70059-833a-45ca-ae49-66fa6564ed2c"
},
"outputs": [],
"source": [
"#@title Check on training status\n",
"start_time = time.time()\n",
"elapsed = 0.0\n",
"while model.status not in (FineTuneStatus.COMPLETED, FineTuneStatus.FAILED):\n",
" model = get_model(context, model.id)\n",
" elapsed = time.time() - start_time\n",
" clear_output(wait=True)\n",
" print(f\"Model {model.name} ({model.id}) status: {model.status} for {elapsed:.0f} seconds\")\n",
" time.sleep(5)\n",
"\n",
"clear_output(wait=True)\n",
"status_message = \"completed\" if model.status == FineTuneStatus.COMPLETED else \"failed\"\n",
"print(f\"Model {model.name} ({model.id}) {status_message} after {elapsed:.0f} seconds\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qr4jBHX7hH8z"
},
"outputs": [],
"source": [
"#@title If fine-tuning fails for some reason, you can resubmit the model\n",
"if model.status == FineTuneStatus.FAILED:\n",
" print(\"Training failed, resubmitting\")\n",
" model = resubmit_model(context, model.id)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "-Ugkjgy2hH8z",
"outputId": "93517776-9578-4d1c-ec84-68dba4b75085"
},
"outputs": [],
"source": [
"#@title Generate images from your fine-tuned model\n",
"results = context.generate(\n",
" prompts=[f\"Illustration of <{model.id}:1> as a wizard\"],\n",
" weights=[1],\n",
" width=1024,\n",
" height=1024,\n",
" seed=42,\n",
" steps=40,\n",
" sampler=generation.SAMPLER_DDIM,\n",
" preset=\"photographic\",\n",
")\n",
"image = results[generation.ARTIFACT_IMAGE][0]\n",
"display(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3BZLVniihH8z"
},
"outputs": [],
"source": [
"#@title Models can be updated to change settings before a resubmit or after training to rename\n",
"update_model(context, model.id, name=\"cat-ft-01-renamed\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eUFTMZOvhH80"
},
"outputs": [],
"source": [
"#@title Delete the model when it's no longer needed\n",
"delete_model(context, model.id)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 0
}
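The notebook above relies on the `<model_id:weight>` prompt syntax (e.g. `<model_id:0.7>`) for blending fine-tuned models into a prompt. The following is a minimal standalone sketch of how such tokens might be parsed and their `:weight` suffixes stripped - it is an illustration only, not the SDK's actual `parse_models_from_prompt` implementation; the regex, function name, and default weight of 1.0 here are assumptions.

```python
import re
from typing import List, Tuple

# Matches tokens like <model-id> or <model-id:0.7> in a prompt (assumed format).
TOKEN_RE = re.compile(r"<([A-Za-z0-9-]+)(?::([0-9]*\.?[0-9]+))?>")

def parse_models(prompt: str) -> Tuple[str, List[Tuple[str, float]]]:
    """Extract (model_id, weight) pairs from a prompt and strip each
    :weight suffix, leaving bare <model-id> tokens in the cleaned text."""
    models: List[Tuple[str, float]] = []

    def repl(m: re.Match) -> str:
        model_id = m.group(1)
        weight = float(m.group(2)) if m.group(2) else 1.0  # assume default weight 1.0
        models.append((model_id, weight))
        return f"<{model_id}>"  # keep the token, drop the weight

    cleaned = TOKEN_RE.sub(repl, prompt)
    return cleaned, models
```

For example, `parse_models("Illustration of <abc123:0.7> as a wizard")` returns the cleaned prompt `"Illustration of <abc123> as a wizard"` together with `[("abc123", 0.7)]`.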
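The training-status polling loop in the notebook can also be generalized into a small helper with a timeout, which is useful outside a notebook where `clear_output` is unavailable. This is a sketch under the assumption of a `get_status` callable returning a plain status string; the SDK's actual `FineTuneStatus` enum and `get_model` call are not used here, and the default interval and timeout are arbitrary choices.

```python
import time

def wait_for_completion(get_status, poll_interval: float = 5.0,
                        timeout: float = 1800.0) -> str:
    """Poll get_status() until it reports a terminal state, or raise
    TimeoutError once `timeout` seconds have elapsed."""
    start = time.time()
    while True:
        status = get_status()
        if status in ("COMPLETED", "FAILED"):
            return status
        if time.time() - start > timeout:
            raise TimeoutError(f"training still {status} after {timeout:.0f}s")
        time.sleep(poll_interval)
```

In the notebook's terms, `get_status` would wrap something like `get_model(context, model.id).status`; the caller can then branch on the returned string to decide whether to resubmit.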