NLP support #56

Open

wants to merge 21 commits into base: master

Changes from all commits (21 commits)
1 change: 1 addition & 0 deletions .gitignore
@@ -11,6 +11,7 @@ README.org
.cache/
.history/
.lib/
.coverage
dist/*
target/
lib_managed/
12 changes: 6 additions & 6 deletions README.md
@@ -131,16 +131,16 @@ Spark DataFrames are a natural construct for applying deep learning models to a

```python
from sparkdl import readImages, TFImageTransformer
+import sparkdl.graph.utils as tfx
from sparkdl.transformers import utils
import tensorflow as tf

-g = tf.Graph()
-with g.as_default():
+graph = tf.Graph()
+with tf.Session(graph=graph) as sess:
    image_arr = utils.imageInputPlaceholder()
    resized_images = tf.image.resize_images(image_arr, (299, 299))
    # the following step is not necessary for this graph, but can be for graphs with variables, etc
-    frozen_graph = utils.stripAndFreezeGraph(g.as_graph_def(add_shapes=True), tf.Session(graph=g),
-                                             [resized_images])
+    frozen_graph = tfx.strip_and_freeze_until([resized_images], graph, sess,
+                                               return_graph=True)

transformer = TFImageTransformer(inputCol="image", outputCol="predictions", graph=frozen_graph,
                                 inputTensor=image_arr, outputTensor=resized_images,
@@ -241,7 +241,7 @@ registerKerasImageUDF("my_keras_inception_udf", InceptionV3(weights="imagenet"),

```

### Estimator

## Releases:
* 0.1.0 initial release

1 change: 1 addition & 0 deletions python/requirements.txt
@@ -9,3 +9,4 @@ pygments>=2.2.0
tensorflow==1.3.0
pandas>=0.19.1
six>=1.10.0
kafka-python>=1.3.5
317 changes: 317 additions & 0 deletions python/sparkdl/estimators/tf_text_file_estimator.py
@@ -0,0 +1,317 @@
#
# Copyright 2017 Databricks, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pylint: disable=protected-access
from __future__ import absolute_import, division, print_function

import logging
import threading
import time
import os
import sys

from kafka import KafkaConsumer
from kafka import KafkaProducer
from pyspark.ml import Estimator

from sparkdl.param import (
keyword_only, HasLabelCol, HasInputCol, HasOutputCol)
from sparkdl.param.shared_params import KafkaParam, FitParam, MapFnParam
import sparkdl.utils.jvmapi as JVMAPI

if sys.version_info[:2] <= (2, 7):
import cPickle as pickle
else:
import pickle

__all__ = ['TFTextFileEstimator']

logger = logging.getLogger('sparkdl')


class TFTextFileEstimator(Estimator, HasInputCol, HasOutputCol, HasLabelCol, KafkaParam, FitParam, MapFnParam):
"""
Build an Estimator from TensorFlow or Keras when the backend is TensorFlow.

First, assume we have data in a DataFrame like the following.

.. code-block:: python
documentDF = spark.createDataFrame([
("Hi I heard about Spark", 1),
("I wish Java could use case classes", 0),
("Logistic regression models are neat", 2)
], ["text", "preds"])

transformer = TFTextTransformer(
inputCol=input_col,
outputCol=output_col)

df = transformer.transform(documentDF)

TFTextTransformer will transform the text column into `output_col`, which is a 2-D array.

Then we create a TensorFlow map function.

.. code-block:: python
def map_fun(args, ctx=None, _read_data=None):
import tensorflow as tf
EMBEDDING_SIZE = args["embedding_size"]
feature = args['feature']
label = args['label']
params = args['params']['fitParam']
SEQUENCE_LENGTH = 64

def feed_dict(batch):
# Gather the sentence matrices for this batch into a single list of features
features = []
for i in batch:
features.append(i['sentence_matrix'])

# print("{} {}".format(feature, features))
return features

encoder_variables_dict = {
"encoder_w1": tf.Variable(
tf.random_normal([SEQUENCE_LENGTH * EMBEDDING_SIZE, 256]), name="encoder_w1"),
"encoder_b1": tf.Variable(tf.random_normal([256]), name="encoder_b1"),
"encoder_w2": tf.Variable(tf.random_normal([256, 128]), name="encoder_w2"),
"encoder_b2": tf.Variable(tf.random_normal([128]), name="encoder_b2")
}

_read_data is a data generator. args provides the hyperparameters configured in this estimator.

Here is how to use _read_data:

.. code-block:: python
for data in _read_data(max_records=params["batch_size"]):
batch_data = feed_dict(data)
sess.run(train_step, feed_dict={input_x: batch_data})

Finally, we can create a TFTextFileEstimator to train our model:

.. code-block:: python
estimator = TFTextFileEstimator(inputCol="sentence_matrix",
outputCol="sentence_matrix", labelCol="preds",
kafkaParam={"bootstrap_servers": ["127.0.0.1"], "topic": "test",
"group_id": "sdl_1"},
fitParam=[{"epochs": 5, "batch_size": 64}, {"epochs": 5, "batch_size": 1}],
mapFnParam=map_fun)
estimator.fit(df)
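Note that fit returns an RDD with one element per fitParam map; the training itself only
runs when an action such as collect() is called on that RDD.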

"""

@keyword_only
def __init__(self, inputCol=None, outputCol=None, labelCol=None, kafkaParam=None, fitParam=None, mapFnParam=None):
super(TFTextFileEstimator, self).__init__()
kwargs = self._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, inputCol=None, outputCol=None, labelCol=None, kafkaParam=None, fitParam=None, mapFnParam=None):
kwargs = self._input_kwargs
return self._set(**kwargs)

def fit(self, dataset, params=None):
self._validateParams()
if params is None:
paramMaps = self.getFitParam()
elif isinstance(params, (list, tuple)):
if len(params) == 0:
paramMaps = [dict()]
else:
self._validateFitParams(params)
paramMaps = params
elif isinstance(params, dict):
paramMaps = [params]
else:
raise ValueError("Params must be either a param map or a list/tuple of param maps, "
"but got %s." % type(params))
return self._fitInParallel(dataset, paramMaps)

def _validateParams(self):
"""
Check Param values so we can throw errors on the driver, rather than workers.
:return: True if parameters are valid
"""
if not self.isDefined(self.inputCol):
raise ValueError("Input column must be defined")
if not self.isDefined(self.outputCol):
raise ValueError("Output column must be defined")
return True

def _fitInParallel(self, dataset, paramMaps):

inputCol = self.getInputCol()
labelCol = self.getLabelCol()

from time import gmtime, strftime
kafkaParams = self.getKafkaParam()
topic = kafkaParams["topic"] + "_" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
group_id = kafkaParams["group_id"]
bootstrap_servers = kafkaParams["bootstrap_servers"]
kafka_test_mode = kafkaParams["test_mode"] if "test_mode" in kafkaParams else False
mock_kafka_file = kafkaParams["mock_kafka_file"] if kafka_test_mode else None

def _write_data():
def _write_partition(index, d_iter):
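# Runs once per partition: send every row to the transient topic, then a "_stop_" sentinel
# so consumers can tell when all partitions have finished writing.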
producer = KafkaMockServer(index, mock_kafka_file) if kafka_test_mode else KafkaProducer(
bootstrap_servers=bootstrap_servers)
try:
for d in d_iter:
producer.send(topic, pickle.dumps(d))
producer.send(topic, pickle.dumps("_stop_"))
producer.flush()
finally:
producer.close()
return []

dataset.rdd.mapPartitionsWithIndex(_write_partition).count()

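# In test mode the rows are written synchronously through the mock server; otherwise a
# background thread streams them to Kafka so the fitting tasks can start consuming immediately.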
if kafka_test_mode:
_write_data()
else:
t = threading.Thread(target=_write_data)
t.start()

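# One "_stop_" sentinel is expected per partition; vocab_size and embedding_size are read
# from the first row and passed through to the user-supplied map function.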
stop_flag_num = dataset.rdd.getNumPartitions()
temp_item = dataset.take(1)[0]
vocab_s = temp_item["vocab_size"]
embedding_size = temp_item["embedding_size"]

sc = JVMAPI._curr_sc()

paramMapsRDD = sc.parallelize(paramMaps, numSlices=len(paramMaps))

# Obtain params for this estimator instance
baseParamMap = self.extractParamMap()
baseParamDict = dict([(param.name, val) for param, val in baseParamMap.items()])
baseParamDictBc = sc.broadcast(baseParamDict)

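# Each param map becomes one Spark task: the task re-reads the full dataset from the Kafka
# topic and trains its own model via the user-supplied map function.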
def _local_fit(override_param_map):
# Update params
params = baseParamDictBc.value
params["fitParam"] = override_param_map

def _read_data(max_records=64):
consumer = KafkaMockServer(0, mock_kafka_file) if kafka_test_mode else KafkaConsumer(topic,
group_id=group_id,
bootstrap_servers=bootstrap_servers,
auto_offset_reset="earliest",
enable_auto_commit=False
)
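# Poll the topic in batches: regular records go into group_msgs, "_stop_" sentinels are
# counted, each non-empty batch is yielded, and the loop ends once every partition has
# signalled completion and a poll comes back empty.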
try:
stop_count = 0
fail_msg_count = 0
while True:
if kafka_test_mode:
time.sleep(1)
messages = consumer.poll(timeout_ms=1000, max_records=max_records)
group_msgs = []
for tp, records in messages.items():
for record in records:
try:
msg_value = pickle.loads(record.value)
if msg_value == "_stop_":
stop_count += 1
else:
group_msgs.append(msg_value)
except Exception:
fail_msg_count += 1
if len(group_msgs) > 0:
yield group_msgs

if kafka_test_mode:
print(
"stop_count = {} "
"group_msgs = {} "
"stop_flag_num = {} "
"fail_msg_count = {}".format(stop_count,
len(group_msgs),
stop_flag_num,
fail_msg_count))

if stop_count >= stop_flag_num and len(group_msgs) == 0:
break
finally:
consumer.close()

self.getMapFnParam()(args={"feature": inputCol,
"label": labelCol,
"vocab_size": vocab_s,
"embedding_size": embedding_size,
"params": params}, ctx=None, _read_data=_read_data,
)

return paramMapsRDD.map(lambda paramMap: (paramMap, _local_fit(paramMap)))

def _fit(self, dataset): # pylint: disable=unused-argument
err_msgs = ["This function should not have been called",
"Please contact library maintainers to file a bug"]
raise NotImplementedError('\n'.join(err_msgs))


class KafkaMockServer(object):
"""
Restrictions of KafkaMockServer:
* Make sure all data have been written before consuming.
* The poll function ignores max_records and returns all data currently in the queue.
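* Producers and consumers exchange data through pickle files under the mock_kafka_file directory.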
"""

_kafka_mock_server_tmp_file_ = None
sended = False

def __init__(self, index=0, tmp_file=None):
super(KafkaMockServer, self).__init__()
self.index = index
self.queue = []
self._kafka_mock_server_tmp_file_ = tmp_file
if not os.path.exists(self._kafka_mock_server_tmp_file_):
os.mkdir(self._kafka_mock_server_tmp_file_)

def send(self, topic, msg):
self.queue.append(pickle.loads(msg))

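# flush() persists the queued messages to a per-partition pickle file that poll() later replays.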
def flush(self):
with open(self._kafka_mock_server_tmp_file_ + "/" + str(self.index), "wb") as f:
pickle.dump(self.queue, f)
self.queue = []

def close(self):
pass

def poll(self, timeout_ms, max_records):
if self.sended:
return {}

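# Replay everything the mock producers flushed to disk (one pickle file per partition), exactly once.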
records = []
for file in os.listdir(self._kafka_mock_server_tmp_file_):
with open(self._kafka_mock_server_tmp_file_ + "/" + file, "rb") as f:
tmp = pickle.load(f)
records += tmp
result = {}
counter = 0
for i in records:
obj = MockRecord()
obj.value = pickle.dumps(i)
counter += 1
result[str(counter) + "_"] = [obj]
self.sended = True
return result


class MockRecord(list):
pass