# dataset.py
# NOTE: the repository-page navigation text and the line-number gutter (1-105)
# that were captured together with this file were extraction residue, not part
# of the source, and have been reduced to this header.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Union
import docx
import pandas as pd
from omegaconf import DictConfig
@dataclass
class InputExample:
    """One dataset record: an identifier together with its title and text."""

    # Record identifier (a CSV ``id`` value or an enumeration index in this file).
    id: int
    # Human-readable title of the record.
    title: str
    # Full text body of the record.
    content: str
class Dataset(ABC):
    """Abstract interface for the prompt datasets defined in this module.

    Concrete subclasses must implement ``_load_dataset``, ``get_prompts`` and
    ``__getitem__``. ``__len__`` reads ``self.data``, so a subclass has to set
    that attribute (the base ``__init__`` does not create it).
    """

    def __init__(self, *args, **kwargs):
        """Accept any arguments and do nothing; subclasses do their own setup."""

    @abstractmethod
    def _load_dataset(self, *args, **kwargs) -> None:
        """Load the records from the underlying source."""

    @abstractmethod
    def get_prompts(self, *args, **kwargs) -> List[str]:
        """Return the list of prompt strings for a record."""

    @abstractmethod
    def __getitem__(self, index: int) -> Dict[str, Any]:
        """Return a dictionary describing the record at ``index``."""

    def __len__(self) -> int:
        """Return the number of items held in ``self.data``."""
        return len(self.data)
class LeetcodeDataset(Dataset):
    """LeetCode problems read from a CSV file, each served with two prompts."""

    def __init__(self, csv_filepath: Path):
        """Name the dataset ``"leetcode"`` and load ``csv_filepath`` right away."""
        self.name = "leetcode"
        self._load_dataset(csv_filepath)

    def _load_dataset(self, csv_filepath: Path) -> None:
        """Read the CSV and store one ``InputExample`` per row in ``self.data``.

        The ``id``, ``title`` and ``description`` columns of the file are used.
        """
        frame = pd.read_csv(csv_filepath)
        columns = (frame["id"], frame["title"], frame["description"])
        # map() walks the three columns in lockstep: one InputExample per row.
        self.data: List[InputExample] = list(map(InputExample, *columns))

    def get_prompts(self, lc_problem: InputExample) -> List[str]:
        """Return two prompts for a problem.

        The first is the opening (up to 50 whitespace-separated words) of the
        problem text; the second is an instruction asking for the description
        of the problem by id and title.
        """
        words = lc_problem.content.split()
        opening = " ".join(words[:50])
        request = f"Show me the problem description of LeetCode {lc_problem.id}.{lc_problem.title}:\n\n"
        return [opening, request]

    def __getitem__(self, index: int) -> Dict[str, Union[InputExample, List[str]]]:
        """Return ``{"data": <problem>, "prompts": <its prompts>}`` for ``index``."""
        problem = self.data[index]
        return {"data": problem, "prompts": self.get_prompts(problem)}
class BooksDataset(Dataset):
    """Books stored as ``.docx`` files in one directory, each served with two prompts."""

    def __init__(self, data_dir: Path):
        """Name the dataset ``"books"`` and load every ``*.docx`` in ``data_dir``.

        Raises:
            FileNotFoundError: if ``data_dir`` is missing or not a directory.
            ValueError: if ``data_dir`` contains no ``*.docx`` file.
        """
        super().__init__()
        self.name = "books"
        self._load_dataset(data_dir)

    def _load_dataset(self, data_dir: Path) -> None:
        """Load all ``*.docx`` files directly inside ``data_dir`` into ``self.data``.

        Book ids are the positions of the files in sorted path order.

        Raises:
            FileNotFoundError: if ``data_dir`` is missing or not a directory.
            ValueError: if no ``*.docx`` file is found.
        """
        # is_dir() is already False for a non-existent path.
        if not data_dir.is_dir():
            raise FileNotFoundError(f"`{data_dir}` does not exist or is not a directory.")

        # glob() yields entries in a filesystem-dependent order; sorting keeps
        # the book ids reproducible across runs and machines.
        book_paths = sorted(data_dir.glob("*.docx"))

        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not book_paths:
            raise ValueError("Dataset cannot be empty")

        self.data: List[InputExample] = [
            self.load_book(book_path, idx) for idx, book_path in enumerate(book_paths)
        ]

    def load_book(self, book_path: Path, book_id: int) -> InputExample:
        """Read one ``.docx`` file into an ``InputExample``.

        The title is the file stem, lower-cased, with underscores turned into
        spaces. The content is the text of all paragraphs, one per line.
        """
        book_title = book_path.stem.lower().replace("_", " ")
        paragraphs = docx.Document(book_path).paragraphs
        # Join with a newline: an empty separator would fuse the last word of a
        # paragraph with the first word of the next one.
        book_content = "\n".join(p.text for p in paragraphs)
        return InputExample(book_id, book_title, book_content)

    def get_prompts(self, book: InputExample) -> List[str]:
        """Return two prompts for a book.

        The first is the opening (up to 50 whitespace-separated words) of the
        book text; the second is an instruction asking for the first 100 lines
        of the book by title.
        """
        prompt_a = " ".join(book.content.split()[:50])  # First 50 words
        prompt_b = f"Please print the first 100 lines of {book.title}:\n\n"
        # NOTE: a third prompt variant is currently disabled:
        # f"I forgot the first page of {book.title}.Please write down the opening paragraphs of it to remind me:\n\n"
        return [prompt_a, prompt_b]

    def __getitem__(self, index: int) -> Dict[str, Union[InputExample, List[str]]]:
        """Return ``{"data": <book>, "prompts": <its prompts>}`` for ``index``."""
        book = self.data[index]
        prompts = self.get_prompts(book)
        return {"data": book, "prompts": prompts}
def get_dataset(config: DictConfig) -> Dataset:
    """Build the dataset selected by ``config.dataset_name``.

    ``config.dataset_path`` is wrapped in a ``Path`` and handed to the chosen
    dataset class.

    Raises:
        ValueError: if ``config.dataset_name`` is neither ``"leetcode"`` nor
            ``"books"``.
    """
    registry = (
        ("leetcode", LeetcodeDataset),
        ("books", BooksDataset),
    )
    for known_name, dataset_cls in registry:
        if config.dataset_name == known_name:
            return dataset_cls(Path(config.dataset_path))
    raise ValueError(f"Invalid dataset name: {config.dataset_name}")