diff --git a/cyaron/graph.py b/cyaron/graph.py index fcde9f2..565ac41 100644 --- a/cyaron/graph.py +++ b/cyaron/graph.py @@ -537,18 +537,19 @@ def _calc_max_edge(point_count, directed, self_loop): @staticmethod def forest(point_count, tree_count, **kwargs): + """ + Return a forest with point_count vertexes and tree_count trees. + Args: + point_count: the count of vertexes + tree_count: the count of trees + """ if tree_count <= 0 or tree_count > point_count: raise ValueError("tree_count must be between 1 and point_count") tree = list(Graph.tree(point_count, **kwargs).iterate_edges()) - need_delete = set( - i[0] for i in (Vector.random_unique_vector(tree_count - 1, [( - 0, point_count - 2)]) if tree_count > 1 else [])) result = Graph(point_count, 0) - for i in range(point_count - 1): - if i not in need_delete: - result.add_edge(tree[i].start, - tree[i].end, - weight=tree[i].weight) + need_add = random.sample(tree, len(tree) - tree_count + 1) + for edge in need_add: + result.add_edge(edge.start, edge.end, weight=edge.weight) return result