Skip to content

Commit

Permalink
feat(compute_for): support ranges and generic iterables
Browse files Browse the repository at this point in the history
  • Loading branch information
wiwichips committed Sep 12, 2024
1 parent 50c5000 commit e52788d
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 4 deletions.
26 changes: 25 additions & 1 deletion dcp/api/compute_for.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .. import dry
from .. import js
from types import FunctionType
from collections.abc import Iterable

def compute_for_maker(Job):
def compute_for(*args, **kwargs):
Expand Down Expand Up @@ -44,7 +45,7 @@ def compute_for(*args, **kwargs):

# clean up job input for PythonMonkey
if job_input_idx != None:
for i, val in enumerate(args[job_input_idx]):
for i, val in enumerate(args[job_input_idx]): #TODO don't enumerate each time... perhaps wrap in iterator
if js.utils.throws_in_pm(val):
args[job_input_idx][i] = { '__pythonmonkey_guard': val }

Expand All @@ -56,6 +57,29 @@ def compute_for(*args, **kwargs):

####################################################

JSIterator = pm.eval("""
(class JSIterator {
constructor(pyit)
{
this.pyit = pyit;
}
next()
{
return this.pyit.next();
}
[Symbol.iterator]()
{
return this;
}
})
""")

if len(args) <= 3:
if isinstance(args[0], Iterable):
args[0] = pm.new(JSIterator)(iter(args[0]))#(IterableWrapper(args[0]))

compute_for_js = pm.eval("globalThis.dcp.compute.for")
job_js = dry.aio.blockify(compute_for_js)(*args, **kwargs)
return Job(job_js)
Expand Down
25 changes: 25 additions & 0 deletions dcp/api/job_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,28 @@ def convert_serializers_to_arguments(serializers):
serialized_serializers = bytearray(cloudpickle.dumps(stringified_serializers))
return serialized_serializers

#TODO: should this DEAD CODE be removed? or used?
class SerializeIterWrapper:
"""Serializes or deserializes itertaively over an iterator."""
def __init__(self, iterable, serializers, mode='serialize'):
self.iterable = iterable
self.serializers = serializers

if mode not in ('serialize', 'deserialize'):
raise ValueError(f"Mode must be 'serialize' or 'deserialize', not {mode}")

def __iter__(self):
return self

def __next__(self):
try:
val = next(self.iterable)

if self.mode == 'serialize':
return serialize(val, self.serializers)
elif self.mode == 'deserialize':
return deserialize(val, self.deserializers)
return val
except StopIteration:
raise StopIteration

19 changes: 19 additions & 0 deletions dcp/js/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,24 @@
PMDict = pm.eval('x = {}; x').__class__


python_to_js_iterator = pm.eval("""(class JSIterator {
constructor(pyit)
{
this.pyit = pyit;
}
next()
{
return this.pyit.next();
}
[Symbol.iterator]()
{
return this;
}
})""")


def isclass(ref):
# TODO: come up with better way to determine if class..
# if a js object prototype has more than one own property, it is a class
Expand All @@ -26,6 +44,7 @@ def obj_ctor(js_instance):
def equals(a, b):
return pm.eval('(a,b) => a === b')(a, b)


def throws_in_pm(value):
"""
Some values such as multi dimensional numpy arrays aren't supported in PM.
Expand Down
10 changes: 7 additions & 3 deletions tests/test_api/test_compute_for.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
import unittest
import pythonmonkey as pm
import dcp
dcp.init()

class TestComputeFor(unittest.TestCase):

def test_compute_for(self):
dcp.init()
job1 = dcp.compute_for('x=>{progress(); return x * 2}', [1,2,3])
job1 = dcp.compute_for([1,2,3], 'x=>{progress(); return x * 2}')
job2 = dcp.compute.do(5, 'x=>{progress(); return x * 2}')

# check compute_for returns the same type as compute.do
self.assertTrue(isinstance(job1, job2.__class__))

def test_smoke_bf2_attrs(self):
job = dcp.compute_for('x=>{progress(); return x * 2}', [1,2,3])
job = dcp.compute_for([1,2,3], 'x=>{progress(); return x * 2}')
self.assertTrue(hasattr(job, 'wait'))

def test_range_and_iterables(self):
dcp.compute_for(range(1,4), '')
pass

if __name__ == '__main__':
unittest.main()

0 comments on commit e52788d

Please sign in to comment.