From d0427388732194a499f7592d70f97fba5136be87 Mon Sep 17 00:00:00 2001 From: haonan-li Date: Mon, 15 Apr 2024 18:35:39 +0400 Subject: [PATCH] support local client --- factcheck/__init__.py | 16 ++++++++++------ factcheck/__main__.py | 6 +++++- factcheck/config/api_config.yaml | 10 +++++----- factcheck/config/sample_prompt.yaml | 1 - factcheck/core/CheckWorthy.py | 2 +- factcheck/core/ClaimVerify.py | 2 +- factcheck/core/Decompose.py | 2 +- factcheck/core/QueryGenerator.py | 2 +- factcheck/core/Retriever/EvidenceRetrieve.py | 2 +- .../core/Retriever/GoogleEvidenceRetrieve.py | 4 ++-- .../core/Retriever/SerperEvidenceRetrieve.py | 2 +- factcheck/core/Retriever/__init__.py | 4 ++-- factcheck/utils/{config => }/api_config.py | 1 - factcheck/utils/llmclient/__init__.py | 14 ++++++++++---- factcheck/utils/llmclient/claude_client.py | 3 +-- factcheck/utils/llmclient/gpt_client.py | 3 +-- .../{local_client.py => local_openai_client.py} | 15 +++++++++------ factcheck/utils/{CustomLogger.py => logger.py} | 0 factcheck/utils/multimodal.py | 2 +- 19 files changed, 52 insertions(+), 39 deletions(-) rename factcheck/utils/{config => }/api_config.py (99%) rename factcheck/utils/llmclient/{local_client.py => local_openai_client.py} (78%) rename factcheck/utils/{CustomLogger.py => logger.py} (100%) diff --git a/factcheck/__init__.py b/factcheck/__init__.py index 9718afd..981d16c 100644 --- a/factcheck/__init__.py +++ b/factcheck/__init__.py @@ -1,10 +1,10 @@ import time import tiktoken -from factcheck.utils.llmclient import client_mapper +from factcheck.utils.llmclient import CLIENTS, model2client from factcheck.utils.prompt import prompt_mapper -from factcheck.utils.CustomLogger import CustomLogger -from factcheck.utils.config.api_config import load_api_config +from factcheck.utils.logger import CustomLogger +from factcheck.utils.api_config import load_api_config from factcheck.core import ( Decompose, Checkworthy, @@ -20,6 +20,7 @@ class FactCheck: def __init__( self, default_model: str = "gpt-4-0125-preview", + client: str = None, prompt: str = "chatgpt_prompt", retriever: str = "serper", decompose_model: str = None, @@ -48,16 +49,19 @@ def __init__( for key, _model_name in step_models.items(): _model_name = default_model if _model_name is None else _model_name print(f"== Init {key} with model: {_model_name}") - LLMClient = client_mapper(_model_name) + if client is not None: + logger.info(f"== Use specified client: {client}") + LLMClient = CLIENTS[client] + else: + logger.info("== Client is not specified, use model2client() to get the default llm client.") + LLMClient = model2client(_model_name) setattr(self, key, LLMClient(model=_model_name, api_config=self.api_config)) # sub-modules self.decomposer = Decompose(llm_client=self.decompose_model, prompt=self.prompt) self.checkworthy = Checkworthy(llm_client=self.checkworthy_model, prompt=self.prompt) self.query_generator = QueryGenerator(llm_client=self.query_generator_model, prompt=self.prompt) - self.evidence_crawler = retriever_mapper(retriever_name=retriever)(api_config=self.api_config) - self.claimverify = ClaimVerify(llm_client=self.claim_verify_model, prompt=self.prompt) logger.info("===Sub-modules Init Finished===") diff --git a/factcheck/__main__.py b/factcheck/__main__.py index 7130b32..4239281 100644 --- a/factcheck/__main__.py +++ b/factcheck/__main__.py @@ -1,6 +1,7 @@ import json import argparse +from factcheck.utils.llmclient import CLIENTS from factcheck.utils.multimodal import modal_normalization from factcheck.utils.utils import load_yaml from factcheck import FactCheck @@ -21,7 +22,9 @@ def check(args): print(f"Error loading api config: {e}") api_config = {} - factcheck = FactCheck(default_model=args.model, api_config=api_config, prompt=args.prompt, retriever=args.retriever) + factcheck = FactCheck( + default_model=args.model, client=args.client, api_config=api_config, prompt=args.prompt, retriever=args.retriever + ) content = modal_normalization(args.modal, args.input) res = factcheck.check_response(content) @@ -31,6 +34,7 @@ def check(args): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, default="gpt-4-0125-preview") + parser.add_argument("--client", type=str, default=None, choices=CLIENTS.keys()) parser.add_argument("--prompt", type=str, default="chatgpt_prompt") parser.add_argument("--retriever", type=str, default="serper") parser.add_argument("--modal", type=str, default="text") diff --git a/factcheck/config/api_config.yaml b/factcheck/config/api_config.yaml index c053c49..148b7ac 100644 --- a/factcheck/config/api_config.yaml +++ b/factcheck/config/api_config.yaml @@ -1,8 +1,8 @@ -SERPER_API_KEY: None +SERPER_API_KEY: null -OPENAI_API_KEY: None +OPENAI_API_KEY: null -ANTHROPIC_API_KEY: None +ANTHROPIC_API_KEY: null -LOCAL_API_KEY: None -LOCAL_API_URL: http://localhost:8000/v1/ +LOCAL_API_KEY: null +LOCAL_API_URL: null diff --git a/factcheck/config/sample_prompt.yaml b/factcheck/config/sample_prompt.yaml index 554aaaa..3db790e 100644 --- a/factcheck/config/sample_prompt.yaml +++ b/factcheck/config/sample_prompt.yaml @@ -104,4 +104,3 @@ verify_prompt: | [evidences]: {evidence} Output: - diff --git a/factcheck/core/CheckWorthy.py b/factcheck/core/CheckWorthy.py index d64c973..7e408b4 100644 --- a/factcheck/core/CheckWorthy.py +++ b/factcheck/core/CheckWorthy.py @@ -1,4 +1,4 @@ -from factcheck.utils.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() diff --git a/factcheck/core/ClaimVerify.py b/factcheck/core/ClaimVerify.py index db23554..b92e138 100644 --- a/factcheck/core/ClaimVerify.py +++ b/factcheck/core/ClaimVerify.py @@ -1,7 +1,7 @@ from __future__ import annotations import json -from factcheck.utils.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() diff --git a/factcheck/core/Decompose.py b/factcheck/core/Decompose.py index debde17..b0dca2b 100644 --- a/factcheck/core/Decompose.py +++ b/factcheck/core/Decompose.py @@ -1,4 +1,4 @@ -from factcheck.utils.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger import nltk logger = CustomLogger(__name__).getlog() diff --git a/factcheck/core/QueryGenerator.py b/factcheck/core/QueryGenerator.py index 984c033..30507e8 100644 --- a/factcheck/core/QueryGenerator.py +++ b/factcheck/core/QueryGenerator.py @@ -1,4 +1,4 @@ -from factcheck.utils.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() diff --git a/factcheck/core/Retriever/EvidenceRetrieve.py b/factcheck/core/Retriever/EvidenceRetrieve.py index f364f13..b5dd548 100644 --- a/factcheck/core/Retriever/EvidenceRetrieve.py +++ b/factcheck/core/Retriever/EvidenceRetrieve.py @@ -2,7 +2,7 @@ import os from copy import deepcopy from factcheck.utils.web_util import parse_response, crawl_web -from factcheck.utils.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() diff --git a/factcheck/core/Retriever/GoogleEvidenceRetrieve.py b/factcheck/core/Retriever/GoogleEvidenceRetrieve.py index 40beb37..58c15f3 100644 --- a/factcheck/core/Retriever/GoogleEvidenceRetrieve.py +++ b/factcheck/core/Retriever/GoogleEvidenceRetrieve.py @@ -1,7 +1,7 @@ from concurrent.futures import ThreadPoolExecutor from factcheck.utils.web_util import common_web_request, crawl_google_web -from factcheck.core.Retriever.EvidenceRetrieve import EvidenceRetrieve -from factcheck.utils.CustomLogger import CustomLogger +from .EvidenceRetrieve import EvidenceRetrieve +from factcheck.utils.logger import CustomLogger logger = CustomLogger(__name__).getlog() diff --git a/factcheck/core/Retriever/SerperEvidenceRetrieve.py b/factcheck/core/Retriever/SerperEvidenceRetrieve.py index 99d6908..97c354d 100644 --- a/factcheck/core/Retriever/SerperEvidenceRetrieve.py +++ b/factcheck/core/Retriever/SerperEvidenceRetrieve.py @@ -4,7 +4,7 @@ import os import re import bs4 -from factcheck.utils.CustomLogger import CustomLogger +from factcheck.utils.logger import CustomLogger from factcheck.utils.web_util import crawl_web logger = CustomLogger(__name__).getlog() diff --git a/factcheck/core/Retriever/__init__.py b/factcheck/core/Retriever/__init__.py index dfd6cb7..fd443af 100644 --- a/factcheck/core/Retriever/__init__.py +++ b/factcheck/core/Retriever/__init__.py @@ -1,5 +1,5 @@ -from factcheck.core.Retriever.GoogleEvidenceRetrieve import GoogleEvidenceRetrieve -from factcheck.core.Retriever.SerperEvidenceRetrieve import SerperEvidenceRetrieve +from .GoogleEvidenceRetrieve import GoogleEvidenceRetrieve +from .SerperEvidenceRetrieve import SerperEvidenceRetrieve retriever_map = { "google": GoogleEvidenceRetrieve, diff --git a/factcheck/utils/config/api_config.py b/factcheck/utils/api_config.py similarity index 99% rename from factcheck/utils/config/api_config.py rename to factcheck/utils/api_config.py index 84d675d..ba9de9d 100644 --- a/factcheck/utils/config/api_config.py +++ b/factcheck/utils/api_config.py @@ -26,5 +26,4 @@ def load_api_config(api_config: dict = None): merged_config[key] = api_config.get(key, None) if merged_config[key] is None: merged_config[key] = os.environ.get(key, None) - return merged_config diff --git a/factcheck/utils/llmclient/__init__.py b/factcheck/utils/llmclient/__init__.py index 41ce230..e744462 100644 --- a/factcheck/utils/llmclient/__init__.py +++ b/factcheck/utils/llmclient/__init__.py @@ -1,15 +1,21 @@ from .gpt_client import GPTClient from .claude_client import ClaudeClient -from .local_client import LocalClient +from .local_openai_client import LocalOpenAIClient +CLIENTS = { + "gpt": GPTClient, + "claude": ClaudeClient, + "local_openai": LocalOpenAIClient +} -def client_mapper(model_name: str): - # router for model to client + +def model2client(model_name: str): + """If the client is not specified, use this function to map the model name to the corresponding client.""" if model_name.startswith("gpt"): return GPTClient elif model_name.startswith("claude"): return ClaudeClient elif model_name.startswith("vicuna"): - return LocalClient + return LocalOpenAIClient else: raise ValueError(f"Model {model_name} not supported.") diff --git a/factcheck/utils/llmclient/claude_client.py b/factcheck/utils/llmclient/claude_client.py index c1c855e..4423650 100644 --- a/factcheck/utils/llmclient/claude_client.py +++ b/factcheck/utils/llmclient/claude_client.py @@ -1,7 +1,6 @@ import time from anthropic import Anthropic - -from factcheck.utils.llmclient.base import BaseClient +from .base import BaseClient class ClaudeClient(BaseClient): diff --git a/factcheck/utils/llmclient/gpt_client.py b/factcheck/utils/llmclient/gpt_client.py index efd40ce..58122e4 100644 --- a/factcheck/utils/llmclient/gpt_client.py +++ b/factcheck/utils/llmclient/gpt_client.py @@ -1,7 +1,6 @@ import time - from openai import OpenAI -from factcheck.utils.llmclient.base import BaseClient +from .base import BaseClient class GPTClient(BaseClient): diff --git a/factcheck/utils/llmclient/local_client.py b/factcheck/utils/llmclient/local_openai_client.py similarity index 78% rename from factcheck/utils/llmclient/local_client.py rename to factcheck/utils/llmclient/local_openai_client.py index 086c788..2c179fd 100644 --- a/factcheck/utils/llmclient/local_client.py +++ b/factcheck/utils/llmclient/local_openai_client.py @@ -1,22 +1,25 @@ import time - import openai from openai import OpenAI -from factcheck.utils.llmclient.base import BaseClient +from .base import BaseClient + +class LocalOpenAIClient(BaseClient): + """Support Local host LLM chatbot with OpenAI API. + see https://github.com/lm-sys/FastChat/blob/main/docs/openai_api.md for example usage. + """ -class LocalClient(BaseClient): def __init__( self, - model: str = "gpt-4-turbo", + model: str = "", api_config: dict = None, max_requests_per_minute=200, request_window=60, ): super().__init__(model, api_config, max_requests_per_minute, request_window) - openai.api_key = "EMPTY" - openai.base_url = "http://localhost:8000/v1/" + openai.api_key = api_config["LOCAL_API_KEY"] + openai.base_url = api_config["LOCAL_API_URL"] def _call(self, messages: str, **kwargs): seed = kwargs.get("seed", 42) # default seed is 42 diff --git a/factcheck/utils/CustomLogger.py b/factcheck/utils/logger.py similarity index 100% rename from factcheck/utils/CustomLogger.py rename to factcheck/utils/logger.py diff --git a/factcheck/utils/multimodal.py b/factcheck/utils/multimodal.py index 8f05eaf..0f78049 100644 --- a/factcheck/utils/multimodal.py +++ b/factcheck/utils/multimodal.py @@ -2,7 +2,7 @@ import cv2 import base64 import requests -from .CustomLogger import CustomLogger +from .logger import CustomLogger logger = CustomLogger(__name__).getlog()