-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #143 from ucl-bug/fix-utils-types
Fix `get_implemented` bug
- Loading branch information
Showing
3 changed files
with
176 additions
and
182 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,61 +1,51 @@ | ||
import warnings | ||
|
||
from jax.numpy import expand_dims, ndarray | ||
|
||
|
||
def append_dimension(x: ndarray): | ||
return expand_dims(x, -1) | ||
|
||
|
||
def update_dictionary(old: dict, new_entries: dict): | ||
r"""Update a dictionary with new entries. | ||
Args: | ||
old (dict): The dictionary to update | ||
new_entries (dict): The new entries to add to the dictionary | ||
Returns: | ||
dict: The updated dictionary | ||
""" | ||
for key, val in zip(new_entries.keys(), new_entries.values()): | ||
old[key] = val | ||
return old | ||
|
||
|
||
def _get_implemented(f): | ||
warnings.warn( | ||
"jaxdf.util._get_implemented is deprecated. Use jaxdf.util.get_implemented instead.", | ||
DeprecationWarning, | ||
) | ||
return get_implemented(f) | ||
|
||
|
||
def get_implemented(f): | ||
r"""Prints the implemented methods of an operator | ||
Arguments: | ||
f (Callable): The operator to get the implemented methods of. | ||
Returns: | ||
None | ||
""" | ||
|
||
# TODO: Why there are more instances for the same types? | ||
|
||
print(f.__name__ + ":") | ||
instances = [] | ||
a = f.methods | ||
for f_instance in a: | ||
# Get types | ||
types = f_instance.types | ||
|
||
# Change each type with its classname | ||
types = tuple(map(lambda x: x.__name__, types)) | ||
|
||
# Append | ||
instances.append(str(types)) | ||
|
||
instances = set(instances) | ||
for instance in instances: | ||
print(" ─ " + instance) | ||
from jax.numpy import expand_dims, ndarray | ||
|
||
|
||
def append_dimension(x: ndarray): | ||
return expand_dims(x, -1) | ||
|
||
|
||
def update_dictionary(old: dict, new_entries: dict): | ||
r"""Update a dictionary with new entries. | ||
Args: | ||
old (dict): The dictionary to update | ||
new_entries (dict): The new entries to add to the dictionary | ||
Returns: | ||
dict: The updated dictionary | ||
""" | ||
for key, val in zip(new_entries.keys(), new_entries.values()): | ||
old[key] = val | ||
return old | ||
|
||
|
||
def get_implemented(f): | ||
r"""Prints the implemented methods of an operator | ||
Arguments: | ||
f (Callable): The operator to get the implemented methods of. | ||
Returns: | ||
None | ||
""" | ||
|
||
# TODO: Why there are more instances for the same types? | ||
|
||
print(f.__name__ + ":") | ||
instances = [] | ||
a = f.methods | ||
for f_instance in a: | ||
# Get types | ||
types = f_instance.signature.types | ||
|
||
# Change each type with its classname | ||
types = tuple(map(lambda x: x.__name__, types)) | ||
|
||
# Append | ||
instances.append(str(types)) | ||
|
||
instances = set(instances) | ||
for instance in instances: | ||
print(" ─ " + instance) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,120 +1,120 @@ | ||
[tool.poetry] | ||
name = "jaxdf" | ||
version = "0.2.7" | ||
description = "A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations" | ||
authors = [ | ||
"Antonio Stanziola <a.stanziola@ucl.ac.uk>", | ||
"Simon Arridge", | ||
"Ben T. Cox", | ||
"Bradley E. Treeby", | ||
] | ||
readme = "README.md" | ||
keywords = [ | ||
"jax", | ||
"pde", | ||
"discretization", | ||
"differential equations", | ||
"simulation", | ||
"differentiable programming", | ||
] | ||
license = "LGPL-3.0-only" | ||
classifiers=[ | ||
"Intended Audience :: Education", | ||
"Intended Audience :: Science/Research", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Environment :: GPU", | ||
"Environment :: GPU :: NVIDIA CUDA", | ||
"Environment :: GPU :: NVIDIA CUDA :: 11.6", | ||
"Environment :: GPU :: NVIDIA CUDA :: 11.7", | ||
"Environment :: GPU :: NVIDIA CUDA :: 11.8", | ||
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.0", | ||
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1", | ||
"Topic :: Scientific/Engineering", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"Topic :: Scientific/Engineering :: Physics", | ||
"Topic :: Software Development", | ||
"Topic :: Software Development :: Libraries", | ||
"Topic :: Software Development :: Libraries :: Python Modules", | ||
] | ||
|
||
packages = [ | ||
{ include="jaxdf", from="." } | ||
] | ||
|
||
[tool.poetry.urls] | ||
"Homepage" = "https://ucl-bug.github.io/jaxdf" | ||
"Repository" = "https://github.com/ucl-bug/jaxdf" | ||
"Bug Tracker" = "https://github.com/ucl-bug/jaxdf/issues" | ||
"Support" = "https://discord.gg/VtUb4fFznt" | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.9" | ||
plum-dispatch = "^2.2.2" | ||
jax = "^0.4.20" | ||
equinox = "^0.11.2" | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
coverage = "^7.3.2" | ||
mypy = "^1.4.0" | ||
pre-commit = "^3.3.3" | ||
mkdocs-material-extensions = "^1.3.1" | ||
mkdocs-material = "^9.4.12" | ||
mkdocs-jupyter = "^0.24.6" | ||
mkdocs-autorefs = "^0.5.0" | ||
mkdocs-mermaid2-plugin = "^0.6.0" | ||
mkdocstrings-python = "^1.7.5" | ||
isort = "^5.12.0" | ||
pycln = "^2.4.0" | ||
python-kacl = "^0.4.6" | ||
mkdocs-macros-plugin = "^1.0.5" | ||
pymdown-extensions = "^10.4" | ||
pytest = "^7.4.0" | ||
plumkdocs = "^0.0.5" | ||
jupyterlab = "^4.0.9" | ||
|
||
[build-system] | ||
requires = ["poetry-core>=1.0.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tools.isort] | ||
src_paths = ["jaxdf", "tests"] | ||
multi_line_output = 3 | ||
include_trailing_comma = true | ||
force_grid_wrap = 0 | ||
use_parentheses = true | ||
|
||
[tool.pycln] | ||
all = true | ||
|
||
[tool.mypy] | ||
disallow_any_unimported = true | ||
disallow_untyped_defs = true | ||
no_implicit_optional = true | ||
strict_equality = true | ||
warn_unused_ignores = true | ||
warn_redundant_casts = true | ||
warn_return_any = true | ||
check_untyped_defs = true | ||
show_error_codes = true | ||
ignore_missing_imports = true | ||
allow_redefinition = true | ||
exclude = ['jaxdf/operators/'] | ||
|
||
[tool.yapf] | ||
based_on_style = "pep8" | ||
spaces_before_comment = 4 | ||
split_before_logical_operator = true | ||
indent_width = 2 | ||
|
||
[tool.pytest.ini_options] | ||
addopts = """\ | ||
--doctest-modules \ | ||
""" | ||
|
||
[tool.coverage.report] | ||
exclude_lines = [ | ||
'if TYPE_CHECKING:', | ||
'pragma: no cover' | ||
] | ||
[tool.poetry] | ||
name = "jaxdf" | ||
version = "0.2.7" | ||
description = "A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations" | ||
authors = [ | ||
"Antonio Stanziola <a.stanziola@ucl.ac.uk>", | ||
"Simon Arridge", | ||
"Ben T. Cox", | ||
"Bradley E. Treeby", | ||
] | ||
readme = "README.md" | ||
keywords = [ | ||
"jax", | ||
"pde", | ||
"discretization", | ||
"differential equations", | ||
"simulation", | ||
"differentiable programming", | ||
] | ||
license = "LGPL-3.0-only" | ||
classifiers=[ | ||
"Intended Audience :: Education", | ||
"Intended Audience :: Science/Research", | ||
"Programming Language :: Python :: 3.9", | ||
"Programming Language :: Python :: 3.10", | ||
"Programming Language :: Python :: 3.11", | ||
"Environment :: GPU", | ||
"Environment :: GPU :: NVIDIA CUDA", | ||
"Environment :: GPU :: NVIDIA CUDA :: 11.6", | ||
"Environment :: GPU :: NVIDIA CUDA :: 11.7", | ||
"Environment :: GPU :: NVIDIA CUDA :: 11.8", | ||
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.0", | ||
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1", | ||
"Topic :: Scientific/Engineering", | ||
"Topic :: Scientific/Engineering :: Artificial Intelligence", | ||
"Topic :: Scientific/Engineering :: Physics", | ||
"Topic :: Software Development", | ||
"Topic :: Software Development :: Libraries", | ||
"Topic :: Software Development :: Libraries :: Python Modules", | ||
] | ||
|
||
packages = [ | ||
{ include="jaxdf", from="." } | ||
] | ||
|
||
[tool.poetry.urls] | ||
"Homepage" = "https://ucl-bug.github.io/jaxdf" | ||
"Repository" = "https://github.com/ucl-bug/jaxdf" | ||
"Bug Tracker" = "https://github.com/ucl-bug/jaxdf/issues" | ||
"Support" = "https://discord.gg/VtUb4fFznt" | ||
|
||
[tool.poetry.dependencies] | ||
python = "^3.9" | ||
plum-dispatch = "^2.5.2" | ||
jax = "^0.4.20" | ||
equinox = "^0.11.2" | ||
|
||
[tool.poetry.group.dev.dependencies] | ||
coverage = "^7.3.2" | ||
mypy = "^1.4.0" | ||
pre-commit = "^3.3.3" | ||
mkdocs-material-extensions = "^1.3.1" | ||
mkdocs-material = "^9.4.12" | ||
mkdocs-jupyter = "^0.24.6" | ||
mkdocs-autorefs = "^0.5.0" | ||
mkdocs-mermaid2-plugin = "^0.6.0" | ||
mkdocstrings-python = "^1.7.5" | ||
isort = "^5.12.0" | ||
pycln = "^2.4.0" | ||
python-kacl = "^0.4.6" | ||
mkdocs-macros-plugin = "^1.0.5" | ||
pymdown-extensions = "^10.4" | ||
pytest = "^7.4.0" | ||
plumkdocs = "^0.0.5" | ||
jupyterlab = "^4.0.9" | ||
|
||
[build-system] | ||
requires = ["poetry-core>=1.0.0"] | ||
build-backend = "poetry.core.masonry.api" | ||
|
||
[tools.isort] | ||
src_paths = ["jaxdf", "tests"] | ||
multi_line_output = 3 | ||
include_trailing_comma = true | ||
force_grid_wrap = 0 | ||
use_parentheses = true | ||
|
||
[tool.pycln] | ||
all = true | ||
|
||
[tool.mypy] | ||
disallow_any_unimported = true | ||
disallow_untyped_defs = true | ||
no_implicit_optional = true | ||
strict_equality = true | ||
warn_unused_ignores = true | ||
warn_redundant_casts = true | ||
warn_return_any = true | ||
check_untyped_defs = true | ||
show_error_codes = true | ||
ignore_missing_imports = true | ||
allow_redefinition = true | ||
exclude = ['jaxdf/operators/'] | ||
|
||
[tool.yapf] | ||
based_on_style = "pep8" | ||
spaces_before_comment = 4 | ||
split_before_logical_operator = true | ||
indent_width = 2 | ||
|
||
[tool.pytest.ini_options] | ||
addopts = """\ | ||
--doctest-modules \ | ||
""" | ||
|
||
[tool.coverage.report] | ||
exclude_lines = [ | ||
'if TYPE_CHECKING:', | ||
'pragma: no cover' | ||
] |