Skip to content

Commit

Permalink
Merge pull request #20 from will7200/feat/ab_graphing
Browse files Browse the repository at this point in the history
feat: add stateful graph traversal
  • Loading branch information
cmutel authored Sep 9, 2024
2 parents 38c01ed + 24caeaa commit bda02d2
Show file tree
Hide file tree
Showing 14 changed files with 1,558 additions and 376 deletions.
25 changes: 25 additions & 0 deletions bw_graph_tools/graph_traversal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
__all__ = (
"AssumedDiagonalGraphTraversal",
"Edge",
"Flow",
"NewNodeEachVisitGraphTraversal",
"NewNodeEachVisitTaggedGraphTraversal",
"Node",
"SameNodeEachVisitGraphTraversal",
"SameNodeEachVisitTaggedGraphTraversal",
"SupplyChainTraversalSettings",
"TaggedSupplyChainTraversalSettings",
)

from .assumed_diagonal import AssumedDiagonalGraphTraversal
from .graph_objects import Edge, Flow, Node
from .new_node_each_visit import (
NewNodeEachVisitGraphTraversal,
SupplyChainTraversalSettings,
)
from .same_node_each_visit import SameNodeEachVisitGraphTraversal
from .tagged_nodes import (
NewNodeEachVisitTaggedGraphTraversal,
SameNodeEachVisitTaggedGraphTraversal,
TaggedSupplyChainTraversalSettings,
)
29 changes: 29 additions & 0 deletions bw_graph_tools/graph_traversal/assumed_diagonal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import matrix_utils as mu
import numpy as np

from .new_node_each_visit import NewNodeEachVisitGraphTraversal


class AssumedDiagonalGraphTraversal(NewNodeEachVisitGraphTraversal):
@classmethod
def get_production_exchanges(cls, mapped_matrix: mu.MappedMatrix) -> (np.ndarray, np.ndarray):
"""
Assume production exchanges are always on the diagonal instead of
examining matrix structure and input data.
Parameters
----------
mapped_matrix : matrix_utils.MappedMatrix
A matrix and mapping data (from database ids to matrix indices)
from the ``matrix_utils`` library. Normally built automatically by
an ``LCA`` class. Should be the ``technosphere_matrix`` or
equivalent.
Returns
-------
(numpy.array, numpy.array)
The matrix row and column indices of the production exchanges.
"""
length = mapped_matrix.matrix.shape[0]
return np.arange(length), np.arange(length)
109 changes: 109 additions & 0 deletions bw_graph_tools/graph_traversal/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import typing
from typing import Dict, Generic, List, TypeVar

from .graph_objects import Edge, Flow, Node
from .utils import CachingSolver

if typing.TYPE_CHECKING:
import bw2calc

Settings = TypeVar("Settings")


class GraphTraversalException(Exception): ...


class BaseGraphTraversal(Generic[Settings]):
def __init__(
self,
lca: "bw2calc.LCA",
settings: Settings,
functional_unit_unique_id: int = -1,
static_activity_indices=None,
):
"""
Base class for common graph traversal methods. Should be inherited from, not used directly.
Parameters
----------
lca : bw2calc.LCA
Already instantiated `LCA` object with inventory and impact
assessment calculated.
settings: object
Settings for the graph traversal
functional_unit_unique_id : int
An integer id we can use for the functional unit virtual activity.
Shouldn't overlap any other activity ids. Don't change unless you
really know what you are doing.
static_activity_indices : set
A set of activity matrix indices which we don't want the graph to
traverse - i.e. we stop traversal when we hit these nodes, but
still add them to the returned `nodes` dictionary, and calculate
their direct and cumulative scores.
"""
if static_activity_indices is None:
static_activity_indices = set()
self.lca = lca
self.settings = settings
self.static_activity_indices = static_activity_indices
# allows the user to store metadata from the traversal
self.metadata = dict()

# internal properties
self._root_node = Node(
unique_id=functional_unit_unique_id,
activity_datapackage_id=functional_unit_unique_id,
activity_index=functional_unit_unique_id,
reference_product_datapackage_id=functional_unit_unique_id,
reference_product_index=functional_unit_unique_id,
reference_product_production_amount=1.0,
depth=0,
# Not one of any particular product in the functional unit, but one functional
# unit itself.
supply_amount=1.0,
cumulative_score=self.lca.score,
direct_emissions_score=0.0,
)
self._nodes: Dict[int, Node] = {functional_unit_unique_id: self._root_node}
self._edges: List[Edge] = []
self._flows: List[Flow] = []
self._heap: List[Node] = []
self._caching_solver = CachingSolver(lca)

@property
def nodes(self):
"""
List of `Node` dataclass instances.
Each `Node` instance has a `unique_id`, regardless of graph traversal class. In some
classes, each node in the database will only appear once in this list of graph traversal
node instances, but in `NewNodeEachVisitGraphTraversal`, we create a new `Node` every time
we reach a database node, even if we have seen it before.
See the `Node` documentation for its other attributes.
"""
return self._nodes

@property
def edges(self):
"""
List of `Edge` instances. Edges link two `Node` instances.
Note that there are no `Edge` instances which link `Flow` instances - these are handled
separately.
See the `Edge` documentation for its other attributes.
"""
return self._edges

@property
def flows(self):
"""
List of `Flow` instances.
A `Flow` instance is a *characterized biosphere flow* associated with a specific `Node`
instance.
See the `Flow` documentation for its other attributes.
"""
return self._flows
152 changes: 152 additions & 0 deletions bw_graph_tools/graph_traversal/graph_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from dataclasses import dataclass
from typing import List


@dataclass
class Node:
"""
A visited activity in a supply chain graph. Although our graph is cyclic, we treat each
activity as a separate node every time we visit it.
Parameters
----------
unique_id : int
A unique integer id for this visit to this activity node
activity_datapackage_id : int
The id that identifies this activity in the datapackage, and hence in the database
activity_index : int
The technosphere matrix column index of this activity
reference_product_datapackage_id : int
The id that identifies the reference product of this activity in the datapackage
reference_product_index : int
The technosphere matrix row index of this activity's reference product
reference_product_production_amount : float
The *net* production amount of this activity's reference product
depth : int
Depth in the supply chain graph, starting from 0 as the functional unit
supply_amount : float
The amount of the *activity* (not reference product!) needed to supply the demand from the
requesting supply chain edge.
cumulative_score : float
Total LCIA score attributed to `supply_amount` of this activity, including impacts from
direct emissions.
direct_emissions_score : float
Total LCIA score attributed only to the direct characterized biosphere flows of
`supply_amount` of this activity.
direct_emissions_score_outside_specific_flows : float
The score attributable to *direct emissions* of this node which isn't broken out into
separate `Flow` objects.
remaining_cumulative_score_outside_specific_flows : float
The *cumulative* score of this node, including direct emissions, which isn't broken out
into separate `Flow` objects.
terminal : bool
Boolean flag indicating whether graph traversal was cutoff at this node
"""

unique_id: int
activity_datapackage_id: int
activity_index: int
reference_product_datapackage_id: int
reference_product_index: int
reference_product_production_amount: float
depth: int
supply_amount: float
cumulative_score: float
direct_emissions_score: float
direct_emissions_score_outside_specific_flows: float = 0.0
remaining_cumulative_score_outside_specific_flows: float = 0.0
terminal: bool = False

def __lt__(self, other):
# Needed for sorting
return self.cumulative_score < other.cumulative_score


@dataclass
class GroupedNodes:
"""
A group of nodes
"""

nodes: List[Node]
label: str
unique_id: int
depth: int
supply_amount: float
cumulative_score: float
direct_emissions_score: float
direct_emissions_score_outside_specific_flows: float = 0.0
terminal: bool = False
activity_index: int = None

def __lt__(self, other):
# Needed for sorting
return self.cumulative_score < other.cumulative_score


@dataclass
class Edge:
"""
An edge between two `Node` instances. The `amount` is the amount of the product demanded by the
`consumer`.
Parameters
----------
consumer_index : int
The matrix column index of the consuming activity
consumer_unique_id : int
The traversal-specific unique id of the consuming activity
producer_index : int
The matrix column index of the producing activity
producer_unique_id : int
The traversal-specific unique id of the producing activity
product_index : int
The matrix row index of the consumed product
amount : float
The amount of the product demanded by the consumer. Not scaled to producer production
amount.
"""

consumer_index: int
consumer_unique_id: int
producer_index: int
producer_unique_id: int
product_index: int
amount: float


@dataclass
class Flow:
"""
A characterized biosphere flow associated with a given `Node` instance.
Parameters
----------
flow_datapackage_id : int
The id that identifies the biosphere flow in the datapackage
flow_index : int
The matrix row index of the biosphere flow
activity_unique_id : int
The `Node.unique_id` of this instance of the emitting activity
activity_id : int
The id that identifies the emitting activity in the datapackage
activity_index : int
The matrix column index of the emitting activity
amount : float
The amount of the biosphere flow being emitting by this activity instance
score : float
The LCIA score for `amount` of this biosphere flow
"""

flow_datapackage_id: int
flow_index: int
activity_unique_id: int
activity_id: int
activity_index: int
amount: float
score: float

def __lt__(self, other):
# Needed for sorting
return self.score < other.score
Loading

0 comments on commit bda02d2

Please sign in to comment.