From 3ca37c8368336d4d7c225b2bdeac1294f72af7c9 Mon Sep 17 00:00:00 2001 From: weilycoder Date: Tue, 8 Oct 2024 17:52:25 +0800 Subject: [PATCH] add test_shuffle --- cyaron/graph.py | 8 +++++- cyaron/tests/graph_test.py | 55 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/cyaron/graph.py b/cyaron/graph.py index 5748acc..66ae82d 100644 --- a/cyaron/graph.py +++ b/cyaron/graph.py @@ -46,7 +46,13 @@ def __init__(self, point_count, directed=False): """ self.directed = directed self.edges = [[] for i in range(point_count + 1)] - + + def vertex_count(self): + """edge_count(self) -> int + Return the vertex of the edges in the graph. + """ + return len(self.edges) - 1 + def edge_count(self): """edge_count(self) -> int Return the count of the edges in the graph. diff --git a/cyaron/tests/graph_test.py b/cyaron/tests/graph_test.py index 1e3f1db..f42959e 100644 --- a/cyaron/tests/graph_test.py +++ b/cyaron/tests/graph_test.py @@ -1,3 +1,5 @@ +import itertools +import random import unittest from cyaron import Graph @@ -172,3 +174,56 @@ def test_GraphMatrix(self): self.assertEqual(str(g.to_matrix(default=9, merge=merge3)), "9 9 3\n9 9 3\n9 1 1") self.assertEqual(str(g.to_matrix(default=0, merge=merge4)), "0 0 3\n0 0 1\n0 1 1") self.assertEqual(str(g.to_matrix(default=0, merge=merge5)), "0 0 3\n0 0 84\n0 1 1") + + def test_shuffle(self): + def read_graph(n, data, directed = False): + g = Graph(n, directed) + for l in data.split('\n'): + u, v, w = map(int, l.split()) + g.add_edge(u, v, weight=w) + return g + + def isomorphic(graph1, graph2, mapping = None, directed = False): + n = graph1.vertex_count() + if n != graph2.vertex_count(): + return False + if graph1.edge_count() != graph2.edge_count(): + return False + if mapping is None: + for per in itertools.permutations(range(1, n + 1)): + if isomorphic(graph1, graph2, (0, ) + per): + return True + return False + edges = {} + for e in graph2.iterate_edges(): + key, val = (e.start, e.end), e.weight + if key in edges: + edges[key].append(val) + else: + edges[key] = [val] + for e in graph1.iterate_edges(): + key, val = (mapping[e.start], mapping[e.end]), e.weight + if not directed and key[0] > key[1]: + key = key[1], key[0] + if key not in edges: + return False + if val not in edges[key]: + return False + edges[key].remove(val) + return True + + def unit_test(n, m, shuffle_kwargs = {}, check_kwargs = {}): + g = Graph.graph(n, m) + data = g.to_str(**shuffle_kwargs) + h = read_graph(n, data) + self.assertTrue(isomorphic(g, h, **check_kwargs)) + + unit_test(8, 20) + unit_test(8, 20, {"shuffle": True}) + mapping = [0] + random.sample(range(1, 8), k = 7) + shuffer = lambda seq: list(map(lambda i: mapping[i], seq)) + unit_test(7, 10, {"shuffle": True, "node_shuffler": shuffer}) + unit_test(7, 14, {"shuffle": True, "node_shuffler": shuffer}, {"mapping": mapping}) + shuffer_without_swap = lambda table: random.sample(table, k=len(table)) + unit_test(7, 12, {"shuffle": True, "edge_shuffler": shuffer_without_swap}, {"directed": True}) +