diff --git a/CHANGELOG.md b/CHANGELOG.md index 80c2d01..2d7577a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - JaxDF now uses standard Python logging. To set the logging level, use `jaxdf.logger.set_logging_level`, for example `jaxdf.logger.set_logging_level("DEBUG")`. The default level is `INFO`. - Fields have now a handy property `.θ` which is an alias for `.params` - `Continuous` and `Linear` fields now have the `.is_complex` property -- `Field` and `Domain` are now `JaxDFModules`s, which are based on from `equinox.Module`. They are entirely equivalent to `equinox.Module`, but have the extra `.replace` method that is used to update a single field. +- `Field` and `Domain` are now `Modules`s, which are based on from `equinox.Module`. They are entirely equivalent to `equinox.Module`, but have the extra `.replace` method that is used to update a single field. ### Deprecated - The property `.is_field_complex` is now deprecated in favor of `.is_complex`. Same goes for `.is_real`. diff --git a/docs/core.md b/docs/core.md index a3e9c01..9e5db59 100644 --- a/docs/core.md +++ b/docs/core.md @@ -1,9 +1,12 @@ - # `jaxdf.core` -The `core` module contains the basic abstractions of the `jaxdf` framework. +## Module Overview + +This module is the fundamental part of the `jaxdf` framework. + +At its core is the `Field` class, a key element of `jaxdf`. This class is designed as a module derived from [`equinox.Module`](https://github.com/patrick-kidger/equinox), which means it's a JAX-compatible dataclass. All types of discretizations within `jaxdf` are derived from the `Field` class. -The key component is the `Field` class, which is a PyTree (see [`equinox`](https://github.com/patrick-kidger/equinox) and [`treeo`](https://github.com/cgarciae/treeo) for two great libraries that deal with defining JAX-compatible PyTrees from classes) from which all discretizations are defined, and the `operator` decorator which allows the use of multiple-dispatch (vial [`plum`](https://github.com/wesselb/plum)) for defining novel operators. +Another crucial feature of `jaxdf` is the `operator` decorator. This decorator enables the implementation of multiple-dispatch functionality through the [`plum`](https://github.com/wesselb/plum) library. This is particularly useful for creating new operators within the framework. ::: jaxdf.core handler: python diff --git a/docs/mods.md b/docs/mods.md new file mode 100644 index 0000000..f2417ee --- /dev/null +++ b/docs/mods.md @@ -0,0 +1,9 @@ +::: jaxdf.mods + handler: python + selection: + filters: + - "__init__$" + rendering: + show_root_heading: true + show_source: false + show_object_full_path: True diff --git a/jaxdf/__init__.py b/jaxdf/__init__.py index 349bb13..7e49dba 100644 --- a/jaxdf/__init__.py +++ b/jaxdf/__init__.py @@ -1,8 +1,37 @@ # nopycln: file -from jaxdf.core import operator, debug_config, constants # isort:skip -from jaxdf import util, geometry # isort:skip -from jaxdf.discretization import * # isort:skip +from jaxdf import conv, logger +from jaxdf.core import constants, debug_config, discretization, operator +# Import geometry elements +from jaxdf.geometry import Domain +from jaxdf.mods import Module + +from jaxdf.discretization import ( # isort:skip + Continuous, FiniteDifferences, FourierSeries, Linear, OnGrid) + +from jaxdf.core import Field # isort:skip + +from jaxdf import util, geometry, mods, operators # isort:skip # Must be imported after discretization from jaxdf.operators.magic import * # isort:skip from jaxdf import operators # isort:skip + +__all__ = [ + 'constants', + 'conv', + 'discretization', + 'debug_config', + 'geometry', + 'logger', + 'operator', + 'operators', + 'util', + 'Continuous', + 'Domain', + 'FiniteDifferences', + 'FourierSeries', + 'Field', + 'Linear', + 'Module', + 'OnGrid', +] diff --git a/jaxdf/core.py b/jaxdf/core.py index 8339b35..e91ce47 100644 --- a/jaxdf/core.py +++ b/jaxdf/core.py @@ -14,7 +14,7 @@ from .geometry import Domain from .logger import logger, set_logging_level -from .mods import JaxDFModule +from .mods import Module # Initialize the dispatch table _jaxdf_dispatch = Dispatcher() @@ -129,80 +129,79 @@ def __call__( init_params: Union[Callable, None] = None, precedence: int = 0, ): - r"""Decorator for defining operators using multiple dispatch. The type annotation of the - `evaluate` function are used to determine the dispatch rules. The dispatch syntax is the - same as the Julia one, that is: operators are dispatched on the types of the positional arguments. - Keyword arguments are not considered for dispatching. - - Keyword arguments are defined after the `*` in the function signature. + if evaluate is None: + # Returns the decorator + def decorator(evaluate): + return _operator(evaluate, precedence, init_params) - !!! example - ```python - @operator - def my_operator(x: FourierSeries, *, dx: float, params=None): - ... - ``` + return decorator + else: + return _operator(evaluate, precedence, init_params) - The argument `params` is mandatory and it must be a keyword argument. It is used to pass the - parameters of the operator, for example the stencil coefficients of a finite difference operator. + def abstract(self, evaluate: Callable): + """Decorator for defining abstract operators. This is mainly used + to define generic docstrings.""" + return _abstract_operator(evaluate) - The default value of the parameters is specified by the `init_params` function, as follows: - !!! example - ```python +operator = Operator() +r"""Decorator for defining operators using multiple dispatch. The type annotation of the + `evaluate` function are used to determine the dispatch rules. The dispatch syntax is the + same as the Julia one, that is: operators are dispatched on the types of the positional arguments. - def params_initializer(x, *, dx): - return {"stencil": jnp.ones(x.shape) * dx} + Args: + evaluate (Callable): A function with the signature `evaluate(field, *args, **kwargs, params)`. + It must return a tuple, with the first element being a field and the second + element being the default parameters for the operator. + init_params (Callable): A function that overrides the default parameters initializer for the + operator. Useful when running the operator just to get the parameters is expensive. + precedence (int): The precedence of the operator if an ambiguous match is found. - @operator(init_params=params_initializer) - def my_operator(x, *, dx, params=None): - b = params["stencil"] / dx - y_params = jnp.convolve(x.params, b, mode="same") - return x.replace_params(y_params) - ``` + Returns: + Callable: The operator function with signature `evaluate(field, *args, **kwargs, params)`. - The default value of `params` is not considered during computation. - If the operator has no parameters, the `init_params` function can be omitted. In this case, the - `params` value is set to `None`. + Keyword arguments are not considered for dispatching. + Keyword arguments are defined after the `*` in the function signature. - For constant parameters, the `constants` function can be used: + !!! example + ```python + @operator + def my_operator(x: FourierSeries, *, dx: float, params=None): + ... + ``` - !!! example - ```python - @operator(init_params=constants({"a": 1, "b": 2.0})) - def my_operator(x, *, params): - return x + params["a"] + params["b"] - ``` + The argument `params` is mandatory and it must be a keyword argument. It is used to pass the + parameters of the operator, for example the stencil coefficients of a finite difference operator. + The default value of the parameters is specified by the `init_params` function, as follows: - Args: - evaluate (Callable): A function with the signature `evaluate(field, *args, **kwargs, params)`. - It must return a tuple, with the first element being a field and the second - element being the default parameters for the operator. - init_params (Callable): A function that overrides the default parameters initializer for the - operator. Useful when running the operator just to get the parameters is expensive. - precedence (int): The precedence of the operator if an ambiguous match is found. + !!! example + ```python - Returns: - Callable: The operator function with signature `evaluate(field, *args, **kwargs, params)`. + def params_initializer(x, *, dx): + return {"stencil": jnp.ones(x.shape) * dx} - """ - if evaluate is None: - # Returns the decorator - def decorator(evaluate): - return _operator(evaluate, precedence, init_params) + @operator(init_params=params_initializer) + def my_operator(x, *, dx, params=None): + b = params["stencil"] / dx + y_params = jnp.convolve(x.params, b, mode="same") + return x.replace_params(y_params) + ``` - return decorator - else: - return _operator(evaluate, precedence, init_params) + The default value of `params` is not considered during computation. + If the operator has no parameters, the `init_params` function can be omitted. In this case, the + `params` value is set to `None`. - def abstract(self, evaluate: Callable): - """Decorator for defining abstract operators. This is mainly used - to define generic docstrings.""" - return _abstract_operator(evaluate) + For constant parameters, the `constants` function can be used: + !!! example + ```python + @operator(init_params=constants({"a": 1, "b": 2.0})) + def my_operator(x, *, params): + return x + params["a"] + params["b"] + ``` -operator = Operator() + """ def discretization(cls): @@ -237,7 +236,7 @@ def init_params(*args, **kwargs): return init_params -class Field(JaxDFModule): +class Field(Module): params: PyTree domain: Domain diff --git a/jaxdf/discretization.py b/jaxdf/discretization.py index c557bd7..aa5949e 100644 --- a/jaxdf/discretization.py +++ b/jaxdf/discretization.py @@ -313,7 +313,7 @@ def f(N, dx): return jnp.fft.fftfreq(N, dx) * 2 * jnp.pi k_axis = [f(n, delta) for n, delta in zip(self.domain.N, self.domain.dx)] - if not self.is_field_complex: + if not self.is_complex: k_axis[-1] = (jnp.fft.rfftfreq(self.domain.N[-1], self.domain.dx[-1]) * 2 * jnp.pi) return k_axis diff --git a/jaxdf/geometry.py b/jaxdf/geometry.py index f7f5a83..9650164 100644 --- a/jaxdf/geometry.py +++ b/jaxdf/geometry.py @@ -6,10 +6,10 @@ from jax import numpy as jnp from jax import random -from .mods import JaxDFModule +from .mods import Module -class Domain(JaxDFModule): +class Domain(Module): r"""Domain class describing a rectangular domain Attributes: diff --git a/jaxdf/mods.py b/jaxdf/mods.py index 56961e5..cbfb93b 100644 --- a/jaxdf/mods.py +++ b/jaxdf/mods.py @@ -2,12 +2,9 @@ from jaxtyping import PyTree -class JaxDFModule(eqx.Module): +class Module(eqx.Module): """ A custom module inheriting from Equinox's Module class. - - This module is designed to work with JAX and Equinox libraries, providing - functionalities that are specific to deep learning models and operations in JAX. """ def replace(self, name: str, value: PyTree): @@ -23,12 +20,15 @@ def replace(self, name: str, value: PyTree): compatible with JAX's PyTree structure. Returns: - A new instance of JaxDFModule with the specified attribute updated. + A new instance of Module with the specified attribute updated. The rest of the module's attributes remain unchanged. - Example: - >>> module = JaxDFModule(...) - >>> new_module = module.replace("weight", new_weight_value) + !!! example + ```python + >>> module = jaxdf.Module(weight=1.0, bias=2.0) + >>> new_module = module.replace("weight", 3.0) + >>> new_module.weight == 3.0 # True + ``` """ f = lambda m: m.__getattribute__(name) return eqx.tree_at(f, self, value) diff --git a/mkdocs.yml b/mkdocs.yml index fc35739..031083e 100755 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -64,6 +64,7 @@ nav: - discretization: discretization.md - exceptions: exceptions.md - geometry: geometry.md + - mods: mods.md - operators: - differential: operators/differential.md - functions: operators/functions.md diff --git a/tests/test_mods.py b/tests/test_mods.py index 98a4dc6..bad5cc2 100644 --- a/tests/test_mods.py +++ b/tests/test_mods.py @@ -1,9 +1,9 @@ -from jaxdf.mods import JaxDFModule +from jaxdf.mods import Module def test_replace_params(): - class TestModule(JaxDFModule): + class TestModule(Module): a: float = 1.0 b: float = 2.0