diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index febcf9c6e..ff45e7610 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -231,7 +231,7 @@ class SparkContext( } taskScheduler.start() - @volatile private var dagScheduler = new DAGScheduler(taskScheduler) + @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) dagScheduler.start() ui.start() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala index 5a24042e1..c87b66f04 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationSource.scala @@ -34,7 +34,7 @@ class ApplicationSource(val application: ApplicationInfo) extends Source { override def getValue: Long = application.duration }) - metricRegistry.register(MetricRegistry.name("cores", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("cores"), new Gauge[Int] { override def getValue: Int = application.coresGranted }) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala index 23d1cb77d..36c1b87b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterSource.scala @@ -26,17 +26,17 @@ private[spark] class MasterSource(val master: Master) extends Source { val sourceName = "master" // Gauge for worker numbers in cluster - metricRegistry.register(MetricRegistry.name("workers","number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("workers"), new Gauge[Int] { override def getValue: Int = master.workers.size }) // Gauge for application numbers in cluster - metricRegistry.register(MetricRegistry.name("apps", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("apps"), new Gauge[Int] { override def getValue: Int = master.apps.size }) // Gauge for waiting application numbers in cluster - metricRegistry.register(MetricRegistry.name("waitingApps", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("waitingApps"), new Gauge[Int] { override def getValue: Int = master.waitingApps.size }) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala index df269fd04..b7ddd8c81 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerSource.scala @@ -25,27 +25,27 @@ private[spark] class WorkerSource(val worker: Worker) extends Source { val sourceName = "worker" val metricRegistry = new MetricRegistry() - metricRegistry.register(MetricRegistry.name("executors", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("executors"), new Gauge[Int] { override def getValue: Int = worker.executors.size }) // Gauge for cores used of this worker - metricRegistry.register(MetricRegistry.name("coresUsed", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("coresUsed"), new Gauge[Int] { override def getValue: Int = worker.coresUsed }) // Gauge for memory used of this worker - metricRegistry.register(MetricRegistry.name("memUsed", "MBytes"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("memUsed_MB"), new Gauge[Int] { override def getValue: Int = worker.memoryUsed }) // Gauge for cores free of this worker - metricRegistry.register(MetricRegistry.name("coresFree", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("coresFree"), new Gauge[Int] { override def getValue: Int = worker.coresFree }) // Gauge for memory free of this worker - metricRegistry.register(MetricRegistry.name("memFree", "MBytes"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("memFree_MB"), new Gauge[Int] { override def getValue: Int = worker.memoryFree }) } diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 18c9dc1c0..34ed9c8f7 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -43,31 +43,31 @@ class ExecutorSource(val executor: Executor, executorId: String) extends Source val sourceName = "executor.%s".format(executorId) // Gauge for executor thread pool's actively executing task counts - metricRegistry.register(MetricRegistry.name("threadpool", "activeTask", "count"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("threadpool", "activeTasks"), new Gauge[Int] { override def getValue: Int = executor.threadPool.getActiveCount() }) // Gauge for executor thread pool's approximate total number of tasks that have been completed - metricRegistry.register(MetricRegistry.name("threadpool", "completeTask", "count"), new Gauge[Long] { + metricRegistry.register(MetricRegistry.name("threadpool", "completeTasks"), new Gauge[Long] { override def getValue: Long = executor.threadPool.getCompletedTaskCount() }) // Gauge for executor thread pool's current number of threads - metricRegistry.register(MetricRegistry.name("threadpool", "currentPool", "size"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("threadpool", "currentPool_size"), new Gauge[Int] { override def getValue: Int = executor.threadPool.getPoolSize() }) // Gauge got executor thread pool's largest number of threads that have ever simultaneously been in th pool - metricRegistry.register(MetricRegistry.name("threadpool", "maxPool", "size"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("threadpool", "maxPool_size"), new Gauge[Int] { override def getValue: Int = executor.threadPool.getMaximumPoolSize() }) // Gauge for file system stats of this executor for (scheme <- Array("hdfs", "file")) { - registerFileSystemStat(scheme, "bytesRead", _.getBytesRead(), 0L) - registerFileSystemStat(scheme, "bytesWritten", _.getBytesWritten(), 0L) - registerFileSystemStat(scheme, "readOps", _.getReadOps(), 0) - registerFileSystemStat(scheme, "largeReadOps", _.getLargeReadOps(), 0) - registerFileSystemStat(scheme, "writeOps", _.getWriteOps(), 0) + registerFileSystemStat(scheme, "read_bytes", _.getBytesRead(), 0L) + registerFileSystemStat(scheme, "write_bytes", _.getBytesWritten(), 0L) + registerFileSystemStat(scheme, "read_ops", _.getReadOps(), 0) + registerFileSystemStat(scheme, "largeRead_ops", _.getLargeReadOps(), 0) + registerFileSystemStat(scheme, "write_ops", _.getWriteOps(), 0) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4053b9113..5c40f5095 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -114,7 +114,7 @@ class DAGScheduler( private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] - private val listenerBus = new SparkListenerBus() + private[spark] val listenerBus = new SparkListenerBus() // Contains the locations that each RDD's partitions are cached on private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]] diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 446d490cc..151514896 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -27,23 +27,23 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar val metricRegistry = new MetricRegistry() val sourceName = "%s.DAGScheduler".format(sc.appName) - metricRegistry.register(MetricRegistry.name("stage", "failedStages", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("stage", "failedStages"), new Gauge[Int] { override def getValue: Int = dagScheduler.failed.size }) - metricRegistry.register(MetricRegistry.name("stage", "runningStages", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("stage", "runningStages"), new Gauge[Int] { override def getValue: Int = dagScheduler.running.size }) - metricRegistry.register(MetricRegistry.name("stage", "waitingStages", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("stage", "waitingStages"), new Gauge[Int] { override def getValue: Int = dagScheduler.waiting.size }) - metricRegistry.register(MetricRegistry.name("job", "allJobs", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] { override def getValue: Int = dagScheduler.nextJobId.get() }) - metricRegistry.register(MetricRegistry.name("job", "activeJobs", "number"), new Gauge[Int] { + metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { override def getValue: Int = dagScheduler.activeJobs.size }) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index a65e1ecd6..4d3e4a17b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -70,5 +70,23 @@ private[spark] class SparkListenerBus() extends Logging { queueFullErrorMessageLogged = true } } + + /** + * Waits until there are no more events in the queue, or until the specified time has elapsed. + * Used for testing only. Returns true if the queue has emptied and false is the specified time + * elapsed before the queue emptied. + */ + def waitUntilEmpty(timeoutMillis: Int): Boolean = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!eventQueue.isEmpty()) { + if (System.currentTimeMillis > finishTime) { + return false + } + /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify + * add overhead in the general case. */ + Thread.sleep(10) + } + return true + } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 6c500bad9..e936b1cfe 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -23,7 +23,7 @@ import java.io.{EOFException, InputStream, OutputStream} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} import com.esotericsoftware.kryo.{KryoException, Kryo} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} -import com.twitter.chill.ScalaKryoInstantiator +import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar} import org.apache.spark.{SerializableWritable, Logging} import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock, StorageLevel} @@ -39,7 +39,7 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging def newKryoOutput() = new KryoOutput(bufferSize) def newKryo(): Kryo = { - val instantiator = new ScalaKryoInstantiator + val instantiator = new EmptyScalaKryoInstantiator val kryo = instantiator.newKryo() val classLoader = Thread.currentThread.getContextClassLoader @@ -49,7 +49,11 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging StorageLevel.MEMORY_ONLY, PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY), GotBlock("1", ByteBuffer.allocate(1)), - GetBlock("1") + GetBlock("1"), + 1 to 10, + 1 until 10, + 1L to 10L, + 1L until 10L ) for (obj <- toRegister) kryo.register(obj.getClass) @@ -69,6 +73,10 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging case _: Exception => println("Failed to register spark.kryo.registrator") } + // Register Chill's classes; we do this after our ranges and the user's own classes to let + // our code override the generic serialziers in Chill for things like Seq + new AllScalaRegistrar().apply(kryo) + kryo.setClassLoader(classLoader) // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index acc395108..365866d1e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -28,7 +28,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar val metricRegistry = new MetricRegistry() val sourceName = "%s.BlockManager".format(sc.appName) - metricRegistry.register(MetricRegistry.name("memory", "maxMem", "MBytes"), new Gauge[Long] { + metricRegistry.register(MetricRegistry.name("memory", "maxMem_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) @@ -36,7 +36,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar } }) - metricRegistry.register(MetricRegistry.name("memory", "remainingMem", "MBytes"), new Gauge[Long] { + metricRegistry.register(MetricRegistry.name("memory", "remainingMem_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus val remainingMem = storageStatusList.map(_.memRemaining).reduce(_ + _) @@ -44,7 +44,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar } }) - metricRegistry.register(MetricRegistry.name("memory", "memUsed", "MBytes"), new Gauge[Long] { + metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus val maxMem = storageStatusList.map(_.maxMem).reduce(_ + _) @@ -53,7 +53,7 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar } }) - metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed", "MBytes"), new Gauge[Long] { + metricRegistry.register(MetricRegistry.name("disk", "diskSpaceUsed_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus val diskSpaceUsed = storageStatusList diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 794c3e8f8..a549417a4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -23,15 +23,9 @@ import scala.collection.mutable import org.scalatest.matchers.ShouldMatchers import org.apache.spark.SparkContext._ -/** - * - */ - class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { - // TODO: This test has a race condition since the DAGScheduler now reports results - // asynchronously. It needs to be updated for that patch. - ignore("local metrics") { + test("local metrics") { sc = new SparkContext("local[4]", "test") val listener = new SaveStageInfo sc.addSparkListener(listener) @@ -45,7 +39,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)} d.count() - Thread.sleep(1000) + val WAIT_TIMEOUT_MILLIS = 10000 + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be (1) val d2 = d.map{i => w(i) -> i * 2}.setName("shuffle input 1") @@ -57,18 +52,25 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc d4.collectAsMap() - Thread.sleep(1000) + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be (4) - listener.stageInfos.foreach {stageInfo => - //small test, so some tasks might take less than 1 millisecond, but average should be greater than 1 ms + listener.stageInfos.foreach { stageInfo => + /* small test, so some tasks might take less than 1 millisecond, but average should be greater + * than 0 ms. */ checkNonZeroAvg(stageInfo.taskInfos.map{_._1.duration}, stageInfo + " duration") - checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, stageInfo + " executorRunTime") - checkNonZeroAvg(stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, stageInfo + " executorDeserializeTime") + checkNonZeroAvg( + stageInfo.taskInfos.map{_._2.executorRunTime.toLong}, + stageInfo + " executorRunTime") + checkNonZeroAvg( + stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong}, + stageInfo + " executorDeserializeTime") if (stageInfo.stage.rdd.name == d4.name) { - checkNonZeroAvg(stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, stageInfo + " fetchWaitTime") + checkNonZeroAvg( + stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime}, + stageInfo + " fetchWaitTime") } - stageInfo.taskInfos.foreach{case (taskInfo, taskMetrics) => + stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) => taskMetrics.resultSize should be > (0l) if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) { taskMetrics.shuffleWriteMetrics should be ('defined) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 0164dda0b..c016c5117 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -103,6 +103,27 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { check(List(mutable.HashMap("one" -> 1, "two" -> 2),mutable.HashMap(1->"one",2->"two",3->"three"))) } + test("ranges") { + val ser = (new KryoSerializer).newInstance() + def check[T](t: T) { + assert(ser.deserialize[T](ser.serialize(t)) === t) + // Check that very long ranges don't get written one element at a time + assert(ser.serialize(t).limit < 100) + } + check(1 to 1000000) + check(1 to 1000000 by 2) + check(1 until 1000000) + check(1 until 1000000 by 2) + check(1L to 1000000L) + check(1L to 1000000L by 2L) + check(1L until 1000000L) + check(1L until 1000000L by 2L) + check(1.0 to 1000000.0 by 1.0) + check(1.0 to 1000000.0 by 2.0) + check(1.0 until 1000000.0 by 1.0) + check(1.0 until 1000000.0 by 2.0) + } + test("custom registrator") { System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index f991d86c8..c1ff9c417 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -144,10 +144,9 @@ Available algorithms for clustering: # Collaborative Filtering -[Collaborative -filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) +[Collaborative filtering](http://en.wikipedia.org/wiki/Recommender_system#Collaborative_filtering) is commonly used for recommender systems. These techniques aim to fill in the -missing entries of a user-product association matrix. MLlib currently supports +missing entries of a user-item association matrix. MLlib currently supports model-based collaborative filtering, in which users and products are described by a small set of latent factors that can be used to predict missing entries. In particular, we implement the [alternating least squares @@ -158,7 +157,24 @@ following parameters: * *numBlocks* is the number of blacks used to parallelize computation (set to -1 to auto-configure). * *rank* is the number of latent factors in our model. * *iterations* is the number of iterations to run. -* *lambda* specifies the regularization parameter in ALS. +* *lambda* specifies the regularization parameter in ALS. +* *implicitPrefs* specifies whether to use the *explicit feedback* ALS variant or one adapted for *implicit feedback* data +* *alpha* is a parameter applicable to the implicit feedback variant of ALS that governs the *baseline* confidence in preference observations + +## Explicit vs Implicit Feedback + +The standard approach to matrix factorization based collaborative filtering treats +the entries in the user-item matrix as *explicit* preferences given by the user to the item. + +It is common in many real-world use cases to only have access to *implicit feedback* +(e.g. views, clicks, purchases, likes, shares etc.). The approach used in MLlib to deal with +such data is taken from +[Collaborative Filtering for Implicit Feedback Datasets](http://research.yahoo.com/pub/2433). +Essentially instead of trying to model the matrix of ratings directly, this approach treats the data as +a combination of binary preferences and *confidence values*. The ratings are then related +to the level of confidence in observed user preferences, rather than explicit ratings given to items. +The model then tries to find latent factors that can be used to predict the expected preference of a user +for an item. Available algorithms for collaborative filtering: diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index f67a1cc49..6c2336ad0 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -16,7 +16,7 @@ This guide will show how to use the Spark features described there in Python. There are a few key differences between the Python and Scala APIs: * Python is dynamically typed, so RDDs can hold objects of multiple types. -* PySpark does not yet support a few API calls, such as `lookup`, `sort`, and non-text input files, though these will be added in future releases. +* PySpark does not yet support a few API calls, such as `lookup` and non-text input files, though these will be added in future releases. In PySpark, RDDs support the same methods as their Scala counterparts but take Python functions and return Python collection types. Short functions can be passed to RDD methods using Python's [`lambda`](http://www.diveintopython.net/power_of_introspection/lambda_functions.html) syntax: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index be002d02b..36853acab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -21,7 +21,8 @@ import scala.collection.mutable.{ArrayBuffer, BitSet} import scala.util.Random import scala.util.Sorting -import org.apache.spark.{HashPartitioner, Partitioner, SparkContext} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.rdd.RDD import org.apache.spark.serializer.KryoRegistrator @@ -61,6 +62,12 @@ case class Rating(val user: Int, val product: Int, val rating: Double) /** * Alternating Least Squares matrix factorization. * + * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices, + * `X` and `Y`, i.e. `Xt * Y = R`. Typically these approximations are called 'factor' matrices. + * The general approach is iterative. During each iteration, one of the factor matrices is held + * constant, while the other is solved for using least squares. The newly-solved factor matrix is + * then held constant while solving for the other factor matrix. + * * This is a blocked implementation of the ALS factorization algorithm that groups the two sets * of factors (referred to as "users" and "products") into blocks and reduces communication by only * sending one copy of each user vector to each product block on each iteration, and only for the @@ -70,11 +77,21 @@ case class Rating(val user: Int, val product: Int, val rating: Double) * vectors it receives from each user block it will depend on). This allows us to send only an * array of feature vectors between each user block and product block, and have the product block * find the users' ratings and update the products based on these messages. + * + * For implicit preference data, the algorithm used is based on + * "Collaborative Filtering for Implicit Feedback Datasets", available at + * [[http://research.yahoo.com/pub/2433]], adapted for the blocked approach used here. + * + * Essentially instead of finding the low-rank approximations to the rating matrix `R`, + * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if r > 0 + * and 0 if r = 0. The ratings then act as 'confidence' values related to strength of indicated user + * preferences rather than explicit ratings given to items. */ -class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double) - extends Serializable +class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var lambda: Double, + var implicitPrefs: Boolean, var alpha: Double) + extends Serializable with Logging { - def this() = this(-1, 10, 10, 0.01) + def this() = this(-1, 10, 10, 0.01, false, 1.0) /** * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured @@ -103,6 +120,16 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l this } + def setImplicitPrefs(implicitPrefs: Boolean): ALS = { + this.implicitPrefs = implicitPrefs + this + } + + def setAlpha(alpha: Double): ALS = { + this.alpha = alpha + this + } + /** * Run ALS with the configured parameters on an input RDD of (user, product, rating) triples. * Returns a MatrixFactorizationModel with feature vectors for each user and product. @@ -147,19 +174,24 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l } } - for (iter <- 0 until iterations) { + for (iter <- 1 to iterations) { // perform ALS update - products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda) - users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda) + logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations)) + // YtY / XtX is an Option[DoubleMatrix] and is only required for the implicit feedback model + val YtY = computeYtY(users) + val YtYb = ratings.context.broadcast(YtY) + products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda, + alpha, YtYb) + logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations)) + val XtX = computeYtY(products) + val XtXb = ratings.context.broadcast(XtX) + users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda, + alpha, XtXb) } // Flatten and cache the two final RDDs to un-block them - val usersOut = users.join(userOutLinks).flatMap { case (b, (factors, outLinkBlock)) => - for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) - } - val productsOut = products.join(productOutLinks).flatMap { case (b, (factors, outLinkBlock)) => - for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) - } + val usersOut = unblockFactors(users, userOutLinks) + val productsOut = unblockFactors(products, productOutLinks) usersOut.persist() productsOut.persist() @@ -167,6 +199,40 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l new MatrixFactorizationModel(rank, usersOut, productsOut) } + /** + * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors + * for each user (or product), in a distributed fashion. Here `reduceByKeyLocally` is used as + * the driver program requires `YtY` to broadcast it to the slaves + * @param factors the (block-distributed) user or product factor vectors + * @return Option[YtY] - whose value is only used in the implicit preference model + */ + def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = { + if (implicitPrefs) { + Option( + factors.flatMapValues{ case factorArray => + factorArray.map{ vector => + val x = new DoubleMatrix(vector) + x.mmul(x.transpose()) + } + }.reduceByKeyLocally((a, b) => a.addi(b)) + .values + .reduce((a, b) => a.addi(b)) + ) + } else { + None + } + } + + /** + * Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs + */ + def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])], + outLinks: RDD[(Int, OutLinkBlock)]) = { + blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) => + for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i)) + } + } + /** * Make the out-links table for a block of the users (or products) dataset given the list of * (user, product, rating) values for the users in that block (or the opposite for products). @@ -251,7 +317,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l userInLinks: RDD[(Int, InLinkBlock)], partitioner: Partitioner, rank: Int, - lambda: Double) + lambda: Double, + alpha: Double, + YtY: Broadcast[Option[DoubleMatrix]]) : RDD[(Int, Array[Array[Double]])] = { val numBlocks = products.partitions.size @@ -265,7 +333,9 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) } }.groupByKey(partitioner) .join(userInLinks) - .mapValues{ case (messages, inLinkBlock) => updateBlock(messages, inLinkBlock, rank, lambda) } + .mapValues{ case (messages, inLinkBlock) => + updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY) + } } /** @@ -273,7 +343,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l * it received from each product and its InLinkBlock. */ def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, - rank: Int, lambda: Double) + rank: Int, lambda: Double, alpha: Double, YtY: Broadcast[Option[DoubleMatrix]]) : Array[Array[Double]] = { // Sort the incoming block factor messages by block ID and make them an array @@ -298,8 +368,14 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l fillXtX(x, tempXtX) val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p) for (i <- 0 until us.length) { - userXtX(us(i)).addi(tempXtX) - SimpleBlas.axpy(rs(i), x, userXy(us(i))) + implicitPrefs match { + case false => + userXtX(us(i)).addi(tempXtX) + SimpleBlas.axpy(rs(i), x, userXy(us(i))) + case true => + userXtX(us(i)).addi(tempXtX.mul(alpha * rs(i))) + SimpleBlas.axpy(1 + alpha * rs(i), x, userXy(us(i))) + } } } } @@ -311,7 +387,10 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l // Add regularization (0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda) // Solve the resulting matrix, which is symmetric and positive-definite - Solve.solvePositive(fullXtX, userXy(index)).data + implicitPrefs match { + case false => Solve.solvePositive(fullXtX, userXy(index)).data + case true => Solve.solvePositive(fullXtX.add(YtY.value.get), userXy(index)).data + } } } @@ -381,7 +460,7 @@ object ALS { blocks: Int) : MatrixFactorizationModel = { - new ALS(blocks, rank, iterations, lambda).run(ratings) + new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings) } /** @@ -419,6 +498,68 @@ object ALS { train(ratings, rank, iterations, 0.01, -1) } + /** + * Train a matrix factorization model given an RDD of 'implicit preferences' given by users + * to some products, in the form of (userID, productID, preference) pairs. We approximate the + * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). + * To solve for these features, we run a given number of iterations of ALS. This is done using + * a level of parallelism given by `blocks`. + * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + * @param lambda regularization factor (recommended: 0.01) + * @param blocks level of parallelism to split computation into + * @param alpha confidence parameter (only applies when immplicitPrefs = true) + */ + def trainImplicit( + ratings: RDD[Rating], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int, + alpha: Double) + : MatrixFactorizationModel = + { + new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings) + } + + /** + * Train a matrix factorization model given an RDD of 'implicit preferences' given by users to + * some products, in the form of (userID, productID, preference) pairs. We approximate the + * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). + * To solve for these features, we run a given number of iterations of ALS. The level of + * parallelism is determined automatically based on the number of partitions in `ratings`. + * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + * @param lambda regularization factor (recommended: 0.01) + */ + def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) + : MatrixFactorizationModel = + { + trainImplicit(ratings, rank, iterations, lambda, -1, alpha) + } + + /** + * Train a matrix factorization model given an RDD of 'implicit preferences' ratings given by + * users to some products, in the form of (userID, productID, rating) pairs. We approximate the + * ratings matrix as the product of two lower-rank matrices of a given rank (number of features). + * To solve for these features, we run a given number of iterations of ALS. The level of + * parallelism is determined automatically based on the number of partitions in `ratings`. + * Model parameters `alpha` and `lambda` are set to reasonable default values + * + * @param ratings RDD of (userID, productID, rating) pairs + * @param rank number of features to use + * @param iterations number of iterations of ALS (recommended: 10-20) + */ + def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int) + : MatrixFactorizationModel = + { + trainImplicit(ratings, rank, iterations, 0.01, -1, 1.0) + } + private class ALSRegistrator extends KryoRegistrator { override def registerClasses(kryo: Kryo) { kryo.register(classOf[Rating]) @@ -426,29 +567,37 @@ object ALS { } def main(args: Array[String]) { - if (args.length != 5 && args.length != 6) { - println("Usage: ALS []") + if (args.length < 5 || args.length > 9) { + println("Usage: ALS " + + "[] [] [] []") System.exit(1) } val (master, ratingsFile, rank, iters, outputDir) = (args(0), args(1), args(2).toInt, args(3).toInt, args(4)) - val blocks = if (args.length == 6) args(5).toInt else -1 + val lambda = if (args.length >= 6) args(5).toDouble else 0.01 + val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false + val alpha = if (args.length >= 8) args(7).toDouble else 1 + val blocks = if (args.length == 9) args(8).toInt else -1 + System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer") System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName) System.setProperty("spark.kryo.referenceTracking", "false") System.setProperty("spark.kryoserializer.buffer.mb", "8") System.setProperty("spark.locality.wait", "10000") + val sc = new SparkContext(master, "ALS") val ratings = sc.textFile(ratingsFile).map { line => val fields = line.split(',') Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) } - val model = ALS.train(ratings, rank, iters, 0.01, blocks) + val model = new ALS(rank = rank, iterations = iters, lambda = lambda, + numBlocks = blocks, implicitPrefs = implicitPrefs, alpha = alpha).run(ratings) + model.userFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } .saveAsTextFile(outputDir + "/userFeatures") model.productFeatures.map{ case (id, vec) => id + "," + vec.mkString(" ") } .saveAsTextFile(outputDir + "/productFeatures") println("Final user/product features written to " + outputDir) - System.exit(0) + sc.stop() } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index 3323f6cee..eafee060c 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -19,6 +19,7 @@ import java.io.Serializable; import java.util.List; +import java.lang.Math; import scala.Tuple2; @@ -48,7 +49,7 @@ public void tearDown() { } void validatePrediction(MatrixFactorizationModel model, int users, int products, int features, - DoubleMatrix trueRatings, double matchThreshold) { + DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { DoubleMatrix predictedU = new DoubleMatrix(users, features); List> userFeatures = model.userFeatures().toJavaRDD().collect(); for (int i = 0; i < features; ++i) { @@ -68,12 +69,32 @@ void validatePrediction(MatrixFactorizationModel model, int users, int products, DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double correct = trueRatings.get(u, p); - Assert.assertTrue(Math.abs(prediction - correct) < matchThreshold); + if (!implicitPrefs) { + for (int u = 0; u < users; ++u) { + for (int p = 0; p < products; ++p) { + double prediction = predictedRatings.get(u, p); + double correct = trueRatings.get(u, p); + Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", + prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); + } } + } else { + // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's implicit ALS tests) + double sqErr = 0.0; + double denom = 0.0; + for (int u = 0; u < users; ++u) { + for (int p = 0; p < products; ++p) { + double prediction = predictedRatings.get(u, p); + double truePref = truePrefs.get(u, p); + double confidence = 1.0 + /* alpha = */ 1.0 * trueRatings.get(u, p); + double err = confidence * (truePref - prediction) * (truePref - prediction); + sqErr += err; + denom += 1.0; + } + } + double rmse = Math.sqrt(sqErr / denom); + Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f", + rmse, matchThreshold), Math.abs(rmse) < matchThreshold); } } @@ -81,30 +102,62 @@ void validatePrediction(MatrixFactorizationModel model, int users, int products, public void runALSUsingStaticMethods() { int features = 1; int iterations = 15; - int users = 10; - int products = 10; - scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7); + int users = 50; + int products = 100; + scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.3); + validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); } @Test public void runALSUsingConstructor() { int features = 2; int iterations = 15; - int users = 20; - int products = 30; - scala.Tuple2, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( - users, products, features, 0.7); + int users = 100; + int products = 200; + scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, false); JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.3); + validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + } + + @Test + public void runImplicitALSUsingStaticMethods() { + int features = 1; + int iterations = 15; + int users = 80; + int products = 160; + scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true); + + JavaRDD data = sc.parallelize(testData._1()); + MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); + validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + } + + @Test + public void runImplicitALSUsingConstructor() { + int features = 2; + int iterations = 15; + int users = 100; + int products = 200; + scala.Tuple3, DoubleMatrix, DoubleMatrix> testData = ALSSuite.generateRatingsAsJavaList( + users, products, features, 0.7, true); + + JavaRDD data = sc.parallelize(testData._1()); + + MatrixFactorizationModel model = new ALS().setRank(features) + .setIterations(iterations) + .setImplicitPrefs(true) + .run(data.rdd()); + validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 347ef238f..fafc5ec5f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -34,16 +34,19 @@ object ALSSuite { users: Int, products: Int, features: Int, - samplingRate: Double): (java.util.List[Rating], DoubleMatrix) = { - val (sampledRatings, trueRatings) = generateRatings(users, products, features, samplingRate) - (seqAsJavaList(sampledRatings), trueRatings) + samplingRate: Double, + implicitPrefs: Boolean): (java.util.List[Rating], DoubleMatrix, DoubleMatrix) = { + val (sampledRatings, trueRatings, truePrefs) = + generateRatings(users, products, features, samplingRate, implicitPrefs) + (seqAsJavaList(sampledRatings), trueRatings, truePrefs) } def generateRatings( users: Int, products: Int, features: Int, - samplingRate: Double): (Seq[Rating], DoubleMatrix) = { + samplingRate: Double, + implicitPrefs: Boolean = false): (Seq[Rating], DoubleMatrix, DoubleMatrix) = { val rand = new Random(42) // Create a random matrix with uniform values from -1 to 1 @@ -52,14 +55,20 @@ object ALSSuite { val userMatrix = randomMatrix(users, features) val productMatrix = randomMatrix(features, products) - val trueRatings = userMatrix.mmul(productMatrix) + val (trueRatings, truePrefs) = implicitPrefs match { + case true => + val raw = new DoubleMatrix(users, products, Array.fill(users * products)(rand.nextInt(10).toDouble): _*) + val prefs = new DoubleMatrix(users, products, raw.data.map(v => if (v > 0) 1.0 else 0.0): _*) + (raw, prefs) + case false => (userMatrix.mmul(productMatrix), null) + } val sampledRatings = { for (u <- 0 until users; p <- 0 until products if rand.nextDouble() < samplingRate) yield Rating(u, p, trueRatings.get(u, p)) } - (sampledRatings, trueRatings) + (sampledRatings, trueRatings, truePrefs) } } @@ -78,11 +87,19 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { } test("rank-1 matrices") { - testALS(10, 20, 1, 15, 0.7, 0.3) + testALS(50, 100, 1, 15, 0.7, 0.3) } test("rank-2 matrices") { - testALS(20, 30, 2, 15, 0.7, 0.3) + testALS(100, 200, 2, 15, 0.7, 0.3) + } + + test("rank-1 matrices implicit") { + testALS(80, 160, 1, 15, 0.7, 0.4, true) + } + + test("rank-2 matrices implicit") { + testALS(100, 200, 2, 15, 0.7, 0.4, true) } /** @@ -96,11 +113,14 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { * @param matchThreshold max difference allowed to consider a predicted rating correct */ def testALS(users: Int, products: Int, features: Int, iterations: Int, - samplingRate: Double, matchThreshold: Double) + samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false) { - val (sampledRatings, trueRatings) = ALSSuite.generateRatings(users, products, - features, samplingRate) - val model = ALS.train(sc.parallelize(sampledRatings), features, iterations) + val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products, + features, samplingRate, implicitPrefs) + val model = implicitPrefs match { + case false => ALS.train(sc.parallelize(sampledRatings), features, iterations) + case true => ALS.trainImplicit(sc.parallelize(sampledRatings), features, iterations) + } val predictedU = new DoubleMatrix(users, features) for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) { @@ -112,12 +132,31 @@ class ALSSuite extends FunSuite with BeforeAndAfterAll { } val predictedRatings = predictedU.mmul(predictedP.transpose) - for (u <- 0 until users; p <- 0 until products) { - val prediction = predictedRatings.get(u, p) - val correct = trueRatings.get(u, p) - if (math.abs(prediction - correct) > matchThreshold) { - fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( - u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) + if (!implicitPrefs) { + for (u <- 0 until users; p <- 0 until products) { + val prediction = predictedRatings.get(u, p) + val correct = trueRatings.get(u, p) + if (math.abs(prediction - correct) > matchThreshold) { + fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( + u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP)) + } + } + } else { + // For implicit prefs we use the confidence-weighted RMSE to test (ref Mahout's tests) + var sqErr = 0.0 + var denom = 0.0 + for (u <- 0 until users; p <- 0 until products) { + val prediction = predictedRatings.get(u, p) + val truePref = truePrefs.get(u, p) + val confidence = 1 + 1.0 * trueRatings.get(u, p) + val err = confidence * (truePref - prediction) * (truePref - prediction) + sqErr += err + denom += 1 + } + val rmse = math.sqrt(sqErr / denom) + if (math.abs(rmse) > matchThreshold) { + fail("Model failed to predict RMSE: %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format( + rmse, truePrefs, predictedRatings, predictedU, predictedP)) } } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 39c402b41..7019fb8be 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -117,8 +117,6 @@ def getCheckpointFile(self): else: return None - # TODO persist(self, storageLevel) - def map(self, f, preservesPartitioning=False): """ Return a new RDD containing the distinct elements in this RDD. @@ -227,7 +225,7 @@ def takeSample(self, withReplacement, num, seed): total = num samples = self.sample(withReplacement, fraction, seed).collect() - + # If the first sample didn't turn out large enough, keep trying to take samples; # this shouldn't happen often because we use a big multiplier for their initial size. # See: scala/spark/RDD.scala @@ -263,7 +261,55 @@ def __add__(self, other): raise TypeError return self.union(other) - # TODO: sort + def sortByKey(self, ascending=True, numPartitions=None, keyfunc = lambda x: x): + """ + Sorts this RDD, which is assumed to consist of (key, value) pairs. + + >>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)] + >>> sc.parallelize(tmp).sortByKey(True, 2).collect() + [('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)] + >>> tmp2 = [('Mary', 1), ('had', 2), ('a', 3), ('little', 4), ('lamb', 5)] + >>> tmp2.extend([('whose', 6), ('fleece', 7), ('was', 8), ('white', 9)]) + >>> sc.parallelize(tmp2).sortByKey(True, 3, keyfunc=lambda k: k.lower()).collect() + [('a', 3), ('fleece', 7), ('had', 2), ('lamb', 5), ('little', 4), ('Mary', 1), ('was', 8), ('white', 9), ('whose', 6)] + """ + if numPartitions is None: + numPartitions = self.ctx.defaultParallelism + + bounds = list() + + # first compute the boundary of each part via sampling: we want to partition + # the key-space into bins such that the bins have roughly the same + # number of (key, value) pairs falling into them + if numPartitions > 1: + rddSize = self.count() + maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner + fraction = min(maxSampleSize / max(rddSize, 1), 1.0) + + samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() + samples = sorted(samples, reverse=(not ascending), key=keyfunc) + + # we have numPartitions many parts but one of the them has + # an implicit boundary + for i in range(0, numPartitions - 1): + index = (len(samples) - 1) * (i + 1) / numPartitions + bounds.append(samples[index]) + + def rangePartitionFunc(k): + p = 0 + while p < len(bounds) and keyfunc(k) > bounds[p]: + p += 1 + if ascending: + return p + else: + return numPartitions-1-p + + def mapFunc(iterator): + yield sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k)) + + return (self.partitionBy(numPartitions, partitionFunc=rangePartitionFunc) + .mapPartitions(mapFunc,preservesPartitioning=True) + .flatMap(lambda x: x, preservesPartitioning=True)) def glom(self): """ @@ -425,7 +471,7 @@ def count(self): 3 """ return self.mapPartitions(lambda i: [sum(1 for _ in i)]).sum() - + def stats(self): """ Return a L{StatCounter} object that captures the mean, variance @@ -462,7 +508,7 @@ def stdev(self): 0.816... """ return self.stats().stdev() - + def sampleStdev(self): """ Compute the sample standard deviation of this RDD's elements (which corrects for bias in @@ -832,7 +878,7 @@ def subtractByKey(self, other, numPartitions=None): >>> y = sc.parallelize([("a", 3), ("c", None)]) >>> sorted(x.subtractByKey(y).collect()) [('b', 4), ('b', 5)] - """ + """ filter_func = lambda (key, vals): len(vals[0]) > 0 and len(vals[1]) == 0 map_func = lambda (key, vals): [(key, val) for val in vals[0]] return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)