Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Strongly connected component algorithm #112

Merged
merged 2 commits into from
Dec 19, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package org.apache.spark.graph.algorithms

import org.apache.spark.graph._

object StronglyConnectedComponents {

/**
* Compute the strongly connected component (SCC) of each vertex and return an RDD with the vertex
* value containing the lowest vertex id in the SCC containing that vertex.
*
* @tparam VD the vertex attribute type (discarded in the computation)
* @tparam ED the edge attribute type (preserved in the computation)
*
* @param graph the graph for which to compute the SCC
*
* @return a graph with vertex attributes containing the smallest vertex id in each SCC
*/
def run[VD: Manifest, ED: Manifest](graph: Graph[VD, ED], numIter: Int): Graph[Vid, ED] = {

// the graph we update with final SCC ids, and the graph we return at the end
var sccGraph = graph.mapVertices { case (vid, _) => vid }
// graph we are going to work with in our iterations
var sccWorkGraph = graph.mapVertices { case (vid, _) => (vid, false) }

var numVertices = sccWorkGraph.numVertices
var iter = 0
while (sccWorkGraph.numVertices > 0 && iter < numIter) {
iter += 1
do {
numVertices = sccWorkGraph.numVertices
sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.outDegrees) {
(vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
}
sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.inDegrees) {
(vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true)
}

// get all vertices to be removed
val finalVertices = sccWorkGraph.vertices
.filter { case (vid, (scc, isFinal)) => isFinal}
.mapValues { (vid, data) => data._1}

// write values to sccGraph
sccGraph = sccGraph.outerJoinVertices(finalVertices) {
(vid, scc, opt) => opt.getOrElse(scc)
}
// only keep vertices that are not final
sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2)
} while (sccWorkGraph.numVertices < numVertices)

sccWorkGraph = sccWorkGraph.mapVertices{ case (vid, (color, isFinal)) => (vid, isFinal) }

// collect min of all my neighbor's scc values, update if it's smaller than mine
// then notify any neighbors with scc values larger than mine
sccWorkGraph = GraphLab[(Vid, Boolean), ED, Vid](sccWorkGraph, Integer.MAX_VALUE)(
(vid, e) => e.otherVertexAttr(vid)._1,
(vid1, vid2) => math.min(vid1, vid2),
(vid, scc, optScc) =>
(math.min(scc._1, optScc.getOrElse(scc._1)), scc._2),
(vid, e) => e.vertexAttr(vid)._1 < e.otherVertexAttr(vid)._1
)

// start at root of SCCs. Traverse values in reverse, notify all my neighbors
// do not propagate if colors do not match!
sccWorkGraph = GraphLab[(Vid, Boolean), ED, Boolean](
sccWorkGraph,
Integer.MAX_VALUE,
EdgeDirection.Out,
EdgeDirection.In
)(
// vertex is final if it is the root of a color
// or it has the same color as a neighbor that is final
(vid, e) => (vid == e.vertexAttr(vid)._1) || (e.vertexAttr(vid)._1 == e.otherVertexAttr(vid)._1),
(final1, final2) => final1 || final2,
(vid, scc, optFinal) =>
(scc._1, scc._2 || optFinal.getOrElse(false)),
// activate neighbor if they are not final, you are, and you have the same color
(vid, e) => e.vertexAttr(vid)._2 &&
!e.otherVertexAttr(vid)._2 && (e.vertexAttr(vid)._1 == e.otherVertexAttr(vid)._1),
// start at root of colors
(vid, data) => vid == data._1
)
}
sccGraph
}

}
43 changes: 43 additions & 0 deletions graph/src/test/scala/org/apache/spark/graph/AnalyticsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,49 @@ class AnalyticsSuite extends FunSuite with LocalSparkContext {
}
} // end of reverse chain connected components

test("Island Strongly Connected Components") {
withSpark(new SparkContext("local", "test")) { sc =>
val vertices = sc.parallelize((1L to 5L).map(x => (x, -1)))
val edges = sc.parallelize(Seq.empty[Edge[Int]])
val graph = Graph(vertices, edges)
val sccGraph = StronglyConnectedComponents.run(graph, 5)
for ((id, scc) <- sccGraph.vertices.collect) {
assert(id == scc)
}
}
}

test("Cycle Strongly Connected Components") {
withSpark(new SparkContext("local", "test")) { sc =>
val rawEdges = sc.parallelize((0L to 6L).map(x => (x, (x + 1) % 7)))
val graph = Graph.fromEdgeTuples(rawEdges, -1)
val sccGraph = StronglyConnectedComponents.run(graph, 20)
for ((id, scc) <- sccGraph.vertices.collect) {
assert(0L == scc)
}
}
}

test("2 Cycle Strongly Connected Components") {
withSpark(new SparkContext("local", "test")) { sc =>
val edges =
Array(0L -> 1L, 1L -> 2L, 2L -> 0L) ++
Array(3L -> 4L, 4L -> 5L, 5L -> 3L) ++
Array(6L -> 0L, 5L -> 7L)
val rawEdges = sc.parallelize(edges)
val graph = Graph.fromEdgeTuples(rawEdges, -1)
val sccGraph = StronglyConnectedComponents.run(graph, 20)
for ((id, scc) <- sccGraph.vertices.collect) {
if (id < 3)
assert(0L == scc)
else if (id < 6)
assert(3L == scc)
else
assert(id == scc)
}
}
}

test("Count a single triangle") {
withSpark(new SparkContext("local", "test")) { sc =>
val rawEdges = sc.parallelize(Array( 0L->1L, 1L->2L, 2L->0L ), 2)
Expand Down