-
-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a36736a
commit cd38813
Showing
50 changed files
with
1,550 additions
and
1,657 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import pickle | ||
from pathlib import Path | ||
|
||
from rdagent.components.coder.CoSTEER.config import CoSTEERSettings | ||
from rdagent.components.coder.CoSTEER.evolvable_subjects import EvolvingItem | ||
from rdagent.components.coder.CoSTEER.evolving_agent import FilterFailedRAGEvoAgent | ||
from rdagent.components.coder.CoSTEER.knowledge_management import ( | ||
CoSTEERKnowledgeBaseV1, | ||
CoSTEERKnowledgeBaseV2, | ||
CoSTEERRAGStrategyV1, | ||
CoSTEERRAGStrategyV2, | ||
) | ||
from rdagent.core.developer import Developer | ||
from rdagent.core.evaluation import Evaluator | ||
from rdagent.core.evolving_agent import EvolvingStrategy | ||
from rdagent.core.evolving_framework import KnowledgeBase | ||
from rdagent.core.experiment import Experiment | ||
from rdagent.log import rdagent_logger as logger | ||
|
||
|
||
class CoSTEER(Developer[Experiment]): | ||
def __init__( | ||
self, | ||
settings: CoSTEERSettings, | ||
eva: Evaluator, | ||
es: EvolvingStrategy, | ||
evolving_version: int, | ||
*args, | ||
with_knowledge: bool = True, | ||
with_feedback: bool = True, | ||
knowledge_self_gen: bool = True, | ||
filter_final_evo: bool = True, | ||
**kwargs, | ||
) -> None: | ||
super().__init__(*args, **kwargs) | ||
self.max_loop = settings.max_loop | ||
self.knowledge_base_path = ( | ||
Path(settings.knowledge_base_path) if settings.knowledge_base_path is not None else None | ||
) | ||
self.new_knowledge_base_path = ( | ||
Path(settings.new_knowledge_base_path) if settings.new_knowledge_base_path is not None else None | ||
) | ||
|
||
self.with_knowledge = with_knowledge | ||
self.with_feedback = with_feedback | ||
self.knowledge_self_gen = knowledge_self_gen | ||
self.filter_final_evo = filter_final_evo | ||
self.evolving_strategy = es | ||
self.evaluator = eva | ||
self.evolving_version = evolving_version | ||
|
||
# init knowledge base | ||
self.knowledge_base = self.load_or_init_knowledge_base( | ||
former_knowledge_base_path=self.knowledge_base_path, | ||
component_init_list=[], | ||
) | ||
# init rag method | ||
self.rag = ( | ||
CoSTEERRAGStrategyV2(self.knowledge_base, settings=settings) | ||
if self.evolving_version == 2 | ||
else CoSTEERRAGStrategyV1(self.knowledge_base, settings=settings) | ||
) | ||
|
||
def load_or_init_knowledge_base(self, former_knowledge_base_path: Path = None, component_init_list: list = []): | ||
if former_knowledge_base_path is not None and former_knowledge_base_path.exists(): | ||
knowledge_base = pickle.load(open(former_knowledge_base_path, "rb")) | ||
if self.evolving_version == 1 and not isinstance(knowledge_base, CoSTEERKnowledgeBaseV1): | ||
raise ValueError("The former knowledge base is not compatible with the current version") | ||
elif self.evolving_version == 2 and not isinstance( | ||
knowledge_base, | ||
CoSTEERKnowledgeBaseV2, | ||
): | ||
raise ValueError("The former knowledge base is not compatible with the current version") | ||
else: | ||
knowledge_base = ( | ||
CoSTEERKnowledgeBaseV2( | ||
init_component_list=component_init_list, | ||
) | ||
if self.evolving_version == 2 | ||
else CoSTEERKnowledgeBaseV1() | ||
) | ||
return knowledge_base | ||
|
||
def develop(self, exp: Experiment) -> Experiment: | ||
|
||
# init intermediate items | ||
experiment = EvolvingItem.from_experiment(exp) | ||
|
||
self.evolve_agent = FilterFailedRAGEvoAgent( | ||
max_loop=self.max_loop, | ||
evolving_strategy=self.evolving_strategy, | ||
rag=self.rag, | ||
with_knowledge=self.with_knowledge, | ||
with_feedback=self.with_feedback, | ||
knowledge_self_gen=self.knowledge_self_gen, | ||
) | ||
|
||
experiment = self.evolve_agent.multistep_evolve( | ||
experiment, | ||
self.evaluator, | ||
filter_final_evo=self.filter_final_evo, | ||
) | ||
|
||
# save new knowledge base | ||
if self.new_knowledge_base_path is not None: | ||
pickle.dump(self.knowledge_base, open(self.new_knowledge_base_path, "wb")) | ||
logger.info(f"New knowledge base saved to {self.new_knowledge_base_path}") | ||
exp.sub_workspace_list = experiment.sub_workspace_list | ||
return exp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import Union | ||
|
||
from rdagent.core.conf import ExtendedBaseSettings | ||
|
||
|
||
class CoSTEERSettings(ExtendedBaseSettings): | ||
"""CoSTEER settings, this setting is supposed not to be used directly!!!""" | ||
|
||
class Config: | ||
env_prefix = "CoSTEER_" | ||
|
||
coder_use_cache: bool = False | ||
"""Indicates whether to use cache for the coder""" | ||
|
||
max_loop: int = 10 | ||
"""Maximum number of task implementation loops""" | ||
|
||
fail_task_trial_limit: int = 20 | ||
|
||
v1_query_former_trace_limit: int = 5 | ||
v1_query_similar_success_limit: int = 5 | ||
|
||
v2_query_component_limit: int = 1 | ||
v2_query_error_limit: int = 1 | ||
v2_query_former_trace_limit: int = 1 | ||
v2_add_fail_attempt_to_latest_successful_execution: bool = False | ||
v2_error_summary: bool = False | ||
v2_knowledge_sampler: float = 1.0 | ||
|
||
knowledge_base_path: Union[str, None] = None | ||
"""Path to the knowledge base""" | ||
|
||
new_knowledge_base_path: Union[str, None] = None | ||
"""Path to the new knowledge base""" | ||
|
||
select_threshold: int = 10 | ||
|
||
|
||
CoSTEER_SETTINGS = CoSTEERSettings() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from abc import abstractmethod | ||
from typing import List, Tuple | ||
|
||
from rdagent.components.coder.CoSTEER.evolvable_subjects import EvolvingItem | ||
from rdagent.core.conf import RD_AGENT_SETTINGS | ||
from rdagent.core.evaluation import Evaluator, Feedback | ||
from rdagent.core.evolving_framework import QueriedKnowledge | ||
from rdagent.core.experiment import Workspace | ||
from rdagent.core.scenario import Task | ||
from rdagent.core.utils import multiprocessing_wrapper | ||
from rdagent.log import rdagent_logger as logger | ||
|
||
|
||
class CoSTEERSingleFeedback(Feedback): | ||
"""This class is a base class for all code generator feedback to single implementation""" | ||
|
||
def __init__( | ||
self, | ||
execution_feedback: str = None, | ||
shape_feedback: str = None, | ||
code_feedback: str = None, | ||
value_feedback: str = None, | ||
final_decision: bool = None, | ||
final_feedback: str = None, | ||
value_generated_flag: bool = None, | ||
final_decision_based_on_gt: bool = None, | ||
) -> None: | ||
self.execution_feedback = execution_feedback | ||
self.shape_feedback = shape_feedback | ||
self.code_feedback = code_feedback | ||
self.value_feedback = value_feedback | ||
self.final_decision = final_decision | ||
self.final_feedback = final_feedback | ||
self.value_generated_flag = value_generated_flag | ||
self.final_decision_based_on_gt = final_decision_based_on_gt | ||
|
||
def __str__(self) -> str: | ||
return f"""------------------Execution Feedback------------------ | ||
{self.execution_feedback if self.execution_feedback is not None else 'No execution feedback'} | ||
------------------Shape Feedback------------------ | ||
{self.shape_feedback if self.shape_feedback is not None else 'No shape feedback'} | ||
------------------Code Feedback------------------ | ||
{self.code_feedback if self.code_feedback is not None else 'No code feedback'} | ||
------------------Value Feedback------------------ | ||
{self.value_feedback if self.value_feedback is not None else 'No value feedback'} | ||
------------------Final Feedback------------------ | ||
{self.final_feedback if self.final_feedback is not None else 'No final feedback'} | ||
------------------Final Decision------------------ | ||
This implementation is {'SUCCESS' if self.final_decision else 'FAIL'}. | ||
""" | ||
|
||
|
||
class CoSTEERMultiFeedback( | ||
Feedback, | ||
List[CoSTEERSingleFeedback], | ||
): | ||
"""Feedback contains a list, each element is the corresponding feedback for each factor implementation.""" | ||
|
||
|
||
class CoSTEEREvaluator(Evaluator): | ||
# TODO: | ||
# I think we should have unified interface for all evaluates, for examples. | ||
# So we should adjust the interface of other factors | ||
@abstractmethod | ||
def evaluate( | ||
self, | ||
target_task: Task, | ||
implementation: Workspace, | ||
gt_implementation: Workspace, | ||
**kwargs, | ||
) -> CoSTEERSingleFeedback: | ||
raise NotImplementedError("Please implement the `evaluator` method") | ||
|
||
|
||
class CoSTEERMultiEvaluator(Evaluator): | ||
def __init__(self, single_evaluator: CoSTEEREvaluator, *args, **kwargs) -> None: | ||
super().__init__(*args, **kwargs) | ||
self.single_evaluator = single_evaluator | ||
|
||
def evaluate( | ||
self, | ||
evo: EvolvingItem, | ||
queried_knowledge: QueriedKnowledge = None, | ||
**kwargs, | ||
) -> CoSTEERMultiFeedback: | ||
multi_implementation_feedback = multiprocessing_wrapper( | ||
[ | ||
( | ||
self.single_evaluator.evaluate, | ||
( | ||
evo.sub_tasks[index], | ||
evo.sub_workspace_list[index], | ||
evo.sub_gt_implementations[index] if evo.sub_gt_implementations is not None else None, | ||
queried_knowledge, | ||
), | ||
) | ||
for index in range(len(evo.sub_tasks)) | ||
], | ||
n=RD_AGENT_SETTINGS.multi_proc_n, | ||
) | ||
|
||
final_decision = [ | ||
None if single_feedback is None else single_feedback.final_decision | ||
for single_feedback in multi_implementation_feedback | ||
] | ||
logger.info(f"Final decisions: {final_decision} True count: {final_decision.count(True)}") | ||
|
||
for index in range(len(evo.sub_tasks)): | ||
if final_decision[index]: | ||
evo.sub_tasks[index].factor_implementation = True | ||
|
||
return multi_implementation_feedback |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from rdagent.components.coder.CoSTEER.evaluators import CoSTEERSingleFeedback | ||
from rdagent.components.coder.CoSTEER.evolvable_subjects import EvolvingItem | ||
from rdagent.core.evolving_agent import RAGEvoAgent | ||
from rdagent.core.evolving_framework import EvolvableSubjects | ||
|
||
|
||
class FilterFailedRAGEvoAgent(RAGEvoAgent): | ||
def filter_evolvable_subjects_by_feedback( | ||
self, evo: EvolvableSubjects, feedback: CoSTEERSingleFeedback | ||
) -> EvolvableSubjects: | ||
assert isinstance(evo, EvolvingItem) | ||
assert isinstance(feedback, list) | ||
assert len(evo.sub_workspace_list) == len(feedback) | ||
|
||
for index in range(len(evo.sub_workspace_list)): | ||
if feedback[index] and not feedback[index].final_decision: | ||
evo.sub_workspace_list[index].clear() | ||
return evo |
Oops, something went wrong.