Skip to content

Commit

Permalink
merged with master
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasatdatabricks committed Dec 29, 2017
1 parent aeff9c9 commit f616462
Show file tree
Hide file tree
Showing 16 changed files with 175 additions and 177 deletions.
4 changes: 2 additions & 2 deletions python/sparkdl/estimators/keras_image_file_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
import pyspark.ml.linalg as spla
from pyspark.ml.param import Param, Params, TypeConverters

from sparkdl.image.imageIO import imageStructToArray
from sparkdl.param import (
keyword_only, CanLoadImage, HasKerasModel, HasKerasOptimizer, HasKerasLoss, HasOutputMode,
HasInputCol, HasInputImageNodeName, HasLabelCol, HasOutputNodeName, HasOutputCol)
from sparkdl.transformers.keras_image import KerasImageFileTransformer
from sparkdl.image.image import ImageSchema
import sparkdl.utils.jvmapi as JVMAPI
import sparkdl.utils.keras_model as kmutil

Expand Down Expand Up @@ -202,7 +202,7 @@ def _getNumpyFeaturesAndLabels(self, dataset):
rows = image_df.collect()
for row in rows:
spimg = row[tmp_image_col]
features = imageStructToArray(spimg)
features = ImageSchema.toNDArray(spimg)
localFeatures.append(features)

if not localFeatures: # NOTE(phi-dbq): pep-8 recommended against testing 0 == len(array)
Expand Down
84 changes: 61 additions & 23 deletions python/sparkdl/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
"""

import numpy as np
from collections import namedtuple

from pyspark import SparkContext
from pyspark.sql.types import Row, _create_row, _parse_datatype_json_string
from pyspark.sql import DataFrame, SparkSession
Expand All @@ -42,9 +44,23 @@ class _ImageSchema(object):
def __init__(self):
self._imageSchema = None
self._ocvTypes = None
self._ocvTypesByName = None
self._ocvTypesByMode = None
self._imageFields = None
self._undefinedImageType = None

# Descriptor of a single OpenCV image type.
_OcvType = namedtuple("OcvType", "name mode nChannels dataType nptype")

# Numpy dtype for each OpenCV depth code (the "8U" part of a type name such as "CV_8UC3").
_ocvToNumpyMap = dict(
    (depthCode, np.dtype(numpyName))
    for depthCode, numpyName in (
        ("8U", "uint8"),
        ("8S", "int8"),
        ("16U", "uint16"),
        ("16S", "int16"),
        ("32S", "int32"),
        ("32F", "float32"),
        ("64F", "float64"),
    ))

# Reverse lookup: numpy dtype -> OpenCV depth code.
_numpyToOcvMap = dict(
    (numpyType, depthCode) for depthCode, numpyType in _ocvToNumpyMap.items())

@property
def imageSchema(self):
"""
Expand All @@ -57,7 +73,7 @@ def imageSchema(self):
"""

if self._imageSchema is None:
ctx = SparkContext._active_spark_context
ctx = SparkContext.getOrCreate()
jschema = ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageSchema()
self._imageSchema = _parse_datatype_json_string(jschema.json())
return self._imageSchema
Expand All @@ -71,11 +87,30 @@ def ocvTypes(self):
.. versionadded:: 2.3.0
"""

if self._ocvTypes is None:
ctx = SparkContext._active_spark_context
self._ocvTypes = dict(ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes())
return self._ocvTypes
ctx = SparkContext.getOrCreate()
self._ocvTypes = [self._OcvType(name=x.name(),
mode=x.mode(),
nChannels=x.nChannels(),
dataType=x.dataType(),
nptype=self._ocvToNumpyMap[x.dataType()])
for x in ctx._jvm.org.apache.spark.ml.image.ImageSchema.javaOcvTypes()]
return [x for x in self._ocvTypes]

def ocvTypeByName(self, name):
    """
    Look up a supported OpenCV type descriptor by its name.

    :param name: str, OpenCV type name, e.g. ``"CV_8UC3"``.
    :return: the entry of ``self.ocvTypes`` whose ``name`` equals ``name``.
    :raises ValueError: if no supported type has the given name.
    """
    # The name -> type index is built lazily, on the first lookup.
    if self._ocvTypesByName is None:
        self._ocvTypesByName = {x.name: x for x in self.ocvTypes}
    if name not in self._ocvTypesByName:
        # sorted() gives a plain, deterministically ordered list, so the message
        # reads the same on Python 2 and 3 (no "dict_keys([...])" wrapper).
        raise ValueError(
            "Unsupported image format, can not find matching OpenCvFormat for type = '%s'; "
            "currently supported formats = %s" %
            (name, str(sorted(self._ocvTypesByName))))
    return self._ocvTypesByName[name]

def ocvTypeByMode(self, mode):
    """
    Look up a supported OpenCV type descriptor by its mode.

    :param mode: value of the ``mode`` field of an image row.
    :return: the entry of ``self.ocvTypes`` whose ``mode`` equals ``mode``.
    :raises KeyError: if no supported type has the given mode.
    """
    # The mode -> type index is built lazily, on the first lookup.
    if self._ocvTypesByMode is None:
        self._ocvTypesByMode = {x.mode: x for x in self.ocvTypes}
    try:
        return self._ocvTypesByMode[mode]
    except KeyError:
        # Keep the KeyError type callers may already rely on, but say what was
        # looked up and what is available instead of surfacing only the raw key.
        raise KeyError(
            "Unsupported image mode %s; currently supported modes = %s" %
            (mode, str(sorted(self._ocvTypesByMode))))

@property
def imageFields(self):
Expand All @@ -88,7 +123,7 @@ def imageFields(self):
"""

if self._imageFields is None:
ctx = SparkContext._active_spark_context
ctx = SparkContext.getOrCreate()
self._imageFields = list(ctx._jvm.org.apache.spark.ml.image.ImageSchema.imageFields())
return self._imageFields

Expand All @@ -101,7 +136,7 @@ def undefinedImageType(self):
"""

if self._undefinedImageType is None:
ctx = SparkContext._active_spark_context
ctx = SparkContext.getOrCreate()
self._undefinedImageType = \
ctx._jvm.org.apache.spark.ml.image.ImageSchema.undefinedImageType()
return self._undefinedImageType
Expand All @@ -126,15 +161,20 @@ def toNDArray(self, image):
raise ValueError(
"image argument should have attributes specified in "
"ImageSchema.imageSchema [%s]." % ", ".join(self.imageFields))

height = image.height
width = image.width
nChannels = image.nChannels
ocvType = self.ocvTypeByMode(image.mode)
if nChannels != ocvType.nChannels:
raise ValueError(
"Unexpected number of channels, image has %d channels but OcvType '%s' expects %d channels." %
(nChannels, ocvType.name, ocvType.nChannels))
itemSz = ocvType.nptype.itemsize
return np.ndarray(
shape=(height, width, nChannels),
dtype=np.uint8,
dtype=ocvType.nptype,
buffer=image.data,
strides=(width * nChannels, nChannels, 1))
strides=(width * nChannels * itemSz, nChannels * itemSz, itemSz))

def toImage(self, array, origin=""):
"""
Expand All @@ -152,29 +192,27 @@ def toImage(self, array, origin=""):
"array argument should be numpy.ndarray; however, it got [%s]." % type(array))

if array.ndim != 3:
raise ValueError("Invalid array shape")
raise ValueError("Invalid array shape %s" % str(array.shape))

height, width, nChannels = array.shape
ocvTypes = ImageSchema.ocvTypes
if nChannels == 1:
mode = ocvTypes["CV_8UC1"]
elif nChannels == 3:
mode = ocvTypes["CV_8UC3"]
elif nChannels == 4:
mode = ocvTypes["CV_8UC4"]
else:
raise ValueError("Invalid number of channels")
dtype = array.dtype
if not dtype in self._numpyToOcvMap:
raise ValueError(
"Unexpected/unsupported array data type '%s', currently only supported formats are %s" %
(str(array.dtype), str(self._numpyToOcvMap.keys())))
ocvName = "CV_%sC%d" % (self._numpyToOcvMap[dtype], nChannels)
ocvType = self.ocvTypeByName(ocvName)

# Running `bytearray(numpy.array([1]))` fails in specific Python versions
# with a specific Numpy version, for example in Python 3.6.0 and NumPy 1.13.3.
# To avoid this, the array is first serialized with tobytes() and the result is wrapped in a bytearray.
data = bytearray(array.astype(dtype=np.uint8).ravel().tobytes())
data = bytearray(array.tobytes())

# Creating new Row with _create_row(), because Row(name = value, ... )
# orders fields by name, which conflicts with expected schema order
# when the new DataFrame is created by UDF
return _create_row(self.imageFields,
[origin, height, width, nChannels, mode, data])
[origin, height, width, nChannels, ocvType.mode, data])

def readImages(self, path, recursive=False, numPartitions=-1,
dropImageFailures=False, sampleRatio=1.0, seed=0):
Expand Down Expand Up @@ -203,7 +241,7 @@ def readImages(self, path, recursive=False, numPartitions=-1,
.. versionadded:: 2.3.0
"""

ctx = SparkContext._active_spark_context
ctx = SparkContext.getOrCreate()
spark = SparkSession(ctx)
image_schema = ctx._jvm.org.apache.spark.ml.image.ImageSchema
jsession = spark._jsparkSession
Expand Down
126 changes: 30 additions & 96 deletions python/sparkdl/image/imageIO.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,83 +21,12 @@
from PIL import Image

# pyspark
from pyspark import Row
from pyspark import SparkContext
from sparkdl.image.image import ImageSchema
from pyspark.sql.functions import udf
from pyspark.sql.types import (
BinaryType, IntegerType, StringType, StructField, StructType)


# ImageType represents supported OpenCV types
# fields:
# name - OpenCvMode
# ord - Ordinal of the corresponding OpenCV mode (stored in mode field of ImageSchema).
# nChannels - number of channels in the image
# dtype - data type of the image's array, stored as a numpy-compatible string.
#
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
_OcvType = namedtuple("OcvType", ["name", "ord", "nChannels", "dtype"])


_supportedOcvTypes = (
_OcvType(name="CV_8UC1", ord=0, nChannels=1, dtype="uint8"),
_OcvType(name="CV_32FC1", ord=5, nChannels=1, dtype="float32"),
_OcvType(name="CV_8UC3", ord=16, nChannels=3, dtype="uint8"),
_OcvType(name="CV_32FC3", ord=21, nChannels=3, dtype="float32"),
_OcvType(name="CV_8UC4", ord=24, nChannels=4, dtype="uint8"),
_OcvType(name="CV_32FC4", ord=29, nChannels=4, dtype="float32"),
)

# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
_ocvTypesByName = {m.name: m for m in _supportedOcvTypes}
_ocvTypesByOrdinal = {m.ord: m for m in _supportedOcvTypes}


def imageTypeByOrdinal(ord):
if not ord in _ocvTypesByOrdinal:
raise KeyError("unsupported image type with ordinal %d, supported OpenCV types = %s" % (
ord, str(_supportedOcvTypes)))
return _ocvTypesByOrdinal[ord]


def imageTypeByName(name):
if not name in _ocvTypesByName:
raise KeyError("unsupported image type with name '%s', supported supported OpenCV types = %s" % (
name, str(_supportedOcvTypes)))
return _ocvTypesByName[name]


def imageArrayToStruct(imgArray, origin=""):
"""
Create a row representation of an image from an image array.
:param imgArray: ndarray, image data.
:return: Row, image as a DataFrame Row with schema==ImageSchema.
"""
# Sometimes tensors have a leading "batch-size" dimension. Assume to be 1 if it exists.
if len(imgArray.shape) == 4:
if imgArray.shape[0] != 1:
raise ValueError(
"The first dimension of a 4-d image array is expected to be 1.")
imgArray = imgArray.reshape(imgArray.shape[1:])
imageType = _arrayToOcvMode(imgArray)
height, width, nChannels = imgArray.shape
data = bytearray(imgArray.tobytes())
return Row(origin=origin, mode=imageType.ord, height=height,
width=width, nChannels=nChannels, data=data)


def imageStructToArray(imageRow):
"""
Convert an image to a numpy array.
:param imageRow: Row, must use imageSchema.
:return: ndarray, image data.
"""
imType = imageTypeByOrdinal(imageRow.mode)
shape = (imageRow.height, imageRow.width, imageRow.nChannels)
return np.ndarray(shape, imType.dtype, imageRow.data)
from sparkdl.image.image import ImageSchema


def imageStructToPIL(imageRow):
Expand All @@ -107,20 +36,20 @@ def imageStructToPIL(imageRow):
:param imageRow: Row, must have ImageSchema
:return PIL image
"""
imgType = imageTypeByOrdinal(imageRow.mode)
if imgType.dtype != 'uint8':
ary = ImageSchema.toNDArray(imageRow)
if ary.dtype != np.uint8:
raise ValueError("Can not convert image of type " +
imgType.dtype + " to PIL, can only deal with 8U format")
ary = imageStructToArray(imageRow)
ary.dtype + " to PIL, can only deal with 8U format")

# PIL expects RGB order, image schema is BGR
# => we need to flip the order unless there is only one channel
if imgType.nChannels != 1:
if imageRow.nChannels != 1:
ary = _reverseChannels(ary)
if imgType.nChannels == 1:
if imageRow.nChannels == 1:
return Image.fromarray(obj=ary, mode='L')
elif imgType.nChannels == 3:
elif imageRow.nChannels == 3:
return Image.fromarray(obj=ary, mode='RGB')
elif imgType.nChannels == 4:
elif imageRow.nChannels == 4:
return Image.fromarray(obj=ary, mode='RGBA')
else:
raise ValueError("don't know how to convert " +
Expand All @@ -132,19 +61,6 @@ def PIL_to_imageStruct(img):
return _reverseChannels(np.asarray(img))


def _arrayToOcvMode(arr):
assert len(arr.shape) == 3, "Array should have 3 dimensions but has shape {}".format(
arr.shape)
num_channels = arr.shape[2]
if arr.dtype == "uint8":
name = "CV_8UC%d" % num_channels
elif arr.dtype == "float32":
name = "CV_32FC%d" % num_channels
else:
raise ValueError("Unsupported type '%s'" % arr.dtype)
return imageTypeByName(name)


def fixColorChannelOrdering(currentOrder, imgAry):
if currentOrder == 'RGB':
return _reverseChannels(imgAry)
Expand All @@ -160,6 +76,24 @@ def fixColorChannelOrdering(currentOrder, imgAry):
"Unexpected channel order, expected one of L,RGB,BGR but got " + currentChannelOrder)


def _stripBatchSize(imgArray):
    """
    Drop a leading batch dimension of size 1 from a 4-d image array.

    Arrays that are not 4-dimensional are returned unchanged.

    :param imgArray: ndarray, image data, optionally with a leading batch dimension.
    :return: imgArray without the leading batch dimension.
    :raises ValueError: if imgArray is 4-d and its first dimension is not 1.
    """
    dims = imgArray.shape
    # Only 4-d input is treated as carrying a batch dimension.
    if len(dims) != 4:
        return imgArray
    if dims[0] != 1:
        raise ValueError(
            "The first dimension of a 4-d image array is expected to be 1.")
    return imgArray.reshape(dims[1:])


def _reverseChannels(ary):
    """
    Return ``ary`` with the order of its last axis reversed.

    :param ary: array whose last axis is to be flipped.
    :return: the result of indexing ``ary`` with a reversed last axis.
    """
    flipped = ary[..., ::-1]
    return flipped

Expand All @@ -183,8 +117,8 @@ def _resizeImageAsRow(imgAsRow):
return imgAsRow
imgAsPil = imageStructToPIL(imgAsRow).resize(sz)
# PIL is RGB based while image schema is BGR based => we need to flip the channels
imgAsArray = _reverseChannels(np.asarray(imgAsPil))
return imageArrayToStruct(imgAsArray, origin=imgAsRow.origin)
imgAsArray = PIL_to_imageStruct(imgAsPil)
return ImageSchema.toImage(imgAsArray, origin=imgAsRow.origin)
return udf(_resizeImageAsRow, ImageSchema.imageSchema['image'].dataType)


Expand Down Expand Up @@ -242,7 +176,7 @@ def readImagesWithCustomFn(path, decode_f, numPartition=None):
def _readImagesWithCustomFn(path, decode_f, numPartition, sc):
def _decode(path, raw_bytes):
try:
return imageArrayToStruct(decode_f(raw_bytes), origin=path)
return ImageSchema.toImage(decode_f(raw_bytes), origin=path)
except BaseException:
return None
decodeImage = udf(_decode, ImageSchema.imageSchema['image'].dataType)
Expand Down
11 changes: 5 additions & 6 deletions python/sparkdl/param/image_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
from sparkdl.image.image import ImageSchema
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.sql.functions import udf
from sparkdl.image.imageIO import imageArrayToStruct
from sparkdl.image.imageIO import _reverseChannels
from sparkdl.image.imageIO import _reverseChannels, _stripBatchSize
from sparkdl.param import SparkDLTypeConverters

OUTPUT_MODES = ["vector", "image"]
Expand Down Expand Up @@ -95,10 +94,10 @@ def loadImagesInternal(self, dataframe, inputCol):
# Load from external resources can fail, so we should allow None to be returned

def load_image_uri_impl(uri):
    """
    Load the image at ``uri`` and convert it to an image row.

    The array returned by ``loader`` has a leading batch dimension of 1
    stripped (if present) and its channel order reversed before it is
    passed to ``ImageSchema.toImage``.

    :param uri: location handed to ``loader``.
    :return: the image row, or None if loading or conversion failed.
    """
    # Loading from external resources can fail; return None for that row
    # instead of failing the whole job.
    try:
        return ImageSchema.toImage(_reverseChannels(_stripBatchSize(loader(uri))))
    except Exception:  # pylint: disable=broad-except
        return None
load_udf = udf(load_image_uri_impl, ImageSchema.imageSchema['image'].dataType)
return dataframe.withColumn(self._loadedImageCol(), load_udf(dataframe[inputCol]))

Expand Down
Loading

0 comments on commit f616462

Please sign in to comment.