Merge pull request #21 from Distributive-Network/severn/serialize-python-types

Feature: Serialize Python Values, Add Job API, and Create BF2 Work Function Wrapper
wiwichips authored Aug 8, 2024
2 parents b5f6eed + 22cd6b9 commit 8da3f92
Showing 15 changed files with 725 additions and 194 deletions.
1 change: 1 addition & 0 deletions .github/workflows/run-tests.yml
@@ -40,6 +40,7 @@ jobs:
poetry env use python3
poetry install
echo "Install [complete]"
timeout-minutes: 1

- name: NPM ls
run: cd dcp/js && npm ls && cd -
3 changes: 2 additions & 1 deletion dcp/api/__init__.py
@@ -1,4 +1,5 @@
from .compute_for import compute_for_maker
from .compute_do import compute_do_maker
from .job import job_maker
from .result_handle import result_handle_maker
from .job_fs import JobFS
@@ -8,5 +9,5 @@
'ResultHandle': result_handle_maker,
}

__all__ = ['compute_for_maker', 'compute_do_maker', 'sub_classes', 'JobFS']

22 changes: 22 additions & 0 deletions dcp/api/compute_do.py
@@ -0,0 +1,22 @@
"""
compute_do API
Author: Severn Lortie <[email protected]>
Date: Aug 2024
"""

import pythonmonkey as pm
import dill
from types import FunctionType
from .. import dry

def compute_do_maker(Job):
def compute_do(*args, **kwargs):
args = list(args)
for i in range(len(args)):
arg = args[i]
if isinstance(arg, FunctionType):
# ship the function's source text so the worktime can rebuild it
args[i] = dill.source.getsource(arg)
compute_do_js = pm.eval("globalThis.dcp.compute.do")
job_js = dry.aio.blockify(compute_do_js)(*args, **kwargs)
return Job(job_js)
return compute_do
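
A minimal usage sketch of the new compute_do entry point. The top-level dcp.init() / dcp.compute_do wiring and the argument shape of compute.do are assumptions drawn from the surrounding package layout, not something this diff shows:

# Hypothetical usage sketch -- dcp.init() and the dcp.compute_do binding are assumptions
import dcp
dcp.init()

def work(x):
    return x * 2

# plain Python functions are shipped as source text, per the FunctionType loop above
job = dcp.compute_do(work)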
22 changes: 17 additions & 5 deletions dcp/api/compute_for.py
@@ -1,11 +1,23 @@
"""
compute_for API
Author: Will Pringle <[email protected]>, Severn Lortie <[email protected]>
Date: July 2024
"""

import pythonmonkey as pm
import dill
from .. import dry
from .. import js
from types import FunctionType

def compute_for_maker(Job):
def compute_for(*args, **kwargs):
args = list(args)
for i in range(len(args)):
arg = args[i]
if isinstance(arg, FunctionType):
# ship the function's source text so the worktime can rebuild it
args[i] = dill.source.getsource(arg)
compute_for_js = pm.eval("globalThis.dcp.compute.for")
job_js = dry.aio.blockify(compute_for_js)(*args, **kwargs)
return Job(job_js)
return compute_for
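
A matching sketch for compute_for, using the blocking exec()/wait() pair that the Job class below provides. Again, the top-level dcp.init() / dcp.compute_for binding is an assumption:

# Hypothetical usage sketch -- dcp.init() and the dcp.compute_for binding are assumptions
import dcp
dcp.init()

def work(x):
    return x * 2

job = dcp.compute_for([1, 2, 3], work)
job.exec()            # blocks until the scheduler accepts the job
results = job.wait()  # blocks until all slices complete; each result is deserialized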

114 changes: 109 additions & 5 deletions dcp/api/job.py
@@ -1,13 +1,117 @@
import pythonmonkey as pm
import cloudpickle
import dill
import asyncio
from ..js import utils
from .. import dry
from .job_serializers import (
default_serializers,
serialize,
deserialize,
convert_serializers_to_arguments,
validate_serializers
)
from .job_env import convert_env_to_arguments
from .job_modules import convert_modules_to_requires
from .job_fs import JobFS
from collections.abc import Iterator
from types import FunctionType
import urllib.parse
from .pyodide_work_function import get_work_function_string

def job_maker(super_class):
class Job(super_class):
def __init__(self, job_js):
super().__init__(job_js)
self.js_ref.worktime = 'pyodide'

self._wrapper_set_attribute("serializers", default_serializers)
self._wrapper_set_attribute("env", {})
self._wrapper_set_attribute("modules", [])
self._wrapper_set_attribute("fs", JobFS())
self._wrapper_set_attribute("_exec_called", False)
self.aio.exec = self._exec
self.aio.wait = self._wait

def _before_exec(self, *args, **kwargs):
if self.js_ref.worktime != "pyodide":
return

work_function = urllib.parse.unquote(self.js_ref.workFunctionURI)
work_function = work_function.replace("data:,", "")

meta_arguments = [
work_function
]

serialized_arguments = []
serialized_input_data = []
if len(self.serializers):
validate_serializers(self.serializers)

super_range_object = pm.eval("globalThis.dcp['range-object'].SuperRangeObject")
if isinstance(self.js_ref.jobInputData, list):
for input_slice in self.js_ref.jobInputData:
serialized_slice = serialize(input_slice, self.serializers)
serialized_input_data.append(serialized_slice)
elif isinstance(self.js_ref.jobInputData, Iterator) and not utils.instanceof(self.js_ref.jobInputData, super_range_object):
serialized_input_data = serialize(self.js_ref.jobInputData, self.serializers)
else:
serialized_input_data = self.js_ref.jobInputData

for argument in self.js_ref.jobArguments:
serialized_argument = serialize(argument, self.serializers)
serialized_arguments.append(serialized_argument)

serialized_serializers = convert_serializers_to_arguments(self.serializers)
meta_arguments.append(serialized_serializers)
else:
serialized_arguments = self.js_ref.jobArguments
serialized_input_data = self.js_ref.jobInputData

job_fs = bytearray(self.fs.to_gzip_tar())
env_args = convert_env_to_arguments(self.env)
modules = convert_modules_to_requires(self.modules)
if len(modules) > 0:
self.js_ref.requires(modules)

offset_to_argument_vector = 3 + len(env_args)
self.js_ref.jobInputData = serialized_input_data
self.js_ref.jobArguments = [offset_to_argument_vector] + ["gzImage", job_fs] + env_args + serialized_arguments + [meta_arguments]
self.js_ref.workFunctionURI = "data:," + urllib.parse.quote(get_work_function_string(), safe="=:,#+;")

#TODO Make sure this runs on our event loop
def _exec(self, *args):
self._before_exec()
self._wrapper_set_attribute("_exec_called", True)
accepted_future = asyncio.Future()
def handle_accepted():
accepted_future.set_result(self.js_ref.id)
self.js_ref.on('accepted', handle_accepted)
self.js_ref.exec(*args)
return accepted_future

#TODO Make sure this runs on our event loop
def _wait(self):
if not self._exec_called:
raise Exception("Wait called before exec()")
complete_future = asyncio.Future()
def handle_complete(resultHandle):
serialized_results = resultHandle["values"]()
results = []
for serialized_result in serialized_results:
result = deserialize(serialized_result, self.serializers)
results.append(result)
complete_future.set_result(results)
self.js_ref.on("complete", handle_complete)
return complete_future

def exec(self, *args):
results = dry.aio.blockify(self._exec)(*args)
return results

def wait(self):
return dry.aio.blockify(self._wait)()

def on(self, *args):
if len(args) > 1 and callable(args[1]):
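
To make the argument-vector packing in _before_exec concrete, here is the jobArguments layout for a job with one env variable and one positional argument (illustrative values; the exact bytes depend on the serializers):

# jobArguments as assembled by _before_exec:
# [
#   5,                          # offset_to_argument_vector = 3 fixed slots + 2 env entries
#   "gzImage", <job_fs bytes>,  # gzipped tar of the job filesystem
#   "env", "DEBUG=1",           # flattened env from convert_env_to_arguments
#   <serialized argument>,      # each positional job argument, serialized
#   [<work function source>, <serialized serializers>],  # meta_arguments
# ]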
14 changes: 14 additions & 0 deletions dcp/api/job_env.py
@@ -0,0 +1,14 @@
"""
Converts Job Env arguments for Pyodide worktime.
Author: Severn Lortie <[email protected]>
Date: Aug 2024
"""

def convert_env_to_arguments(env):
if not len(env):
return []
args = ["env"]
for env_key in env:
args.append(f"{env_key}={env[env_key]}")
return args
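
For example, derived directly from the function above:

# convert_env_to_arguments({"DEBUG": "1", "MODE": "fast"})
# -> ["env", "DEBUG=1", "MODE=fast"]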
13 changes: 13 additions & 0 deletions dcp/api/job_modules.py
@@ -0,0 +1,13 @@
"""
Converts Job modules (Pyodide core packages) for Pyodide worktime.
Author: Severn Lortie <[email protected]>
Date: Aug 2024
"""

def convert_modules_to_requires(modules):
requires = []
for module in modules:
requires.append(f"pyodide-{module}/pyodide-{module}.js")
return requires
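
For example, derived directly from the function above (module names are illustrative):

# convert_modules_to_requires(["numpy", "pandas"])
# -> ["pyodide-numpy/pyodide-numpy.js", "pyodide-pandas/pyodide-pandas.js"]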

124 changes: 124 additions & 0 deletions dcp/api/job_serializers.py
@@ -0,0 +1,124 @@
"""
Responsible for serializing Job arguments and input data.
Author: Severn Lortie <[email protected]>
Date: July 2024
"""

import cloudpickle
import dill
from collections.abc import Iterator

def numpy_save_interrogate(value):
import numpy as np
return isinstance(value, np.ndarray)
def numpy_save_serialize(value):
import numpy as np
from io import BytesIO
byte_buffer = BytesIO()
np.save(byte_buffer, value)
byte_buffer.seek(0)
return byte_buffer.read()
def numpy_save_deserialize(value):
import numpy as np
from io import BytesIO
byte_buffer = BytesIO(value)
byte_buffer.seek(0)
return np.load(byte_buffer)

def pickle_interrogate(value):
return True
def pickle_serialize(value):
import cloudpickle
return cloudpickle.dumps(value)
def pickle_deserialize(value):
import cloudpickle
return cloudpickle.loads(value)

default_serializers = [
{
"name": "numpy-save",
"interrogator": numpy_save_interrogate,
"serializer": numpy_save_serialize,
"deserializer": numpy_save_deserialize,
},
{
"name": "pickle",
"interrogator": pickle_interrogate,
"serializer": pickle_serialize,
"deserializer": pickle_deserialize,
},
]

def validate_serializers(serializers):
required_keys = ['name', 'interrogator', 'serializer', 'deserializer']
for i in range(len(serializers)):
serializer = serializers[i]
missing_keys = [key for key in required_keys if key not in serializer]
if len(missing_keys) > 0:
raise TypeError(f"Serializer at index {i} is missing keys: {missing_keys}")
if len(serializer["name"]) > 256:
raise TypeError(f"Serializer at index {i} has name '{serializer.name}' which exceeds 256 characters")

def serialize(value, serializers):
class IteratorWrapper:
def __init__(self, iterator):
self.iterator = iterator

def __iter__(self):
return self

def __next__(self):
value = next(self.iterator)
return serialize(value, serializers)

primitive_types = (int, float, bool, str, bytes)
if isinstance(value, primitive_types):
return value
if isinstance(value, Iterator):
return IteratorWrapper(value)

for serializer in serializers:
if serializer["interrogator"](value):
serialized_value_bytes = serializer["serializer"](value)
serialized_serializer_name_bytes = serializer["name"].encode('utf-8')
serializer_name_length = len(serializer["name"])
serializer_name_length_byte = bytearray(serializer_name_length.to_bytes(1, byteorder='big'))
serialized_serializer_name_byte_array = bytearray(serialized_serializer_name_bytes)
serialized_value_byte_array = bytearray(serialized_value_bytes)
return serializer_name_length_byte + serialized_serializer_name_byte_array + serialized_value_byte_array

def deserialize(serialized_value, serializers):
if isinstance(serialized_value, memoryview):
value = bytearray(serialized_value.tobytes())
elif not isinstance(serialized_value, bytearray):
return serialized_value
else:
value = serialized_value.copy()
serializer_name_length = value[0]
if serializer_name_length > len(value):
return serialized_value
name_start_idx = 1
name_end_idx = serializer_name_length + 1
serializer_name_bytes = value[name_start_idx:name_end_idx]
del value[0:name_end_idx]
serializer_name = serializer_name_bytes.decode('utf-8')
allowed_serializer_names = [serializer["name"] for serializer in serializers]
if serializer_name not in allowed_serializer_names:
return serialized_value

serializer = next((serializer for serializer in serializers if serializer["name"] == serializer_name), None)
return serializer["deserializer"](value)

def convert_serializers_to_arguments(serializers):
stringified_serializers = []
for serializer in serializers:
stringified_serializers.append({
"name": serializer["name"],
"interrogator": dill.source.getsource(serializer["interrogator"], lstrip=True),
"serializer": dill.source.getsource(serializer["serializer"], lstrip=True),
"deserializer": dill.source.getsource(serializer["deserializer"], lstrip=True)
})
serialized_serializers = bytearray(cloudpickle.dumps(stringified_serializers))
return serialized_serializers
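
A round-trip sketch of the wire format defined above: a 1-byte name length, then the serializer name, then the serialized payload. This assumes numpy is importable, since the first default interrogator imports it:

from dcp.api.job_serializers import default_serializers, serialize, deserialize

payload = {"answer": 42}                # not a primitive, so the pickle serializer handles it
blob = serialize(payload, default_serializers)
assert blob[0] == len("pickle")         # leading byte is the serializer-name length
assert deserialize(blob, default_serializers) == payload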
