B-STAR: Monitoring and Balancing Exploration and Exploitation in Self-Taught Reasoners

📄 Paper   

B-STAR (Balanced Self-Taught Reasoner) is a framework designed to improve the self-improvement process of reasoning models by dynamically balancing exploration and exploitation throughout training. This approach is particularly effective in enhancing performance in tasks requiring complex reasoning, such as mathematical problem-solving, coding, and commonsense reasoning.


Overview

Self-improvement in reasoning models involves iterative training where models generate their own training data from outputs. However, existing methods often stagnate after a few iterations due to imbalances between two critical factors:

  1. Exploration: The model's ability to generate diverse and high-quality responses.
  2. Exploitation: The effectiveness of external rewards in distinguishing and leveraging high-quality responses.


B-STAR introduces an adaptive mechanism to monitor and balance these factors dynamically, ensuring consistent performance improvements over multiple training iterations.
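To make this concrete, here is a minimal Python sketch of a balance-score-style metric. It illustrates the idea rather than reproducing the paper's exact formula (see the paper for the precise definition): a query scores high only when its selected data contains several distinct correct responses (exploration) and most selected responses are correct (exploitation). The n_star knob below is an assumed hyperparameter.

def balance_score(selected, n_star=4):
    """selected: (response_text, is_correct) pairs that survived reward-based
    selection for one query; n_star: desired number of distinct correct
    responses (an assumed knob, not a value from the paper)."""
    if not selected:
        return 0.0
    unique_correct = len({text for text, ok in selected if ok})
    precision = sum(ok for _, ok in selected) / len(selected)
    # The exploration term saturates once n_star distinct correct answers
    # survive; the exploitation term is the precision of the selection.
    return min(unique_correct / n_star, 1.0) * precision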

Key Features

  • Dynamic Configuration Adjustments: Automatically tunes exploration and exploitation configurations (e.g., sampling temperature, reward thresholds) to optimize the self-improvement process; see the sketch after this list.
  • Balance Score Metric: Quantifies the interplay between exploration and exploitation and guides the dynamic adjustments.
  • Generalization Across Tasks: Demonstrates effectiveness on mathematical reasoning, coding challenges, and commonsense reasoning tasks.
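The dynamic adjustment can be pictured as a small grid search run before each training iteration: try candidate (temperature, reward threshold) pairs on a probe set and keep the pair with the highest average balance score. Everything below (function names, the generate interface, the candidate grids) is an illustrative assumption; train_bstar.sh drives the actual procedure.

import itertools

def tune_configs(probe_queries, generate, temps=(0.5, 0.8, 1.0, 1.2),
                 thresholds=(0.0, 0.25, 0.5)):
    """Pick the (temperature, reward_threshold) pair maximizing the mean
    balance score. `generate(q, temperature)` is assumed to return candidate
    objects with .text, .reward, and .correct attributes."""
    best_score, best_cfg = float("-inf"), None
    for temp, thr in itertools.product(temps, thresholds):
        scores = []
        for q in probe_queries:
            candidates = generate(q, temperature=temp)
            kept = [(c.text, c.correct) for c in candidates if c.reward >= thr]
            scores.append(balance_score(kept))  # sketch defined above
        mean_score = sum(scores) / max(len(scores), 1)
        if mean_score > best_score:
            best_score, best_cfg = mean_score, (temp, thr)
    return best_cfg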

Results

B-STAR achieves state-of-the-art performance across various benchmarks:

  • Significant improvements over previous self-improvement methods.

  • Sustained performance growth across multiple iterations, outperforming existing methods that stagnate after a few iterations.

Reproduction

Our code builds upon easy-to-hard and gpt-accelerate. Please refer to gpt-accelerate for environment setup and model weight conversion instructions.

1. Prepare Model

We first need to prepare the model checkpoint in the gpt-fast format.

export DATA_DIR=/path/to/your/data/directory
export MODEL_REPO=mistralai/Mistral-7B-v0.1

python scripts/download.py \
    --repo_id $MODEL_REPO \
    --local_dir $DATA_DIR/checkpoints

python scripts/convert_hf_checkpoint.py \
    --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \
    --target_precision bf16
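Optionally, sanity-check the converted checkpoint before moving on. This is a small sketch, assuming the conversion wrote model.pth under the downloaded checkpoint directory (the gpt-fast layout):

import os
import torch

# Assumed output location of convert_hf_checkpoint.py (gpt-fast layout).
ckpt_path = os.path.expandvars("$DATA_DIR/checkpoints/mistralai/Mistral-7B-v0.1/model.pth")
state_dict = torch.load(ckpt_path, map_location="cpu", mmap=True)
print(f"{len(state_dict)} tensors loaded")
name, tensor = next(iter(state_dict.items()))
print(name, tensor.dtype)  # expect torch.bfloat16 after the bf16 conversion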

2. Train SFT Model

export DATA_DIR=/path/to/your/data/directory
export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1

export OMP_NUM_THREADS=8


# Please download this dataset to a local folder first
SFT_TRAIN_DATA=https://huggingface.co/datasets/AndrewZeng/math-trn-format/blob/main/math_format.json
SFT_MODEL_SAVE_NAME=math_format_11k_mistral

torchrun --standalone --nproc_per_node=8 \
    train_sft.py \
    --do_train \
    --checkpoint_path $MODEL_REPO/model.pth \
    --source_max_len 768 \
    --target_max_len 768 \
    --total_max_len 1024 \
    --per_device_train_batch_size 16 \
    --micro_train_batch_size 4 \
    --learning_rate 5e-6 \
    --lr_eta_min 2e-7 \
    --num_train_epochs 3 \
    --dataset "$SFT_TRAIN_DATA" \
    --dataset_format "metamath" \
    --add_eos_to_marked_target \
    --save_strategy "steps" \
    --save_steps 25 \
    --optim_dtype bf16 \
    --save_total_limit 40 \
    --tensor_parallel_size 1 \
    --save_dir $DATA_DIR/checkpoints/$SFT_MODEL_SAVE_NAME \
    --resume_from_checkpoint
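After downloading math_format.json locally, you can quickly inspect what --dataset_format "metamath" will consume. MetaMath-style records typically carry query/response fields, but check your copy; this snippet only prints the first example:

import json

# Path assumes the dataset was downloaded to the working directory.
with open("math_format.json") as f:
    data = json.load(f)
print(len(data), "examples")
print(json.dumps(data[0], indent=2, ensure_ascii=False)[:500])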

3. Train PRM Model

We constructed the PRM training data using the Math-Shepherd approach and trained the reward model with a pointwise objective.

export DATA_DIR=/path/to/your/data/directory

export MODEL_REPO=$DATA_DIR/checkpoints/Mistral-7B-v0.1
export OMP_NUM_THREADS=4


RM_DATA=train_prm_math_shepherd_mistral.json
RM_MODEL_SAVE_NAME=prm_model_mistral_sample_complete

torchrun --standalone --nproc_per_node=8 \
    train_rm_pointwise.py \
    --do_train \
    --checkpoint_path $MODEL_REPO/model.pth \
    --source_max_len 768 \
    --target_max_len 768 \
    --total_max_len 1024 \
    --per_device_train_batch_size 32 \
    --micro_train_batch_size 32 \
    --learning_rate 2e-6 \
    --lr_eta_min 2e-7 \
    --num_train_epochs 2 \
    --dataset "$RM_DATA" \
    --dataset_format "prm-v4" \
    --save_strategy epoch \
    --save_total_limit 5 \
    --train_on_every_token \
    --tensor_parallel_size 1 \
    --save_only_model True \
    --optim_dtype bf16 \
    --save_dir $DATA_DIR/checkpoints/$RM_MODEL_SAVE_NAME \
    --resume_from_checkpoint
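For intuition about the pointwise objective: a Math-Shepherd-style process reward model labels each reasoning step correct or incorrect and learns to predict every step label independently, e.g. with binary cross-entropy. The toy snippet below only shows the shape of such a loss; names and sizes are illustrative, not the repository's training code:

import torch
import torch.nn.functional as F

# One logit per annotated step position in a four-step solution (illustrative).
step_logits = torch.randn(4)
step_labels = torch.tensor([1.0, 1.0, 0.0, 0.0])  # Math-Shepherd-style hard labels
loss = F.binary_cross_entropy_with_logits(step_logits, step_labels)
print(loss.item())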

4. Train B-STaR

## This is our initial release. We are working on cleaning up the code
## to make it clearer and more readable.
cd train_code
bash train_bstar.sh

5. Evaluation

Coming soon!

Citation

If you find B-STaR useful, please cite our paper:

@article{zeng2024bstar,
  title={B-STAR: Monitoring and Balancing Exploration and Exploitation in Self-Taught Reasoners},
  author={Zeng, Weihao and Huang, Yuzhen and Zhao, Lulu and Wang, Yijun and Shan, Zifei and He, Junxian},
  journal={arXiv preprint arXiv:2412.17256},
  year={2024},
  url={https://arxiv.org/abs/2412.17256}
}
