[BUG]: Size Mismatch Issue When Loading Model Checkpoints Trained with Tensor Parallel if vocab_size % tp_size != 0 #6167

Lemon-412 opened this issue Dec 24, 2024 · 2 comments · May be fixed by #6168

Lemon-412 commented Dec 24, 2024

Is there an existing issue for this bug?

  • I have searched the existing issues

🐛 Describe the bug

A size mismatch error occurs when loading model checkpoints that were saved with tensor parallelism enabled, if vocab_size is not divisible by tp_size.
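
For context, here is a minimal sketch of the arithmetic (the exact padding strategy inside ColossalAI is an assumption on my part, but the numbers match the traceback below): with vocab_size=65535 and tp_size=2, the vocabulary is padded up to the next multiple of tp_size and then split across ranks.

# Illustrative sketch only -- not ColossalAI code.
vocab_size = 65535
tp_size = 2

padded_vocab = ((vocab_size + tp_size - 1) // tp_size) * tp_size  # 65536
rows_per_rank = padded_vocab // tp_size                           # 32768 rows per TP rank

print(padded_vocab, rows_per_rank)  # matches the torch.Size([32768, 1024]) in the error below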

To Reproduce
Let's modify the official Llama benchmark to reproduce this with minimal changes.

benchmark.py (modify llama model vocab_size):

MODEL_CONFIGS = {
    "100m": LlamaConfig(
        max_position_embeddings=4096,
        num_hidden_layers=4,
        num_attention_heads=32,
        intermediate_size=2048,
        hidden_size=1024,
        vocab_size=65535  # Note that vocab_size % tp_size != 0
    ),
}

benchmark.py (add to the end of the main function):

    # save the checkpoint and load it again
    output_dir = './scripts/save'
    booster.save_model(model, output_dir, shard=True, size_per_shard=10240)

    print('wait 10 secs to ensure ckpts are saved.')
    from time import sleep; sleep(10)

    model = AutoModelForCausalLM.from_pretrained(  # Note that this will fail
        output_dir,
        trust_remote_code=True,
        **init_kwargs,
        torch_dtype=torch.bfloat16,
    )

entrypoint:

export OMP_NUM_THREADS=8
colossalai run --nproc_per_node 8 --hostfile $HOSTFILE benchmark.py \
    --plugin 3d --config 100m --xformers \
    --batch_size 1 --num_steps 5 \
    --grad_checkpoint --zero 1 \
    --tp 2 --pp 1 --mbs 1

The script fails with a RuntimeError after executing model = AutoModelForCausalLM.from_pretrained():

[rank0]: RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
[rank0]:        size mismatch for model.embed_tokens.weight: copying a param with shape torch.Size([32768, 1024]) from checkpoint, the shape in current model is torch.Size([65535, 1024]).
[rank0]:        size mismatch for lm_head.weight: copying a param with shape torch.Size([32768, 1024]) from checkpoint, the shape in current model is torch.Size([65535, 1024]).

Others
No error is reported if we set vocab_size=65536.
No error is reported if we set --tp 1 --pp 2.
A similar error is reported if we set --tp 2 --pp 2.

Environment

colossalai: latest (8b0ed61)
cluster: single node with 8 × H20 GPUs.
Feel free to ask for further environment information (though I don't think it is crucial to this issue ^_^).

Lemon-412 added the bug (Something isn't working) label on Dec 24, 2024

Lemon-412 commented Dec 24, 2024

Here's some insight. Let's add some debug messages here:

        # create a debug print function
        def print_emb_w(t, name, annotation):
            if name != 'model.embed_tokens.weight':
                return
            from colossalai.tensor.d_tensor import is_distributed_tensor
            is_p_tensor = str(is_padded_tensor(t))
            is_d_tensor = str(is_distributed_tensor(t))
            print(f'model.embed_tokens.weight {annotation:20}: ptensor={is_p_tensor:5} dtensor={is_d_tensor:5} shape={t.shape}')
        
        # Save parameters.
        for name, param in model.named_parameters():
            if param is None:
                continue
            
            print_emb_w(param, name, 'param(before unpad)')
            
            # Gather tensor pieces when using tensor parallel.
            if is_padded_tensor(param):
                param = to_unpadded_tensor(param)
                
                print_emb_w(param, name, 'param(after unpad)')
            
            print_emb_w(param, name, 'param(before gather)')
            
            param_ = gather_distributed_param(param, keep_vars=False)
            
            print_emb_w(param_, name, 'param_(after gather)')
            ...

we will see:

model.embed_tokens.weight param(before unpad) : ptensor=True  dtensor=True  shape=torch.Size([32768, 1024])
model.embed_tokens.weight param(after unpad)  : ptensor=False dtensor=True  shape=torch.Size([32768, 1024])
model.embed_tokens.weight param(before gather): ptensor=False dtensor=True  shape=torch.Size([32768, 1024])
model.embed_tokens.weight param_(after gather): ptensor=False dtensor=False shape=torch.Size([32768, 1024])

However, param_ with shape torch.Size([65535, 1024]) is expected.

Maybe the tensor should be gathered before it is unpadded.
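
Here is a minimal sketch of that ordering, reusing the helpers from the snippet above. Whether the padding metadata (origin_len) survives gather_distributed_param is an assumption here; if it does not, the metadata would have to be read from the original param before gathering.

# Sketch only: gather the tensor-parallel shards first, then strip the padding.
for name, param in model.named_parameters():
    if param is None:
        continue

    # 2 x [32768, 1024] -> [65536, 1024] (the padded global shape)
    param_ = gather_distributed_param(param, keep_vars=False)

    # [65536, 1024] -> [65535, 1024], assuming the padding metadata is still
    # attached after the gather; otherwise take origin_len from `param`.
    if is_padded_tensor(param_):
        param_ = to_unpadded_tensor(param_)
    ...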

(screenshot: Clipboard_Screenshot_1735039526)

Lemon-412 (Author) commented:

/cc @flybird11111 @ver217

embed_tokens.weight param(before unpad) : ptensor=True origin_len=65535     dtensor=True global_shape=torch.Size([65536, 1024])   shape=torch.Size([32768, 1024])
embed_tokens.weight param(after unpad)  : ptensor=False                     dtensor=True global_shape=torch.Size([65536, 1024])   shape=torch.Size([32768, 1024])
embed_tokens.weight param(before gather): ptensor=False                     dtensor=True global_shape=torch.Size([65536, 1024])   shape=torch.Size([32768, 1024])
embed_tokens.weight param_(after gather): ptensor=False                     dtensor=False                                         shape=torch.Size([32768, 1024])
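
From this trace, global_shape and origin_len together already describe the expected result. In plain PyTorch (not ColossalAI API, and assuming the shards are split along the vocab dimension, as the shapes suggest), the reconstruction would look like:

import torch

# Numbers taken from the trace above: two TP shards of the padded embedding.
origin_len = 65535
shards = [torch.zeros(32768, 1024), torch.zeros(32768, 1024)]  # one per TP rank

full_padded = torch.cat(shards, dim=0)  # [65536, 1024] == global_shape
full = full_padded[:origin_len]         # [65535, 1024] -- what the checkpoint should contain

print(full_padded.shape, full.shape)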
