From b535db7d89abd59713ce83ae937d06193a04441e Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Tue, 13 Aug 2013 18:09:40 -0700 Subject: [PATCH 1/5] Added a fast and low-memory append-only map implementation for cogroup and parallel reduce operations --- .../scala/org/apache/spark/Aggregator.scala | 38 +-- .../org/apache/spark/rdd/CoGroupedRDD.scala | 16 +- .../main/scala/spark/util/AppendOnlyMap.scala | 241 ++++++++++++++++++ .../spark/scheduler/SparkListenerSuite.scala | 4 +- .../scala/spark/util/AppendOnlyMapSuite.scala | 141 ++++++++++ 5 files changed, 411 insertions(+), 29 deletions(-) create mode 100644 core/src/main/scala/spark/util/AppendOnlyMap.scala create mode 100644 core/src/test/scala/spark/util/AppendOnlyMapSuite.scala diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 3ef402926..fa1419df1 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -21,8 +21,10 @@ import java.util.{HashMap => JHashMap} import scala.collection.JavaConversions._ +import spark.util.AppendOnlyMap + /** A set of functions used to aggregate data. - * + * * @param createCombiner function to create the initial value of the aggregation. * @param mergeValue function to merge a new value into the aggregation result. * @param mergeCombiners function to merge outputs from multiple mergeValue function. @@ -33,27 +35,29 @@ case class Aggregator[K, V, C] ( mergeCombiners: (C, C) => C) { def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - for (kv <- iter) { - val oldC = combiners.get(kv._1) - if (oldC == null) { - combiners.put(kv._1, createCombiner(kv._2)) - } else { - combiners.put(kv._1, mergeValue(oldC, kv._2)) - } + val combiners = new AppendOnlyMap[K, C] + for ((k, v) <- iter) { + combiners.changeValue(k, (hadValue, oldValue) => { + if (hadValue) { + mergeValue(oldValue, v) + } else { + createCombiner(v) + } + }) } combiners.iterator } def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { - val combiners = new JHashMap[K, C] - iter.foreach { case(k, c) => - val oldC = combiners.get(k) - if (oldC == null) { - combiners.put(k, c) - } else { - combiners.put(k, mergeCombiners(oldC, c)) - } + val combiners = new AppendOnlyMap[K, C] + for ((k, c) <- iter) { + combiners.changeValue(k, (hadValue, oldValue) => { + if (hadValue) { + mergeCombiners(oldValue, c) + } else { + c + } + }) } combiners.iterator } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 0187256a8..f6dd8a65c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} +import org.apache.spark.util.AppendOnlyMap private[spark] sealed trait CoGroupSplitDep extends Serializable @@ -105,17 +106,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: val split = s.asInstanceOf[CoGroupPartition] val numRdds = split.deps.size // e.g. 
for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) - val map = new JHashMap[K, Seq[ArrayBuffer[Any]]] + val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - val seq = map.get(k) - if (seq != null) { - seq - } else { - val seq = Array.fill(numRdds)(new ArrayBuffer[Any]) - map.put(k, seq) - seq - } + map.changeValue(k, (hadValue, oldValue) => { + if (hadValue) oldValue else Array.fill(numRdds)(new ArrayBuffer[Any]) + }) } val ser = SparkEnv.get.serializerManager.get(serializerClass) @@ -134,7 +130,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: } } } - JavaConversions.mapAsScalaMap(map).iterator + map.iterator } override def clearDependencies() { diff --git a/core/src/main/scala/spark/util/AppendOnlyMap.scala b/core/src/main/scala/spark/util/AppendOnlyMap.scala new file mode 100644 index 000000000..416b93ea4 --- /dev/null +++ b/core/src/main/scala/spark/util/AppendOnlyMap.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.util + +/** + * A simple open hash table optimized for the append-only use case, where keys + * are never removed, but the value for each key may be changed. + * + * This implementation uses quadratic probing with a power-of-2 hash table + * size, which is guaranteed to explore all spaces for each key (see + * http://en.wikipedia.org/wiki/Quadratic_probing). + * + * TODO: Cache the hash values of each key? java.util.HashMap does that. + */ +private[spark] +class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable { + if (!isPowerOf2(initialCapacity)) { + throw new IllegalArgumentException("Initial capacity must be power of 2") + } + if (initialCapacity >= (1 << 30)) { + throw new IllegalArgumentException("Can't make capacity bigger than 2^29 elements") + } + + private var capacity = initialCapacity + private var curSize = 0 + + // Holds keys and values in the same array for memory locality; specifically, the order of + // elements is key0, value0, key1, value1, key2, value2, etc. + private var data = new Array[AnyRef](2 * capacity) + + // Treat the null key differently so we can use nulls in "data" to represent empty items. 
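+  // A null value may itself be stored under the null key, so a separate haveNullValue flag,
+  // rather than a nullValue == null check, records whether the null key is present.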
+ private var haveNullValue = false + private var nullValue: V = null.asInstanceOf[V] + + private val LOAD_FACTOR = 0.7 + + /** Get the value for a given key */ + def apply(key: K): V = { + val k = key.asInstanceOf[AnyRef] + if (k.eq(null)) { + return nullValue + } + val mask = capacity - 1 + var pos = rehash(k.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (curKey.eq(k) || curKey.eq(null) || curKey == k) { + return data(2 * pos + 1).asInstanceOf[V] + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + return null.asInstanceOf[V] + } + + /** Set the value for a key */ + def update(key: K, value: V) { + val k = key.asInstanceOf[AnyRef] + if (k.eq(null)) { + if (!haveNullValue) { + incrementSize() + } + nullValue = value + haveNullValue = true + return + } + val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef]) + if (isNewEntry) { + incrementSize() + } + } + + /** + * Set the value for key to updateFunc(hadValue, oldValue), where oldValue will be the old value + * for key, if any, or null otherwise. Returns the newly updated value. + */ + def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { + val k = key.asInstanceOf[AnyRef] + if (k.eq(null)) { + if (!haveNullValue) { + incrementSize() + } + nullValue = updateFunc(haveNullValue, nullValue) + haveNullValue = true + return nullValue + } + val mask = capacity - 1 + var pos = rehash(k.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (curKey.eq(null)) { + val newValue = updateFunc(false, null.asInstanceOf[V]) + data(2 * pos) = k + data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] + incrementSize() + return newValue + } else if (curKey.eq(k) || curKey == k) { + val newValue = updateFunc(true, data(2*pos + 1).asInstanceOf[V]) + data(2*pos + 1) = newValue.asInstanceOf[AnyRef] + return newValue + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + null.asInstanceOf[V] // Never reached but needed to keep compiler happy + } + + /** Iterator method from Iterable */ + override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] { + var pos = -1 + + /** Get the next value we should return from next(), or null if we're finished iterating */ + def nextValue(): (K, V) = { + if (pos == -1) { // Treat position -1 as looking at the null value + if (haveNullValue) { + return (null.asInstanceOf[K], nullValue) + } + pos += 1 + } + while (pos < capacity) { + if (!data(2 * pos).eq(null)) { + return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V]) + } + pos += 1 + } + null + } + + override def hasNext: Boolean = nextValue() != null + + override def next(): (K, V) = { + val value = nextValue() + if (value == null) { + throw new NoSuchElementException("End of iterator") + } + pos += 1 + value + } + } + + override def size: Int = curSize + + /** Increase table size by 1, rehashing if necessary */ + private def incrementSize() { + curSize += 1 + if (curSize > LOAD_FACTOR * capacity) { + growTable() + } + } + + /** + * Re-hash a value to deal better with hash functions that don't differ + * in the lower bits, similar to java.util.HashMap + */ + private def rehash(h: Int): Int = { + val r = h ^ (h >>> 20) ^ (h >>> 12) + r ^ (r >>> 7) ^ (r >>> 4) + } + + /** + * Put an entry into a table represented by data, returning true if + * this increases the size of the table or false otherwise. Assumes + * that "data" has at least one empty slot. 
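+   * (Callers maintain that invariant: update() inserts while the table is at most
+   * LOAD_FACTOR full, growing only afterwards via incrementSize(), and growTable()
+   * passes a freshly allocated array of twice the old capacity.)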
+ */ + private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = { + val mask = (data.length / 2) - 1 + var pos = rehash(key.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (curKey.eq(null)) { + data(2 * pos) = key + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + return true + } else if (curKey.eq(key) || curKey == key) { + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + return false + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } + } + return false // Never reached but needed to keep compiler happy + } + + /** Double the table's size and re-hash everything */ + private def growTable() { + val newCapacity = capacity * 2 + if (newCapacity >= (1 << 30)) { + // We can't make the table this big because we want an array of 2x + // that size for our data, but array sizes are at most Int.MaxValue + throw new Exception("Can't make capacity bigger than 2^29 elements") + } + val newData = new Array[AnyRef](2 * newCapacity) + var pos = 0 + while (pos < capacity) { + if (!data(2 * pos).eq(null)) { + putInto(newData, data(2 * pos), data(2 * pos + 1)) + } + pos += 1 + } + data = newData + capacity = newCapacity + } + + private def isPowerOf2(num: Int): Boolean = { + var n = num + while (n > 0) { + if (n == 1) { + return true + } else if (n % 2 == 1) { + return false + } else { + n /= 2 + } + } + return false + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 41a161e08..794c3e8f8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -44,7 +44,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)} - d.count + d.count() Thread.sleep(1000) listener.stageInfos.size should be (1) @@ -55,7 +55,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc val d4 = d2.cogroup(d3, 64).map{case(k,(v1,v2)) => w(k) -> (v1.size, v2.size)} d4.setName("A Cogroup") - d4.collectAsMap + d4.collectAsMap() Thread.sleep(1000) listener.stageInfos.size should be (4) diff --git a/core/src/test/scala/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/spark/util/AppendOnlyMapSuite.scala new file mode 100644 index 000000000..d1e36781e --- /dev/null +++ b/core/src/test/scala/spark/util/AppendOnlyMapSuite.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package spark.util + +import scala.collection.mutable.HashSet + +import org.scalatest.FunSuite + +class AppendOnlyMapSuite extends FunSuite { + test("initialization") { + val goodMap1 = new AppendOnlyMap[Int, Int](1) + assert(goodMap1.size === 0) + val goodMap2 = new AppendOnlyMap[Int, Int](256) + assert(goodMap2.size === 0) + intercept[IllegalArgumentException] { + new AppendOnlyMap[Int, Int](255) // Invalid map size: not power of 2 + } + intercept[IllegalArgumentException] { + new AppendOnlyMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 + } + intercept[IllegalArgumentException] { + new AppendOnlyMap[Int, Int](-1) // Invalid map size: not power of 2 + } + } + + test("object keys and values") { + val map = new AppendOnlyMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = "" + i + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map("" + i) === "" + i) + } + assert(map("0") === null) + assert(map("101") === null) + assert(map(null) === null) + val set = new HashSet[(String, String)] + for ((k, v) <- map) { // Test the foreach method + set += ((k, v)) + } + assert(set === (1 to 100).map(_.toString).map(x => (x, x)).toSet) + } + + test("primitive keys and values") { + val map = new AppendOnlyMap[Int, Int]() + for (i <- 1 to 100) { + map(i) = i + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map(i) === i) + } + assert(map(0) === null) + assert(map(101) === null) + val set = new HashSet[(Int, Int)] + for ((k, v) <- map) { // Test the foreach method + set += ((k, v)) + } + assert(set === (1 to 100).map(x => (x, x)).toSet) + } + + test("null keys") { + val map = new AppendOnlyMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = "" + i + } + assert(map.size === 100) + assert(map(null) === null) + map(null) = "hello" + assert(map.size === 101) + assert(map(null) === "hello") + } + + test("null values") { + val map = new AppendOnlyMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = null + } + assert(map.size === 100) + assert(map("1") === null) + assert(map(null) === null) + assert(map.size === 100) + map(null) = null + assert(map.size === 101) + assert(map(null) === null) + } + + test("changeValue") { + val map = new AppendOnlyMap[String, String]() + for (i <- 1 to 100) { + map("" + i) = "" + i + } + assert(map.size === 100) + for (i <- 1 to 100) { + val res = map.changeValue("" + i, (hadValue, oldValue) => { + assert(hadValue === true) + assert(oldValue === "" + i) + oldValue + "!" + }) + assert(res === i + "!") + } + // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a + // bug where changeValue would return the wrong result when the map grew on that insert + for (i <- 101 to 400) { + val res = map.changeValue("" + i, (hadValue, oldValue) => { + assert(hadValue === false) + i + "!" + }) + assert(res === i + "!") + } + assert(map.size === 400) + assert(map(null) === null) + map.changeValue(null, (hadValue, oldValue) => { + assert(hadValue === false) + "null!" + }) + assert(map.size === 401) + map.changeValue(null, (hadValue, oldValue) => { + assert(hadValue === true) + assert(oldValue === "null!") + "null!!" 
+ }) + assert(map.size === 401) + } +} From 0e40cfabf867469f988979decd9981adc03c90b3 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Wed, 14 Aug 2013 11:45:21 -0700 Subject: [PATCH 2/5] Fix some review comments --- .../scala/org/apache/spark/Aggregator.scala | 19 ++++++------- .../org/apache/spark/rdd/CoGroupedRDD.scala | 2 -- .../main/scala/spark/util/AppendOnlyMap.scala | 27 +++++-------------- .../scala/spark/util/AppendOnlyMapSuite.scala | 23 ++++++++++++---- 4 files changed, 33 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index fa1419df1..84e15fc0c 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -17,18 +17,15 @@ package org.apache.spark -import java.util.{HashMap => JHashMap} +import org.apache.spark.util.AppendOnlyMap -import scala.collection.JavaConversions._ - -import spark.util.AppendOnlyMap - -/** A set of functions used to aggregate data. - * - * @param createCombiner function to create the initial value of the aggregation. - * @param mergeValue function to merge a new value into the aggregation result. - * @param mergeCombiners function to merge outputs from multiple mergeValue function. - */ +/** + * A set of functions used to aggregate data. + * + * @param createCombiner function to create the initial value of the aggregation. + * @param mergeValue function to merge a new value into the aggregation result. + * @param mergeCombiners function to merge outputs from multiple mergeValue function. + */ case class Aggregator[K, V, C] ( createCombiner: V => C, mergeValue: (C, V) => C, diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index f6dd8a65c..f41a023bc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -18,9 +18,7 @@ package org.apache.spark.rdd import java.io.{ObjectOutputStream, IOException} -import java.util.{HashMap => JHashMap} -import scala.collection.JavaConversions import scala.collection.mutable.ArrayBuffer import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext} diff --git a/core/src/main/scala/spark/util/AppendOnlyMap.scala b/core/src/main/scala/spark/util/AppendOnlyMap.scala index 416b93ea4..a7a8625c9 100644 --- a/core/src/main/scala/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/spark/util/AppendOnlyMap.scala @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package spark.util +package org.apache.spark.util /** * A simple open hash table optimized for the append-only use case, where keys @@ -29,14 +29,10 @@ package spark.util */ private[spark] class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] with Serializable { - if (!isPowerOf2(initialCapacity)) { - throw new IllegalArgumentException("Initial capacity must be power of 2") - } - if (initialCapacity >= (1 << 30)) { - throw new IllegalArgumentException("Can't make capacity bigger than 2^29 elements") - } + require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") + require(initialCapacity >= 1, "Invalid initial capacity") - private var capacity = initialCapacity + private var capacity = nextPowerOf2(initialCapacity) private var curSize = 0 // Holds keys and values in the same array for memory locality; specifically, the order of @@ -225,17 +221,8 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi capacity = newCapacity } - private def isPowerOf2(num: Int): Boolean = { - var n = num - while (n > 0) { - if (n == 1) { - return true - } else if (n % 2 == 1) { - return false - } else { - n /= 2 - } - } - return false + private def nextPowerOf2(n: Int): Int = { + val highBit = Integer.highestOneBit(n) + if (highBit == n) n else highBit << 1 } } diff --git a/core/src/test/scala/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/spark/util/AppendOnlyMapSuite.scala index d1e36781e..7177919a5 100644 --- a/core/src/test/scala/spark/util/AppendOnlyMapSuite.scala +++ b/core/src/test/scala/spark/util/AppendOnlyMapSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package spark.util +package org.apache.spark.util import scala.collection.mutable.HashSet @@ -25,16 +25,18 @@ class AppendOnlyMapSuite extends FunSuite { test("initialization") { val goodMap1 = new AppendOnlyMap[Int, Int](1) assert(goodMap1.size === 0) - val goodMap2 = new AppendOnlyMap[Int, Int](256) + val goodMap2 = new AppendOnlyMap[Int, Int](255) assert(goodMap2.size === 0) + val goodMap3 = new AppendOnlyMap[Int, Int](256) + assert(goodMap3.size === 0) intercept[IllegalArgumentException] { - new AppendOnlyMap[Int, Int](255) // Invalid map size: not power of 2 + new AppendOnlyMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 } intercept[IllegalArgumentException] { - new AppendOnlyMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 + new AppendOnlyMap[Int, Int](-1) } intercept[IllegalArgumentException] { - new AppendOnlyMap[Int, Int](-1) // Invalid map size: not power of 2 + new AppendOnlyMap[Int, Int](0) } } @@ -138,4 +140,15 @@ class AppendOnlyMapSuite extends FunSuite { }) assert(map.size === 401) } + + test("inserting in capacity-1 map") { + val map = new AppendOnlyMap[String, String](1) + for (i <- 1 to 100) { + map("" + i) = "" + i + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map("" + i) === "" + i) + } + } } From 4acbc5afdd1eb26c936614ade2bf200c14a71d2b Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 7 Oct 2013 11:28:37 -0700 Subject: [PATCH 3/5] Moved files that were in the wrong directory after package rename --- .../main/scala/{ => org/apache}/spark/util/AppendOnlyMap.scala | 0 .../scala/{ => org/apache}/spark/util/AppendOnlyMapSuite.scala | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename core/src/main/scala/{ => org/apache}/spark/util/AppendOnlyMap.scala (100%) rename core/src/test/scala/{ => org/apache}/spark/util/AppendOnlyMapSuite.scala (100%) diff --git 
a/core/src/main/scala/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala similarity index 100% rename from core/src/main/scala/spark/util/AppendOnlyMap.scala rename to core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala diff --git a/core/src/test/scala/spark/util/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala similarity index 100% rename from core/src/test/scala/spark/util/AppendOnlyMapSuite.scala rename to core/src/test/scala/org/apache/spark/util/AppendOnlyMapSuite.scala From 0b35051f19bcd9c432574ad5c0a921d45cd902cb Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 7 Oct 2013 15:28:22 -0700 Subject: [PATCH 4/5] Address some comments on code clarity --- .../org/apache/spark/util/AppendOnlyMap.scala | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala index a7a8625c9..f60deafc6 100644 --- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala @@ -33,6 +33,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi require(initialCapacity >= 1, "Invalid initial capacity") private var capacity = nextPowerOf2(initialCapacity) + private var mask = capacity - 1 private var curSize = 0 // Holds keys and values in the same array for memory locality; specifically, the order of @@ -51,13 +52,14 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi if (k.eq(null)) { return nullValue } - val mask = capacity - 1 var pos = rehash(k.hashCode) & mask var i = 1 while (true) { val curKey = data(2 * pos) - if (curKey.eq(k) || curKey.eq(null) || curKey == k) { + if (k.eq(curKey) || k == curKey) { return data(2 * pos + 1).asInstanceOf[V] + } else if (curKey.eq(null)) { + return null.asInstanceOf[V] } else { val delta = i pos = (pos + delta) & mask @@ -68,7 +70,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } /** Set the value for a key */ - def update(key: K, value: V) { + def update(key: K, value: V): Unit = { val k = key.asInstanceOf[AnyRef] if (k.eq(null)) { if (!haveNullValue) { @@ -98,21 +100,20 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi haveNullValue = true return nullValue } - val mask = capacity - 1 var pos = rehash(k.hashCode) & mask var i = 1 while (true) { val curKey = data(2 * pos) - if (curKey.eq(null)) { + if (k.eq(curKey) || k == curKey) { + val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) + data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] + return newValue + } else if (curKey.eq(null)) { val newValue = updateFunc(false, null.asInstanceOf[V]) data(2 * pos) = k data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] incrementSize() return newValue - } else if (curKey.eq(k) || curKey == k) { - val newValue = updateFunc(true, data(2*pos + 1).asInstanceOf[V]) - data(2*pos + 1) = newValue.asInstanceOf[AnyRef] - return newValue } else { val delta = i pos = (pos + delta) & mask @@ -219,6 +220,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } data = newData capacity = newCapacity + mask = newCapacity - 1 } private def nextPowerOf2(n: Int): Int = { From 12d593129df8a434f66bd3d01812cab76f40e6b8 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Mon, 7 Oct 2013 21:22:08 -0700 Subject: 
[PATCH 5/5] Create fewer function objects in uses of AppendOnlyMap.changeValue --- .../scala/org/apache/spark/Aggregator.scala | 30 +++++++++---------- .../org/apache/spark/rdd/CoGroupedRDD.scala | 10 ++++--- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 84e15fc0c..1a2ec5587 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -33,28 +33,26 @@ case class Aggregator[K, V, C] ( def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = { val combiners = new AppendOnlyMap[K, C] - for ((k, v) <- iter) { - combiners.changeValue(k, (hadValue, oldValue) => { - if (hadValue) { - mergeValue(oldValue, v) - } else { - createCombiner(v) - } - }) + var kv: Product2[K, V] = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2) + } + while (iter.hasNext) { + kv = iter.next() + combiners.changeValue(kv._1, update) } combiners.iterator } def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = { val combiners = new AppendOnlyMap[K, C] - for ((k, c) <- iter) { - combiners.changeValue(k, (hadValue, oldValue) => { - if (hadValue) { - mergeCombiners(oldValue, c) - } else { - c - } - }) + var kc: (K, C) = null + val update = (hadValue: Boolean, oldValue: C) => { + if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2 + } + while (iter.hasNext) { + kc = iter.next() + combiners.changeValue(kc._1, update) } combiners.iterator } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index f41a023bc..d237797aa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -106,10 +106,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs) val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] - def getSeq(k: K): Seq[ArrayBuffer[Any]] = { - map.changeValue(k, (hadValue, oldValue) => { - if (hadValue) oldValue else Array.fill(numRdds)(new ArrayBuffer[Any]) - }) + val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => { + if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any]) + } + + val getSeq = (k: K) => { + map.changeValue(k, update) } val ser = SparkEnv.get.serializerManager.get(serializerClass)
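
The sketch below is not part of the patch series; it only illustrates the changeValue pattern that PATCH 5 settles on: allocate a single update closure per partition and have it read the current record from a captured var, instead of building a new function object for every record, which avoids per-record allocation on large partitions. The GroupByKeyExample object and its groupByKey helper are hypothetical names added for illustration; the file is assumed to sit in org.apache.spark.util because AppendOnlyMap is private[spark].

package org.apache.spark.util  // assumed location: AppendOnlyMap is private[spark]

import scala.collection.mutable.ArrayBuffer

// Hypothetical example object, not part of the patch.
object GroupByKeyExample {
  /** Groups values by key with one reusable update closure, mirroring Aggregator. */
  def groupByKey[K, V](iter: Iterator[(K, V)]): Iterator[(K, ArrayBuffer[V])] = {
    val map = new AppendOnlyMap[K, ArrayBuffer[V]]
    var kv: (K, V) = null
    // One function object for the whole partition; it reads the current record from kv.
    val update = (hadValue: Boolean, oldValue: ArrayBuffer[V]) => {
      val buf = if (hadValue) oldValue else new ArrayBuffer[V]
      buf += kv._2
      buf
    }
    while (iter.hasNext) {
      kv = iter.next()
      map.changeValue(kv._1, update)
    }
    map.iterator
  }
}

The same hoisted-closure technique appears in the final Aggregator.combineValuesByKey, combineCombinersByKey, and CoGroupedRDD hunks above.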