diff --git a/docs/source/api.rst b/docs/source/api.rst index bb689ff68..6f2c70604 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -89,7 +89,7 @@ Prediction error steps Compute the value and volatility prediction errors of a given node. The prediction error can only be computed after the posterior update (or observation) of a given node. Binary state nodes -^^^^^^^^^^^^^^^^^^ +------------------ .. currentmodule:: pyhgf.updates.prediction_error.binary @@ -100,7 +100,7 @@ Binary state nodes binary_finite_state_node_prediction_error Categorical state nodes -^^^^^^^^^^^^^^^^^^^^^^^ +----------------------- .. currentmodule:: pyhgf.updates.prediction_error.categorical @@ -110,7 +110,7 @@ Categorical state nodes categorical_state_prediction_error Continuous state nodes -^^^^^^^^^^^^^^^^^^^^^^ +---------------------- .. currentmodule:: pyhgf.updates.prediction_error.continuous @@ -122,7 +122,7 @@ Continuous state nodes continuous_node_prediction_error Dirichlet state nodes -^^^^^^^^^^^^^^^^^^^^^ +--------------------- .. currentmodule:: pyhgf.updates.prediction_error.dirichlet @@ -137,7 +137,7 @@ Dirichlet state nodes clusters_likelihood Exponential family -^^^^^^^^^^^^^^^^^^ +------------------ .. 
currentmodule:: pyhgf.updates.prediction_error.exponential diff --git a/pyhgf/updates/posterior/continuous/__init__.py b/pyhgf/updates/posterior/continuous/__init__.py index fd8740754..a87285b60 100644 --- a/pyhgf/updates/posterior/continuous/__init__.py +++ b/pyhgf/updates/posterior/continuous/__init__.py @@ -3,6 +3,5 @@ __all__ = [ "continuous_node_posterior_update_ehgf", - "continuous_node_posterior_update_unbounded", "continuous_node_posterior_update", ] diff --git a/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node.py b/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node.py index 9ede2d017..5bfc7f066 100644 --- a/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node.py +++ b/pyhgf/updates/posterior/continuous/posterior_update_precision_continuous_node.py @@ -5,6 +5,8 @@ import jax.numpy as jnp from jax import grad, jit +from jax.lax import cond +from jax.tree_util import Partial from pyhgf.typing import Edges @@ -31,7 +33,7 @@ def posterior_update_precision_continuous_node( Where :math:`\kappa_j` is the volatility coupling strength between the child node and the state node and :math:`\delta_j^{(k)}` is the value prediction error that - was computed before hand by + was computed beforehand by :py:func:`pyhgf.updates.prediction_errors.continuous.continuous_node_value_prediction_error`. For non-linear value coupling: @@ -80,8 +82,9 @@ def posterior_update_precision_continuous_node( The attributes of the probabilistic nodes. edges : The edges of the probabilistic nodes as a tuple of - :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as node number. - For each node, the index list value and volatility parents and children. + :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number + of nodes. For each node, the index lists the value and volatility parents and + children. node_idx : Pointer to the value parent node that will be updated. 
time_step : @@ -108,6 +111,60 @@ def posterior_update_precision_continuous_node( Mathys, C. (2023). The generalized Hierarchical Gaussian Filter (Version 1). arXiv. https://doi.org/10.48550/ARXIV.2305.10937 + """ + # ---------------------------------------------------------------------------------- + # Decide which update to use depending on the presence of observed values in the + # children nodes. If no values were observed, the precision should decrease + # as a function of time using precision_update_missing_values(). Otherwise, + # we use regular HGF updates for value and volatility couplings. + # ---------------------------------------------------------------------------------- + + # For all children, get the `observed` flag - if all these values are 0.0, the node + # has not received any observations and we call precision_update_missing_values() + observations = [] + if edges[node_idx].value_children is not None: + for children_idx in edges[node_idx].value_children: # type: ignore + observations.append(attributes[children_idx]["observed"]) + if edges[node_idx].volatility_children is not None: + for children_idx in edges[node_idx].volatility_children: # type: ignore + observations.append(attributes[children_idx]["observed"]) + observations = jnp.any(jnp.array(observations)) + + posterior_precision = cond( + observations, + Partial(precision_update, edges=edges, node_idx=node_idx), + Partial(precision_update_missing_values, edges=edges, node_idx=node_idx), + attributes, + ) + + return posterior_precision + + +@partial(jit, static_argnames=("edges", "node_idx")) +def precision_update(attributes: Dict, edges: Edges, node_idx: int) -> float: + """Compute new precision in the case of observed values. + + Parameters + ---------- + attributes : + The attributes of the probabilistic nodes. + edges : + The edges of the probabilistic nodes as a tuple of + :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number + of nodes. 
For each node, the index lists the value and volatility parents and + children. + node_idx : + Pointer to the value parent node that will be updated. + time_step : + The time elapsed between this observation and the previous one. + + Returns + ------- + posterior_precision : + The new posterior precision when at least one of the children has + observed a new value. We then use the regular HGF update for volatility + coupling. + """ # sum the prediction errors from both value and volatility coupling precision_weigthed_prediction_error = 0.0 @@ -177,13 +234,41 @@ def posterior_update_precision_continuous_node( ) # ensure the new precision is greater than 0 - observed_posterior_precision = jnp.where( + posterior_precision = jnp.where( posterior_precision > 1e-128, posterior_precision, jnp.nan ) - # additionnal steps for unobserved values - # --------------------------------------- + return posterior_precision + +@partial(jit, static_argnames=("edges", "node_idx")) +def precision_update_missing_values( + attributes: Dict, edges: Edges, node_idx: int +) -> float: + """Compute new precision in the case of missing observations. + + Parameters + ---------- + attributes : + The attributes of the probabilistic nodes. + edges : + The edges of the probabilistic nodes as a tuple of + :py:class:`pyhgf.typing.Indexes`. The tuple has the same length as the number + of nodes. For each node, the index lists the value and volatility parents and + children. + node_idx : + Pointer to the value parent node that will be updated. + time_step : + The time elapsed between this observation and the previous one. + + Returns + ------- + posterior_precision_missing_values : + The new posterior precision in the case of missing values in all child nodes. + The new precision decreases proportionally to the time elapsed, accounting for + the influence of volatility parents. 
+ + """ # List the node's volatility parents volatility_parents_idxs = edges[node_idx].volatility_parents @@ -201,29 +286,13 @@ def posterior_update_precision_continuous_node( volatility_coupling * attributes[volatility_parents_idx]["mean"] ) - # compute the predicted_volatility from the total volatility + # compute the new predicted_volatility from the total volatility time_step = attributes[-1]["time_step"] predicted_volatility = time_step * jnp.exp(total_volatility) # Estimate the new precision for the continuous state node - unobserved_posterior_precision = 1 / ( + posterior_precision_missing_values = 1 / ( (1 / attributes[node_idx]["precision"]) + predicted_volatility ) - # for all children, look at the values of VAPE - # if all these values are NaNs, the node has not received observations - observations = [] - if edges[node_idx].value_children is not None: - for children_idx in edges[node_idx].value_children: # type: ignore - observations.append(attributes[children_idx]["observed"]) - if edges[node_idx].volatility_children is not None: - for children_idx in edges[node_idx].volatility_children: # type: ignore - observations.append(attributes[children_idx]["observed"]) - observations = jnp.any(jnp.array(observations)) - - posterior_precision = ( - unobserved_posterior_precision * (1 - observations) # type: ignore - + observed_posterior_precision * observations - ) - - return posterior_precision + return posterior_precision_missing_values diff --git a/tests/test_updates/posterior/continuous.py b/tests/test_updates/posterior/continuous.py index 6bc59f396..6982a63cb 100644 --- a/tests/test_updates/posterior/continuous.py +++ b/tests/test_updates/posterior/continuous.py @@ -1,10 +1,11 @@ # Author: Nicolas Legrand +import jax.numpy as jnp + from pyhgf.model import Network from pyhgf.updates.posterior.continuous import ( continuous_node_posterior_update, continuous_node_posterior_update_ehgf, - continuous_node_posterior_update_unbounded, ) @@ -20,17 +21,56 @@ def 
test_continuous_posterior_updates(): # Standard HGF updates ------------------------------------------------------------- # ---------------------------------------------------------------------------------- + + # value update + attributes, edges, _ = network.get_network() + attributes[0]["temp"]["value_prediction_error"] = 1.0357 + attributes[0]["mean"] = 1.0357 + + new_attributes = continuous_node_posterior_update( + attributes=attributes, node_idx=1, edges=edges + ) + assert jnp.isclose(new_attributes[1]["mean"], 0.51785) + + # volatility update attributes, edges, _ = network.get_network() - _ = continuous_node_posterior_update(attributes=attributes, node_idx=2, edges=edges) + attributes[1]["temp"]["effective_precision"] = 0.01798621006309986 + attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467 + attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894 + attributes[1]["expected_precision"] = 0.9820137619972229 + attributes[1]["mean"] = 0.5225493907928467 + attributes[1]["precision"] = 1.9820137023925781 + + new_attributes = continuous_node_posterior_update( + attributes=attributes, node_idx=2, edges=edges + ) + assert jnp.isclose(new_attributes[1]["mean"], -0.0021212) + assert jnp.isclose(new_attributes[1]["precision"], 1.0022112) # eHGF updates --------------------------------------------------------------------- # ---------------------------------------------------------------------------------- - _ = continuous_node_posterior_update_ehgf( + + # value update + attributes, edges, _ = network.get_network() + attributes[0]["temp"]["value_prediction_error"] = 1.0357 + attributes[0]["mean"] = 1.0357 + + new_attributes = continuous_node_posterior_update_ehgf( attributes=attributes, node_idx=2, edges=edges ) + assert jnp.isclose(new_attributes[1]["mean"], 0.51785) - # unbounded updates ---------------------------------------------------------------- - # 
---------------------------------------------------------------------------------- - _ = continuous_node_posterior_update_unbounded( + # volatility update + attributes, edges, _ = network.get_network() + attributes[1]["temp"]["effective_precision"] = 0.01798621006309986 + attributes[1]["temp"]["value_prediction_error"] = 0.5225493907928467 + attributes[1]["temp"]["volatility_prediction_error"] = -0.23639076948165894 + attributes[1]["expected_precision"] = 0.9820137619972229 + attributes[1]["mean"] = 0.5225493907928467 + attributes[1]["precision"] = 1.9820137023925781 + + new_attributes = continuous_node_posterior_update_ehgf( attributes=attributes, node_idx=2, edges=edges ) + assert jnp.isclose(new_attributes[1]["mean"], -0.00212589) + assert jnp.isclose(new_attributes[1]["precision"], 1.0022112)