Skip to content

Commit

Permalink
initial update for-loop added to notebook. Mean of new parent node ad…
Browse files Browse the repository at this point in the history
…ded to add_parent function
  • Loading branch information
LouieMH authored and LegrandNico committed Dec 17, 2024
1 parent 6f970fe commit 1f7b8c2
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 57 deletions.
219 changes: 169 additions & 50 deletions docs/source/notebooks/Latent_var_notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,52 +31,41 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 70,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'pyhgf.updates.structure'",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[45], line 12\u001b[0m\n\u001b[0;32m 10\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpyhgf\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mresponse\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m first_level_gaussian_surprise\n\u001b[0;32m 11\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpyhgf\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m beliefs_propagation\n\u001b[1;32m---> 12\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpyhgf\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mupdates\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mstructure\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m add_parent\n\u001b[0;32m 15\u001b[0m plt\u001b[38;5;241m.\u001b[39mrcParams[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfigure.constrained_layout.use\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
"\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'pyhgf.updates.structure'"
]
}
],
"outputs": [],
"source": [
"import arviz as az\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"import pymc as pm\n",
"import numpy as np\n",
"import jax\n",
"\n",
"from pyhgf import load_data\n",
"from pyhgf.distribution import HGFDistribution\n",
"from pyhgf.model import HGF, Network\n",
"from pyhgf.response import first_level_gaussian_surprise\n",
"from pyhgf.utils import beliefs_propagation\n",
"from pyhgf.updates.structure import add_parent\n",
"# from pyhgf.updates.structure import add_parent\n",
"\n",
"\n",
"plt.rcParams[\"figure.constrained_layout.use\"] = True"
]
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"# Disable JIT compilation globally\n",
"# jax.config.update(\"jax_disable_jit\", False) # True - If I want the compiler disabled."
"jax.config.update(\"jax_disable_jit\", False) # True - If I want the compiler disabled."
]
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -97,34 +86,19 @@
"source": [
"timeserie = load_data(\"continuous\")\n",
"\n",
"# latent_hgf = (\n",
"# Network()\n",
"# .add_nodes(precision=1e4)\n",
"# .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=0)\n",
"# .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=0)\n",
"# .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=1)\n",
"# # .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=1)\n",
"# # .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=2)\n",
"# # .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=2)\n",
"# ).create_belief_propagation_fn()\n",
"\n",
"latent_hgf = (\n",
" Network()\n",
" .add_nodes(precision=1e4)\n",
" .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=0)\n",
" .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, \n",
" value_children=0)\n",
" .add_nodes(precision=1e1, tonic_volatility=-2.0, value_children=1)\n",
" .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=0)\n",
" .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, \n",
" value_children=0)\n",
" # .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=1)\n",
" # .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, value_children=2)\n",
" # .add_nodes(precision=1e4, mean=timeserie[0], tonic_volatility=-13.0, \n",
" # value_children=2)\n",
" # .add_nodes(precision=1e1, tonic_volatility=-2.0, volatility_children=2)\n",
").create_belief_propagation_fn()\n",
"\n",
"attributes, edges, update_sequence = (\n",
" latent_hgf.get_network()\n",
")\n",
"\n",
"print(len(attributes))\n",
"print(len(edges))"
").create_belief_propagation_fn()"
]
},
{
Expand Down Expand Up @@ -218,20 +192,20 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"attributes, edges, update_sequence = (\n",
" latent_hgf.get_network()\n",
")\n",
"\n",
"latent_hgf_alt_attributes, latent_hgf_alt_edges = add_parent(attributes, edges, 3, 'volatility')"
"new_hgf_attributes, new_hgf_edges = add_parent(attributes, edges, 3, 'volatility')"
]
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -244,23 +218,23 @@
}
],
"source": [
"print(len(latent_hgf_alt_attributes))\n",
"print(len(latent_hgf_alt_edges))"
"print(len(new_hgf_attributes))\n",
"print(len(new_hgf_edges))"
]
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"latent_hgf.attributes = latent_hgf_alt_attributes\n",
"latent_hgf.edges = latent_hgf_alt_edges"
"latent_hgf.attributes = new_hgf_attributes\n",
"latent_hgf.edges = new_hgf_edges"
]
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -344,7 +318,152 @@
}
],
"source": [
"latent_hgf.plot_network() # Not sure why the plot function doesn't function with altered Attributes and Edges..."
"latent_hgf.plot_network()"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(123)\n",
"dist_mean, dist_std = 5, 1\n",
"input_data = np.random.normal(loc=dist_mean, scale=dist_std, size=1000)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"# for each observation\n",
"for value in input_data:\n",
"\n",
" # interleave observations and masks\n",
" data = (value, 1.0, 1.0)\n",
"\n",
" # update the probabilistic network\n",
" attributes, _ = beliefs_propagation(\n",
" attributes=attributes,\n",
" inputs=data,\n",
" update_sequence=update_sequence,\n",
" edges=edges,\n",
" input_idxs=latent_hgf.input_idxs\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{-1: {'time_step': Array(1., dtype=float32, weak_type=True)},\n",
" 0: {'autoconnection_strength': Array(0., dtype=float32, weak_type=True),\n",
" 'expected_mean': Array(5.1438637, dtype=float32),\n",
" 'expected_precision': Array(10000., dtype=float32, weak_type=True),\n",
" 'mean': Array(3.8885696, dtype=float32),\n",
" 'observed': Array(1., dtype=float32, weak_type=True),\n",
" 'precision': Array(10000., dtype=float32, weak_type=True),\n",
" 'temp': {'effective_precision': Array(0.9999, dtype=float32, weak_type=True),\n",
" 'value_prediction_error': Array(-0.62764704, dtype=float32),\n",
" 'volatility_prediction_error': Array(3939.408, dtype=float32)},\n",
" 'tonic_drift': Array(0., dtype=float32, weak_type=True),\n",
" 'tonic_volatility': Array(0., dtype=float32, weak_type=True),\n",
" 'value_coupling_children': None,\n",
" 'value_coupling_parents': (Array(1., dtype=float32, weak_type=True),\n",
" Array(1., dtype=float32, weak_type=True)),\n",
" 'volatility_coupling_children': None,\n",
" 'volatility_coupling_parents': None},\n",
" 1: {'autoconnection_strength': Array(1., dtype=float32, weak_type=True),\n",
" 'expected_mean': Array(4.1078835, dtype=float32),\n",
" 'expected_precision': Array(61701.812, dtype=float32, weak_type=True),\n",
" 'mean': Array(4.0203476, dtype=float32),\n",
" 'observed': Array(1, dtype=int32, weak_type=True),\n",
" 'precision': Array(71701.81, dtype=float32, weak_type=True),\n",
" 'temp': {'effective_precision': Array(0.13946642, dtype=float32, weak_type=True),\n",
" 'value_prediction_error': Array(-0.08753586, dtype=float32),\n",
" 'volatility_prediction_error': Array(472.65228, dtype=float32)},\n",
" 'tonic_drift': Array(0., dtype=float32, weak_type=True),\n",
" 'tonic_volatility': Array(-13., dtype=float32, weak_type=True),\n",
" 'value_coupling_children': (Array(1., dtype=float32, weak_type=True),),\n",
" 'value_coupling_parents': (Array(1., dtype=float32, weak_type=True),),\n",
" 'volatility_coupling_children': None,\n",
" 'volatility_coupling_parents': None},\n",
" 2: {'autoconnection_strength': Array(1., dtype=float32, weak_type=True),\n",
" 'expected_mean': Array(0.34866714, dtype=float32),\n",
" 'expected_precision': Array(7.388171, dtype=float32, weak_type=True),\n",
" 'mean': Array(0.26114178, dtype=float32),\n",
" 'observed': Array(1, dtype=int32, weak_type=True),\n",
" 'precision': Array(61709.2, dtype=float32, weak_type=True),\n",
" 'temp': {'effective_precision': Array(0.99988025, dtype=float32, weak_type=True),\n",
" 'value_prediction_error': Array(0., dtype=float32, weak_type=True),\n",
" 'volatility_prediction_error': Array(0., dtype=float32, weak_type=True)},\n",
" 'tonic_drift': Array(0., dtype=float32, weak_type=True),\n",
" 'tonic_volatility': Array(-2., dtype=float32, weak_type=True),\n",
" 'value_coupling_children': (Array(1., dtype=float32, weak_type=True),),\n",
" 'value_coupling_parents': None,\n",
" 'volatility_coupling_children': None,\n",
" 'volatility_coupling_parents': None},\n",
" 3: {'autoconnection_strength': Array(1., dtype=float32, weak_type=True),\n",
" 'expected_mean': Array(1.3846478, dtype=float32),\n",
" 'expected_precision': Array(61701.812, dtype=float32, weak_type=True),\n",
" 'mean': Array(1.2971121, dtype=float32),\n",
" 'observed': Array(1, dtype=int32, weak_type=True),\n",
" 'precision': Array(71701.81, dtype=float32, weak_type=True),\n",
" 'temp': {'effective_precision': Array(0.13946642, dtype=float32, weak_type=True),\n",
" 'value_prediction_error': Array(0., dtype=float32, weak_type=True),\n",
" 'volatility_prediction_error': Array(0., dtype=float32, weak_type=True)},\n",
" 'tonic_drift': Array(0., dtype=float32, weak_type=True),\n",
" 'tonic_volatility': Array(-13., dtype=float32, weak_type=True),\n",
" 'value_coupling_children': (Array(1., dtype=float32, weak_type=True),),\n",
" 'value_coupling_parents': None,\n",
" 'volatility_coupling_children': None,\n",
" 'volatility_coupling_parents': (Array(1., dtype=float32, weak_type=True),\n",
" Array(1., dtype=float32, weak_type=True))},\n",
" 4: {'autoconnection_strength': Array(1., dtype=float32, weak_type=True),\n",
" 'expected_mean': Array(0., dtype=float32, weak_type=True),\n",
" 'expected_precision': Array(1., dtype=float32, weak_type=True),\n",
" 'mean': Array(0., dtype=float32, weak_type=True),\n",
" 'observed': Array(1, dtype=int32, weak_type=True),\n",
" 'precision': Array(1., dtype=float32, weak_type=True),\n",
" 'temp': {'effective_precision': Array(0., dtype=float32, weak_type=True),\n",
" 'value_prediction_error': Array(0., dtype=float32, weak_type=True),\n",
" 'volatility_prediction_error': Array(0., dtype=float32, weak_type=True)},\n",
" 'tonic_drift': Array(0., dtype=float32, weak_type=True),\n",
" 'tonic_volatility': Array(-4., dtype=float32, weak_type=True),\n",
" 'value_coupling_children': None,\n",
" 'value_coupling_parents': None,\n",
" 'volatility_coupling_children': (Array(1., dtype=float32, weak_type=True),),\n",
" 'volatility_coupling_parents': None},\n",
" 5: {'autoconnection_strength': Array(1., dtype=float32, weak_type=True),\n",
" 'expected_mean': Array(0., dtype=float32, weak_type=True),\n",
" 'expected_precision': Array(1., dtype=float32, weak_type=True),\n",
" 'mean': Array(0., dtype=float32, weak_type=True),\n",
" 'observed': Array(1, dtype=int32, weak_type=True),\n",
" 'precision': Array(1., dtype=float32, weak_type=True),\n",
" 'temp': {'effective_precision': Array(0., dtype=float32, weak_type=True),\n",
" 'value_prediction_error': Array(0., dtype=float32, weak_type=True),\n",
" 'volatility_prediction_error': Array(0., dtype=float32, weak_type=True)},\n",
" 'tonic_drift': Array(0., dtype=float32, weak_type=True),\n",
" 'tonic_volatility': Array(-4., dtype=float32, weak_type=True),\n",
" 'value_coupling_children': None,\n",
" 'value_coupling_parents': None,\n",
" 'volatility_coupling_children': (Array(1., dtype=float32, weak_type=True),),\n",
" 'volatility_coupling_parents': None}}"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"attributes"
]
}
],
Expand Down
16 changes: 9 additions & 7 deletions pyhgf/updates/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@


def add_parent(
attributes: Dict, edges: Edges, index: int, coupling_type: str
attributes: Dict, edges: Edges, index: int, coupling_type: str, mean: float
) -> Tuple[Dict, Edges]:
r"""Add new continuous-state parent node to the attributes and edges of an existing
network.
r"""Add a new continuous-state parent node to the attributes and edges of an
existing network.
Parameters
----------
Expand All @@ -21,8 +21,10 @@ def add_parent(
index :
The index of the node you want to connect a new parent node to.
coupling_type :
The type of coupling you want between the existing node and it's new parent. Can
be either "value" or "volatility".
The type of coupling you want between the existing node and it's new parent.
Can be either "value" or "volatility".
mean :
The mean value of the new parent node.
Returns
-------
Expand All @@ -37,8 +39,8 @@ def add_parent(

# Add new node to attributes
attributes[new_node_idx] = {
"mean": 0.0,
"expected_mean": 0.0,
"mean": mean,
"expected_mean": mean,
"precision": 1.0,
"expected_precision": 1.0,
"volatility_coupling_children": None,
Expand Down

0 comments on commit 1f7b8c2

Please sign in to comment.