Skip to content

Commit

Permalink
Move max_message to correct level of abstraction. Updated version and…
Browse files Browse the repository at this point in the history
… requirements.txt.
  • Loading branch information
tomsch420 committed Nov 3, 2023
1 parent 5c86a52 commit 943b259
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 39 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
networkx>=3.2.1
networkx>=3.0
numpy>=1.24.4
random_events>=1.1.1
tabulate>=0.9.0
2 changes: 1 addition & 1 deletion src/fglib2/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.1.0'
__version__ = '1.1.1'
21 changes: 1 addition & 20 deletions src/fglib2/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from random_events.variables import Discrete, Symbolic
from random_events.variables import Discrete
from random_events.events import Event, EncodedEvent

import tabulate
Expand Down Expand Up @@ -205,25 +205,6 @@ def likelihood(self, event: List) -> float:
"""
return self._likelihood(self.encode(event))

def max_message(self, variable) -> 'Multinomial':
"""
Construct a message that contains the maximum likelihood for each value of the variable.
.. Note::
The message is not normalized. The reason is the purpose of a max message. In every entry of the
`probabilities` array is the maximum possible likelihood for the corresponding event. Therefore,
this message should not be normalized.
:param variable: The variable to construct it over.
:return: A not normalized distribution over the variable with the maximum likelihood for each value.
"""
if variable not in self.variables:
raise ValueError("The variable {} is not in the distribution."
"The distributions variables are {}".format(variable, self.variables))
axis = tuple(index for index, var in enumerate(self.variables) if var != variable)
probabilities = np.max(self.probabilities, axis=axis)
return Multinomial([variable], probabilities)

def _conditional(self, event: EncodedEvent) -> 'Multinomial':
"""
Calculate the conditional distribution given an event encoded.
Expand Down
23 changes: 21 additions & 2 deletions src/fglib2/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,25 @@ def __mul__(self, other: 'FactorNode') -> 'FactorGraph':
"""
return FactorGraph() * self * other

def max_message(self, variable) -> 'Multinomial':
"""
Construct a message that contains the maximum likelihood for each value of the variable.
.. Note::
The message is not normalized. The reason is the purpose of a max message. In every entry of the
`probabilities` array is the maximum possible likelihood for the corresponding event. Therefore,
this message should not be normalized.
:param variable: The variable to construct it over.
:return: A not normalized distribution over the variable with the maximum likelihood for each value.
"""
if variable not in self.variables:
raise ValueError("The variable {} is not in the distribution."
"The distributions variables are {}".format(variable, self.variables))
axis = tuple(index for index, var in enumerate(self.variables) if var != variable)
probabilities = np.max(self.distribution.probabilities, axis=axis)
return Multinomial([variable], probabilities)


class Edge:
"""
Expand Down Expand Up @@ -318,7 +337,7 @@ def max_product(self) -> Event:
self.neighbors(source) if neighbour != target]
msg = source.sum_product(incoming_messages)

msg = msg.max_message(target.variable)
msg = FactorNode(msg).max_message(target.variable)
self.edges[source, target]['edge'].factor_to_variable = msg

for source, target in backtracking_path:
Expand All @@ -345,7 +364,7 @@ def max_product(self) -> Event:
self.neighbors(source)]
msg = source.sum_product(incoming_messages)

msg = msg.max_message(target.variable)
msg = FactorNode(msg).max_message(target.variable)
self.edges[source, target]['edge'].factor_to_variable = msg

result = Event()
Expand Down
13 changes: 1 addition & 12 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,18 +160,6 @@ def test_multiple_modes(self):
self.assertEqual(mode[1]["X"], (1,))
self.assertEqual(mode[1]["Y"], (0,))

def test_random_max_message(self):
max_message = self.random_distribution.max_message(self.z)
self.assertEqual(max_message.probabilities.shape, (5, ))

def test_crafted_max_message(self):
max_message = self.crafted_distribution.max_message(self.x)
self.assertTrue(np.allclose(max_message.probabilities, np.array([0.3, 0.7])))

def test_max_message_wrong_variable(self):
with self.assertRaises(ValueError):
self.crafted_distribution.max_message(self.z)

def test_crafted_probability(self):
distribution = self.crafted_distribution.normalize()
event = Event()
Expand Down Expand Up @@ -270,5 +258,6 @@ def test_disjoint_variables(self):
with self.assertRaises(AssertionError):
self.distribution_x * self.distribution_y


if __name__ == '__main__':
unittest.main()
39 changes: 36 additions & 3 deletions test/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,39 @@
from random_events.variables import Symbolic


class FactorNodeTestCase(unittest.TestCase):
x: Symbolic
y: Symbolic
z: Symbolic
random_factor_node: FactorNode
crafted_factor_node: FactorNode

@classmethod
def setUpClass(cls):
np.random.seed(69)
cls.x = Symbolic("X", range(2))
cls.y = Symbolic("Y", range(3))
cls.z = Symbolic("Z", range(5))
cls.random_factor_node = FactorNode(Multinomial([cls.x, cls.y, cls.z], np.random.rand(len(cls.x.domain),
len(cls.y.domain),
len(cls.z.domain))))

cls.crafted_factor_node = (
FactorNode(Multinomial([cls.x, cls.y], np.array([[0.1, 0.2, 0.3], [0.7, 0.4, 0.1]]))))

def test_random_max_message(self):
max_message = self.random_factor_node.max_message(self.z)
self.assertEqual(max_message.probabilities.shape, (5, ))

def test_crafted_max_message(self):
max_message = self.crafted_factor_node.max_message(self.x)
self.assertTrue(np.allclose(max_message.probabilities, np.array([0.3, 0.7])))

def test_max_message_wrong_variable(self):
with self.assertRaises(ValueError):
self.crafted_factor_node.max_message(self.z)


class FactorGraphTestCase(unittest.TestCase):
x: Symbolic
y: Symbolic
Expand Down Expand Up @@ -187,9 +220,9 @@ def test_latex_equation(self):
def test_max_product(self):
mode = self.graph.max_product()
self.assertEqual(mode[self.x1], (0, 1))
self.assertEqual(mode[self.x2], (0, ))
self.assertEqual(mode[self.x3], (1, ))
self.assertEqual(mode[self.x4], (1, ))
self.assertEqual(mode[self.x2], (0,))
self.assertEqual(mode[self.x3], (1,))
self.assertEqual(mode[self.x4], (1,))


if __name__ == "__main__":
Expand Down

0 comments on commit 943b259

Please sign in to comment.