Skip to content

Commit

Permalink
fix: #1, #2
Browse files Browse the repository at this point in the history
  • Loading branch information
JamzumSum committed Nov 5, 2022
1 parent 1ebaea6 commit cfd96dd
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 22 deletions.
181 changes: 180 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "yacs-stubgen"
version = "0.2.2"
version = "0.3.0"
description = "Generate stub file for yacs config."
authors = ["JamzumSum <zzzzss990315@gmail.com>"]
license = "MIT"
Expand All @@ -15,6 +15,12 @@ repository = "https://github.com/JamzumSum/yacs-stubgen"
python = "^3.6.2"
yacs = "~0.1.4"

[tool.poetry.group.test.dependencies]
pytest = [
{ version = "^7.2.0", python = "^3.7" },
{ version = "<7.1.0", python = "~3.6.2" },
]

[tool.poetry.group.dev]
optional = true

Expand Down
65 changes: 45 additions & 20 deletions src/yacs_stubgen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,50 @@
from pathlib import Path
from typing import Union
from types import NoneType
from typing import Any, Dict, Optional, Union

import yaml
from yacs.config import CfgNode

typo_map = {
list: "T.Sequence",
tuple: "T.Sequence",
NoneType: "T.Any",
}

def _to_py_obj(cfg: CfgNode, cls_name: str):
classes = {}
d = {}
for k, v in cfg.items():
if isinstance(v, CfgNode):
_clsname = str.capitalize(k)
clss, _ = _to_py_obj(v, _clsname)
classes.update(clss)
d[k] = _clsname
else:
assert str.isidentifier(type(v).__name__)
d[k] = type(v).__name__
classes[f"class {cls_name}(CN)"] = d
return classes, d

def _cls_def(name: str):
return f"class {name}(CN)"


class _CfgTyper:
def __init__(self, var_name: str, cls_name: str) -> None:
self.classes: Dict[str, Any] = {var_name: cls_name}
self.dup_fmt = "{name}_{id}"
self.cls_name = cls_name

def __select_name(self, name: str) -> str:
if _cls_def(name) not in self.classes:
return name

idx = 1
while True:
name_id = self.dup_fmt.format(name=name, id=idx)
if _cls_def(name_id) not in self.classes:
return name_id
idx += 1

def add_cfg(self, cfg: CfgNode, cls_name: Optional[str] = None):
d = {}
for k, v in cfg.items():
if isinstance(v, CfgNode):
_clsname = self.__select_name(str.capitalize(k))
self.add_cfg(v, _clsname)
d[k] = _clsname
else:
d[k] = typo_map.get(type(v), type(v).__name__)

self.classes[_cls_def(cls_name or self.cls_name)] = d
return self.classes


class _BlackDumper(yaml.SafeDumper):
Expand All @@ -40,11 +66,10 @@ def build_pyi(
:param var_name: name of the `cfg` object. You should passin this param correctly.
"""
assert cls_name != var_name, "class name should not be the same with var name"
d, _ = _to_py_obj(cfg, cls_name)
d[var_name] = cls_name
path = Path(path)
with open(path.with_suffix(".pyi"), "w") as f:
# f.write("from typing import *\n")
d = _CfgTyper(var_name=var_name, cls_name=cls_name).add_cfg(cfg, cls_name)

with open(Path(path).with_suffix(".pyi"), "w") as f:
f.write("import typing as T\n\n")
f.write("from yacs.config import CfgNode as CN\n\n")
yaml.dump(
d,
Expand Down

0 comments on commit cfd96dd

Please sign in to comment.