Skip to content

Commit

Permalink
support local client
Browse files Browse the repository at this point in the history
  • Loading branch information
haonan-li committed Apr 15, 2024
1 parent 6c3b49e commit d042738
Show file tree
Hide file tree
Showing 19 changed files with 52 additions and 39 deletions.
16 changes: 10 additions & 6 deletions factcheck/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import time
import tiktoken

from factcheck.utils.llmclient import client_mapper
from factcheck.utils.llmclient import CLIENTS, model2client
from factcheck.utils.prompt import prompt_mapper
from factcheck.utils.CustomLogger import CustomLogger
from factcheck.utils.config.api_config import load_api_config
from factcheck.utils.logger import CustomLogger
from factcheck.utils.api_config import load_api_config
from factcheck.core import (
Decompose,
Checkworthy,
Expand All @@ -20,6 +20,7 @@ class FactCheck:
def __init__(
self,
default_model: str = "gpt-4-0125-preview",
client: str = None,
prompt: str = "chatgpt_prompt",
retriever: str = "serper",
decompose_model: str = None,
Expand Down Expand Up @@ -48,16 +49,19 @@ def __init__(
for key, _model_name in step_models.items():
_model_name = default_model if _model_name is None else _model_name
print(f"== Init {key} with model: {_model_name}")
LLMClient = client_mapper(_model_name)
if client is not None:
logger.info(f"== Use specified client: {client}")
LLMClient = CLIENTS[client]
else:
logger.info("== Client is not specified, use model2client() to get the default llm client.")
LLMClient = model2client(_model_name)
setattr(self, key, LLMClient(model=_model_name, api_config=self.api_config))

# sub-modules
self.decomposer = Decompose(llm_client=self.decompose_model, prompt=self.prompt)
self.checkworthy = Checkworthy(llm_client=self.checkworthy_model, prompt=self.prompt)
self.query_generator = QueryGenerator(llm_client=self.query_generator_model, prompt=self.prompt)

self.evidence_crawler = retriever_mapper(retriever_name=retriever)(api_config=self.api_config)

self.claimverify = ClaimVerify(llm_client=self.claim_verify_model, prompt=self.prompt)

logger.info("===Sub-modules Init Finished===")
Expand Down
6 changes: 5 additions & 1 deletion factcheck/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import argparse

from factcheck.utils.llmclient import CLIENTS
from factcheck.utils.multimodal import modal_normalization
from factcheck.utils.utils import load_yaml
from factcheck import FactCheck
Expand All @@ -21,7 +22,9 @@ def check(args):
print(f"Error loading api config: {e}")
api_config = {}

factcheck = FactCheck(default_model=args.model, api_config=api_config, prompt=args.prompt, retriever=args.retriever)
factcheck = FactCheck(
default_model=args.model, client=args.client, api_config=api_config, prompt=args.prompt, retriever=args.retriever
)

content = modal_normalization(args.modal, args.input)
res = factcheck.check_response(content)
Expand All @@ -31,6 +34,7 @@ def check(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="gpt-4-0125-preview")
parser.add_argument("--client", type=str, default=None, choices=CLIENTS.keys())
parser.add_argument("--prompt", type=str, default="chatgpt_prompt")
parser.add_argument("--retriever", type=str, default="serper")
parser.add_argument("--modal", type=str, default="text")
Expand Down
10 changes: 5 additions & 5 deletions factcheck/config/api_config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
SERPER_API_KEY: null

OPENAI_API_KEY: null

ANTHROPIC_API_KEY: null

LOCAL_API_KEY: null
LOCAL_API_URL: null
1 change: 0 additions & 1 deletion factcheck/config/sample_prompt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,3 @@ verify_prompt: |
[evidences]: {evidence}
Output:
2 changes: 1 addition & 1 deletion factcheck/core/CheckWorthy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from factcheck.utils.CustomLogger import CustomLogger
from factcheck.utils.logger import CustomLogger

logger = CustomLogger(__name__).getlog()

Expand Down
2 changes: 1 addition & 1 deletion factcheck/core/ClaimVerify.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from factcheck.utils.CustomLogger import CustomLogger
from factcheck.utils.logger import CustomLogger

logger = CustomLogger(__name__).getlog()

Expand Down
2 changes: 1 addition & 1 deletion factcheck/core/Decompose.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from factcheck.utils.CustomLogger import CustomLogger
from factcheck.utils.logger import CustomLogger
import nltk

logger = CustomLogger(__name__).getlog()
Expand Down
2 changes: 1 addition & 1 deletion factcheck/core/QueryGenerator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from factcheck.utils.CustomLogger import CustomLogger
from factcheck.utils.logger import CustomLogger

logger = CustomLogger(__name__).getlog()

Expand Down
2 changes: 1 addition & 1 deletion factcheck/core/Retriever/EvidenceRetrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from copy import deepcopy
from factcheck.utils.web_util import parse_response, crawl_web
from factcheck.utils.CustomLogger import CustomLogger
from factcheck.utils.logger import CustomLogger

logger = CustomLogger(__name__).getlog()

Expand Down
4 changes: 2 additions & 2 deletions factcheck/core/Retriever/GoogleEvidenceRetrieve.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from concurrent.futures import ThreadPoolExecutor
from factcheck.utils.web_util import common_web_request, crawl_google_web
from factcheck.core.Retriever.EvidenceRetrieve import EvidenceRetrieve
from factcheck.utils.CustomLogger import CustomLogger
from .EvidenceRetrieve import EvidenceRetrieve
from factcheck.utils.logger import CustomLogger

logger = CustomLogger(__name__).getlog()

Expand Down
2 changes: 1 addition & 1 deletion factcheck/core/Retriever/SerperEvidenceRetrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import re
import bs4
from factcheck.utils.CustomLogger import CustomLogger
from factcheck.utils.logger import CustomLogger
from factcheck.utils.web_util import crawl_web

logger = CustomLogger(__name__).getlog()
Expand Down
4 changes: 2 additions & 2 deletions factcheck/core/Retriever/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from factcheck.core.Retriever.GoogleEvidenceRetrieve import GoogleEvidenceRetrieve
from factcheck.core.Retriever.SerperEvidenceRetrieve import SerperEvidenceRetrieve
from .GoogleEvidenceRetrieve import GoogleEvidenceRetrieve
from .SerperEvidenceRetrieve import SerperEvidenceRetrieve

retriever_map = {
"google": GoogleEvidenceRetrieve,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,4 @@ def load_api_config(api_config: dict = None):
merged_config[key] = api_config.get(key, None)
if merged_config[key] is None:
merged_config[key] = os.environ.get(key, None)

return merged_config
14 changes: 10 additions & 4 deletions factcheck/utils/llmclient/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from .gpt_client import GPTClient
from .claude_client import ClaudeClient
from .local_openai_client import LocalOpenAIClient

# Registry of explicitly selectable LLM clients, keyed by the string a user
# may pass as `client` (e.g. on the CLI via --client).
CLIENTS = {
    "gpt": GPTClient,
    "claude": ClaudeClient,
    "local_openai": LocalOpenAIClient,
}


# router for model to client
def model2client(model_name: str):
    """If the client is not specified, use this function to map the model name to the corresponding client.

    Args:
        model_name: Name of the model, e.g. "gpt-4-0125-preview", "claude-3-opus", "vicuna-13b".

    Returns:
        The client class (not an instance) matching the model-name prefix.

    Raises:
        ValueError: If the model name matches no known prefix.
    """
    if model_name.startswith("gpt"):
        return GPTClient
    elif model_name.startswith("claude"):
        return ClaudeClient
    elif model_name.startswith("vicuna"):
        # Vicuna-family models are served locally through an OpenAI-compatible API.
        return LocalOpenAIClient
    else:
        raise ValueError(f"Model {model_name} not supported.")
3 changes: 1 addition & 2 deletions factcheck/utils/llmclient/claude_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import time
from anthropic import Anthropic

from factcheck.utils.llmclient.base import BaseClient
from .base import BaseClient


class ClaudeClient(BaseClient):
Expand Down
3 changes: 1 addition & 2 deletions factcheck/utils/llmclient/gpt_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import time

from openai import OpenAI
from factcheck.utils.llmclient.base import BaseClient
from .base import BaseClient


class GPTClient(BaseClient):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import time

import openai
from openai import OpenAI
from factcheck.utils.llmclient.base import BaseClient
from .base import BaseClient


class LocalOpenAIClient(BaseClient):
"""Support Local host LLM chatbot with OpenAI API.
see https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md for example usage.
"""

class LocalClient(BaseClient):
def __init__(
self,
model: str = "gpt-4-turbo",
model: str = "",
api_config: dict = None,
max_requests_per_minute=200,
request_window=60,
):
super().__init__(model, api_config, max_requests_per_minute, request_window)

openai.api_key = "EMPTY"
openai.base_url = "http://localhost:8000/v1/"
openai.api_key = api_config["LOCAL_API_KEY"]
openai.base_url = api_config["LOCAL_API_URL"]

def _call(self, messages: str, **kwargs):
seed = kwargs.get("seed", 42) # default seed is 42
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion factcheck/utils/multimodal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import cv2
import base64
import requests
from .CustomLogger import CustomLogger
from .logger import CustomLogger

logger = CustomLogger(__name__).getlog()

Expand Down

0 comments on commit d042738

Please sign in to comment.