Skip to content

Commit

Permalink
Fix flaky python tests
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwalker committed Dec 18, 2024
1 parent bbcea74 commit b49925b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
10 changes: 5 additions & 5 deletions crates/mate/tests/ds_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import numpy as np
import pandas as pd

from graph_mate import DiGraph, Graph
from graph_mate import DiGraph, Graph, Layout


def test_numpy_graph():
el = np.array([[0, 1], [2, 3], [4, 1]], dtype=np.uint32)
g = Graph.from_numpy(el)
g = Graph.from_numpy(el, layout=Layout.Sorted)

assert g.node_count() == 5
assert g.edge_count() == 3
Expand All @@ -20,7 +20,7 @@ def test_numpy_graph():

def test_pandas_graph():
df = pd.DataFrame({"source": [0, 2, 4], "target": [1, 3, 1]})
g = Graph.from_pandas(df)
g = Graph.from_pandas(df, layout=Layout.Sorted)

assert g.node_count() == 5
assert g.edge_count() == 3
Expand All @@ -34,7 +34,7 @@ def test_pandas_graph():

def test_numpy_digraph():
el = np.array([[0, 1], [2, 3], [4, 1]], dtype=np.uint32)
g = DiGraph.from_numpy(el)
g = DiGraph.from_numpy(el, layout=Layout.Sorted)

assert g.node_count() == 5
assert g.edge_count() == 3
Expand All @@ -49,7 +49,7 @@ def test_numpy_digraph():

def test_pandas_digraph():
df = pd.DataFrame({"source": [0, 2, 4], "target": [1, 3, 1]})
g = DiGraph.from_pandas(df)
g = DiGraph.from_pandas(df, layout=Layout.Sorted)

assert g.node_count() == 5
assert g.edge_count() == 3
Expand Down
27 changes: 16 additions & 11 deletions crates/mate/tests/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,33 @@ def test_load_graph(g: DiGraph):


def test_to_undirected(g: DiGraph, ug: Graph):
g = g.to_undirected()
undirected = g.to_undirected()

for n in range(g.node_count()):
assert set(g.copy_neighbors(n)) == set(ug.copy_neighbors(n))
for n in range(undirected.node_count()):
assert set(undirected.copy_neighbors(n)) == set(ug.copy_neighbors(n))


def test_to_undirected_with_layout():
g = DiGraph.from_numpy(
np.array([[0, 1], [0, 1], [0, 2], [1, 2], [2, 1], [0, 3]], dtype=np.uint32)
)

def compare_unsorted(expect, actual):
sorted = expect.copy()
sorted.sort()
return np.array_equal(sorted, actual)

ug = g.to_undirected()
assert np.array_equal(ug.neighbors(0), [1, 1, 2, 3])
assert np.array_equal(ug.neighbors(1), [2, 0, 0, 2])
assert np.array_equal(ug.neighbors(2), [1, 0, 1])
assert np.array_equal(ug.neighbors(3), [0])
assert compare_unsorted(ug.neighbors(0), [1, 1, 2, 3])
assert compare_unsorted(ug.neighbors(1), [0, 0, 2, 2])
assert compare_unsorted(ug.neighbors(2), [0, 1, 1])
assert compare_unsorted(ug.neighbors(3), [0])

ug = g.to_undirected(Layout.Unsorted)
assert np.array_equal(ug.neighbors(0), [1, 1, 2, 3])
assert np.array_equal(ug.neighbors(1), [2, 0, 0, 2])
assert np.array_equal(ug.neighbors(2), [1, 0, 1])
assert np.array_equal(ug.neighbors(3), [0])
assert compare_unsorted(ug.neighbors(0), [1, 1, 2, 3])
assert compare_unsorted(ug.neighbors(1), [0, 0, 2, 2])
assert compare_unsorted(ug.neighbors(2), [0, 1, 1])
assert compare_unsorted(ug.neighbors(3), [0])

ug = g.to_undirected(Layout.Sorted)
assert np.array_equal(ug.neighbors(0), [1, 1, 2, 3])
Expand Down

0 comments on commit b49925b

Please sign in to comment.