Skip to content

Commit

Permalink
Recording environment params now peeks at environment to load dynamic…
Browse files Browse the repository at this point in the history
… params.
  • Loading branch information
mrucker committed Aug 17, 2023
1 parent 3e021ae commit 8fd09d3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
9 changes: 6 additions & 3 deletions coba/experiments/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self,
(self.lrn_id, self.lrn) = lrn or (None,None)
(self.val_id, self.val) = val or (None,None)
self.copy = copy

def __eq__(self, o: object) -> bool:
return isinstance(o,Task) \
and self.env_id == o.env_id \
Expand All @@ -38,10 +38,10 @@ def __eq__(self, o: object) -> bool:

class MakeTasks(Source[Iterable[Task]]):

def __init__(self,
def __init__(self,
triples: Sequence[Tuple[Environment,Learner,Evaluator]],
restored: Optional[Result] = None) -> None:

self._triples = triples
self._restored = restored or Result()

Expand Down Expand Up @@ -151,6 +151,9 @@ def filter(self, chunk: Iterable[Task]) -> Iterable[Any]:
if task.copy: lrn = deepcopy(lrn)

if env and not lrn and not val:
with CobaContext.logger.time(f"Peeking at Environment {env_id}..."):
peek_first(env.read())

with CobaContext.logger.time(f"Recording Environment {env_id} parameters..."):
yield ["T1", env_id, SafeEnvironment(env).params]

Expand Down
41 changes: 27 additions & 14 deletions coba/tests/test_experiments_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import cast, Iterable

from coba.context import CobaContext, BasicLogger
from coba.environments import LambdaSimulation, SimulatedInteraction, Environments, LinearSyntheticSimulation
from coba.environments import LambdaSimulation, SimulatedInteraction, Environments, LinearSyntheticSimulation, SupervisedSimulation
from coba.pipes import Pipes, ListSink, Cache
from coba.learners import Learner
from coba.evaluators import OnPolicyEvaluator
Expand Down Expand Up @@ -61,12 +61,19 @@ def __str__(self) -> str:
return "CountRead"

class ExceptionSimulation:
def __init__(self, params_ex = False, read_ex = False):
self._params_ex = params_ex
self._read_ex = read_ex
@property
def params(self):
raise Exception('ExceptionSimulation.params')
if self._params_ex:
raise Exception('ExceptionSimulation.params')
return {}

def read(self) -> Iterable[SimulatedInteraction]:
raise Exception('ExceptionSimulation.read')
if self._read_ex:
raise Exception('ExceptionSimulation.read')
return []

class ParamObj:
def __init__(self,**params):
Expand Down Expand Up @@ -332,14 +339,13 @@ def test_max_size_two(self):
self.assertCountEqual(groups[3], [tasks[7],tasks[6]])
self.assertCountEqual(groups[4], [tasks[2],tasks[4]])

class ProcessWorkItems_Tests(unittest.TestCase):
class ProcessTasks_Tests(unittest.TestCase):

def setUp(self) -> None:
CobaContext.logger = BasicLogger(ListSink())
ModuloLearner.n_finish = 0

def test_simple(self):

env1 = LambdaSimulation(5, lambda i: i, lambda i,c: [0,1,2], lambda i,c,a: cast(float,a))
lrn1 = ModuloLearner("1")
evl1 = ObserveEvaluator()
Expand All @@ -352,8 +358,17 @@ def test_simple(self):
self.assertIs(evl1.observed[1] , lrn1)
self.assertEqual(['T4', (1,1,1), []], transactions[0])

def test_environment_reused(self):
def test_env_task(self):
env1 = SupervisedSimulation([1,2],[1,2],label_type='c')
env2 = SupervisedSimulation([1,2],[1,2],label_type='c')
list(env2.read())

tasks = [Task((1,env1), None, None)]

transactions = list(ProcessTasks().filter(tasks))
self.assertEqual(['T1', 1, env2.params], transactions[0])

def test_environment_reused(self):
sim1 = Pipes.join(CountReadSimulation(), Cache())

lrn1 = ModuloLearner("1")
Expand All @@ -380,10 +395,9 @@ def test_environment_reused(self):
self.assertEqual(['T4', (0,0,0), [] ], transactions[1])
self.assertEqual(['T4', (0,1,1), [] ], transactions[2])

self.assertEqual(sim1[0].n_reads, 1)
self.assertEqual(sim1[0].n_reads, 2)

def test_task_copy_true(self):

lrn1 = ModuloLearner("1")

sim1 = CountReadSimulation()
Expand All @@ -402,7 +416,6 @@ def test_task_copy_true(self):
self.assertEqual(2,ModuloLearner.n_finish)

def test_task_copy_false(self):

lrn1 = ModuloLearner("1")

sim1 = CountReadSimulation()
Expand All @@ -419,7 +432,6 @@ def test_task_copy_false(self):
self.assertIs(task2.observed[1], lrn1)

def test_empty_env_skipped(self):

lrn1 = ModuloLearner("1")
src1 = LinearSyntheticSimulation(n_interactions=0)

Expand All @@ -433,7 +445,8 @@ def test_empty_env_skipped(self):
self.assertEqual(len(transactions), 0)

def test_exception_during_tasks(self):
env1 = ExceptionSimulation()
env0 = ExceptionSimulation(params_ex=True)
env1 = ExceptionSimulation(read_ex=True)
env2 = LambdaSimulation(5, lambda i: i, lambda i,c: [0,1,2], lambda i,c,a: cast(float,a))
lrn1 = ModuloLearner("1")

Expand All @@ -442,15 +455,15 @@ def test_exception_during_tasks(self):

CobaContext.logger.sink = ListSink()

tasks = [Task((0,env1), None, None), Task((0,env1),(0,lrn1), (1,val2)), Task((1,env2),(0,lrn1), (2,val3)) ]
tasks = [Task((0,env0), None, None), Task((0,env1),(0,lrn1), (1,val2)), Task((1,env2),(0,lrn1), (2,val3)) ]

expected = [ ["T4", (1,0,2), [] ] ]
actual = list(ProcessTasks().filter(tasks))

self.assertIs(val3.observed[1], lrn1)
self.assertEqual(expected, actual)
self.assertEqual('ExceptionSimulation.params', str(CobaContext.logger.sink.items[2]))
self.assertEqual('ExceptionSimulation.read', str(CobaContext.logger.sink.items[5]))
self.assertEqual('ExceptionSimulation.params', str(CobaContext.logger.sink.items[4]))
self.assertEqual('ExceptionSimulation.read', str(CobaContext.logger.sink.items[7]))

if __name__ == '__main__':
unittest.main()

0 comments on commit 8fd09d3

Please sign in to comment.