Skip to content

Commit

Permalink
Release current version of ml-platform (GoogleCloudPlatform#819)
Browse files Browse the repository at this point in the history
* Added product category preprocessing  (GoogleCloudPlatform#733)

* Added product category cleanup

* Added architecture diagrams

* Data processing observability (GoogleCloudPlatform#752)

* adding dataprep for fine tuning

* taking in the new code with logging

* finetune use cases

* Updating the Google Terraform provider version

* Switched to Google managed DCGM

* variables for batch size and gradient

* Moved the git configuration to a secret manager secret

* Converted use case scripts to use structured JSON logging

* Standardized cloud build, logging, and signal handling

* Standardized bucket name variables and cloud build commands

* Simplied variable configuration and environment setup

* Added l4 deployment option to model-eval

* Add steps to install MLflow

* Create end-2-end-narrative.md

* Updated test scripts

* Enhancements to the mlflow logging and metrics for fine-tuning (GoogleCloudPlatform#813)

---------

Co-authored-by: xiangshen-dk <shenxiang@google.com>
Co-authored-by: Kavitha Rajendran <103603287+karajendran@users.noreply.github.com>
Co-authored-by: Kavitha Rajendran <karajendran@google.com>
Co-authored-by: Kent Hua <kenthua@google.com>
Co-authored-by: Shobhit Gupta <43795024+gushob21@users.noreply.github.com>
Co-authored-by: gushob21 <gushob@google.com>
Co-authored-by: Xiang Shen <shen.xiang@gmail.com>
  • Loading branch information
8 people authored Sep 17, 2024
1 parent 6114d02 commit 0f771a1
Show file tree
Hide file tree
Showing 138 changed files with 160,530 additions and 1,041 deletions.
5 changes: 5 additions & 0 deletions best-practices/ml-platform/.gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
.venv/
mlruns/
venv/

.python-version
test/log/*.log
test/scripts/locks/*.lock
4 changes: 3 additions & 1 deletion best-practices/ml-platform/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ For an outline of products and features used in the platform, see the [Platform

## Use cases

- [Distributed Data Processing with Ray](examples/use-case/ray/dataprocessing/README.md): Run a distributed data processing job using Ray.
- [Distributed Data Processing with Ray](examples/use-case/data-processing/ray/README.md): Run a distributed data processing job using Ray.
- [Dataset Prepartion for Fine Tuning Gemma IT With Gemini Flash](examples/use-case/data-preparation/gemma-it/README.md): Generate prompts for fine tuning Gemma Instruction Tuned model with Vertex AI Gemini Flash
- [Fine Tuning Gemma2 9B IT model With FSDP](examples/use-case/fine-tuning/pytorch/README.md): Fine tune Gemma2 9B IT model with PyTorch FSDP

## Resources

Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Playground Machine learning platform (MLP) on GKE: Architecture

![Playground Architecture](/best-practices/ml-platform/docs/images/platform/playground/mlp_playground_architecture.svg)

## Platform

- [Google Cloud Project](https://console.cloud.google.com/cloud-resource-manager)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

This document outlines the products and features that are used in the platform.

## Cloud Logging
![Playground Architecture](/best-practices/ml-platform/docs/images/platform/playground/mlp_playground_architecture_full.svg)

## Cloud Logging

Cloud Logging is a real-time log-management system with storage, search, analysis, and monitoring support. Cloud Logging automatically collects logs from Google Cloud resources. You can also collect logs from your applications, on-premise resources, and resources from other cloud providers. You can also configure alerting policies so that Cloud Monitoring notifies you if certain kinds of events are reported in your logs. For regulatory or security reasons, you can determine where your log data is stored.

Expand Down Expand Up @@ -89,11 +91,11 @@ GKE Dataplane V2 observability offers the following troubleshooting tools:

For more information see the [Dataplane V2 observability documentation](https://cloud.google.com/kubernetes-engine/docs/concepts/about-dpv2-observability)

### Filestore CSI driver
### Filestore CSI driver

The Filestore CSI driver is the primary way to use Filestore instances with GKE. The CSI driver provides a fully-managed experience powered by the open source Google Cloud Filestore CSI driver.

The CSI driver version is tied to Kubernetes minor version numbers and is typically the latest driver available at the time that the Kubernetes minor version is released. The drivers update automatically when the cluster is upgraded to the latest GKE patch.
The CSI driver version is tied to Kubernetes minor version numbers and is typically the latest driver available at the time that the Kubernetes minor version is released. The drivers update automatically when the cluster is upgraded to the latest GKE patch.

For more information see the [Filestore CSI driver](https://cloud.google.com/kubernetes-engine/docs/how-to/persistent-volumes/filestore-csi-driver)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
## Introduction

This document presents a comprehensive guide to implementing the ML use case from start to finish. It explores the key components involved, outlines the step-by-step process flow, and highlights crucial considerations for successful implementation.

## Use case

The goal is to build a Virtual Retail Assistant \- a chat bot that is contextually sensitive to a retailer's product catalog, capable of responding to customer inquiries and suggesting alternatives and complimentary outfits.

An instruction tuned base model can help with the conversation, but further fine-tuning is necessary to include retailer’s product catalog data and to respond in an instructed way. Throughout this process, we will explore:

* Data Set Preparation
* Fine-Tuning of the Instruction Tuned Model
* Hyperparameter Tuning to Generate Multiple Model Iterations
* Model Validation, Evaluation, and Identification of the Optimal Model

Within each topic, we will share methodologies, frameworks, tools, and lessons learned to support your journey.


## Process Flow for Implementing the ML Usecase End to End


![MLOps workflow](/best-practices/ml-platform/docs/images/use-case/MLOps_e2e.png)


## Data Preprocessing

The data preprocessing phase in MLOps is foundational. It directly impacts the quality and performance of machine learning models. Preprocessing includes tasks such as data cleaning, feature engineering, scaling, and encoding, all of which are essential for ensuring that models learn effectively from the data.

### Dataset

We are leveraging a [pre-crawled public dataset](https://www.kaggle.com/datasets/PromptCloudHQ/flipkart-products), taken as a subset (20000) of a bigger dataset (more than 5.8 million products) that was created by extracting data from [Flipkart](https://www.flipkart.com/), a leading Indian eCommerce store.

![Dataset Snapshot](/best-practices/ml-platform/docs/images/use-case/dataset_info.png)

### Data Preprocessing Steps

The dataset has 15 different columns. The columns of our interest are: 'uniq\_id', 'product\_name', 'description', 'brand' ,'product\_category\_tree', 'image' ,'product\_specifications'.

Other than dropping null values and duplicates, we will do following steps:

1. description: Cleaning up Product Description by removing stop words & punctuations
2. product\_category\_tree: Split into different columns
3. product\_specifications: Parse the Product Specifications into Key:Value pairs
4. image: Parse the list of image urls. Validate the URL and download the image.

Now, consider the scenario where a preprocessing task involves extracting multiple image URLs from each row of a large dataset and uploading the images to a Google Cloud Storage (GCS) bucket. This might sound straightforward, but with a dataset containing 20,000+ rows, each with potentially up to seven URLs, the process can become incredibly time-consuming when executed serially in Python. In our experience, such a task can take upwards of eight hours to complete.

### Parallelism: The Solution to Scalability

To tackle this scalability issue, we turn to parallelism. By breaking the dataset into smaller chunks and distributing the processing across multiple threads or processes, we can drastically reduce the overall execution time.

For implementaion steps, please check this document [Distributed Data Preprocessing with Ray](/best-practices/ml-platform/examples/use-case/data-processing/ray/README.md)


## Data Preparation

For our use case, we are fine tuning the Gemma2 9B Instruction Tuned model to align it with our specific requirements. The instruction-tuned nature of this model mandates that input data adheres to a natural language conversational format. However, our preprocessed product catalog is currently structured in a tabular format, necessitating a data preparation phase to construct suitable prompts.

These prompts will consist of a range of potential questions that an online shopper might pose, alongside corresponding answers derived from the product catalog. The answers will be presented in a contextually relevant manner, simulating the interaction between a shopper and a knowledgeable assistant within an e-commerce environment.



**Methodology:**

The dataset underwent a filtering process to concentrate exclusively on clothing items, catering specifically to Women, Men, and Kids. Products with less than ten units in a specific category were excluded.

To optimize instruction-tuned models, it is imperative that each data point consists of three distinct components: Question, Answer and Context

The Gemini Flash model (through Vertex AI) was used for the generation of natural shopping-style conversational questions/Inputs and answers/Outputs. The formulation of these Input and Output was based on the product specifications detailed within the catalog.

The prompts were created utilizing the previously generated questions and answers. The product category is used to set the context of the conversation.

**Prompt Sample:**

```
<start_of_turn>user
Context:Online shopping for Men's Clothing
What kind of material is the SKOOKIE Sleeveless Solid Men's Jacket made of?<end_of_turn> <start_of_turn>model
The SKOOKIE Sleeveless Solid Men's Jacket is made of a blinded fabric.
Product Name: SKOOKIE Sleeveless Solid Men's Jacket
Product Category: Men's Clothing
Product Details:
- Sleeve: Sleeveless
- Reversible: No
- Fabric: Blinded
- Pattern: Solid
- Ideal For: Men's
- Style Code: SKMJ-3005-C-Blue
<end_of_turn>
```

The ‘End Of Sequence’ token was appended to each prompt.

EOS\_TOKEN \= '\<eos\>'

For implementaion steps, please check this document [Data preparation for fine tuning Gemma IT model](/best-practices/ml-platform/examples/use-case/data-preparation/gemma-it/README.md)


## Fine-tuning

With our prepared dataset, we proceed to fine-tune the Gemma model using PyTorch's Low-Rank Adaptation (LoRA) within the Parameter-Efficient Fine-Tuning (PEFT) framework. This approach incorporates Transformer Reinforcement Learning's (TRL) Supervised Fine-Tuning (SFT) methodology. The SFT process is initiated by providing the *`SFTrainer`* with both the dataset and the specific tuning parameter configurations.

While single-GPU setups without quantization might be sufficient, scaling to multiple GPUs across multiple hosts necessitates sharding and distributed workload management. We employ Fully Sharded Data Parallelism ([FSDP](https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/)) for model sharding and leverage [Accelerate](https://huggingface.co/docs/accelerate/en/index) for efficient distributed training.

### **Initial Fine-Tuning and Optimization**

The initial fine-tuning run utilized a baseline set of parameters. Accelerate aggregates data across all shards and saves the model to a mounted Google Cloud Storage (GCS) bucket for convenient access and persistence.

Having established a working baseline, our focus shifts to optimization. Our goal is to minimize loss while ensuring the model achieves a satisfactory level of accuracy. We'll systematically adjust hyperparameters such as learning rate, the number of training epochs to identify the optimal configuration for our specific task and dataset.

## Hyperparameter Tuning to Generate Multiple Model Iterations

Fine-tuning large language models like Gemma involves carefully adjusting various hyperparameters to optimize performance and tailor the model's behavior for specific tasks. These hyperparameters control aspects of the model's architecture, training process, and optimization algorithm.

In our fine-tuning process, we focused on the following key parameters:

### **Core LoRA Parameters**

* **lora\_r (Rank):** This parameter controls the rank of the low-rank matrices used in LoRA. A higher rank increases the expressiveness of the model but also increases the number of trainable parameters.
* **lora\_alpha (Alpha):** This scaling factor balances the contribution of the LoRA updates to the original model weights. A larger alpha can lead to faster adaptation but may also introduce instability.
* **lora\_dropout (Dropout):** Dropout is a regularization technique that helps prevent overfitting. Applying dropout to LoRA layers can further improve generalization.

### **Learning Rate, Optimization, and Training Duration**

* **learning\_rate (Learning Rate):** The learning rate determines the step size the optimizer takes during training. Finding the optimal learning rate is crucial for efficient convergence and avoiding overshooting.
* **epochs (Number of Epochs):** The number of epochs dictates how many times the model sees the entire training dataset. More epochs generally lead to better performance, but overfitting can become a concern.
* **max\_grad\_norm (Maximum Gradient Norm):** Gradient clipping helps prevent exploding gradients, which can destabilize training. This parameter limits the maximum norm of the gradients.
* **weight\_decay (Weight Decay):** Weight decay is a regularization technique that adds a penalty to the loss function based on the magnitude of the model weights. This helps prevent overfitting.
* **warmup\_ratio (Warmup Ratio):** This parameter determines the proportion of the total training steps during which the learning rate is gradually increased from zero to its initial value. Warmup can improve stability and convergence.

### **Sequence Length Considerations**

* **max\_seq\_length (Maximum Sequence Length):** This parameter controls the maximum length of the input sequences the model can process. Longer sequences can provide more context but also require more computational resources.

### Experimental Values

The following table displays the experimental values across multiple jobs we used for the aforementioned parameters:

| parameter | base | job-0 | job-1 | job-2 |
| :---- | ----- | ----- | ----- | ----- |
| epochs | 1 | 2 | 3 | 4 |
| lora\_r | 8 | 8 | 16 | 32 |
| lora\_alpha | 16 | 16 | 32 | 64 |
| lora\_dropout | 0.1 | 0.1 | 0.2 | 0.3 |
| max\_grad\_norm | 0.3 | 1 | 1 | 1 |
| learning\_rate | 2.00E+04 | 2.00E+05 | 2.00E+04 | 3.00E+04 |
| weight\_decay | 0.001 | 0.01 | 0.005 | 0.001 |
| warmup\_ratio | 0.03 | 0.1 | 0.2 | 0.3 |
| max\_seq\_length | 512 | 1024 | 2048 | 8192 |

### Monitoring Hyperparamer tuning jobs

Tracking multiple hyperparameter tuning jobs with varying parameters and metrics requires a mechanism to track input parameters, monitor the training progress and share the results. Custom tools can be developed, but there are also open-source options like [MLflow](https://www.mlflow.org/).

MLflow is an open-source platform for tracking experiments, including hyperparameter tuning. It provides a central repository for storing and organizing experimental data, making it easy to compare and analyze different configurations and results. MLflow was deployed onto the platform cluster to track our experiments.

**Incorporating MLflow**

Incorporating MLflow into your code is straightforward. Import the MLflow library and set the tracking server URI. Enable system metrics logging by setting the MLFLOW\_ENABLE\_SYSTEM\_METRICS\_LOGGING environment variable and installing the psutil and pynvml libraries.

```py
import mlflow

mlflow.set_tracking_uri(remote_server_uri)
mlflow.autolog()
```

This will allow MLflow to track fine-tuning parameters, results, and system-level metrics like CPU, memory, and GPU utilization. This added layer of monitoring provides a comprehensive view of batch fine-tuning jobs, making it easier to compare different configurations and results.

![epoch_vs_loss](/best-practices/ml-platform/docs/images/use-case/mlflow_epoch_loss.png)

![ml_flow_](/best-practices/ml-platform/docs/images/use-case/MLFlow_experiment_tracking.png)



Alternative solutions, such as MLflow and Weights & Biases, offer additional capabilities. While MLflow provides comprehensive pipeline features, our immediate requirements are satisfied by its core tracking functionality.

Integrating MLflow was relatively simple, involving the deployment of a tracking server and a local SQLite database. While the file-based approach is swift and straightforward, it becomes inefficient at scale. In a production-grade deployment, organizations may opt for a managed Postgres database (e.g., Cloud SQL, AlloyDB) to host MLflow data.

## Model evaluation and validation

After each fine-tuning cycle, the model must be validated for precision and accuracy. This assessment involves a two-stage process:

### Prediction Stage

* The fine-tuned model, loaded into vLLM for efficient serving, is presented with prompts derived from the test dataset.
* Model outputs (predictions) are captured and stored for subsequent analysis.

### Accuracy Calculation

Accuracy calculation will be performed on the predictions by comparing the responses from the model with the product data from the training dataset.

Key metrics that this task will output the following fields:

* **Training Data Set Size:** The total number of samples in the training dataset.
* **Data Match:** The count of predictions that exactly align with the ground truth.
* **False Positives:** The count of predictions where the model incorrectly recommended a product.
* **No Product Details:** The count of instances where the model failed to extract product details when they were present in the training dataset
* **Calculated Accuracy:** The overall proportion of correct predictions made by the model.
* **Calculated Precision:** The proportion of positive predictions made by the model that were actually correct.

### Our Results

| Model | Job | Training Data Size | Test Data Size | True Positives | False Positives | No Product Details | Accuracy (match/total) | Precision (match/(total-false)) |
| :---- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| Gemma 1.1 7b-it | 1 | 3902 | 976 | 815 | 160 | 1 | 83.50% | 83.59% |
| Gemma 1.1 7b-it | 2 | 3902 | 976 | 737 | 225 | 14 | 75.51% | 76.61% |
| Gemma 2 9b-it | 1 | 3902 | 976 | 836 | 137 | 4 | 85.66% | 86.01% |
| Gemma 2 9b-it | 2 | 3902 | 976 | 812 | 164 | 0 | 83.20% | 83.20% |
| Gemma 2 9b-it | 1 | 12838 | 3210 | 2838 | 372 | 0 | 88.41% | 88.41% |
| Gemma 2 9b-it | 2 | 12838 | 3210 | 2482 | 719 | 9 | 77.32% | 77.54% |

### Interpreting Results and Decision-Making

Above evaluation provides a granular understanding of the model's performance. By examining the accuracy, precision, and the distribution of errors (false positives, no product details), the team can make informed decisions about the model's suitability for deployment:

* **Acceptable Accuracy Threshold:** Determine a minimum accuracy level required for the model to be considered effective.
* **Error Analysis:** Investigate the types of errors the model makes to identify potential areas for improvement in future fine-tuning iterations.
* **Deployment Decision:** If the model meets or exceeds the predetermined criteria, it can be confidently deployed for real-world use. If not, further refinement or adjustments may be necessary.

**Additional Considerations**

* Ensure the test dataset accurately reflects the real-world data distribution the model will encounter in production.
* Consider the size of the test dataset to assess the statistical significance of the results.

For implementaion steps, please check this document [Fine tuning Gemma IT model](/best-practices/ml-platform/examples/use-case/fine-tuning/pytorch/README.md)
Loading

0 comments on commit 0f771a1

Please sign in to comment.