Skip to content

Commit

Permalink
Split posterior precision update into two branches (#263)
Browse files Browse the repository at this point in the history
* split posterior update of precision into two branches

* add tests

* fix api docs
  • Loading branch information
LegrandNico authored Dec 18, 2024
1 parent 80ab968 commit 043beed
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 37 deletions.
10 changes: 5 additions & 5 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ Prediction error steps
Compute the value and volatility prediction errors of a given node. The prediction error can only be computed after the posterior update (or observation) of a given node.

Binary state nodes
^^^^^^^^^^^^^^^^^^
------------------

.. currentmodule:: pyhgf.updates.prediction_error.binary

Expand All @@ -100,7 +100,7 @@ Binary state nodes
binary_finite_state_node_prediction_error

Categorical state nodes
^^^^^^^^^^^^^^^^^^^^^^^
-----------------------

.. currentmodule:: pyhgf.updates.prediction_error.categorical

Expand All @@ -110,7 +110,7 @@ Categorical state nodes
categorical_state_prediction_error

Continuous state nodes
^^^^^^^^^^^^^^^^^^^^^^
----------------------

.. currentmodule:: pyhgf.updates.prediction_error.continuous

Expand All @@ -122,7 +122,7 @@ Continuous state nodes
continuous_node_prediction_error

Dirichlet state nodes
^^^^^^^^^^^^^^^^^^^^^
---------------------

.. currentmodule:: pyhgf.updates.prediction_error.dirichlet

Expand All @@ -137,7 +137,7 @@ Dirichlet state nodes
clusters_likelihood

Exponential family
^^^^^^^^^^^^^^^^^^
------------------

.. currentmodule:: pyhgf.updates.prediction_error.exponential

Expand Down
1 change: 0 additions & 1 deletion pyhgf/updates/posterior/continuous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@

__all__ = [
"continuous_node_posterior_update_ehgf",
"continuous_node_posterior_update_unbounded",
"continuous_node_posterior_update",
]
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import jax.numpy as jnp
from jax import grad, jit
from jax.lax import cond
from jax.tree_util import Partial

from pyhgf.typing import Edges

Expand All @@ -31,7 +33,7 @@ def posterior_update_precision_continuous_node(
Where :math:`\kappa_j` is the volatility coupling strength between the child node
and the state node and :math:`\delta_j^{(k)}` is the value prediction error that
was computed before hand by
was computed beforehand by
:py:func:`pyhgf.updates.prediction_errors.continuous.continuous_node_value_prediction_error`.
For non-linear value coupling:
Expand Down Expand Up @@ -80,8 +82,9 @@ def posterior_update_precision_continuous_node(
The attributes of the probabilistic nodes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number.
For each node, the index list value and volatility parents and children.
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
of nodes. For each node, the index lists the value and volatility parents and
children.
node_idx :
Pointer to the value parent node that will be updated.
time_step :
Expand All @@ -108,6 +111,60 @@ def posterior_update_precision_continuous_node(
Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1).
arXiv. https://doi.org/10.48550/ARXIV.2305.10937
"""
# ----------------------------------------------------------------------------------
# Decide which update to use depending on the presence of observed value in the
# children nodes. If no values were observed, the precision should increase
# as a function of time using the function precision_missing_values(). Otherwise,
# we use regular HGF updates for value and volatility couplings.
# ----------------------------------------------------------------------------------

# For all children, get the `observed` flag - if all these values are 0.0, the node
# has not received any observations and we should call precision_missing_values()
observations = []
if edges[node_idx].value_children is not None:
for children_idx in edges[node_idx].value_children: # type: ignore
observations.append(attributes[children_idx]["observed"])
if edges[node_idx].volatility_children is not None:
for children_idx in edges[node_idx].volatility_children: # type: ignore
observations.append(attributes[children_idx]["observed"])
observations = jnp.any(jnp.array(observations))

posterior_precision = cond(
observations,
Partial(precision_update, edges=edges, node_idx=node_idx),
Partial(precision_update_missing_values, edges=edges, node_idx=node_idx),
attributes,
)

return posterior_precision


@partial(jit, static_argnames=("edges", "node_idx"))
def precision_update(attributes: Dict, edges: Edges, node_idx: int) -> float:
"""Compute new precision in the case of observed values.
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
of nodes. For each node, the index lists the value and volatility parents and
children.
node_idx :
Pointer to the value parent node that will be updated.
time_step :
The time elapsed between this observation and the previous one.
Returns
-------
posterior_precision :
The new posterior precision when at least one of the children has
observed a new value. We then use the regular HGF update for volatility
coupling.
"""
# sum the prediction errors from both value and volatility coupling
precision_weigthed_prediction_error = 0.0
Expand Down Expand Up @@ -177,13 +234,41 @@ def posterior_update_precision_continuous_node(
)

# ensure the new precision is greater than 0
observed_posterior_precision = jnp.where(
posterior_precision = jnp.where(
posterior_precision > 1e-128, posterior_precision, jnp.nan
)

# additionnal steps for unobserved values
# ---------------------------------------
return posterior_precision


@partial(jit, static_argnames=("edges", "node_idx"))
def precision_update_missing_values(
attributes: Dict, edges: Edges, node_idx: int
) -> float:
"""Compute new precision in the case of missing observations.
Parameters
----------
attributes :
The attributes of the probabilistic nodes.
edges :
The edges of the probabilistic nodes as a tuple of
:py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number
of nodes. For each node, the index lists the value and volatility parents and
children.
node_idx :
Pointer to the value parent node that will be updated.
time_step :
The time elapsed between this observation and the previous one.
Returns
-------
posterior_precision_missing_values :
The new posterior precision in the case of missing values in all child nodes.
The new precision decreases proportionally to the time elapsed, accounting for
the influence of volatility parents.
"""
# List the node's volatility parents
volatility_parents_idxs = edges[node_idx].volatility_parents

Expand All @@ -201,29 +286,13 @@ def posterior_update_precision_continuous_node(
volatility_coupling * attributes[volatility_parents_idx]["mean"]
)

# compute the predicted_volatility from the total volatility
# compute the new predicted_volatility from the total volatility
time_step = attributes[-1]["time_step"]
predicted_volatility = time_step * jnp.exp(total_volatility)

# Estimate the new precision for the continuous state node
unobserved_posterior_precision = 1 / (
posterior_precision_missing_values = 1 / (
(1 / attributes[node_idx]["precision"]) + predicted_volatility
)

# for all children, look at the values of VAPE
# if all these values are NaNs, the node has not received observations
observations = []
if edges[node_idx].value_children is not None:
for children_idx in edges[node_idx].value_children: # type: ignore
observations.append(attributes[children_idx]["observed"])
if edges[node_idx].volatility_children is not None:
for children_idx in edges[node_idx].volatility_children: # type: ignore
observations.append(attributes[children_idx]["observed"])
observations = jnp.any(jnp.array(observations))

posterior_precision = (
unobserved_posterior_precision * (1 - observations) # type: ignore
+ observed_posterior_precision * observations
)

return posterior_precision
return posterior_precision_missing_values
52 changes: 46 additions & 6 deletions tests/test_updates/posterior/continuous.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Author: Nicolas Legrand <nicolas.legrand@cas.au.dk>

import jax.numpy as jnp

from pyhgf.model import Network
from pyhgf.updates.posterior.continuous import (
continuous_node_posterior_update,
continuous_node_posterior_update_ehgf,
continuous_node_posterior_update_unbounded,
)


Expand All @@ -20,17 +21,56 @@ def test_continuous_posterior_updates():

# Standard HGF updates -------------------------------------------------------------
# ----------------------------------------------------------------------------------

# value update
attributes, edges, _ = network.get_network()
attributes[0]["temp"]["value_prediction_error"] = 1.0357
attributes[0]["mean"] = 1.0357

new_attributes = continuous_node_posterior_update(
attributes=attributes, node_idx=1, edges=edges
)
assert jnp.isclose(new_attributes[1]["mean"], 0.51785)

# volatility update
attributes, edges, _ = network.get_network()
_ = continuous_node_posterior_update(attributes=attributes, node_idx=2, edges=edges)
attributes[1]["temp"]["effective_precision"] = 0.01798621006309986
attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467
attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894
attributes[1]["expected_precision"] = 0.9820137619972229
attributes[1]["mean"] = 0.5225493907928467
attributes[1]["precision"] = 1.9820137023925781

new_attributes = continuous_node_posterior_update(
attributes=attributes, node_idx=2, edges=edges
)
assert jnp.isclose(new_attributes[1]["mean"], -0.0021212)
assert jnp.isclose(new_attributes[1]["precision"], 1.0022112)

# eHGF updates ---------------------------------------------------------------------
# ----------------------------------------------------------------------------------
_ = continuous_node_posterior_update_ehgf(

# value update
attributes, edges, _ = network.get_network()
attributes[0]["temp"]["value_prediction_error"] = 1.0357
attributes[0]["mean"] = 1.0357

new_attributes = continuous_node_posterior_update_ehgf(
attributes=attributes, node_idx=2, edges=edges
)
assert jnp.isclose(new_attributes[1]["mean"], 0.51785)

# unbounded updates ----------------------------------------------------------------
# ----------------------------------------------------------------------------------
_ = continuous_node_posterior_update_unbounded(
# volatility update
attributes, edges, _ = network.get_network()
attributes[1]["temp"]["effective_precision"] = 0.01798621006309986
attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467
attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894
attributes[1]["expected_precision"] = 0.9820137619972229
attributes[1]["mean"] = 0.5225493907928467
attributes[1]["precision"] = 1.9820137023925781

new_attributes = continuous_node_posterior_update_ehgf(
attributes=attributes, node_idx=2, edges=edges
)
assert jnp.isclose(new_attributes[1]["mean"], -0.00212589)
assert jnp.isclose(new_attributes[1]["precision"], 1.0022112)

0 comments on commit 043beed

Please sign in to comment.