Skip to content

Commit

Permalink
Merge pull request #143 from ucl-bug/fix-utils-types
Browse files Browse the repository at this point in the history
Fix `get_implemented` bug
  • Loading branch information
astanziola authored Sep 17, 2024
2 parents 3285cf7 + 9c1629f commit c48ccf3
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 182 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Fixed
- Fixed `util.get_implemented` bug that was happening with the new version of `plum`

### Removed
- Removed the deprecated `util._get_implemented` function

## [0.2.7] - 2023-11-24
### Changed
Expand Down Expand Up @@ -66,4 +71,3 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
[0.2.7]: https://github.com/ucl-bug/jaxdf/compare/0.2.6...0.2.7
[0.2.6]: https://github.com/ucl-bug/jaxdf/compare/0.2.5...0.2.6
[0.2.5]: https://github.com/ucl-bug/jaxdf/tree/0.2.5

112 changes: 51 additions & 61 deletions jaxdf/util.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,51 @@
import warnings

from jax.numpy import expand_dims, ndarray


def append_dimension(x: ndarray):
return expand_dims(x, -1)


def update_dictionary(old: dict, new_entries: dict):
r"""Update a dictionary with new entries.
Args:
old (dict): The dictionary to update
new_entries (dict): The new entries to add to the dictionary
Returns:
dict: The updated dictionary
"""
for key, val in zip(new_entries.keys(), new_entries.values()):
old[key] = val
return old


def _get_implemented(f):
warnings.warn(
"jaxdf.util._get_implemented is deprecated. Use jaxdf.util.get_implemented instead.",
DeprecationWarning,
)
return get_implemented(f)


def get_implemented(f):
r"""Prints the implemented methods of an operator
Arguments:
f (Callable): The operator to get the implemented methods of.
Returns:
None
"""

# TODO: Why there are more instances for the same types?

print(f.__name__ + ":")
instances = []
a = f.methods
for f_instance in a:
# Get types
types = f_instance.types

# Change each type with its classname
types = tuple(map(lambda x: x.__name__, types))

# Append
instances.append(str(types))

instances = set(instances)
for instance in instances:
print(" ─ " + instance)
from jax.numpy import expand_dims, ndarray


def append_dimension(x: ndarray):
return expand_dims(x, -1)


def update_dictionary(old: dict, new_entries: dict):
r"""Update a dictionary with new entries.
Args:
old (dict): The dictionary to update
new_entries (dict): The new entries to add to the dictionary
Returns:
dict: The updated dictionary
"""
for key, val in zip(new_entries.keys(), new_entries.values()):
old[key] = val
return old


def get_implemented(f):
r"""Prints the implemented methods of an operator
Arguments:
f (Callable): The operator to get the implemented methods of.
Returns:
None
"""

# TODO: Why there are more instances for the same types?

print(f.__name__ + ":")
instances = []
a = f.methods
for f_instance in a:
# Get types
types = f_instance.signature.types

# Change each type with its classname
types = tuple(map(lambda x: x.__name__, types))

# Append
instances.append(str(types))

instances = set(instances)
for instance in instances:
print(" ─ " + instance)
240 changes: 120 additions & 120 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,120 +1,120 @@
[tool.poetry]
name = "jaxdf"
version = "0.2.7"
description = "A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations"
authors = [
"Antonio Stanziola <a.stanziola@ucl.ac.uk>",
"Simon Arridge",
"Ben T. Cox",
"Bradley E. Treeby",
]
readme = "README.md"
keywords = [
"jax",
"pde",
"discretization",
"differential equations",
"simulation",
"differentiable programming",
]
license = "LGPL-3.0-only"
classifiers=[
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Environment :: GPU",
"Environment :: GPU :: NVIDIA CUDA",
"Environment :: GPU :: NVIDIA CUDA :: 11.6",
"Environment :: GPU :: NVIDIA CUDA :: 11.7",
"Environment :: GPU :: NVIDIA CUDA :: 11.8",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.0",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
]

packages = [
{ include="jaxdf", from="." }
]

[tool.poetry.urls]
"Homepage" = "https://ucl-bug.github.io/jaxdf"
"Repository" = "https://github.com/ucl-bug/jaxdf"
"Bug Tracker" = "https://github.com/ucl-bug/jaxdf/issues"
"Support" = "https://discord.gg/VtUb4fFznt"

[tool.poetry.dependencies]
python = "^3.9"
plum-dispatch = "^2.2.2"
jax = "^0.4.20"
equinox = "^0.11.2"

[tool.poetry.group.dev.dependencies]
coverage = "^7.3.2"
mypy = "^1.4.0"
pre-commit = "^3.3.3"
mkdocs-material-extensions = "^1.3.1"
mkdocs-material = "^9.4.12"
mkdocs-jupyter = "^0.24.6"
mkdocs-autorefs = "^0.5.0"
mkdocs-mermaid2-plugin = "^0.6.0"
mkdocstrings-python = "^1.7.5"
isort = "^5.12.0"
pycln = "^2.4.0"
python-kacl = "^0.4.6"
mkdocs-macros-plugin = "^1.0.5"
pymdown-extensions = "^10.4"
pytest = "^7.4.0"
plumkdocs = "^0.0.5"
jupyterlab = "^4.0.9"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tools.isort]
src_paths = ["jaxdf", "tests"]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true

[tool.pycln]
all = true

[tool.mypy]
disallow_any_unimported = true
disallow_untyped_defs = true
no_implicit_optional = true
strict_equality = true
warn_unused_ignores = true
warn_redundant_casts = true
warn_return_any = true
check_untyped_defs = true
show_error_codes = true
ignore_missing_imports = true
allow_redefinition = true
exclude = ['jaxdf/operators/']

[tool.yapf]
based_on_style = "pep8"
spaces_before_comment = 4
split_before_logical_operator = true
indent_width = 2

[tool.pytest.ini_options]
addopts = """\
--doctest-modules \
"""

[tool.coverage.report]
exclude_lines = [
'if TYPE_CHECKING:',
'pragma: no cover'
]
[tool.poetry]
name = "jaxdf"
version = "0.2.7"
description = "A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations"
authors = [
"Antonio Stanziola <a.stanziola@ucl.ac.uk>",
"Simon Arridge",
"Ben T. Cox",
"Bradley E. Treeby",
]
readme = "README.md"
keywords = [
"jax",
"pde",
"discretization",
"differential equations",
"simulation",
"differentiable programming",
]
license = "LGPL-3.0-only"
classifiers=[
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Environment :: GPU",
"Environment :: GPU :: NVIDIA CUDA",
"Environment :: GPU :: NVIDIA CUDA :: 11.6",
"Environment :: GPU :: NVIDIA CUDA :: 11.7",
"Environment :: GPU :: NVIDIA CUDA :: 11.8",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.0",
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Physics",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
]

packages = [
{ include="jaxdf", from="." }
]

[tool.poetry.urls]
"Homepage" = "https://ucl-bug.github.io/jaxdf"
"Repository" = "https://github.com/ucl-bug/jaxdf"
"Bug Tracker" = "https://github.com/ucl-bug/jaxdf/issues"
"Support" = "https://discord.gg/VtUb4fFznt"

[tool.poetry.dependencies]
python = "^3.9"
plum-dispatch = "^2.5.2"
jax = "^0.4.20"
equinox = "^0.11.2"

[tool.poetry.group.dev.dependencies]
coverage = "^7.3.2"
mypy = "^1.4.0"
pre-commit = "^3.3.3"
mkdocs-material-extensions = "^1.3.1"
mkdocs-material = "^9.4.12"
mkdocs-jupyter = "^0.24.6"
mkdocs-autorefs = "^0.5.0"
mkdocs-mermaid2-plugin = "^0.6.0"
mkdocstrings-python = "^1.7.5"
isort = "^5.12.0"
pycln = "^2.4.0"
python-kacl = "^0.4.6"
mkdocs-macros-plugin = "^1.0.5"
pymdown-extensions = "^10.4"
pytest = "^7.4.0"
plumkdocs = "^0.0.5"
jupyterlab = "^4.0.9"

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tools.isort]
src_paths = ["jaxdf", "tests"]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true

[tool.pycln]
all = true

[tool.mypy]
disallow_any_unimported = true
disallow_untyped_defs = true
no_implicit_optional = true
strict_equality = true
warn_unused_ignores = true
warn_redundant_casts = true
warn_return_any = true
check_untyped_defs = true
show_error_codes = true
ignore_missing_imports = true
allow_redefinition = true
exclude = ['jaxdf/operators/']

[tool.yapf]
based_on_style = "pep8"
spaces_before_comment = 4
split_before_logical_operator = true
indent_width = 2

[tool.pytest.ini_options]
addopts = """\
--doctest-modules \
"""

[tool.coverage.report]
exclude_lines = [
'if TYPE_CHECKING:',
'pragma: no cover'
]

0 comments on commit c48ccf3

Please sign in to comment.