Skip to content

Commit

Permalink
Add HuggingFace support for automated inference checkpoint conversion (
Browse files Browse the repository at this point in the history
…GoogleCloudPlatform#712)

* Add HuggingFace support for automated inference checkpoint conversion

* Add HuggingFace support for inference checkpoint conversion

* fix llama checkpoint names

* update containers to v0.2.3 / v0.2.2

* update containers to v0.2.3 / v0.2.2
  • Loading branch information
vivianrwu authored Jul 11, 2024
1 parent ac2c2ff commit 0fd14fd
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyri
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list
RUN apt -y update && apt install -y google-cloud-cli

RUN pip install kaggle
RUN pip install kaggle && \
pip install huggingface_hub[cli] && \
pip install google-jetstream

COPY checkpoint_converter.sh /usr/bin/
RUN chmod +x /usr/bin/checkpoint_converter.sh
Expand Down
33 changes: 25 additions & 8 deletions tutorials-and-examples/inference-servers/checkpoints/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,37 @@ docker push gcr.io/${PROJECT_ID}/inference-checkpoint:latest

Now you can use it in a [Kubernetes job](../jetstream/maxtext/single-host-inference/checkpoint-job.yaml) and pass the following arguments

Jetstream + MaxText
## Jetstream + MaxText
```
- -i=INFERENCE_SERVER
- -s=INFERENCE_SERVER
- -b=BUCKET_NAME
- -m=MODEL_PATH
- -v=VERSION (Optional)
```

Jetstream + Pytorch/XLA
## Jetstream + Pytorch/XLA
```
- -i=INFERENCE_SERVER
- -s=INFERENCE_SERVER
- -m=MODEL_PATH
- -q=QUANTIZE (Optional)
- -v=VERSION
- -1=EXTRA_PARAM_1
- -2=EXTRA_PARAM_2
- -n=MODEL_NAME
- -q=QUANTIZE_WEIGHTS (Optional) (default=False)
- -t=QUANTIZE_TYPE (Optional) (default=int8_per_channel)
- -v=VERSION (Optional) (default=jetstream-v0.2.3)
- -i=INPUT_DIRECTORY (Optional)
- -o=OUTPUT_DIRECTORY
- -h=HUGGINGFACE (Optional) (default=False)
```

## Argument descriptions:
```
b) BUCKET_NAME: (str) GSBucket, without gs://
s) INFERENCE_SERVER: (str) Inference server, ex. jetstream-maxtext, jetstream-pytorch
m) MODEL_PATH: (str) Model path, varies depending on inference server and location of base checkpoint
n) MODEL_NAME: (str) Model name, ex. llama-2, llama-3, gemma
h) HUGGINGFACE: (bool) Checkpoint is from HuggingFace.
q) QUANTIZE_WEIGHTS: (str) Whether to quantize weights
t) QUANTIZE_TYPE: (str) Quantization type, QUANTIZE_WEIGHTS must be set to true. Availabe quantize type: {"int8", "int4"} x {"per_channel", "blockwise"},
v) VERSION: (str) Version of inference server to override, ex. jetstream-v0.2.2, jetstream-v0.2.3
i) INPUT_DIRECTORY: (str) Input checkpoint directory, likely a GSBucket path
o) OUTPUT_DIRECTORY: (str) Output checkpoint directory, likely a GSBucket path
```
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
#!/bin/bash

export KAGGLE_CONFIG_DIR="/kaggle"
export HUGGINGFACE_TOKEN_DIR="/huggingface"
INFERENCE_SERVER="jetstream-maxtext"
BUCKET_NAME=""
MODEL_PATH=""

print_usage() {
printf "Usage: $0 [ -b BUCKET_NAME ] [ -i INFERENCE_SERVER ] [ -m MODEL_PATH ] [ -q QUANTIZE ] [ -v VERSION ] [ -1 EXTRA_PARAM_1 ] [ -2 EXTRA_PARAM_2 ]"
printf "Usage: $0 [ -b BUCKET_NAME ] [ -s INFERENCE_SERVER ] [ -m MODEL_PATH ] [ -n MODEL_NAME ] [ -h HUGGINGFACE ] [ -q QUANTIZE_WEIGHTS ] [ -t QUANTIZE_TYPE ] [ -v VERSION ] [ -i INPUT_DIRECTORY ] [ -o OUTPUT_DIRECTORY ]"
}

print_inference_server_unknown() {
printf "Enter a valid inference server [ -i INFERENCE_SERVER ]"
printf "Enter a valid inference server [ -s INFERENCE_SERVER ]"
printf "Valid options: jetstream-maxtext, jetstream-pytorch"
}

Expand Down Expand Up @@ -43,6 +44,31 @@ download_kaggle_checkpoint() {
echo -e "\nCompleted copy of data to gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}"
}

download_huggingface_checkpoint() {
MODEL_PATH=$1
MODEL_NAME=$2

INPUT_CKPT_DIR_LOCAL=/base/
mkdir /base/
huggingface-cli login --token $(cat ${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN)
huggingface-cli download ${MODEL_PATH} --local-dir ${INPUT_CKPT_DIR_LOCAL}

if [[ $MODEL_NAME == *"llama"* ]]; then
if [[ $MODEL_NAME == "llama-2" ]]; then
TOKENIZER_PATH=/base/tokenizer.model
if [[ $MODEL_PATH != *"hf"* ]]; then
HUGGINGFACE="False"
fi
else
TOKENIZER_PATH=/base/original/tokenizer.model
fi
elif [[ $MODEL_NAME == *"gemma"* ]]; then
TOKENIZER_PATH=/base/tokenizer.model
else
echo -e "Unclear of tokenizer.model for ${MODEL_NAME}. May have to manually upload."
fi
}

convert_maxtext_checkpoint() {
BUCKET_NAME=$1
MODEL_NAME=$2
Expand All @@ -60,7 +86,7 @@ convert_maxtext_checkpoint() {
cd maxtext
git checkout ${MAXTEXT_VERSION}
python3 -m pip install -r requirements.txt
echo -e "\Cloned MaxText repository and completed installing requirements"
echo -e "\nCloned MaxText repository and completed installing requirements"

python3 MaxText/convert_gemma_chkpt.py --base_model_path gs://${BUCKET_NAME}/base/${MODEL_NAME}_${VARIATION_NAME}/${VARIATION_NAME} --maxtext_model_path gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME} --model_size ${MODEL_SIZE}
echo -e "\nCompleted conversion of checkpoint to gs://${BUCKET_NAME}/final/scanned/${MODEL_NAME}_${VARIATION_NAME}"
Expand All @@ -73,59 +99,92 @@ convert_maxtext_checkpoint() {

convert_pytorch_checkpoint() {
MODEL_PATH=$1
INPUT_CKPT_DIR=$2
OUTPUT_CKPT_DIR=$3
QUANTIZE=$4
PYTORCH_VERSION=$5
JETSTREAM_VERSION=v0.2.2
MODEL_NAME=$2
HUGGINGFACE=$3
INPUT_CKPT_DIR=$4
OUTPUT_CKPT_DIR=$5
QUANTIZE_TYPE=$6
QUANTIZE_WEIGHTS=$7
PYTORCH_VERSION=$8

if [ -z $PYTORCH_VERSION ]; then
PYTORCH_VERSION=jetstream-v0.2.2
PYTORCH_VERSION=jetstream-v0.2.3
fi

CKPT_PATH="$(echo ${INPUT_CKPT_DIR} | awk -F'gs://' '{print $2}')"
BUCKET_NAME="$(echo ${CKPT_PATH} | awk -F'/' '{print $1}')"

TO_REPLACE=gs://${BUCKET_NAME}
INPUT_CKPT_DIR_LOCAL=${INPUT_CKPT_DIR/${TO_REPLACE}/${MODEL_PATH}}
OUTPUT_CKPT_DIR_LOCAL=/pt-ckpt/

if [ -z $QUANTIZE ]; then
QUANTIZE="False"
fi
OUTPUT_CKPT_DIR_LOCAL=/pt-ckpt/

git clone https://github.com/google/JetStream.git
git clone https://github.com/google/jetstream-pytorch.git
cd JetStream
git checkout ${JETSTREAM_VERSION}
pip install -e

# checkout stable Pytorch commit
cd ../jetstream-pytorch
cd /jetstream-pytorch
git checkout ${PYTORCH_VERSION}
bash install_everything.sh
export PYTHONPATH=$PYTHONPATH:$(pwd)/deps/xla/experimental/torch_xla2:$(pwd)/JetStream:$(pwd)
echo -e "\nCloned JetStream PyTorch repository and completed installing requirements"

echo -e "\nRunning conversion script to convert model weights. This can take a couple minutes..."
python3 -m convert_checkpoints --input_checkpoint_dir=${INPUT_CKPT_DIR_LOCAL} --output_checkpoint_dir=${OUTPUT_CKPT_DIR_LOCAL} --quantize=${QUANTIZE}

if [ $HUGGINGFACE == "True" ]; then
echo "Checkpoint weights are from HuggingFace"
download_huggingface_checkpoint "$MODEL_PATH" "$MODEL_NAME"
else
HUGGINGFACE="False"

# Example:
# the input checkpoint directory is gs://jetstream-checkpoints/llama-2-7b/base-checkpoint/
# the local checkpoint directory will be /models/llama-2-7b/base-checkpoint/
# INPUT_CKPT_DIR_LOCAL=${INPUT_CKPT_DIR/${TO_REPLACE}/${MODEL_PATH}}
INPUT_CKPT_DIR_LOCAL=${INPUT_CKPT_DIR/${TO_REPLACE}/${MODEL_PATH}}
TOKENIZER_PATH=${INPUT_CKPT_DIR_LOCAL}/tokenizer.model
fi

if [ -z $QUANTIZE_WEIGHTS ]; then
QUANTIZE_WEIGHTS="False"
fi

# Possible quantizations:
# 1. quantize_weights = False, we run without specifying quantize_type
# 2. quantize_weights = True, we run without specifying quantize_type to use the default int8_per_channel
# 3. quantize_weights = True, we run and specify quantize_type
# We can use the same command for case #1 and #2, since both have quantize_weights set without needing to specify quantize_type

echo -e "\n quantize weights: ${QUANTIZE_WEIGHTS}"
if [ $QUANTIZE_WEIGHTS == "True" ]; then
# quantize_type is required, it will be set to the default value if not turned on
if [ -n $QUANTIZE_TYPE ]; then
python3 -m convert_checkpoints --model_name=${MODEL_NAME} --input_checkpoint_dir=${INPUT_CKPT_DIR_LOCAL} --output_checkpoint_dir=${OUTPUT_CKPT_DIR_LOCAL} --quantize_type=${QUANTIZE_TYPE} --quantize_weights=${QUANTIZE_WEIGHTS} --from_hf=${HUGGINGFACE}
fi
else
# quantize_weights should be false, but if not the convert_checkpoints script will catch it
python3 -m convert_checkpoints --model_name=${MODEL_NAME} --input_checkpoint_dir=${INPUT_CKPT_DIR_LOCAL} --output_checkpoint_dir=${OUTPUT_CKPT_DIR_LOCAL} --quantize_weights=${QUANTIZE_WEIGHTS} --from_hf=${HUGGINGFACE}
fi

echo -e "\nCompleted conversion of checkpoint to ${OUTPUT_CKPT_DIR_LOCAL}"
echo -e "\nUploading converted checkpoint from local path ${OUTPUT_CKPT_DIR_LOCAL} to GSBucket ${OUTPUT_CKPT_DIR}"


gcloud storage cp -r ${OUTPUT_CKPT_DIR_LOCAL}/* ${OUTPUT_CKPT_DIR}
gcloud storage cp ${TOKENIZER_PATH} ${OUTPUT_CKPT_DIR}
echo -e "\nCompleted uploading converted checkpoint from local path ${OUTPUT_CKPT_DIR_LOCAL} to GSBucket ${OUTPUT_CKPT_DIR}"
}


while getopts 'b:i:m:q:v:1:2:' flag; do
while getopts 'b:s:m:n:h:t:q:v:i:o:' flag; do
case "${flag}" in
b) BUCKET_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
i) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
s) INFERENCE_SERVER="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
m) MODEL_PATH="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
q) QUANTIZE="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
n) MODEL_NAME="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
h) HUGGINGFACE="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
t) QUANTIZE_TYPE="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
q) QUANTIZE_WEIGHTS="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
v) VERSION="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
1) EXTRA_PARAM_1="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
2) EXTRA_PARAM_2="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
i) INPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
o) OUTPUT_DIRECTORY="$(echo ${OPTARG} | awk -F'=' '{print $2}')" ;;
*) print_usage
exit 1 ;;
esac
Expand All @@ -142,8 +201,8 @@ case ${INFERENCE_SERVER} in
convert_maxtext_checkpoint "$BUCKET_NAME" "$MODEL_NAME" "$VARIATION_NAME" "$MODEL_SIZE" "$VERSION"
;;
jetstream-pytorch)
check_model_path "$MODEL_PATH"
convert_pytorch_checkpoint "$MODEL_PATH" "$EXTRA_PARAM_1" "$EXTRA_PARAM_2" "$QUANTIZE" "$VERSION"
check_model_path "$MODEL_PATH"
convert_pytorch_checkpoint "$MODEL_PATH" "$MODEL_NAME" "$HUGGINGFACE" "$INPUT_DIRECTORY" "$OUTPUT_DIRECTORY" "$QUANTIZE_TYPE" "$QUANTIZE_WEIGHTS" "$VERSION"
;;
*) print_inference_server_unknown
exit 1 ;;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTORCH_JETSTREAM_VERSION=jetstream-v0.2.2
ENV PYTORCH_JETSTREAM_VERSION=jetstream-v0.2.3

RUN apt -y update && apt install -y --no-install-recommends \
ca-certificates \
Expand All @@ -20,8 +20,6 @@ cd /jetstream-pytorch && \
git checkout ${PYTORCH_JETSTREAM_VERSION} && \
bash install_everything.sh

ENV PYTHONPATH=$PYTHONPATH:$(pwd)/deps/xla/experimental/torch_xla2:$(pwd)/JetStream:$(pwd)

COPY jetstream_pytorch_server_entrypoint.sh /usr/bin/

RUN chmod +x /usr/bin/jetstream_pytorch_server_entrypoint.sh
Expand Down
Loading

0 comments on commit 0fd14fd

Please sign in to comment.