Skip to content

Commit

Permalink
export edited model, gradient checkpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
hiyouga committed Jul 17, 2023
1 parent f23c7b8 commit 7e333db
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
24 changes: 16 additions & 8 deletions fastedit/editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from .utils.generate import generate_fast, generate_interactive


def test_rome(data: str, model: str, config: str, template: Optional[str] = "default") -> None:
def test_rome(
data: str, model: str, config: str, template: Optional[str] = "default",
output: Optional[str] = None, checkpointing: Optional[bool] = False
) -> None:
r"""
Edits a pre-trained model using model-editing algorithms.
Expand All @@ -23,10 +26,10 @@ def test_rome(data: str, model: str, config: str, template: Optional[str] = "def
The name of the hyper-parameters to use for editing the model.
template (`str`, *optional*, defaults to `default`):
The name of the template to use in generation.
Returns:
diff_weights (`Dict[str, Tensor]`):
A dict of diff weights that have been changed.
output (`str`, *optional*, defaults to `None`):
The path to save the edited model.
checkpointing (`bool`, *optional*, defaults to `False`):
Whether to enable gradient checkpointing or not.
"""

assert os.path.exists(data), "data not found"
Expand All @@ -36,7 +39,7 @@ def test_rome(data: str, model: str, config: str, template: Optional[str] = "def

queries = [query for request in requests for query in request["queries"]]

model, tokenizer, batch_first = load_model_and_tokenizer(model)
model_old, tokenizer, batch_first = load_model_and_tokenizer(model, checkpointing)
template = Template(name=template)

print_loud("Retrieving hyperparameters")
Expand All @@ -45,12 +48,12 @@ def test_rome(data: str, model: str, config: str, template: Optional[str] = "def

if len(queries) > 0:
print_loud("Generating pre-update text")
pre_update_text = generate_fast(model, tokenizer, queries, template, max_length=100)
pre_update_text = generate_fast(model_old, tokenizer, queries, template, max_length=100)
print("\n\n".join([queries[i] + " " + pre_update_text[i] for i in range(len(queries))]))

print_loud(f"Applying rome to model")
model_new, _ = apply_rome_to_model(
model,
model_old,
tokenizer,
requests,
hparams,
Expand All @@ -66,6 +69,11 @@ def test_rome(data: str, model: str, config: str, template: Optional[str] = "def
print_loud("Starting interactively generation interface")
generate_interactive(model_new, tokenizer, template)

if output is not None:
model_new.config.use_cache = True
model_new.save_pretrained(output, max_shard_size="10GB")
tokenizer.save_pretrained(output)


if __name__ == "__main__":
fire.Fire(test_rome)
27 changes: 23 additions & 4 deletions fastedit/utils/mtloader.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import torch
from typing import Tuple
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizer
PreTrainedTokenizerBase
)

def load_model_and_tokenizer(model: str) -> Tuple[PreTrainedModel, PreTrainedTokenizer, bool]:
def load_model_and_tokenizer(
model: str, checkpointing: bool
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase, bool]:

batch_first = True

Expand All @@ -18,14 +22,29 @@ def load_model_and_tokenizer(model: str) -> Tuple[PreTrainedModel, PreTrainedTok
trust_remote_code=True
)

if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = 0

config = AutoConfig.from_pretrained(model)

model = AutoModelForCausalLM.from_pretrained(
model,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
trust_remote_code=True
).cuda()

if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = 0
# Register auto class to save the custom code files.
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()
if isinstance(model, PreTrainedModel) and "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
model.__class__.register_for_auto_class()
if isinstance(tokenizer, PreTrainedTokenizerBase) and "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
tokenizer.__class__.register_for_auto_class()

if checkpointing:
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.config.use_cache = False

return model, tokenizer, batch_first

0 comments on commit 7e333db

Please sign in to comment.