Skip to content

Commit

Permalink
Refactor, add --scripts-path arg
Browse files Browse the repository at this point in the history
  • Loading branch information
x-mass authored and nkaskov committed Dec 5, 2023
1 parent 27365b6 commit 455a45e
Showing 1 changed file with 121 additions and 90 deletions.
211 changes: 121 additions & 90 deletions scripts/aggregated_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datetime import datetime, timedelta


MERGE_TASK_KEY = 1234
MERGE_TASK_KEY = 1234  # TODO: replace with actual merge task key, adjust task input as well


def progress_bar(iterable, prefix="", suffix="", fill="█"):
Expand All @@ -30,105 +30,130 @@ def print_bar(iteration):
print_bar(i)


class TimeoutError(Exception):
    """Raised when a task does not reach completion before its timeout."""


def dummy_num_parts(key):
    """Return the number of parts a task is split into.

    Placeholder implementation: ``key`` is ignored and the count is fixed.
    """
    # TODO: extract parallel factor from .ll file
    parts = 2
    return parts


def extract_json(text):
    """Parse the JSON object embedded in ``text``.

    The candidate span runs from the first ``{`` to the last ``}``.
    Returns the decoded value, or ``None`` when either brace is missing
    or the span is not valid JSON.
    """
    start = text.find("{")
    stop = text.rfind("}")
    if start == -1 or stop == -1:
        return None
    try:
        return json.loads(text[start:stop + 1])
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass, so this covers both.
        return None


def push_task(key, file, cost, subkey=None):
    """Submit a task through ``request_tools.py push``.

    ``--subkey`` is only passed when ``subkey`` is not ``None``.

    Returns the ``_key`` field of the JSON found in the tool's stdout, or
    ``None`` when no JSON could be extracted or it has no ``_key``.
    """
    command = [
        "python3", "request_tools.py", "push",
        "--cost", str(cost),
        "--file", file,
        "--key", str(key),
    ]
    command.extend(() if subkey is None else ("--subkey", str(subkey)))

    completed = subprocess.run(command, stdout=subprocess.PIPE)
    payload = extract_json(completed.stdout.decode())
    if not payload or "_key" not in payload:
        return None
    return payload["_key"]


def get_proof(request_key):
    """Run ``proof_tools.py get`` for ``request_key``.

    Returns the tool's decoded stdout with surrounding whitespace stripped.
    """
    completed = subprocess.run(
        ["python3", "proof_tools.py", "get", "--request_key", str(request_key)],
        stdout=subprocess.PIPE,
    )
    raw_output = completed.stdout.decode()
    return raw_output.strip()


def get_status(task_key):
    """Query ``request_tools.py get`` for ``task_key``.

    Returns the JSON extracted from the tool's stdout, or ``None`` when
    nothing parseable is found.
    """
    completed = subprocess.run(
        ["python3", "request_tools.py", "get", "--key", str(task_key)],
        stdout=subprocess.PIPE,
    )
    raw_output = completed.stdout.decode()
    return extract_json(raw_output)


def wait_for_completion(task_key, timeout, poll_interval):
    """Poll a task until it reports ``"completed"`` or the timeout expires.

    Args:
        task_key: Key passed to ``get_status``.
        timeout: Maximum time to keep polling, in seconds.
        poll_interval: Seconds to sleep between polls.

    Returns:
        The status response whose ``"status"`` field is ``"completed"``.

    Raises:
        TimeoutError: If the task is not completed before the deadline.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status = get_status(task_key)
        # A reply without a "status" field is treated as "not done yet"
        # rather than crashing the whole run with a KeyError.
        if status and status.get("status") == "completed":
            return status
        time.sleep(poll_interval)
    raise TimeoutError(f"Task {task_key} timed out.")


def run_tasks(key, file, cost, task_timeout, poll_interval):
    """Push one task per part and block until each has completed.

    The number of parts comes from ``dummy_num_parts(key)``; the part index
    is passed to ``push_task`` as the subkey.

    Returns the list of pushed task keys, in push order.
    """
    task_keys = []
    for part in range(dummy_num_parts(key)):
        task_keys.append(push_task(key, file, cost, part))

    for pending in progress_bar(task_keys, prefix="Proofs awaited:"):
        wait_for_completion(pending, task_timeout, poll_interval)

    return task_keys
class TaskDistributor:
class TimeoutError(Exception):
    """Timeout exception type; adds nothing beyond ``Exception``."""

def __init__(self, scripts_path, subtasks_number, task_timeout, poll_interval):
self.scripts_path = scripts_path
self.subtasks_number = subtasks_number
self.task_timeout = task_timeout
self.poll_interval = poll_interval

@staticmethod
def _extract_json(text):
# proof_tools output contains multiple JSON's, we need the last one
try:
json_str = text[text.rindex("{"): text.rindex("}") + 1]
return json.loads(json_str)
except (ValueError, json.JSONDecodeError):
return None

@staticmethod
def _run_command(command):
result = subprocess.run(command, stderr=subprocess.PIPE, text=True)
return result.stderr

def distribute_and_merge_tasks(self, key, file, cost):
    """Run the subtasks for ``key``/``file`` and merge their proofs.

    Returns whatever ``_merge_proofs`` returns for the completed tasks.
    """
    return self._merge_proofs(self._run_tasks(key, file, cost), cost)

def _push_task(self, key, file, cost, subkey=None):
cmd = [
"python3",
f"{self.scripts_path}/request_tools.py",
"push",
"--cost", str(cost),
"--file", file,
"--key", str(key),
]
if subkey is not None:
cmd += ["--subkey", str(subkey)]
result = self._run_command(cmd)
response = self._extract_json(result)
return response["_key"]

def _get_proof(self, request_key):
cmd = [
"python3",
f"{self.scripts_path}/proof_tools.py",
"get",
"--request_key", str(request_key),
]
result = self._run_command(cmd)
response = self._extract_json(result)
return response["proof"]

def _get_status(self, task_key):
cmd = [
"python3",
f"{self.scripts_path}/request_tools.py",
"get",
"--key", str(task_key),
]
result = self._run_command(cmd)
response = self._extract_json(result)
return response["status"]

def _wait_for_completion(self, task_key):
end_time = datetime.now() + timedelta(seconds=self.task_timeout)
while datetime.now() < end_time:
status = self._get_status(task_key)
if status and status == "completed":
return
time.sleep(self.poll_interval)
raise TimeoutError(f"Task {task_key} timed out.")

def _process_level(self, tasks, cost):
new_tasks = []
i = 0
while i < len(tasks):
proofs = [self._get_proof(tasks[i])]
i += 1

if i < len(tasks):
proofs.append(self._get_proof(tasks[i]))
i += 1

def process_level(tasks, cost, task_timeout, poll_interval):
new_tasks = []
i = 0
while i < len(tasks):
proofs = [get_proof(tasks[i])]
i += 1
with tempfile.NamedTemporaryFile(mode="w") as tmp_file:
json.dump(proofs, tmp_file)
tmp_file.flush()
combine_task_key = self._push_task(
MERGE_TASK_KEY, tmp_file.name, cost)

if i < len(tasks):
proofs.append(get_proof(tasks[i]))
i += 1
new_tasks.append(combine_task_key)

with tempfile.NamedTemporaryFile(mode="w") as tmp_file:
json.dump(proofs, tmp_file)
tmp_file.flush()
combine_task_key = push_task(MERGE_TASK_KEY, tmp_file.name, cost)
for task_key in progress_bar(new_tasks, prefix="Merges awaited:"):
self._wait_for_completion(task_key)

new_tasks.append(combine_task_key)
return new_tasks

for task_key in progress_bar(new_tasks, prefix="Merges awaited:"):
wait_for_completion(task_key, task_timeout, poll_interval)
def _run_tasks(self, key, file, cost):
tasks = [
self._push_task(key, file, cost, i) for i in range(self.subtasks_number)
]
for task_key in progress_bar(tasks, prefix="Proofs awaited:"):
self._wait_for_completion(task_key)

return new_tasks
return tasks

def _merge_proofs(self, tasks, cost):
# Process results in a Merkle tree fashion
while len(tasks) > 1:
tasks = self._process_level(tasks, cost)

def merge_proofs(tasks, cost, task_timeout, poll_interval):
# Process results in a Merkle tree fashion
while len(tasks) > 1:
tasks = process_level(tasks, cost, task_timeout, poll_interval)

return get_proof(tasks[0])
return self._get_proof(tasks[0])


def main():
parser = argparse.ArgumentParser(
description="Distribute tasks and assemble results in a Merkle tree fashion."
)
parser.add_argument(
"--scripts-path",
required=True,
help="Path to the directory containing the scripts",
)
parser.add_argument(
"--key", required=True, help="Key to be forwarded to request_tool.py"
)
Expand All @@ -142,9 +167,15 @@ def main():
help="Cost parameter to be forwarded to request_tool.py",
)
parser.add_argument(
"--task-timeout",
"--subtasks-number",
type=int,
required=True,
help="How many subtasks to split into",
)
parser.add_argument(
"--task-timeout",
type=int,
default=120,
help="Timeout in seconds for while waiting for task to complete",
)
parser.add_argument(
Expand All @@ -153,11 +184,11 @@ def main():

args = parser.parse_args()

completed_tasks = run_tasks(
args.key, args.file, args.cost, args.task_timeout, args.poll_interval
distributor = TaskDistributor(
args.scripts_path, args.subtasks_number, args.task_timeout, args.poll_interval
)
merged_proof = merge_proofs(
completed_tasks, args.cost, args.task_timeout, args.poll_interval
merged_proof = distributor.distribute_and_merge_tasks(
args.key, args.file, args.cost
)

print(f"Final proof is: {merged_proof}")
Expand Down

0 comments on commit 455a45e

Please sign in to comment.