Skip to content

Commit

Permalink
add test for tagged traversal
Browse files Browse the repository at this point in the history
  • Loading branch information
will7200 committed Sep 11, 2024
1 parent 54a491e commit 86ae0c5
Showing 1 changed file with 37 additions and 7 deletions.
44 changes: 37 additions & 7 deletions tests/traversal/test_tagging_traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,26 @@
import bw2data as bd
import pytest

from bw_graph_tools import NewNodeEachVisitGraphTraversal
from bw_graph_tools.graph_traversal import (
GraphTraversalSettings,
NewNodeEachVisitGraphTraversal,
NewNodeEachVisitTaggedGraphTraversal,
SameNodeEachVisitTaggedGraphTraversal,
TaggedGraphTraversalSettings,
)
from bw_graph_tools.graph_traversal.graph_objects import Edge, GroupedNodes, Node
from bw_graph_tools.graph_traversal.utils import Counter


def get_default_graph(lca, tags):
return NewNodeEachVisitTaggedGraphTraversal(
def get_default_graph(lca, tags, variant=NewNodeEachVisitTaggedGraphTraversal):
return variant(
lca=lca,
settings=TaggedGraphTraversalSettings(cutoff=0.001, max_calc=10, tags=tags),
)


def get_untagged_new_graph(lca):
return NewNodeEachVisitGraphTraversal(
lca=lca, settings=GraphTraversalSettings(cutoff=0.001, max_calc=10)
)
def get_untagged_new_graph(lca, variant=NewNodeEachVisitGraphTraversal):
return variant(lca=lca, settings=GraphTraversalSettings(cutoff=0.001, max_calc=10))


@pytest.fixture
Expand All @@ -32,6 +31,16 @@ def graph(sample_database_with_tagged_products):
yield g


@pytest.fixture
def same_node_graph(sample_database_with_tagged_products):
g: SameNodeEachVisitTaggedGraphTraversal = get_default_graph(
sample_database_with_tagged_products,
["test"],
variant=SameNodeEachVisitTaggedGraphTraversal,
)
yield g


class Faker:
@staticmethod
def generate_list_from_bwdata_names(items):
Expand Down Expand Up @@ -200,3 +209,24 @@ def test_tagged_traversal(self, graph, sample_database_with_tagged_products):

assert len(groups["test: group-a"][0].nodes) == 1
assert len(groups["test: group-b"][0].nodes) == 3


class TestSameNodeTaggingTraversal:
@staticmethod
def count_grouped_nodes(some_graph):
return sum(
[1 if isinstance(node, GroupedNodes) else 0 for node in some_graph.nodes.values()]
)

def test_tagged_traversal(self, same_node_graph):
same_node_graph.traverse(depth=2)
assert len(same_node_graph.nodes) == 5, "Expecting 5 nodes in the graph"
assert self.count_grouped_nodes(same_node_graph) == 2, "Expecting to find two grouped nodes"

gn = same_node_graph.nodes[7]
assert isinstance(gn, GroupedNodes), "Expecting a grouped node"
assert len(gn.nodes) == 1, "Expecting only one node in the group"
same_node_graph.traverse_from_node(gn, depth=1)
count_of_grouped_nodes = self.count_grouped_nodes(same_node_graph)
assert count_of_grouped_nodes == 1, "Expecting only one grouped node"
assert len(same_node_graph.nodes) == 5, "Expecting 5 nodes only"

0 comments on commit 86ae0c5

Please sign in to comment.