Skip to content

Commit

Permalink
✨ Add option to github action
Browse files Browse the repository at this point in the history
  • Loading branch information
Freed-Wu committed Jul 16, 2023
1 parent 0c91cfc commit 531a1ab
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 12 deletions.
3 changes: 3 additions & 0 deletions action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ inputs:
translator:
description: translator
default: google
option:
description: >
the option passed to translator, such as 'temperature=0 max_tokens=256'
wrapwidth:
description: wrap the width by polib
default: "76"
Expand Down
14 changes: 11 additions & 3 deletions src/translate_shell/tools/po/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,21 @@ def get_parser() -> ArgumentParser:
parser.add_argument("--version", version=VERSION, action="version")
shtab.add_argument_to(parser)
for input_name, info in action["inputs"].items():
    # GitHub passes action inputs to the process as INPUT_* environment
    # variables:
    # https://docs.github.com/en/actions/creating-actions/metadata-syntax-for-github-actions#example-specifying-inputs
    default = os.getenv(
        "INPUT_" + input_name.upper().replace("-", "_"),
        info.get("default", ""),
    )
    if input_name == "option":
        # "--option" is repeatable on the CLI (action="append"); split the
        # single env-var string on whitespace so both forms yield a list
        default = default.split()
        argument_action = "append"
    else:
        argument_action = "store"
    # NOTE(review): the original rebound the name ``action`` (the dict being
    # iterated) to "append"/"store" and shadowed the builtin ``input``;
    # both locals are renamed here to avoid the clash.
    parser.add_argument(
        "--" + input_name,
        default=default,
        help=info["description"] + ". default: %(default)s",
        action=argument_action,
    )
parser.add_argument(
"workspace",
Expand Down
16 changes: 14 additions & 2 deletions src/translate_shell/tools/po/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
r"""Refer ``action.yml``."""
import logging
import os
from argparse import Namespace
from difflib import Differ
Expand All @@ -9,6 +10,8 @@

from translate_shell.translate import translate

logger = logging.getLogger(__name__)


def run(args: Namespace) -> None:
"""Run.
Expand All @@ -21,6 +24,10 @@ def run(args: Namespace) -> None:
default_target_lang = args.target_lang
source_lang = args.source_lang
translator = args.translator
option = {
option.partition("=")[0]: option.partition("=")[2]
for option in args.option
}
wrapwidth = int(args.wrapwidth)
progress = args.progress.lower() == "true"
verbose = args.verbose.lower() == "true"
Expand Down Expand Up @@ -53,9 +60,14 @@ def run(args: Namespace) -> None:
old = str(entry).splitlines()
try:
entry.msgstr = translate(
entry.msgid, target_lang, source_lang, [translator]
entry.msgid,
target_lang,
source_lang,
[translator],
{translator: option},
).results[0]["paraphrase"]
except Exception: # skipcq: PYL-W0703
except Exception as e: # skipcq: PYL-W0703
logger.warning(e)
po.save()
continue
entry.fuzzy = False # type: ignore
Expand Down
34 changes: 33 additions & 1 deletion src/translate_shell/translators/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def init_messages(
}
for template in templates
]
if prompt := option.get("prompt"):
messages[0] = {"role": "system", "content": prompt}
return messages # type: ignore

@staticmethod
def init_kwargs(option: dict) -> dict:
    """Init keyword arguments for the LLM completion call.

    Values usually arrive as strings (parsed from ``key=value`` pairs), so
    each recognised option is converted to the type the API expects.

    :param option: translator options
    :type option: dict
    :rtype: dict
    """
    kwargs = {}
    # options forwarded as floats (same conversions as the original,
    # including top_k/mirostat_mode — NOTE(review): these look like they
    # should be ints; confirm against the backend API)
    for key in (
        "temperature",
        "top_p",
        "top_k",
        "presence_penalty",
        "frequency_penalty",
        "repeat_penalty",
        "tfs_z",
        "mirostat_mode",
        "mirostat_tau",
        "mirostat_eta",
    ):
        if value := option.get(key):
            kwargs[key] = float(value)
    if max_tokens := option.get("max_tokens"):
        kwargs["max_tokens"] = int(max_tokens)
    if stream := option.get("stream"):
        # NOTE(review): bool("false") is True — any non-empty string enables
        # streaming; confirm whether the string "false" should disable it
        kwargs["stream"] = bool(stream)
    if stop := option.get("stop"):
        if isinstance(stop, list):
            kwargs["stop"] = stop
        else:
            # BUG fix: the original called ``":".split(stop)``, which splits
            # the literal ":" by separator ``stop`` (always [":"]); the
            # intent is to split the user's value on ":"
            kwargs["stop"] = stop.split(":")
    return kwargs

@staticmethod
def init_result(completion: Mapping, result: TRANSLATION) -> TRANSLATION:
Expand Down
54 changes: 48 additions & 6 deletions src/translate_shell/translators/llm/_llama_cpp.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
r"""LLaMa cpp
=============
"""
import os
from typing import Any

from llama_cpp import Llama

from ...external.platformdirs import AppDirs
from . import LLMTranslator

# Constructing a ``Llama`` instance costs about 1 s, so the model is no
# longer built eagerly at import time; only the default model *path* is
# kept here, and ``LlamaTranslator.init_model`` creates the model on demand.
MODEL = str(AppDirs("translate-shell").user_data_path / "model.bin")


class LlamaTranslator(LLMTranslator):
def init_model(option: dict) -> Any:
    """Init a ``Llama`` model.

    ``option["model"]`` may override the default model path; the remaining
    recognised options are converted and forwarded to the ``Llama``
    constructor.

    :param option: translator options
    :type option: dict
    :rtype: Any
    """
    model = option.get("model", MODEL)
    if isinstance(model, str):
        # support "~/..." model paths
        model = os.path.expanduser(model)
    kwargs = {}
    # integer-valued constructor arguments
    # BUG fix: the original forwarded ``n_threads`` as the raw string while
    # converting every sibling numeric option with int()
    for key in (
        "n_ctx",
        "n_parts",
        "n_gpu_layers",
        "seed",
        "n_threads",
        "n_batch",
        "last_n_tokens_size",
    ):
        if value := option.get(key):
            kwargs[key] = int(value)
    # boolean flags; NOTE(review): bool("false") is True — any non-empty
    # string enables a flag; confirm whether "false" should disable it
    for key in (
        "f16_kv",
        "logits_all",
        "vocab_only",
        "use_mmap",
        "use_mlock",
        "embedding",
        "low_vram",
        "verbose",
    ):
        if value := option.get(key):
            kwargs[key] = bool(value)
    for key in ("rope_freq_base", "rope_freq_scale"):
        if value := option.get(key):
            kwargs[key] = float(value)
    # forwarded verbatim
    for key in ("lora_base", "lora_path", "tensor_split"):
        if value := option.get(key):
            kwargs[key] = value
    return Llama(model, **kwargs)

0 comments on commit 531a1ab

Please sign in to comment.