Skip to content

Commit

Permalink
add test_shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
weilycoder committed Oct 8, 2024
1 parent d5f7980 commit 3ca37c8
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
8 changes: 7 additions & 1 deletion cyaron/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,13 @@ def __init__(self, point_count, directed=False):
"""
self.directed = directed
self.edges = [[] for i in range(point_count + 1)]


def vertex_count(self):
"""edge_count(self) -> int
Return the vertex of the edges in the graph.
"""
return len(self.edges) - 1

def edge_count(self):
"""edge_count(self) -> int
Return the count of the edges in the graph.
Expand Down
55 changes: 55 additions & 0 deletions cyaron/tests/graph_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools
import random
import unittest
from cyaron import Graph

Expand Down Expand Up @@ -172,3 +174,56 @@ def test_GraphMatrix(self):
self.assertEqual(str(g.to_matrix(default=9, merge=merge3)), "9 9 3\n9 9 3\n9 1 1")
self.assertEqual(str(g.to_matrix(default=0, merge=merge4)), "0 0 3\n0 0 1\n0 1 1")
self.assertEqual(str(g.to_matrix(default=0, merge=merge5)), "0 0 3\n0 0 84\n0 1 1")

def test_shuffle(self):
def read_graph(n, data, directed = False):
g = Graph(n, directed)
for l in data.split('\n'):
u, v, w = map(int, l.split())
g.add_edge(u, v, weight=w)
return g

def isomorphic(graph1, graph2, mapping = None, directed = False):
n = graph1.vertex_count()
if n != graph2.vertex_count():
return False
if graph1.edge_count() != graph2.edge_count():
return False
if mapping is None:
for per in itertools.permutations(range(1, n + 1)):
if isomorphic(graph1, graph2, (0, ) + per):
return True
return False
edges = {}
for e in graph2.iterate_edges():
key, val = (e.start, e.end), e.weight
if key in edges:
edges[key].append(val)
else:
edges[key] = [val]
for e in graph1.iterate_edges():
key, val = (mapping[e.start], mapping[e.end]), e.weight
if not directed and key[0] > key[1]:
key = key[1], key[0]
if key not in edges:
return False
if val not in edges[key]:
return False
edges[key].remove(val)
return True

def unit_test(n, m, shuffle_kwargs = {}, check_kwargs = {}):
g = Graph.graph(n, m)
data = g.to_str(**shuffle_kwargs)
h = read_graph(n, data)
self.assertTrue(isomorphic(g, h, **check_kwargs))

unit_test(8, 20)
unit_test(8, 20, {"shuffle": True})
mapping = [0] + random.sample(range(1, 8), k = 7)
shuffer = lambda seq: list(map(lambda i: mapping[i], seq))
unit_test(7, 10, {"shuffle": True, "node_shuffler": shuffer})
unit_test(7, 14, {"shuffle": True, "node_shuffler": shuffer}, {"mapping": mapping})
shuffer_without_swap = lambda table: random.sample(table, k=len(table))
unit_test(7, 12, {"shuffle": True, "edge_shuffler": shuffer_without_swap}, {"directed": True})

0 comments on commit 3ca37c8

Please sign in to comment.