Commit

local model test
Haonan Li committed Apr 15, 2024
1 parent c19d8ac commit 6c3b49e
Showing 4 changed files with 57 additions and 4 deletions.
2 changes: 1 addition & 1 deletion factcheck/config/api_config.yaml
@@ -5,4 +5,4 @@ OPENAI_API_KEY: None
 ANTHROPIC_API_KEY: None
 
 LOCAL_API_KEY: None
-LOCAL_API_URL: None
+LOCAL_API_URL: http://localhost:8000/v1/
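
The URL now defaults to a local OpenAI-compatible endpoint instead of a placeholder. As a minimal sketch (assuming a FastChat- or vLLM-style server hosting a vicuna model is already listening there; nothing in this commit starts one), the endpoint can be smoke-tested with the 1.x SDK:

from openai import OpenAI

# Local OpenAI-compatible servers typically accept any key; "EMPTY"
# matches the dummy value LocalClient uses below.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1/")
print([m.id for m in client.models.list().data])  # ids of the served models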
7 changes: 5 additions & 2 deletions factcheck/utils/llmclient/__init__.py
@@ -1,5 +1,6 @@
-from factcheck.utils.llmclient.gpt_client import GPTClient
-from factcheck.utils.llmclient.claude_client import ClaudeClient
+from .gpt_client import GPTClient
+from .claude_client import ClaudeClient
+from .local_client import LocalClient
 
 
 def client_mapper(model_name: str):
@@ -8,5 +9,7 @@ def client_mapper(model_name: str):
         return GPTClient
     elif model_name.startswith("claude"):
         return ClaudeClient
+    elif model_name.startswith("vicuna"):
+        return LocalClient
     else:
         raise ValueError(f"Model {model_name} not supported.")
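
client_mapper returns a class, not an instance, and dispatches purely on the model-name prefix; with this commit, any name starting with "vicuna" resolves to LocalClient. A quick sketch (the concrete model names are illustrative):

from factcheck.utils.llmclient import client_mapper

client_mapper("gpt-4-turbo")     # -> GPTClient
client_mapper("claude-3-opus")   # -> ClaudeClient
client_mapper("vicuna-7b-v1.5")  # -> LocalClient
client_mapper("mistral-7b")      # raises ValueError: Model mistral-7b not supported.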
50 changes: 50 additions & 0 deletions factcheck/utils/llmclient/local_client.py
@@ -0,0 +1,50 @@
+import time
+
+import openai
+from openai import OpenAI
+from factcheck.utils.llmclient.base import BaseClient
+
+
+class LocalClient(BaseClient):
+    def __init__(
+        self,
+        model: str = "gpt-4-turbo",
+        api_config: dict = None,
+        max_requests_per_minute=200,
+        request_window=60,
+    ):
+        super().__init__(model, api_config, max_requests_per_minute, request_window)
+
+        openai.api_key = "EMPTY"
+        openai.base_url = "http://localhost:8000/v1/"
+
+    def _call(self, messages: list, **kwargs):
+        seed = kwargs.get("seed", 42)  # default seed is 42
+        assert type(seed) is int, "Seed must be an integer."
+
+        response = openai.chat.completions.create(
+            response_format={"type": "json_object"},
+            seed=seed,
+            model=self.model,
+            messages=messages,
+        )
+        r = response.choices[0].message.content
+        return r
+
+    def get_request_length(self, messages):
+        # TODO: check if we should return len(messages) instead
+        return 1
+
+    def construct_message_list(
+        self,
+        prompt_list: list[str],
+        system_role: str = "You are a helpful assistant designed to output JSON.",
+    ):
+        messages_list = list()
+        for prompt in prompt_list:
+            messages = [
+                {"role": "system", "content": system_role},
+                {"role": "user", "content": prompt},
+            ]
+            messages_list.append(messages)
+        return messages_list
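
End to end, usage would look roughly like the sketch below (assumptions: the model name is illustrative, an OpenAI-compatible server serving it is already running at http://localhost:8000/v1/, and BaseClient accepts the constructor arguments exactly as the signature above suggests):

from factcheck.utils.llmclient import client_mapper

ClientCls = client_mapper("vicuna-7b-v1.5")  # resolves to LocalClient
client = ClientCls(model="vicuna-7b-v1.5")   # api_config defaults to None per the signature
[messages] = client.construct_message_list(['Reply with {"ok": true}.'])
print(client._call(messages, seed=42))       # JSON text returned by the local model

One caveat in the implementation: __init__ mutates the module-level openai.api_key and openai.base_url, which are shared by every OpenAI SDK call in the process, so mixing LocalClient with GPTClient in one run could redirect GPT traffic to the local server (depending on how GPTClient configures the SDK). The imported-but-unused OpenAI class points at the usual alternative: a per-instance OpenAI(api_key=..., base_url=...) client. Note also that _call asks for response_format={"type": "json_object"}, which the serving backend must support.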
2 changes: 1 addition & 1 deletion requirements.txt
@@ -4,7 +4,7 @@ bs4
 flask
 httpx
 nltk
-openai
+openai>=1.0.0
 opencv-python
 pandas
 playwright
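
The new floor on openai is load-bearing: local_client.py calls openai.chat.completions.create(...), an interface that exists only in the 1.x SDK. The pre-1.0 package exposed openai.ChatCompletion.create(...) instead, so an older pin would break the new client.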
