Merge pull request #44 from mateiz/fast-map
A fast and low-memory append-only map for shuffle operations

This is a continuation of the old repo's pull request mesos/spark#823 to add a more efficient hashmap class for shuffles. I've optimized and tested this more thoroughly now so I think it's good to go. I've also addressed some of the comments that were outstanding there.

The idea is to reduce the cost of shuffles by taking advantage of the properties their hash maps need. In particular, the maps used there are append-only, and a common operation is updating a key's value based on its old value. The included AppendOnlyMap class uses open hashing, so it takes less space than Java's map (there is no linked list per bucket), does not support deletes, and provides a changeValue operation that updates a key in place without following the hash chain twice. In micro-benchmarks against java.util.HashMap and scala.collection.mutable.HashMap, it is 20-30% smaller and 10-40% faster depending on the number and type of keys. It's also noticeably faster than fastutil's Object2ObjectOpenHashMap.
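
As a minimal sketch of the usage pattern changeValue enables (the wrapper object and word-count example below are hypothetical, not from this commit; AppendOnlyMap and changeValue are the class and method added in it):
```scala
import org.apache.spark.util.AppendOnlyMap

object ChangeValueSketch {
  def main(args: Array[String]): Unit = {
    val counts = new AppendOnlyMap[String, Int]
    val words = Seq("a", "b", "a", "c", "a")
    // One probe per record: the update function is told whether the key was
    // already present and either creates the initial value or merges into it.
    for (w <- words) {
      counts.changeValue(w, (hadValue, oldValue) => if (hadValue) oldValue + 1 else 1)
    }
    counts.iterator.foreach { case (k, v) => println(s"$k -> $v") }
  }
}
```
With java.util.HashMap the same loop would need a get followed by a put for every record.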

I've also tested this in Spark apps now. While the speed gain is modest (partly due to other overheads, like serialization), there is some, and I think the lower memory usage is worth it. Here's one example where the speedup is most noticeable, in spark-shell on local mode:
```
scala> val nums = sc.parallelize(1 to 8).flatMap(x => (1 to 5e6.toInt)).cache

scala> nums.count

scala> def time(x: => Unit) = { val now = System.currentTimeMillis; x; System.currentTimeMillis - now }

scala> (1 to 8).map(_ => time(nums.map(x => (x % 100000, x)).reduceByKey(_ + _).count) / 1000.0)
```

This prints the following times before and after this change:
```
Before: Vector(4.368, 2.635, 2.549, 2.522, 2.233, 2.222, 2.214, 2.195)

After: Vector(3.588, 1.741, 1.706, 1.648, 1.777, 1.81, 1.776, 1.731)
```

I've also run the spark-perf suite, enhanced with some tests that use Ints (https://github.com/amplab/spark-perf/pull/9), and it shows some speedup on those, but less on the string ones (presumably due to existing overhead): https://gist.github.com/mateiz/6897121.
mateiz committed Oct 10, 2013
2 parents 320418f + 001d13f commit cd08f73
Showing 5 changed files with 419 additions and 40 deletions.
49 changes: 24 additions & 25 deletions core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -17,43 +17,42 @@

 package org.apache.spark

-import java.util.{HashMap => JHashMap}
+import org.apache.spark.util.AppendOnlyMap

-import scala.collection.JavaConversions._
-
-/** A set of functions used to aggregate data.
- *
- * @param createCombiner function to create the initial value of the aggregation.
- * @param mergeValue function to merge a new value into the aggregation result.
- * @param mergeCombiners function to merge outputs from multiple mergeValue function.
- */
+/**
+ * A set of functions used to aggregate data.
+ *
+ * @param createCombiner function to create the initial value of the aggregation.
+ * @param mergeValue function to merge a new value into the aggregation result.
+ * @param mergeCombiners function to merge outputs from multiple mergeValue function.
+ */
 case class Aggregator[K, V, C] (
     createCombiner: V => C,
     mergeValue: (C, V) => C,
     mergeCombiners: (C, C) => C) {

   def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
-    val combiners = new JHashMap[K, C]
-    for (kv <- iter) {
-      val oldC = combiners.get(kv._1)
-      if (oldC == null) {
-        combiners.put(kv._1, createCombiner(kv._2))
-      } else {
-        combiners.put(kv._1, mergeValue(oldC, kv._2))
-      }
+    val combiners = new AppendOnlyMap[K, C]
+    var kv: Product2[K, V] = null
+    val update = (hadValue: Boolean, oldValue: C) => {
+      if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
+    }
+    while (iter.hasNext) {
+      kv = iter.next()
+      combiners.changeValue(kv._1, update)
     }
     combiners.iterator
   }

   def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
-    val combiners = new JHashMap[K, C]
-    iter.foreach { case(k, c) =>
-      val oldC = combiners.get(k)
-      if (oldC == null) {
-        combiners.put(k, c)
-      } else {
-        combiners.put(k, mergeCombiners(oldC, c))
-      }
+    val combiners = new AppendOnlyMap[K, C]
+    var kc: (K, C) = null
+    val update = (hadValue: Boolean, oldValue: C) => {
+      if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
+    }
+    while (iter.hasNext) {
+      kc = iter.next()
+      combiners.changeValue(kc._1, update)
     }
     combiners.iterator
   }
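
A note on the pattern above (an observation about the diff, not stated in the commit message): each call allocates a single update closure and passes the current record through the captured kv/kc variable, so the hot loop does not create a new function object per record. The same idiom in isolation, with a hypothetical Int-sum aggregation:
```scala
import org.apache.spark.util.AppendOnlyMap

object SumByKeySketch {
  def sumByKey(iter: Iterator[(String, Int)]): Iterator[(String, Int)] = {
    val sums = new AppendOnlyMap[String, Int]
    var cur: (String, Int) = null
    // Allocated once; reads the record currently being processed via `cur`.
    val update = (hadValue: Boolean, oldValue: Int) =>
      if (hadValue) oldValue + cur._2 else cur._2
    while (iter.hasNext) {
      cur = iter.next()
      sums.changeValue(cur._1, update)
    }
    sums.iterator
  }
}
```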
22 changes: 9 additions & 13 deletions core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -18,13 +18,12 @@
 package org.apache.spark.rdd

 import java.io.{ObjectOutputStream, IOException}
-import java.util.{HashMap => JHashMap}

-import scala.collection.JavaConversions
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
 import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
+import org.apache.spark.util.AppendOnlyMap


 private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -105,17 +104,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
     val split = s.asInstanceOf[CoGroupPartition]
     val numRdds = split.deps.size
     // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
-    val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
+    val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]

-    def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
-      val seq = map.get(k)
-      if (seq != null) {
-        seq
-      } else {
-        val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
-        map.put(k, seq)
-        seq
-      }
+    val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
+      if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
+    }
+
+    val getSeq = (k: K) => {
+      map.changeValue(k, update)
     }

     val ser = SparkEnv.get.serializerManager.get(serializerClass)
@@ -134,7 +130,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
        }
      }
    }
-    JavaConversions.mapAsScalaMap(map).iterator
+    map.iterator
  }

  override def clearDependencies() {
230 changes: 230 additions & 0 deletions core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -0,0 +1,230 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

/**
 * A simple open hash table optimized for the append-only use case, where keys
 * are never removed, but the value for each key may be changed.
 *
 * This implementation uses quadratic probing with a power-of-2 hash table
 * size, which is guaranteed to explore all spaces for each key (see
 * http://en.wikipedia.org/wiki/Quadratic_probing).
 *
 * TODO: Cache the hash values of each key? java.util.HashMap does that.
 */
private[spark]
class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable {
  require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
  require(initialCapacity >= 1, "Invalid initial capacity")

  private var capacity = nextPowerOf2(initialCapacity)
  private var mask = capacity - 1
  private var curSize = 0

  // Holds keys and values in the same array for memory locality; specifically, the order of
  // elements is key0, value0, key1, value1, key2, value2, etc.
  private var data = new Array[AnyRef](2 * capacity)

  // Treat the null key differently so we can use nulls in "data" to represent empty items.
  private var haveNullValue = false
  private var nullValue: V = null.asInstanceOf[V]

  private val LOAD_FACTOR = 0.7

  /** Get the value for a given key */
  def apply(key: K): V = {
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
      return nullValue
    }
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      if (k.eq(curKey) || k == curKey) {
        return data(2 * pos + 1).asInstanceOf[V]
      } else if (curKey.eq(null)) {
        return null.asInstanceOf[V]
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    return null.asInstanceOf[V]
  }

  /** Set the value for a key */
  def update(key: K, value: V): Unit = {
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
      if (!haveNullValue) {
        incrementSize()
      }
      nullValue = value
      haveNullValue = true
      return
    }
    val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
    if (isNewEntry) {
      incrementSize()
    }
  }

  /**
   * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value
   * for key, if any, or null otherwise. Returns the newly updated value.
   */
  def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
    val k = key.asInstanceOf[AnyRef]
    if (k.eq(null)) {
      if (!haveNullValue) {
        incrementSize()
      }
      nullValue = updateFunc(haveNullValue, nullValue)
      haveNullValue = true
      return nullValue
    }
    var pos = rehash(k.hashCode) & mask
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      if (k.eq(curKey) || k == curKey) {
        val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        return newValue
      } else if (curKey.eq(null)) {
        val newValue = updateFunc(false, null.asInstanceOf[V])
        data(2 * pos) = k
        data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
        incrementSize()
        return newValue
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    null.asInstanceOf[V] // Never reached but needed to keep compiler happy
  }

  /** Iterator method from Iterable */
  override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
    var pos = -1

    /** Get the next value we should return from next(), or null if we're finished iterating */
    def nextValue(): (K, V) = {
      if (pos == -1) { // Treat position -1 as looking at the null value
        if (haveNullValue) {
          return (null.asInstanceOf[K], nullValue)
        }
        pos += 1
      }
      while (pos < capacity) {
        if (!data(2 * pos).eq(null)) {
          return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
        }
        pos += 1
      }
      null
    }

    override def hasNext: Boolean = nextValue() != null

    override def next(): (K, V) = {
      val value = nextValue()
      if (value == null) {
        throw new NoSuchElementException("End of iterator")
      }
      pos += 1
      value
    }
  }

  override def size: Int = curSize

  /** Increase table size by 1, rehashing if necessary */
  private def incrementSize() {
    curSize += 1
    if (curSize > LOAD_FACTOR * capacity) {
      growTable()
    }
  }

  /**
   * Re-hash a value to deal better with hash functions that don't differ
   * in the lower bits, similar to java.util.HashMap
   */
  private def rehash(h: Int): Int = {
    val r = h ^ (h >>> 20) ^ (h >>> 12)
    r ^ (r >>> 7) ^ (r >>> 4)
  }

  /**
   * Put an entry into a table represented by data, returning true if
   * this increases the size of the table or false otherwise. Assumes
   * that "data" has at least one empty slot.
   */
  private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
    val mask = (data.length / 2) - 1
    var pos = rehash(key.hashCode) & mask
    var i = 1
    while (true) {
      val curKey = data(2 * pos)
      if (curKey.eq(null)) {
        data(2 * pos) = key
        data(2 * pos + 1) = value.asInstanceOf[AnyRef]
        return true
      } else if (curKey.eq(key) || curKey == key) {
        data(2 * pos + 1) = value.asInstanceOf[AnyRef]
        return false
      } else {
        val delta = i
        pos = (pos + delta) & mask
        i += 1
      }
    }
    return false // Never reached but needed to keep compiler happy
  }

  /** Double the table's size and re-hash everything */
  private def growTable() {
    val newCapacity = capacity * 2
    if (newCapacity >= (1 << 30)) {
      // We can't make the table this big because we want an array of 2x
      // that size for our data, but array sizes are at most Int.MaxValue
      throw new Exception("Can't make capacity bigger than 2^29 elements")
    }
    val newData = new Array[AnyRef](2 * newCapacity)
    var pos = 0
    while (pos < capacity) {
      if (!data(2 * pos).eq(null)) {
        putInto(newData, data(2 * pos), data(2 * pos + 1))
      }
      pos += 1
    }
    data = newData
    capacity = newCapacity
    mask = newCapacity - 1
  }

  private def nextPowerOf2(n: Int): Int = {
    val highBit = Integer.highestOneBit(n)
    if (highBit == n) n else highBit << 1
  }
}
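
A quick standalone check (not part of the commit) of the claim in the class comment that quadratic probing with a power-of-2 table size explores all slots: the probe offsets are the triangular numbers 0, 1, 3, 6, ..., which cover every residue modulo a power of 2.
```scala
object ProbeCoverageCheck {
  def main(args: Array[String]): Unit = {
    val capacity = 64                  // power of 2, like AppendOnlyMap's table
    val mask = capacity - 1
    var pos = 17 & mask                // arbitrary starting bucket
    var i = 1
    val seen = scala.collection.mutable.Set(pos)
    // Same update rule as the probe loops in AppendOnlyMap: delta grows by 1 each step.
    for (_ <- 1 until capacity) {
      pos = (pos + i) & mask
      i += 1
      seen += pos
    }
    assert(seen.size == capacity, "probe sequence should cover every slot")
  }
}
```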
@@ -38,7 +38,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     }

     val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
-    d.count
+    d.count()
     val WAIT_TIMEOUT_MILLIS = 10000
     assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
     listener.stageInfos.size should be (1)
@@ -50,7 +50,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
     val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)}
     d4.setName("A Cogroup")

-    d4.collectAsMap
+    d4.collectAsMap()

     assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
     listener.stageInfos.size should be (4)