Log on-policy action and probability for off-policy evaluation #43
@@ -13,7 +13,7 @@
 from coba.learners import Learner, SafeLearner
 from coba.primitives import Batch, argmax
 from coba.statistics import percentile
-from coba.utilities import PackageChecker, peek_first
+from coba.utilities import PackageChecker, peek_first, sample_actions

 from coba.evaluators.primitives import Evaluator, get_ope_loss
@@ -235,15 +235,15 @@ def evaluate(self, environment: Optional[Environment], learner: Optional[Learner
             predict_time = time.time()-start_time
             if not batched:
                 ope_reward = sum(p*float(log_rewards.eval(a)) for p,a in zip(on_probs,log_actions))
+                on_action, on_prob = sample_actions(log_actions, on_probs)
             else:
                 ope_reward = [ sum(p*float(R.eval(a)) for p,a in zip(P,A)) for P,A,R in zip(on_probs,log_actions,log_rewards) ]
+                on_action, on_prob = zip(*[sample_actions(actions, probs) for actions, probs in zip(log_actions, on_probs)])
         else:
             start_time = time.time()
             if not batched:
-                on_prob = request(log_context,log_actions,[log_action])
-            else:
-                on_prob = request(log_context,log_actions,log_action)
+                on_action, on_prob = predict(log_context, log_actions)[:2]
             predict_time = time.time()-start_time

             if not batched:
                 ope_reward = on_prob*float(log_rewards.eval(log_action))
             else:
@@ -263,15 +263,15 @@ def evaluate(self, environment: Optional[Environment], learner: Optional[Learner
             if record_time   : out['predict_time'] = predict_time
             if record_time   : out['learn_time']   = learn_time
             if record_reward : out['reward']       = ope_reward
-            if record_action : out['action']       = log_action
-            if record_prob   : out['probability']  = log_prob
+            if record_action : out['action']       = on_action
+            if record_prob   : out['probability']  = on_prob
             if record_context: out['context']      = log_context
             if record_actions: out['actions']      = log_actions
             if record_rewards: out['rewards']      = log_rewards

             out.update({k: interaction[k] for k in interaction.keys()-OffPolicyEvaluator.IMPLICIT_EXCLUDE})

-            if record_ope_loss: out['ope_loss'] = get_ope_loss(learner)
+            if record_ope_loss: out['ope_loss'] = get_ope_loss(learner) if not batched else [get_ope_loss(learner)] * len(log_context)

Review comment: Make OPE loss work for batched evaluation.

             if info:
                 out.update(info)
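To make the bookkeeping above concrete: the off-policy reward is the current policy's PMF-weighted average of the logged rewards, while the newly logged 'action' and 'probability' come from sampling that PMF. A toy, self-contained illustration with plain Python stand-ins (these example values and the direct random.choices draw are mine, not the evaluator's actual objects or coba's sample_actions helper):

import random

log_actions = ['a', 'b', 'c']                 # actions available when the data was logged
on_probs    = [0.2, 0.5, 0.3]                 # current policy's PMF over log_actions
rewards     = {'a': 0.0, 'b': 1.0, 'c': 0.5}  # stand-in for log_rewards.eval

# Expected (off-policy) reward under the current policy.
ope_reward = sum(p * rewards[a] for p, a in zip(on_probs, log_actions))

# On-policy action/probability now recorded by the evaluator,
# sampled here directly rather than through coba's helper.
i = random.choices(range(len(log_actions)), weights=on_probs, k=1)[0]
on_action, on_prob = log_actions[i], on_probs[i]

print(ope_reward, on_action, on_prob)         # e.g. 0.65 b 0.5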
@@ -2,6 +2,7 @@
 from math import isclose
 from typing import Any, Sequence, Tuple, Mapping, Literal

+from coba.utilities import sample_actions
 from coba.exceptions import CobaException
 from coba.random import CobaRandom
 from coba.primitives import Batch, Context, Action, Actions
@@ -266,8 +267,7 @@ def predict(self, context: Context, actions: Actions) -> Tuple[Action,Prob,kwarg
             pred = list(pred.values())[0]

         if self._pred_format[:2] == 'PM':
-            i = self._get_pmf_index(pred)
-            a,p = actions[i], pred[i]
+            a,p = sample_actions(actions, pred, self._rng)

         if self._pred_format[:2] == 'AP':
             a,p = pred[:2]
@@ -287,9 +287,7 @@ def predict(self, context: Context, actions: Actions) -> Tuple[Action,Prob,kwarg

             A,P = [],[]
             if self._pred_format[:2] == 'PM':
-                I = [self._get_pmf_index(p) for p in pred]
-                A = [ a[i] for a,i in zip(actions,I) ]
-                P = [ p[i] for p,i in zip(pred,I) ]
+                A, P = list(map(list, zip(*[sample_actions(a, p, self._rng) for a, p in zip(actions, pred)])))

Review comment: Could remove.

             if self._pred_format[:2] == 'AX':
                 A = pred
@@ -308,9 +306,7 @@ def predict(self, context: Context, actions: Actions) -> Tuple[Action,Prob,kwarg
             pred = list(pred.values())[0]

             if self._pred_format[:2] == 'PM':
-                I = [self._get_pmf_index(p) for p in zip(*pred)]
-                A = [ a[i] for a,i in zip(actions,I) ]
-                P = [ p[i] for p,i in zip(pred,I) ]
+                A, P = list(map(list, zip(*[sample_actions(a, p, self._rng) for a, p in zip(actions, pred)])))

             if self._pred_format[:2] == 'AX':
                 A = pred
@@ -335,8 +331,5 @@ def learn(self, context, action, reward, probability, **kwargs) -> None:
                 raise CobaException("It appears that learner.learn expected kwargs but learner.predict did not provide any.") from ex
             raise

-    def _get_pmf_index(self,pmf):
-        return self._rng.choice(range(len(pmf)), pmf)
-
     def __str__(self) -> str:
         return self.full_name
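The removed _get_pmf_index simply drew an index from the PMF with self._rng.choice(range(len(pmf)), pmf). Judging from the two call sites in this PR (with and without an explicit rng argument), the new coba.utilities.sample_actions helper presumably wraps that same draw and also returns the probability. A minimal sketch of that assumed behavior, not the actual coba implementation:

from coba.random import CobaRandom

def sample_actions(actions, pmf, rng=None):
    # Assumed behavior: choose an index according to the PMF (as the old
    # _get_pmf_index did) and return the chosen action with its probability.
    rng = rng if rng is not None else CobaRandom()
    i = rng.choice(range(len(pmf)), pmf)
    return actions[i], pmf[i]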
@@ -658,8 +658,17 @@ def request(self, context, actions, request):
     def test_batched_request_continuous(self):
         class TestLearner:
             def request(self,context,actions,request):
-                if isinstance(context,BatchType): raise Exception()
-                return .5
+                if isinstance(context,BatchType):
+                    raise Exception()
+                return 0.5
+
+            def predict(self, context, actions):
+                # if isinstance(context,BatchType):
+                #     raise Exception()
+                return [(2, 0.5, None), (3, 0.5, None)]
+                # return (2, 0.5)
+                # return 2, 0.5, None
+                # return 2

Review comment: Struggling to make this test pass.

         task    = OffPolicyEvaluator(learn=False)
         learner = TestLearner()
Review comment: How to do this for continuous actions?

Review comment: For continuous actions we just need to call on_action, on_prob = predict(log_context, log_actions)[:2].

Review comment: Maybe on line 246? I don't think we need separate processing for batched and non-batched here. Man, I hate all this batched logic; it's only here for the neural-network work, where backpropagation with mini-batches can give huge gains in computation time.

Review comment: Tried to add support for continuous actions but am struggling to make some tests pass, see below.