✨(Spider): similarity search for in context learning
svensundell committed Dec 21, 2023
1 parent e09ad2d commit 5111a59
Showing 6 changed files with 152 additions and 21 deletions.
5 changes: 3 additions & 2 deletions lib/spider/configs/prompts/default.yaml
@@ -3,12 +3,13 @@ defaults:

edit_template: |
/* Error Returned: {} */
- /* Take in account the Error returned to edit or change the SQL query to make it executable. */
+ /* Take in account the Error returned to edit or change the SQL query to make it executable.
+ Please generate only the SQL query*/
SELECT
template: |
/* Given the following database schema: */
{}
- /* Answer the following, Please do not generate any other opening, closing, and explanations. : {} */
+ /* Answer the following, please generate only the SQL query: {} */
SELECT
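
The {} markers are positional slots for Python's str.format: the first receives the schema dump, the second the natural-language question, and the trailing SELECT primes the model to continue with the query body. A minimal sketch of the rendering, with an illustrative schema and question (not taken from the Spider dev set):

# Sketch: how the zero-shot template above is filled (illustrative values).
template = (
    "/* Given the following database schema: */\n"
    "{}\n"
    "/* Answer the following, please generate only the SQL query: {} */\n"
    "SELECT"
)
schema = "CREATE TABLE singer (singer_id INT, name TEXT, country TEXT);"
question = "How many singers are there?"
prompt = template.format(schema, question)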
19 changes: 19 additions & 0 deletions lib/spider/configs/prompts/icl-3-shot.yaml
@@ -0,0 +1,19 @@
defaults:
- default

template: |
/* Some example questions and corresponding SQL queries are provided based on similar problems: */
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Given the following database schema: */
{}
/* Answer the following, please generate only the SQL query. : {} */
SELECT
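
The shot count is never configured explicitly; the script derives it from the number of {} slots in the template. This file has eight: three (question, query) example pairs plus the schema and the target question. A sketch of the arithmetic used in in_context_learning.py below:

# Sketch: deriving the shot count from the placeholder count.
n_slots = 8                 # "{}" occurrences in the icl-3-shot template above
top_k = (n_slots - 2) // 2  # drop the schema and question slots, halve: 3 examples

The 4-, 5-, and 6-shot templates that follow differ only in the number of example pairs.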
22 changes: 22 additions & 0 deletions lib/spider/configs/prompts/icl-4-shot.yaml
@@ -0,0 +1,22 @@
defaults:
- default

template: |
/* Some example questions and corresponding SQL queries are provided based on similar problems: */
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Given the following database schema: */
{}
/* Answer the following, please generate only the SQL query. : {} */
SELECT
25 changes: 25 additions & 0 deletions lib/spider/configs/prompts/icl-5-shot.yaml
@@ -0,0 +1,25 @@
defaults:
- default

template: |
/* Some example questions and corresponding SQL queries are provided based on similar problems: */
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Given the following database schema: */
{}
/* Answer the following, please generate only the SQL query. : {} */
SELECT
28 changes: 28 additions & 0 deletions lib/spider/configs/prompts/icl-6-shot.yaml
@@ -0,0 +1,28 @@
defaults:
- default

template: |
/* Some example questions and corresponding SQL queries are provided based on similar problems: */
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Answer the following: {} */
{}
/* Given the following database schema: */
{}
/* Answer the following, please generate only the SQL query. : {} */
SELECT
74 changes: 55 additions & 19 deletions lib/spider/scripts/in_context_learning.py
@@ -18,6 +18,8 @@
import wandb
from clearml import Task
from omegaconf import OmegaConf
+ from sentence_transformers import SentenceTransformer
+ from torch.nn import functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from wandb import wandb_run
@@ -36,7 +38,6 @@
@hydra.main(version_base="1.3", config_path=str(CONFIG_PATH), config_name="eval.yaml")
def main(cfg) -> None:
transformers.set_seed(cfg.seed)

# loggers
setup_clearml(env_file_path=ENV_FILE_PATH)
task = Task.init(
@@ -57,6 +58,22 @@ def main(cfg) -> None:
}
for data_point in dev_json
]
+ if cfg.prompts.template.count("{}") > 2:
+     train_json = json.load(open(DATASET_PATH / "train_spider.json"))
+     train_json = [
+         {
+             "db_id": data_point["db_id"],
+             "question": data_point["question"],
+             "query": data_point["query"],
+         }
+         for data_point in train_json
+     ]
+     train_questions = [data_point["question"] for data_point in train_json]
+     top_k = int((cfg.prompts.template.count("{}") - 2) / 2)
+     encoder_model = SentenceTransformer("all-mpnet-base-v2")
+     encoder_model.eval()
+     with torch.no_grad():
+         train_embeddings = encoder_model.encode(train_questions, convert_to_tensor=True)

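For reference, here is the retrieval step in isolation, a runnable sketch with toy data; the encoder name and tensor handling mirror the code above, while the questions are made up:

# Standalone sketch of the similarity search (toy data, mirrors the code above).
import torch
from torch.nn import functional as F
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-mpnet-base-v2")
train_questions = ["How many singers are there?", "List the names of all teachers."]
with torch.no_grad():
    train_emb = encoder.encode(train_questions, convert_to_tensor=True)       # shape (N, 768)
    query_emb = encoder.encode("Count the singers.", convert_to_tensor=True)  # shape (768,)
scores = F.cosine_similarity(query_emb, train_emb, dim=1)  # broadcasts to (N,)
best = torch.topk(scores, k=1)
print(train_questions[best.indices[0].item()])  # "How many singers are there?"
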
# Model loading
bnb_config: BitsAndBytesConfig = hydra.utils.instantiate(cfg.bnb_config)
@@ -65,7 +82,9 @@ def main(cfg) -> None:
device_map="auto",
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_2",
)
+ model.eval()
tokenizer = AutoTokenizer.from_pretrained(
cfg.pretrained_model_name_or_path,
padding_side="left",
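
The quantization settings come from cfg.bnb_config, which is not part of this diff; a typical 4-bit setup, shown here purely as an assumption rather than the repo's actual values, could look like the sketch below. Note that attn_implementation="flash_attention_2" additionally requires the flash-attn package and a half-precision dtype such as the bfloat16 used above.

# Assumed example only; the real values live in the Hydra config, not in this diff.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize weights to 4 bits
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # matches torch_dtype above
)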
@@ -91,19 +110,39 @@ def main(cfg) -> None:
data_point["schema"] = schema_file.read()
else:
data_point["schema"] = ""
- prompt = cfg.prompts.template.format(data_point["schema"], data_point["question"])
+ if cfg.prompts.template.count("{}") > 2:
+     dev_question = data_point["question"]
+     with torch.no_grad():
+         dev_embedding = encoder_model.encode(dev_question, convert_to_tensor=True)
+
+     cos_similarities = F.cosine_similarity(dev_embedding, train_embeddings, dim=1)
+
+     top_results = torch.topk(cos_similarities, k=top_k)
+
+     similar_questions = [train_questions[index] for index in top_results.indices]
+     format_args = []
+     for element in train_json:
+         if element["question"] in similar_questions:
+             format_args.append(element["question"])
+             format_args.append(element["query"])
+     format_args.append(data_point["schema"])
+     format_args.append(data_point["question"])
+     prompt = cfg.prompts.template.format(*format_args)
+
+ else:
+     prompt = cfg.prompts.template.format(data_point["schema"], data_point["question"])
messages = [
{"role": "user", "content": prompt},
]
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to("cuda")
- outputs = model.generate(
-     model_inputs,
-     generation_config=generation_config,
-     pad_token_id=tokenizer.eos_token_id,
- )
-
- response = tokenizer.decode(outputs[0], skip_special_tokens=True).split("[/INST]")[1]
+ with torch.no_grad():
+     outputs = model.generate(
+         model_inputs,
+         generation_config=generation_config,
+         pad_token_id=tokenizer.eos_token_id,
+     )
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = extract_sql(message=response, if_first_answer=True)
messages.append({"role": "assistant", "content": response})
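
One caveat in the loop above: format_args is rebuilt by scanning train_json in corpus order and matching on question text, so the examples appear in dataset order rather than similarity rank, and a question string that occurs more than once in the training set is appended once per occurrence, which pushes the schema and question out of their template slots. A rank-preserving sketch, offered as an alternative rather than what this commit does:

# Sketch: index train_json directly with the top-k indices so the most similar
# example comes first and duplicate question texts cannot over-fill the slots.
format_args = []
for index in top_results.indices.tolist():
    format_args.append(train_json[index]["question"])
    format_args.append(train_json[index]["query"])
format_args += [data_point["schema"], data_point["question"]]
prompt = cfg.prompts.template.format(*format_args)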

@@ -115,17 +154,14 @@ def main(cfg) -> None:
messages.append({"role": "user", "content": edit_prompt})
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to("cuda")
- outputs = model.generate(
-     model_inputs,
-     generation_config=generation_config,
-     pad_token_id=tokenizer.eos_token_id,
- )
- response = tokenizer.decode(outputs[0], skip_special_tokens=True).split("[/INST]")[
-     1
- ]
+ with torch.no_grad():
+     outputs = model.generate(
+         model_inputs,
+         generation_config=generation_config,
+         pad_token_id=tokenizer.eos_token_id,
+     )
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
response = extract_sql(message=response, if_first_answer=False)

print("RESPONSE: ", response)
file.write(response + "\n")

kmaps = build_foreign_key_map_from_json(table)