Skip to content

Commit

Permalink
Merge pull request #140 from ucl-bug/docs
Browse files Browse the repository at this point in the history
Updated docs
  • Loading branch information
astanziola authored Nov 24, 2023
2 parents a15d053 + 063249a commit 11daadf
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 79 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- JaxDF now uses standard Python logging. To set the logging level, use `jaxdf.logger.set_logging_level`, for example `jaxdf.logger.set_logging_level("DEBUG")`. The default level is `INFO`.
- Fields have now a handy property `` which is an alias for `.params`
- `Continuous` and `Linear` fields now have the `.is_complex` property
- `Field` and `Domain` are now `JaxDFModules`s, which are based on from `equinox.Module`. They are entirely equivalent to `equinox.Module`, but have the extra `.replace` method that is used to update a single field.
- `Field` and `Domain` are now `Modules`s, which are based on from `equinox.Module`. They are entirely equivalent to `equinox.Module`, but have the extra `.replace` method that is used to update a single field.

### Deprecated
- The property `.is_field_complex` is now deprecated in favor of `.is_complex`. Same goes for `.is_real`.
Expand Down
9 changes: 6 additions & 3 deletions docs/core.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@

# `jaxdf.core`

The `core` module contains the basic abstractions of the `jaxdf` framework.
## Module Overview

This module is the fundamental part of the `jaxdf` framework.

At its core is the `Field` class, a key element of `jaxdf`. This class is designed as a module derived from [`equinox.Module`](https://github.com/patrick-kidger/equinox), which means it's a JAX-compatible dataclass. All types of discretizations within `jaxdf` are derived from the `Field` class.

The key component is the `Field` class, which is a PyTree (see [`equinox`](https://github.com/patrick-kidger/equinox) and [`treeo`](https://github.com/cgarciae/treeo) for two great libraries that deal with defining JAX-compatible PyTrees from classes) from which all discretizations are defined, and the `operator` decorator which allows the use of multiple-dispatch (vial [`plum`](https://github.com/wesselb/plum)) for defining novel operators.
Another crucial feature of `jaxdf` is the `operator` decorator. This decorator enables the implementation of multiple-dispatch functionality through the [`plum`](https://github.com/wesselb/plum) library. This is particularly useful for creating new operators within the framework.

::: jaxdf.core
handler: python
Expand Down
9 changes: 9 additions & 0 deletions docs/mods.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
::: jaxdf.mods
handler: python
selection:
filters:
- "__init__$"
rendering:
show_root_heading: true
show_source: false
show_object_full_path: True
35 changes: 32 additions & 3 deletions jaxdf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,37 @@
# nopycln: file
from jaxdf.core import operator, debug_config, constants # isort:skip
from jaxdf import util, geometry # isort:skip
from jaxdf.discretization import * # isort:skip
from jaxdf import conv, logger
from jaxdf.core import constants, debug_config, discretization, operator
# Import geometry elements
from jaxdf.geometry import Domain
from jaxdf.mods import Module

from jaxdf.discretization import ( # isort:skip
Continuous, FiniteDifferences, FourierSeries, Linear, OnGrid)

from jaxdf.core import Field # isort:skip

from jaxdf import util, geometry, mods, operators # isort:skip

# Must be imported after discretization
from jaxdf.operators.magic import * # isort:skip
from jaxdf import operators # isort:skip

__all__ = [
'constants',
'conv',
'discretization',
'debug_config',
'geometry',
'logger',
'operator',
'operators',
'util',
'Continuous',
'Domain',
'FiniteDifferences',
'FourierSeries',
'Field',
'Linear',
'Module',
'OnGrid',
]
117 changes: 58 additions & 59 deletions jaxdf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from .geometry import Domain
from .logger import logger, set_logging_level
from .mods import JaxDFModule
from .mods import Module

# Initialize the dispatch table
_jaxdf_dispatch = Dispatcher()
Expand Down Expand Up @@ -129,80 +129,79 @@ def __call__(
init_params: Union[Callable, None] = None,
precedence: int = 0,
):
r"""Decorator for defining operators using multiple dispatch. The type annotation of the
`evaluate` function are used to determine the dispatch rules. The dispatch syntax is the
same as the Julia one, that is: operators are dispatched on the types of the positional arguments.
Keyword arguments are not considered for dispatching.
Keyword arguments are defined after the `*` in the function signature.
if evaluate is None:
# Returns the decorator
def decorator(evaluate):
return _operator(evaluate, precedence, init_params)

!!! example
```python
@operator
def my_operator(x: FourierSeries, *, dx: float, params=None):
...
```
return decorator
else:
return _operator(evaluate, precedence, init_params)

The argument `params` is mandatory and it must be a keyword argument. It is used to pass the
parameters of the operator, for example the stencil coefficients of a finite difference operator.
def abstract(self, evaluate: Callable):
"""Decorator for defining abstract operators. This is mainly used
to define generic docstrings."""
return _abstract_operator(evaluate)

The default value of the parameters is specified by the `init_params` function, as follows:

!!! example
```python
operator = Operator()
r"""Decorator for defining operators using multiple dispatch. The type annotation of the
`evaluate` function are used to determine the dispatch rules. The dispatch syntax is the
same as the Julia one, that is: operators are dispatched on the types of the positional arguments.
def params_initializer(x, *, dx):
return {"stencil": jnp.ones(x.shape) * dx}
Args:
evaluate (Callable): A function with the signature `evaluate(field, *args, **kwargs, params)`.
It must return a tuple, with the first element being a field and the second
element being the default parameters for the operator.
init_params (Callable): A function that overrides the default parameters initializer for the
operator. Useful when running the operator just to get the parameters is expensive.
precedence (int): The precedence of the operator if an ambiguous match is found.
@operator(init_params=params_initializer)
def my_operator(x, *, dx, params=None):
b = params["stencil"] / dx
y_params = jnp.convolve(x.params, b, mode="same")
return x.replace_params(y_params)
```
Returns:
Callable: The operator function with signature `evaluate(field, *args, **kwargs, params)`.
The default value of `params` is not considered during computation.
If the operator has no parameters, the `init_params` function can be omitted. In this case, the
`params` value is set to `None`.
Keyword arguments are not considered for dispatching.
Keyword arguments are defined after the `*` in the function signature.
For constant parameters, the `constants` function can be used:
!!! example
```python
@operator
def my_operator(x: FourierSeries, *, dx: float, params=None):
...
```
!!! example
```python
@operator(init_params=constants({"a": 1, "b": 2.0}))
def my_operator(x, *, params):
return x + params["a"] + params["b"]
```
The argument `params` is mandatory and it must be a keyword argument. It is used to pass the
parameters of the operator, for example the stencil coefficients of a finite difference operator.
The default value of the parameters is specified by the `init_params` function, as follows:
Args:
evaluate (Callable): A function with the signature `evaluate(field, *args, **kwargs, params)`.
It must return a tuple, with the first element being a field and the second
element being the default parameters for the operator.
init_params (Callable): A function that overrides the default parameters initializer for the
operator. Useful when running the operator just to get the parameters is expensive.
precedence (int): The precedence of the operator if an ambiguous match is found.
!!! example
```python
Returns:
Callable: The operator function with signature `evaluate(field, *args, **kwargs, params)`.
def params_initializer(x, *, dx):
return {"stencil": jnp.ones(x.shape) * dx}
"""
if evaluate is None:
# Returns the decorator
def decorator(evaluate):
return _operator(evaluate, precedence, init_params)
@operator(init_params=params_initializer)
def my_operator(x, *, dx, params=None):
b = params["stencil"] / dx
y_params = jnp.convolve(x.params, b, mode="same")
return x.replace_params(y_params)
```
return decorator
else:
return _operator(evaluate, precedence, init_params)
The default value of `params` is not considered during computation.
If the operator has no parameters, the `init_params` function can be omitted. In this case, the
`params` value is set to `None`.
def abstract(self, evaluate: Callable):
"""Decorator for defining abstract operators. This is mainly used
to define generic docstrings."""
return _abstract_operator(evaluate)
For constant parameters, the `constants` function can be used:
!!! example
```python
@operator(init_params=constants({"a": 1, "b": 2.0}))
def my_operator(x, *, params):
return x + params["a"] + params["b"]
```
operator = Operator()
"""


def discretization(cls):
Expand Down Expand Up @@ -237,7 +236,7 @@ def init_params(*args, **kwargs):
return init_params


class Field(JaxDFModule):
class Field(Module):
params: PyTree
domain: Domain

Expand Down
2 changes: 1 addition & 1 deletion jaxdf/discretization.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def f(N, dx):
return jnp.fft.fftfreq(N, dx) * 2 * jnp.pi

k_axis = [f(n, delta) for n, delta in zip(self.domain.N, self.domain.dx)]
if not self.is_field_complex:
if not self.is_complex:
k_axis[-1] = (jnp.fft.rfftfreq(self.domain.N[-1], self.domain.dx[-1]) *
2 * jnp.pi)
return k_axis
Expand Down
4 changes: 2 additions & 2 deletions jaxdf/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from jax import numpy as jnp
from jax import random

from .mods import JaxDFModule
from .mods import Module


class Domain(JaxDFModule):
class Domain(Module):
r"""Domain class describing a rectangular domain
Attributes:
Expand Down
16 changes: 8 additions & 8 deletions jaxdf/mods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
from jaxtyping import PyTree


class JaxDFModule(eqx.Module):
class Module(eqx.Module):
"""
A custom module inheriting from Equinox's Module class.
This module is designed to work with JAX and Equinox libraries, providing
functionalities that are specific to deep learning models and operations in JAX.
"""

def replace(self, name: str, value: PyTree):
Expand All @@ -23,12 +20,15 @@ def replace(self, name: str, value: PyTree):
compatible with JAX's PyTree structure.
Returns:
A new instance of JaxDFModule with the specified attribute updated.
A new instance of Module with the specified attribute updated.
The rest of the module's attributes remain unchanged.
Example:
>>> module = JaxDFModule(...)
>>> new_module = module.replace("weight", new_weight_value)
!!! example
```python
>>> module = jaxdf.Module(weight=1.0, bias=2.0)
>>> new_module = module.replace("weight", 3.0)
>>> new_module.weight == 3.0 # True
```
"""
f = lambda m: m.__getattribute__(name)
return eqx.tree_at(f, self, value)
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ nav:
- discretization: discretization.md
- exceptions: exceptions.md
- geometry: geometry.md
- mods: mods.md
- operators:
- differential: operators/differential.md
- functions: operators/functions.md
Expand Down
4 changes: 2 additions & 2 deletions tests/test_mods.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from jaxdf.mods import JaxDFModule
from jaxdf.mods import Module


def test_replace_params():

class TestModule(JaxDFModule):
class TestModule(Module):
a: float = 1.0
b: float = 2.0

Expand Down

0 comments on commit 11daadf

Please sign in to comment.