diff --git a/llm/benchmark/analyse.py b/llm/benchmark/analyse.py new file mode 100644 index 0000000000..9cc5355e03 --- /dev/null +++ b/llm/benchmark/analyse.py @@ -0,0 +1,280 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import sys +from dataclasses import dataclass, field +from datetime import datetime + +import numpy as np + + +@dataclass +class Resp: + is_valid: bool = False + + req_id: str = None + max_dec_len: int = None + min_dec_len: int = None + max_send_idx: int = None + + input_token_num: int = None + output_token_num: int = None + is_end: bool = None + + send_req_time: float = None + first_token_end2end_time: float = None + all_token_end2end_time: float = None + first_token_infer_time: float = None + all_token_infer_time: float = None + + http_received_cost_time: float = 0 + infer_received_cost_time: float = 0 + tokenizer_encode_cost_time: float = 0 + tokenizer_decode_cost_time: float = 0 + preprocess_cost_time: float = 0 + pending_cost_time: float = 0 + get_image_cost_time: float = 0 + process_image_cost_time: float = 0 + + input_text: str = None + output_list: list = field(default_factory=list) + + error_msg: str = "" + exception_msg: str = "" + + def auto_set_valid(self): + self.is_valid = True + names = ["req_id", "max_dec_len", "min_dec_len", "max_send_idx", "is_end", + "output_token_num", "send_req_time", "first_token_end2end_time", + "all_token_end2end_time", "first_token_infer_time", "all_token_infer_time"] + for name in names: + if getattr(self, name) is None: + self.is_valid = False + if self.error_msg != "" or self.exception_msg != "": + self.is_valid = False + + def is_error(self) -> bool: + return self.error_msg != "" + + def is_exception(self) -> bool: + return self.exception_msg != "" + + +def str_to_datetime(date_string): + if "." 
in date_string: + return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f") + else: + return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S") + +def datetime_diff(datetime_start, datetime_end): + if isinstance(datetime_start, str): + datetime_start = str_to_datetime(datetime_start) + if isinstance(datetime_end, str): + datetime_end = str_to_datetime(datetime_end) + if datetime_end > datetime_start: + cost = datetime_end - datetime_start + else: + cost = datetime_start - datetime_end + return cost.total_seconds() + +def pp_print(name, input_list): + out_str = f"{name:<35}" + for item in input_list: + out_str += f"{item:<15}" + print(out_str) + +def pp_print_md(name, lst): + info = f"| {name:<35} |" + for i in lst: + info += f" {i:<15} |" + + info = f"| {name:<35} | " + print(info) + + +def collect_response(input_path): + result_dict = {} + start_time = None + end_time = None + log_step = 100000 + + print("\nstart read and collect response...") + with open(input_path, 'r', encoding='utf-8') as file: + for idx, line in enumerate(file): + try: + item = json.loads(line.rstrip('\n')) + except Exception as e: + print(f"error when parse line. idx: {idx}, line: {line} error:{e}") + + item_type = item['type'] + assert item_type in ["request", "response", "error", "exception"] + + req_id = item['content']['req_id'] + if req_id in result_dict: + resp = result_dict[req_id] + if resp.is_valid: + print("error: the req_id is already in result_dict") + continue + else: + resp = Resp(req_id=req_id) + result_dict[req_id] = resp + + if item_type == "request": + resp.max_dec_len = item['content']["max_dec_len"] + resp.min_dec_len = item['content']["min_dec_len"] + resp.input_text = item['content']["text"] + resp.send_req_time = str_to_datetime(item["now_time"]) + elif item_type == "response": + content = item['content'] + if content["send_idx"] == 0: + resp.input_token_num = content.get("input_ids_len", 0) + if content.get("http_received_time"): + resp.http_received_cost_time = datetime_diff(resp.send_req_time, content.get("http_received_time")) + if content.get("preprocess_start_time"): + resp.infer_received_cost_time = datetime_diff(resp.send_req_time, content.get("preprocess_start_time")) + + resp.first_token_infer_time = content["inference_time_cost"] + if content.get("preprocess_start_time") and content.get("preprocess_end_time"): + resp.preprocess_cost_time = datetime_diff(content.get("preprocess_start_time"), + content.get("preprocess_end_time")) + if content.get("preprocess_end_time") and content.get("schedule_start_time"): + resp.pending_cost_time = datetime_diff(content.get("preprocess_end_time"), + content.get("schedule_start_time")) + resp.get_image_cost_time = content.get("get_image_cost_time", 0) + resp.process_image_cost_time = content.get("process_image_cost_time", 0) + resp.tokenizer_encode_cost_time = content.get("tokenizer_encode_cost_time", 0) + resp.first_token_end2end_time = datetime_diff(resp.send_req_time, item["now_time"]) + if content["is_end"] == 1: + resp.is_end = True + resp.max_send_idx = content["send_idx"] + resp.output_token_num = content["tokens_all_num"] + resp.all_token_end2end_time = datetime_diff(resp.send_req_time, item["now_time"]) + resp.all_token_infer_time = content["inference_time_cost"] + resp.auto_set_valid() + resp.output_list.append({'idx': int(content['send_idx']), 'token':content['token']}) + resp.tokenizer_decode_cost_time += content.get("tokenizer_decode_cost_time", 0) + elif item_type == "error": + resp.error_msg += item['content']["error_msg"] + elif 
item_type == "exception": + resp.exception_msg += item['content']["exception_msg"] + + now_time = str_to_datetime(item["now_time"]) + if start_time is None: + start_time = resp.send_req_time + if end_time is None: + end_time = now_time + elif end_time < now_time: + end_time = now_time + + if idx % log_step == 0: + print(f"read {idx+1} chunks", end=', ', flush=True) + + result_list = result_dict.values() + cost_time = datetime_diff(start_time, end_time) + print(f"\nstart_time: {start_time}, end_time: {end_time}, " + f"cost_time: {cost_time}, result_list_num: {len(result_list)}") + return result_list, cost_time + +def save_output_text(result_list, input_path): + output_path = input_path.replace(".jsonl", "-out_msg.jsonl") + with open(output_path, "w", encoding='utf-8') as out_file: + for result in result_list: + if result.is_valid: + output_list = sorted(result.output_list, key=lambda d: d['idx']) + output_text = "" + for i in output_list: + output_text += i['token'] + #dict_obj = {'req_id': result.req_id, 'input_text': result.input_text, 'output_text': output_text} + dict_obj = {'input_text': result.input_text, 'output_text': output_text} + out_file.write(json.dumps(dict_obj, ensure_ascii=False) + "\n") + print(f"output save in {output_path}") + + +def stats_and_percentiles(lst, round_bit=3, multi=1): + lst = [item * multi for item in lst] + num = len(lst) + max_val = round(max(lst), round_bit) + min_val = round(min(lst), round_bit) + avg_val = round(sum(lst) / len(lst), round_bit) + + pct_50, pct_80, pct_95, pct_99 = np.percentile(lst, [50, 80, 95, 99]) + pct_50 = round(pct_50, round_bit) + pct_80 = round(pct_80, round_bit) + pct_95 = round(pct_95, round_bit) + pct_99 = round(pct_99, round_bit) + + return {"num": num, "max": max_val, "min": min_val, "avg": avg_val, + "pct_50": pct_50, "pct_80": pct_80, "pct_95": pct_95, "pct_99": pct_99} + +def analyse_single_key(result_list, key_name, round_bit=2, multi=1): + key_list = [] + for resp in result_list: + if not resp.is_valid: + continue + key_list.append(resp.__dict__[key_name]) + + return stats_and_percentiles(key_list, round_bit, multi) + +def analyse_response(result_list, cost_time): + print("\nstart anaylse response...") + valid_resp_num = 0 + error_num = 0 + exception_num = 0 + for resp in result_list: + if resp.is_valid: + valid_resp_num += 1 + elif resp.is_error(): + error_num += 1 + print(f"error resp: {resp}") + elif resp.is_exception(): + exception_num += 1 + print(f"exception resp: {resp}") + + print(f"total response num: {len(result_list)}, valid response num: {valid_resp_num}, " + f"error_num: {error_num}, exception_num: {exception_num}") + print(f"qps: {round(valid_resp_num / cost_time, 2)} \n") + + info_list = [{'key': 'output_token_num', 'multi': 1, 'msg': '生成token数'}, + {'key': 'first_token_infer_time', 'multi': 1000, 'msg': '首token推理耗时(ms)'}, + {'key': 'all_token_infer_time', 'multi': 1000, 'msg': '整句推理耗时(ms)'}, + {'key': 'first_token_end2end_time', 'multi': 1000, 'msg': '首token用户侧耗时(ms)'}, + {'key': 'all_token_end2end_time', 'multi': 1000, 'msg': '整句用户侧耗时(ms)'}, + {'key': 'infer_received_cost_time', 'multi': 1000, 'msg': '推理收到请求耗时(ms)'}, + {'key': 'http_received_cost_time', 'multi': 1000, 'msg': 'http收到请求耗时(ms)'}, + {'key': 'preprocess_cost_time', 'multi': 1000, 'msg': '预处理耗时(ms)'}, + {'key': 'pending_cost_time', 'multi': 1000, 'msg': '缓存等待推理耗时(ms)'}, + ] + print("| 指标 | 样本数 | 最大 | 最小 | 平均 | 50% | 80% | 95% | 99% |") + print("| ---- | ---- | ---- | ----| ---- | ---- | ---- | ---- | ---- |") + for info in info_list: + out = 
analyse_single_key(result_list, info['key'], multi=info['multi']) + print(f"| {info['msg']:<35} | {out['num']:<15} | {out['max']:<15} | {out['min']:<15} | {out['avg']:<15} " + f"| {out['pct_50']:<15} | {out['pct_80']:<15} | {out['pct_95']:<15} | {out['pct_99']:<15} |") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_path", type=str, help="the jsonl result file generated by run_benchmark_xx.py") + args = parser.parse_args() + return args + +if __name__ == '__main__': + args = parse_args() + print(f"input_path: {args.input_path}") + + result_list, cost_time = collect_response(args.input_path) + analyse_response(result_list, cost_time) + save_output_text(result_list, args.input_path) diff --git a/llm/benchmark/benchmark.py b/llm/benchmark/benchmark.py new file mode 100644 index 0000000000..587201f88f --- /dev/null +++ b/llm/benchmark/benchmark.py @@ -0,0 +1,279 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os +import sys +import queue +import threading +import time +import uuid +from dataclasses import asdict, dataclass +from datetime import datetime +from functools import partial + +import httpx +import numpy as np + +import tritonclient.grpc as grpcclient +from tritonclient.utils import * + +from logger import get_logger + +def _http_send_worker(args, req_dict, result_queue): + is_error_resp = False + headers = {'Content-Type': 'application/json'} + with httpx.stream("POST", args.url, headers=headers, timeout=args.timeout, json=req_dict) as r: + for chunk in r.iter_lines(): + resp = json.loads(chunk) + if resp.get("error_msg") or resp.get("error_code"): + is_error_resp = True + content = {"error_msg": resp.get("error_msg"), "req_id": req_dict.get("req_id")} + result_queue.put({"type": "error", "now_time": str(datetime.now()), "content": content}) + else: + result_queue.put({"type": "response", "now_time": str(datetime.now()), "content": resp}) + return is_error_resp + +def _grpc_send_worker(args, req_dict, result_queue): + class OutputData: + def __init__(self): + self._completed_requests = queue.Queue() + + def triton_callback(output_data, result, error): + if error: + output_data._completed_requests.put(error) + else: + output_data._completed_requests.put(result) + + model_name = "model" + inputs = [grpcclient.InferInput("IN", [1], np_to_triton_dtype(np.object_))] + outputs = [grpcclient.InferRequestedOutput("OUT")] + output_data = OutputData() + is_error_resp = False + + with grpcclient.InferenceServerClient(url=args.url, verbose=False) as triton_client: + triton_client.start_stream(callback=partial(triton_callback, output_data)) + + input_data = json.dumps([req_dict]) + inputs[0].set_data_from_numpy(np.array([input_data], dtype=np.object_)) + + triton_client.async_stream_infer(model_name=model_name, + inputs=inputs, + request_id=str(uuid.uuid4()), + outputs=outputs) + + while True: + output_item = 
output_data._completed_requests.get(timeout=args.timeout) + if type(output_item) == InferenceServerException: + is_error_resp = True + error_msg = f"Exception: status is {output_item.status()}, msg is {output_item.message()}" + content = {"error_msg": error_msg, "req_id": req_dict.get("req_id")} + result_queue.put({"type": "error", "now_time": str(datetime.now()), "content": content}) + else: + result = json.loads(output_item.as_numpy("OUT")[0]) + result = result[0] if isinstance(result, list) else result + result_queue.put({"type": "response", "now_time": str(datetime.now()), "content": result}) + if result.get("is_end") == 1: + break + return is_error_resp + +def send_worker(args, data_queue, result_queue, worker_idx, logger): + """ + send requests and put response into result_queue + """ + logger.info(f"[send_worker {worker_idx}] start...") + + cur_idx = 0 + exception_num = 0 + exception_threshold = 10 + error_resp_num = 0 + log_step = 10 + + while not data_queue.empty(): + # read data + try: + input_data = data_queue.get(timeout=3) + remaining_num = data_queue.qsize() + cur_idx += 1 + except queue.Empty: + logger.info(f"[send_worker {worker_idx}] data queue is empty") + break + except Exception as e: + exception_num += 1 + logger.error(f"[send_worker {worker_idx}][fd_error] fetch data error: {e}") + continue + + result_queue.put({"type": "request", "now_time": str(datetime.now()), "content": input_data}) + + # send request + try: + if args.api_type == 'http': + is_error_resp = _http_send_worker(args, input_data, result_queue) + elif args.api_type == 'grpc': + is_error_resp = _grpc_send_worker(args, input_data, result_queue) + error_resp_num += 1 if is_error_resp else 0 + except Exception as e: + exception_num += 1 + content = {"exception_msg": str(e), "req_id": input_data.get("req_id")} + result_queue.put({"type": "exception", "now_time": str(datetime.now()), "content": content}) + if exception_num > exception_threshold: + logger.error(f"[send_worker {worker_idx}] exception num ({exception_num}) exceeds " + f"threshold, exit") + break + + # log + if cur_idx % log_step == 1: + logger.info(f"[send_worker {worker_idx}] processed_num: {cur_idx}, exception_num: {exception_num}, " + f"error_resp_num: {error_resp_num}, data queue remaining ({remaining_num}) tasks") + + logger.info(f"[send_worker {worker_idx}] exit, processed_num: {cur_idx}, exception_num: {exception_num}, " + f"error_resp_num: {error_resp_num}") + +def save_worker(result_path, result_queue, logger, timeout=50, log_step=10000): + """ + save the result to file + """ + logger.info("[save_worker] start...") + num = 0 + with open(result_path, "w", encoding='utf-8') as out_file: + while True: + try: + res_chunk = result_queue.get(timeout=timeout) + except queue.Empty: + logger.info("[save_worker] result queue is empty") + break + except Exception as e: + logger.error(f"[save_worker] Error retrieving data from queue: {e}") + break + + json_str = json.dumps(res_chunk, ensure_ascii=False) + out_file.write(json_str + "\n") + num += 1 + if num % log_step == 0: + logger.info(f"[save_worker] process {num} response chunks") + + logger.info("[save_worker] exit") + +def prepare_data(data_path, data_num, benchmark=True, stream=True, timeout=180): + """ + prepare data + """ + ''' + data_queue = queue.Queue() + with open(data_path, 'r', encoding='utf-8') as file: + for idx, line in enumerate(file): + raw_data = json.loads(line.rstrip('\n')) + input_data = { + "text": raw_data['text_before_process'], + "max_dec_len": raw_data["max_dec_len"], + 
"min_dec_len": raw_data["min_dec_len"], + "topp": raw_data["topp"], + "temperature": raw_data["temperature"], + "frequency_score": raw_data["frequency_score"], + "penalty_score": raw_data["penalty_score"], + "presence_score": raw_data["presence_score"], + "req_id": str(uuid.uuid4()), + "stream": stream, + "benchmark": benchmark, + "timeout": timeout, + } + if raw_data["history_QA"] != []: + input_data["history_qa"] = raw_data["history_QA"] + + data_queue.put(input_data) + if data_num > 0 and idx + 1 >= data_num: + break + return data_queue + ''' + data_queue = queue.Queue() + with open(data_path, 'r', encoding='utf-8') as file: + dataset = json.load(file) + dataset = [data for data in dataset if len(data['conversations']) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data['conversations'][0]['value'], + data['conversations'][1]['value']) for data in dataset] + prompts = [prompt for prompt, _ in dataset] + + for idx, text in enumerate(prompts): + input_data = { + "text": text, + "max_dec_len": 1024, + "min_dec_len": 1, + "topp": 0, + "temperature": 1, + "req_id": str(uuid.uuid4()), + "stream": stream, + "benchmark": benchmark, + "timeout": timeout, + } + data_queue.put(input_data) + if data_num > 0 and idx + 1 >= data_num: + break + return data_queue + + +def parse_args(): + """ + parse the arguments + """ + parser = argparse.ArgumentParser() + parser.add_argument("--api_type", default="http", type=str, help="grpc or http api") + parser.add_argument("--url", default="http://0.0.0.0:8894/v1/chat/completions", type=str, help="the url for model server") + parser.add_argument("--data_path", default="data.jsonl", type=str, help="the path of data with jsonl format") + parser.add_argument("--data_num", default=-1, type=int, help="-1 means all data") + parser.add_argument("--timeout", default=180, type=int, help="timeout for waiting repsonse") + parser.add_argument("--worker_num", default=1, type=int, help="the number of worker_num for sending requests") + parser.add_argument("--tag", default="test", type=str, help="identify the test case") + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = parse_args() + + # prepare + data_queue = prepare_data(args.data_path, args.data_num, benchmark=False, timeout=args.timeout) + if args.data_num < 0: + args.data_num = data_queue.qsize() + print(f"data_queue size: {data_queue.qsize()}") + + test_tag = f"{args.tag}-{args.api_type}-wk{args.worker_num}-dn{args.data_num}" + logger = get_logger('benchmark', f'{test_tag}.log') + logger.info(f"args: {args}") + logger.info(f"test_tag: {test_tag}") + + result_path = f"output/{test_tag}.jsonl" + if os.path.exists(result_path): + logger.error(f"result file ({result_path}) already exists, overwrite it") + if not os.path.exists("output/"): + os.makedirs("output/") + logger.info(f"result_path: {result_path}") + + # save worker + worker_list = [] + result_queue = queue.Queue() + worker = threading.Thread(target=save_worker, args=(result_path, result_queue, logger, 20)) + worker.start() + worker_list.append(worker) + + # send worker + tic = time.time() + for idx in range(args.worker_num): + worker = threading.Thread(target=send_worker, args=(args, data_queue, result_queue, idx, logger)) + worker.start() + worker_list.append(worker) + for worker in worker_list: + worker.join() + + toc = time.time() + logger.info(f'Done, cost time: {round(toc - tic, 2)}s') diff --git a/llm/benchmark/logger.py b/llm/benchmark/logger.py new file mode 100644 index 0000000000..32c38777d2 --- 
/dev/null +++ b/llm/benchmark/logger.py @@ -0,0 +1,160 @@ + +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import codecs +import logging +import os +import pickle +import re +import subprocess +import time +from datetime import datetime +from enum import Enum +from logging.handlers import BaseRotatingHandler +from pathlib import Path + + +class DailyRotatingFileHandler(BaseRotatingHandler): + """ + like `logging.TimedRotatingFileHandler`, but this class support multi-process + """ + + def __init__( + self, + filename, + backupCount=0, + encoding="utf-8", + delay=False, + utc=False, + **kwargs + ): + self.backup_count = backupCount + self.utc = utc + self.suffix = "%Y-%m-%d" + self.base_log_path = Path(filename) + self.base_filename = self.base_log_path.name + self.current_filename = self._compute_fn() + self.current_log_path = self.base_log_path.with_name(self.current_filename) + BaseRotatingHandler.__init__(self, filename, "a", encoding, delay) + + def shouldRollover(self, record): + """ + check scroll through the log + """ + if self.current_filename != self._compute_fn(): + return True + return False + + def doRollover(self): + """ + scroll log + """ + if self.stream: + self.stream.close() + self.stream = None + + self.current_filename = self._compute_fn() + self.current_log_path = self.base_log_path.with_name(self.current_filename) + + if not self.delay: + self.stream = self._open() + + self.delete_expired_files() + + def _compute_fn(self): + """ + Calculate the log file name corresponding current time + """ + return self.base_filename + "." + time.strftime(self.suffix, time.localtime()) + + def _open(self): + """ + open new log file + """ + if self.encoding is None: + stream = open(str(self.current_log_path), self.mode) + else: + stream = codecs.open(str(self.current_log_path), self.mode, self.encoding) + + if self.base_log_path.exists(): + try: + if ( + not self.base_log_path.is_symlink() + or os.readlink(self.base_log_path) != self.current_filename + ): + os.remove(self.base_log_path) + except OSError: + pass + + try: + os.symlink(self.current_filename, str(self.base_log_path)) + except OSError: + pass + return stream + + def delete_expired_files(self): + """ + delete expired log files + """ + if self.backup_count <= 0: + return + + file_names = os.listdir(str(self.base_log_path.parent)) + result = [] + prefix = self.base_filename + "." 
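+        # rotated log files are named "<base_filename>.<YYYY-MM-DD>"; match them by this
+        # prefix and a date suffix below before pruning backups beyond backup_count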
+ plen = len(prefix) + for file_name in file_names: + if file_name[:plen] == prefix: + suffix = file_name[plen:] + if re.match(r"^\d{4}-\d{2}-\d{2}(\.\w+)?$", suffix): + result.append(file_name) + if len(result) < self.backup_count: + result = [] + else: + result.sort() + result = result[: len(result) - self.backup_count] + + for file_name in result: + os.remove(str(self.base_log_path.with_name(file_name))) + + +def get_logger(name, file_name=None): + """ + 获取logger + """ + if file_name is None: + file_name = name + ".log" + log_dir = os.getenv("log_dir", default="output") + if not os.path.exists(log_dir): + os.mkdir(log_dir) + + logger = logging.getLogger(name) + is_debug = int(os.getenv("FD_DEBUG", default=0)) + if is_debug: + logger.setLevel(level=logging.DEBUG) + else: + logger.setLevel(level=logging.INFO) + + log_file = "{0}/{1}".format(log_dir, file_name) + handler = DailyRotatingFileHandler(log_file, backupCount=7) + + formatter = logging.Formatter( + "%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s" + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + handler.propagate = False + return logger diff --git a/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8 b/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8 index b7cd4205c4..c5b477fc22 100644 --- a/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8 +++ b/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8 @@ -6,10 +6,11 @@ COPY ./client/ /opt/output/client/ ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH" +RUN apt update && apt install net-tools + RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/ \ - && python3 -m pip install paddlenlp==3.0.0b0 \ - && python3 -m pip install --no-cache-dir sentencepiece pycryptodome tritonclient[all]==2.41.1 + && python3 -m pip install paddlenlp==3.0.0b0 RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc \ && python3 setup_cuda.py build && python3 setup_cuda.py install --user \ @@ -32,3 +33,4 @@ RUN cd /opt/output/Serving/ \ ENV http_proxy="" ENV https_proxy="" +ENV TZ=Asia/Shanghai diff --git a/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8_pure b/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8_pure new file mode 100644 index 0000000000..3ee6c7c1dc --- /dev/null +++ b/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8_pure @@ -0,0 +1,25 @@ +FROM registry.baidubce.com/paddlepaddle/fastdeploy:llm-base-gcc12.3-cuda11.8-cudnn8-nccl2.15.5 + +WORKDIR /opt/output/ + +ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH" + +RUN apt update && apt install net-tools + +RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple +RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/ \ + && python3 -m pip install paddlenlp==3.0.0b0 + +RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc \ + && python3 setup_cuda.py build && python3 setup_cuda.py install --user \ + && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \ + && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \ + && rm -rf PaddleNLP + +RUN git clone https://github.com/PaddlePaddle/FastDeploy.git \ + && cd FastDeploy/llm \ + && python3 -m pip install -r server/requirements.txt + +ENV http_proxy="" +ENV https_proxy="" +ENV TZ=Asia/Shanghai 
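For quick reference, a minimal sketch of how the new logger module above might be exercised on its own (an assumed standalone check, not part of this patch: it supposes the snippet is run from llm/benchmark/ so that logger.py is importable, and it uses the log_dir / FD_DEBUG environment variables that get_logger reads):

import os

os.environ["log_dir"] = "output"   # directory get_logger() writes into; "output" is its default
# os.environ["FD_DEBUG"] = "1"     # switches the log level from INFO to DEBUG

from logger import get_logger

# creates output/smoke_test.log.<YYYY-MM-DD> plus a smoke_test.log symlink,
# rotated daily by DailyRotatingFileHandler with backupCount=7
logger = get_logger("benchmark", "smoke_test.log")
logger.info("logger is wired up; files roll over by date suffix")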
diff --git a/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9 b/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9 index fabb7c1724..c7c62d4056 100644 --- a/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9 +++ b/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9 @@ -6,10 +6,11 @@ COPY ./client/ /opt/output/client/ ENV LD_LIBRARY_PATH="/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH" +RUN apt update && apt install net-tools + RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ \ - && python3 -m pip install paddlenlp==3.0.0b0 \ - && python3 -m pip install --no-cache-dir sentencepiece pycryptodome tritonclient[all]==2.41.1 + && python3 -m pip install paddlenlp==3.0.0b0 RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc \ && python3 setup_cuda.py build && python3 setup_cuda.py install --user \ @@ -32,3 +33,4 @@ RUN cd /opt/output/Serving/ \ ENV http_proxy="" ENV https_proxy="" +ENV TZ=Asia/Shanghai diff --git a/llm/server/scripts/start_server.sh b/llm/server/scripts/start_server.sh index 12d9d35e0e..773436aea6 100644 --- a/llm/server/scripts/start_server.sh +++ b/llm/server/scripts/start_server.sh @@ -39,6 +39,19 @@ export METRICS_PORT=${METRICS_PORT:-"8722"} export INFER_QUEUE_PORT=${INFER_QUEUE_PORT:-"8813"} export PUSH_MODE_HTTP_PORT=${PUSH_MODE_HTTP_PORT:-"9965"} +ports=(${HTTP_PORT} ${GRPC_PORT} ${METRICS_PORT} ${INFER_QUEUE_PORT} ${PUSH_MODE_HTTP_PORT}) +for port in "${ports[@]}"; do + output=$(netstat -tuln | grep ":${port} ") + if [ -n "$output" ]; then + echo "${port} is already in use" + exit 1 + fi +done + +script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")") +root_dir=$(dirname "$script_dir") +export PYTHONPATH=${root_dir}:${PYTHONPATH} + mkdir -p log rm -rf console.log log/* rm -rf /dev/shm/* @@ -48,7 +61,7 @@ echo "start serving ..." tritonserver --exit-timeout-secs 100 --cuda-memory-pool-byte-size 0:0 --cuda-memory-pool-byte-size 1:0 \ --cuda-memory-pool-byte-size 2:0 --cuda-memory-pool-byte-size 3:0 --cuda-memory-pool-byte-size 4:0 \ --cuda-memory-pool-byte-size 5:0 --cuda-memory-pool-byte-size 6:0 --cuda-memory-pool-byte-size 7:0 \ - --pinned-memory-pool-byte-size 0 --model-repository llm_model/ \ + --pinned-memory-pool-byte-size 0 --model-repository trition_server_model \ --allow-http false \ --grpc-port=${GRPC_PORT} \ --metrics-port=${METRICS_PORT} \ diff --git a/llm/server/server/common.py b/llm/server/server/common.py new file mode 100644 index 0000000000..2e0e601f07 --- /dev/null +++ b/llm/server/server/common.py @@ -0,0 +1,19 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
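+# Process-wide queue: OutProcessor puts generated result batches here, and the
+# serving side (see triton_server.py) reads them back out to return responses.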
+ +import queue + +global_output_queue = queue.Queue() +def get_global_output_queue(): + return global_output_queue diff --git a/llm/server/server/engine/config.py b/llm/server/server/engine/config.py index 6f0e1964e2..f7ca557972 100644 --- a/llm/server/server/engine/config.py +++ b/llm/server/server/engine/config.py @@ -17,7 +17,9 @@ import sys from datetime import datetime +import paddle from paddlenlp.generation import GenerationConfig + from server.utils import model_server_logger @@ -28,23 +30,22 @@ class Config: def __init__(self): self.read_from_env() + self.read_from_config() + self.postprocess() + self.check() def read_from_env(self): """ get the configuration from environment """ env = os.environ - self.model_dir = env.get( - "MODEL_DIR", "/opt/output/Serving/models") + + self.model_dir = env.get("MODEL_DIR", "/opt/output/Serving/models") if not self.model_dir: raise Exception("The parameter MODEL_DIR is None.") self.mp_num = int(env.get("MP_NUM", 8)) - self.config_json_file = env.get("CONFIG_JSON_FILE", "config.json") - self.model_config_path = os.path.join(self.model_dir, self.config_json_file) - if env.get("FD_MODEL_CONFIG_PATH", None): - self.model_config_path = env.get("FD_MODEL_CONFIG_PATH") - - # distributed config + self.model_config_path = os.path.join(self.model_dir, + env.get("CONFIG_JSON_FILE", "config.json")) self.distributed_config_path = os.path.join(self.model_dir, "rank_mapping.csv") if os.getenv("DISTRIBUTED_CONFIG", None): self.distributed_config_path = os.getenv("DISTRIBUTED_CONFIG") @@ -64,15 +65,16 @@ def read_from_env(self): raise Exception(f"MAX_PREFILL_BATCH ({self.max_prefill_batch}) must be greater than 0") self.disable_streaming = int(os.getenv("DISABLE_STREAMING", 0)) + # server ports + self.grpc_port = int(os.getenv("GRPC_PORT", 8000)) + self.http_port = int(os.getenv("HTTP_PORT", 8001)) + self.metrics_port = int(os.getenv("METRICS_PORT", 8002)) + self.infer_queue_port = int(os.getenv("INFER_QUEUE_PORT", 8005)) + # if PUSH_MODE_HTTP_PORT is not configured, only GRPC service is enabled + self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", -1)) + # max cached task num self.max_cached_task_num = int(os.getenv("MAX_CACHED_TASK_NUM", "128")) - # if PUSH_MODE_HTTP_PORT is not configured, only GRPC service is enabled - self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", "-1")) - if self.push_mode_http_port > 0: - grpc_port = os.getenv("GRPC_PORT", None) - if grpc_port is None: - raise Exception("GRPC_PORT cannot be None, while PUSH_MODE_HTTP_PORT>0") - self.grpc_port = int(grpc_port) # http worker num self.push_mode_http_workers = int(os.getenv("PUSH_MODE_HTTP_WORKERS", "1")) @@ -80,12 +82,8 @@ def read_from_env(self): raise Exception(f"PUSH_MODE_HTTP_WORKERS ({self.push_mode_http_workers}) must be positive") # Padlle commit id - import paddle self.paddle_commit_id = paddle.version.commit - # time interval for detecting whether the engine loop is normal during probing - self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10)) - # model config self.dtype = env.get("DTYPE", "bfloat16") self.block_size = int(env.get("BLOCK_SIZE", 64)) @@ -102,21 +100,20 @@ def read_from_env(self): self.bad_tokens = str(env.get("BAD_TOKENS", "-1")) self.first_token_id = int(os.getenv("FIRST_TOKEN_ID", 1)) - # infer queue port - self.infer_port = int(os.getenv("INFER_QUEUE_PORT", 56666)) - - # whether to use custom health checker - self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1)) - # Check the legality of 
requests self.seq_len_limit = int(env.get("MAX_SEQ_LEN", 8192)) self.dec_len_limit = int(env.get("MAX_DEC_LEN", 1024)) + # whether to use custom health checker + self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1)) + # time interval for detecting whether the engine loop is normal during probing + self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10)) + # warmup self.use_warmup = int(os.getenv("USE_WARMUP", 0)) == 1 # uuid - self.shm_uuid = os.getenv("SHM_UUID", '') + self.shm_uuid = os.getenv("SHM_UUID", "") # use huggingface tokenizer self.use_hf_tokenizer = int(os.getenv("USE_HF_TOKENIZER", 0)) == 1 @@ -130,10 +127,6 @@ def read_from_env(self): ) self.generation_config = None - self.read_from_config() - self.postprocess() - self.check() - def postprocess(self): """ calculate some parameters @@ -234,3 +227,10 @@ def get_unique_name(self, name): def __str__(self) -> str: return json.dumps(self.__dict__, indent=4) + +cfg_inst = None +def get_global_config(): + global cfg_inst + if cfg_inst is None: + cfg_inst = Config() + return cfg_inst diff --git a/llm/server/server/engine/engine.py b/llm/server/server/engine/engine.py index 932404d9c0..eea5560f47 100644 --- a/llm/server/server/engine/engine.py +++ b/llm/server/server/engine/engine.py @@ -25,8 +25,8 @@ import numpy as np from server.engine.resource_manager import ResourceManager from server.engine.task_queue_manager import (TaskQueueManager, - launch_queue_service) -from server.engine.token_processor import TokenProcessor, WarmUpTokenProcessor + launch_task_queue_manager) +from server.engine.out_processor import OutProcessor from server.utils import model_server_logger @@ -34,81 +34,25 @@ class Engine(object): """ Engine Class """ - def __init__(self, cfg, token_processor): + def __init__(self, cfg): self.cfg = cfg self.resource_manager = ResourceManager(self.cfg) - self.token_processor = token_processor - self.token_processor.set_resource_manager(self.resource_manager) - self.is_started = False + self.out_processor = OutProcessor(self.cfg) + self.out_processor.set_resource_manager(self.resource_manager) self._init_engine_flags() - self._finalizer = weakref.finalize(self, self._exit_sub_services) - def start(self): - """ - initialize engine and start sub services - """ - assert not self.is_started, "The engine is already started.!" 
- start_time = time.time() - self.queue_service = self._start_tasks_queue_service() - self.tasks_queue = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port) + self.tqm_proc = self._start_task_queue_manager() + self.task_queue_manager = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_queue_port) - self.token_processor.tasks_queue = self.tasks_queue - self.infer_proc = self._start_infer_service() + start_time = time.time() + self.infer_proc = self._start_infer_process() model_server_logger.info("Waitting infer processes ready...") while not self._infer_processes_ready(): time.sleep(1) - self.is_started = True - - # start warmup - if self.cfg.use_warmup: - model_server_logger.info("Start warmup") - self._set_warmup_token_processor() - self.warmup() - self._del_warmup_token_processor() - model_server_logger.info("Warmup finish") - - # start TokenProcessor thread - self.token_processor.run() model_server_logger.info("Infer processes are launched with {} seconds.".format(time.time() - start_time)) - def warmup(self): - """ - construct test tasks and avoid out of memory problem in the infer process - """ - # get eos_token_id - from server.data.processor import DataProcessor - eos_token_ids = DataProcessor().get_eos_tokens() - - # construct test tasks - res_task = [] - for j in range(2 * self.cfg.max_batch_size): - data = { - "input_ids": [5], - "req_id": j, - "max_dec_len": self.cfg.dec_len_limit, - "min_dec_len": int(self.cfg.dec_len_limit * 0.5) + 1, - "eos_token_ids": eos_token_ids - } - res_task.append(data) - for j in range(2 * self.cfg.max_prefill_batch): - data = { - "input_ids": [5] * self.cfg.seq_len_limit, - "req_id": j + 2 * self.cfg.max_batch_size, - "max_dec_len": 1, - "min_dec_len": 1, - "eos_token_ids": eos_token_ids - } - res_task.append(data) - - for x in res_task: - while self.available_batch() == 0 or not self.insert_tasks([x]): - time.sleep(0.0002) - - self.token_processor._is_blocking = False - # wait for all tasks finished - while not self.all_tasks_finished(): - time.sleep(1) + self._finalizer = weakref.finalize(self, self._exit_sub_services) def insert_tasks(self, tasks): """ @@ -158,13 +102,9 @@ def insert_tasks(self, tasks): if not tasks: return False - self.token_processor.number_of_tasks += len(tasks) - for i in range(len(tasks)): - self.token_processor.number_of_input_tokens += len(tasks[i]["input_ids"]) - req_ids = [t["req_id"] for t in tasks] model_server_logger.info(f"Tasks are sent to engine, req_ids={req_ids}") - self.tasks_queue.put((tasks, self.resource_manager.real_bsz)) + self.task_queue_manager.put((tasks, self.resource_manager.real_bsz)) return True def task_is_finished(self, index): @@ -187,7 +127,7 @@ def is_queue_empty(self): Returns: return: True if empty, False otherwise """ - return self.tasks_queue.empty() + return self.task_queue_manager.empty() def is_resource_sufficient(self, input_token_num): """ @@ -228,29 +168,6 @@ def available_block_num(self): """ return self.resource_manager.availabel_block_num() - def _set_warmup_token_processor(self): - """ - set token_processor for warmup - """ - self.token_processor_backup = self.token_processor - self.token_processor = WarmUpTokenProcessor(self.cfg) - self.token_processor.set_resource_manager(self.resource_manager) - self.token_processor.tasks_queue = self.tasks_queue - - # start TokenProcessor thread - self.token_processor.run() - - def _del_warmup_token_processor(self): - """ - delete token_processor for warmup - """ - self.token_processor.stop() - del self.token_processor 
- - # reset token_processor - self.token_processor = self.token_processor_backup - del self.token_processor_backup - def _infer_processes_ready(self): """ judge if all infer processes are ready @@ -341,34 +258,33 @@ def _exit_sub_services(self): """ exit sub services """ - if hasattr(self, "queue_service") and self.queue_service is not None: - self.queue_service.terminate() - self.queue_service.join() + if hasattr(self, "tqm_proc") and self.tqm_proc is not None: + self.tqm_proc.terminate() + self.tqm_proc.join() if hasattr(self, "infer_proc") and self.infer_proc is not None: os.killpg(self.infer_proc.pid, signal.SIGTERM) - def _start_tasks_queue_service(self): + def _start_task_queue_manager(self): """ start tasks queue service Returns: p: process handle """ - p = multiprocessing.Process(target=launch_queue_service, args=(self.cfg.infer_port, self.cfg.mp_num)) + p = multiprocessing.Process(target=launch_task_queue_manager, args=(self.cfg.infer_queue_port, self.cfg.mp_num)) p.start() - time.sleep(0.3) if p.is_alive(): model_server_logger.info("start tasks queue service successfully") else: - error_msg = "Failed to start tasks queue service, please check " \ + error_msg = "Failed to start task queue manager, please check " \ "the log/task_queue_manager.log for details" model_server_logger.info(error_msg) raise Exception(error_msg) return p - def _start_gpu_infer_service(self): + def _start_gpu_infer_process(self): """ - start gpu infer service + start gpu infer process Returns: p: process handle @@ -394,8 +310,8 @@ def _start_gpu_infer_service(self): ) return p - def _start_infer_service(self): + def _start_infer_process(self): """ - start infer service + start infer process """ - return self._start_gpu_infer_service() + return self._start_gpu_infer_process() diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py index 5d1f9bd33b..90b902cde1 100644 --- a/llm/server/server/engine/infer.py +++ b/llm/server/server/engine/infer.py @@ -27,10 +27,11 @@ import paddle.distributed.fleet as fleet from paddlenlp.utils.llm_utils import get_rotary_position_embedding from paddlenlp_ops import step_paddle + from server.data.processor import DataProcessor -from server.engine.config import Config +from server.engine.config import get_global_config from server.utils import get_logger -from task_queue_manager import TaskQueueManager +from server.engine.task_queue_manager import TaskQueueManager File_Path = os.path.realpath(sys.argv[0]) Dir_Path = os.path.dirname(File_Path) @@ -44,7 +45,7 @@ def __init__(self, args): # 2**63 - 1 self.MAX_INFER_SEED = 9223372036854775806 - self.config = Config() + self.config = get_global_config() self.model_cfg = self.config.get_model_config() self.format_print_configuration() @@ -62,7 +63,7 @@ def __init__(self, args): self.cache_kvs = {} self.init_inputs() - self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port) + self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_queue_port) model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}") if not os.path.exists(model_rank_path): @@ -353,6 +354,14 @@ def run(self): """ run infer """ + use_custom_health_checker = self.config.use_custom_health_checker + if use_custom_health_checker: + shm_engine_ready_check_flag_array, engine_ready_check_flag_array = self.initialize_engine_ready_check_flag() + engine_ready_check_flag_array[0] = 1 + shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array 
= self.initialize_engine_healthy_recorded_time_flag() + engine_healthy_recorded_time_array[0] = time.time() + infer_live_flag_shm = self.initialize_engine_live_flag() + flag_array = np.zeros([1], dtype=np.int32) shm_flag_broadcast = shared_memory.SharedMemory( name=self.config.get_unique_name("shm_pd_infer_flag_broadcast")) @@ -373,13 +382,6 @@ def run(self): dtype=flag_array.dtype, buffer=shm_flag_has_block_step.buf) - use_custom_health_checker = self.config.use_custom_health_checker - if use_custom_health_checker: - shm_engine_ready_check_flag_array, engine_ready_check_flag_array = self.initialize_engine_ready_check_flag() - engine_ready_check_flag_array[0] = 1 - shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array = self.initialize_engine_healthy_recorded_time_flag() - engine_healthy_recorded_time_array[0] = time.time() - infer_live_flag_shm = self.initialize_engine_live_flag() infer_seed_increment = paddle.full(shape=[self.args.max_batch_size, 1], fill_value=4, dtype="int64") diff --git a/llm/server/server/engine/token_processor.py b/llm/server/server/engine/out_processor.py similarity index 76% rename from llm/server/server/engine/token_processor.py rename to llm/server/server/engine/out_processor.py index 507a3d43bd..612127bf11 100644 --- a/llm/server/server/engine/token_processor.py +++ b/llm/server/server/engine/out_processor.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import queue import threading import time import traceback @@ -22,9 +23,10 @@ import numpy as np from paddlenlp_ops import get_output from server.utils import datetime_diff, model_server_logger, monitor_logger +from server.common import get_global_output_queue -class TokenProcessor(object): +class OutProcessor(object): """ get Token/Score from Paddle inference engine """ @@ -32,20 +34,17 @@ def __init__(self, cfg): import paddle paddle.device.set_device("cpu") self.cfg = cfg + self.out_queue = get_global_output_queue() self.resource_manager = None # record all tokens for each request self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)] self.tokens_counter = Counter() self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64") - self.worker = None - self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600")) - assert self.record_time_interval < 3600, "The RECORD_TIME_INTERVAL cannot exceed 3600." - self.statics_start_time = time.time() - self.number_of_tasks = 0 - self.number_of_input_tokens = 0 - self.number_of_output_tokens = 0 + self.worker = threading.Thread(target=self.process_sampling_results, args=()) + self.worker.daemon = True + self.worker.start() def set_resource_manager(self, resource_manager): """ @@ -57,18 +56,6 @@ def set_resource_manager(self, resource_manager): assert self.resource_manager is None, "The resource manager is not None, cannot set again." self.resource_manager = resource_manager - def run(self): - """ - start thread to get tokens - """ - assert self.resource_manager is not None, "The resource manager is None, cannot run." 
- if self.worker is not None: - raise Exception("Worker is already running!") - - self.worker = threading.Thread(target=self.process_sampling_results, args=()) - self.worker.daemon = True - self.worker.start() - def process_sampling_results(self): """ read tokens from paddle inference engine and process @@ -93,13 +80,7 @@ def postprocess(self, batch_result, exist_finished_task=False): batch_result (list): batch results exist_finished_task (bool): whether there is a finished task """ - result_dir = "./generate_token_results" - if not os.path.exists(result_dir): - os.makedirs(result_dir) - for result in batch_result: - result_file = os.path.join(result_dir, result["req_id"]) - with open(result_file, "a") as f: - f.write("{}\n".format(result)) + self.out_queue.put(batch_result) def _get_single_result(self, i, task_id, token_id, task): """ @@ -198,7 +179,6 @@ def _process_batch_output(self): if token_id not in task["eos_token_ids"]: self.all_tokens[i].append(token_id) - self.number_of_output_tokens += 1 if token_id in task["eos_token_ids"]: self._recycle_resources(task_id, i, task) model_server_logger.info("req_id: {0} finished".format(task_id)) @@ -207,40 +187,3 @@ def _process_batch_output(self): batch_result.append(result) self.postprocess(batch_result, exist_finished_task) - - -class WarmUpTokenProcessor(TokenProcessor): - """ - Warmup Processor - """ - def __init__(self, cfg): - super().__init__(cfg) - self._is_running = True - self._is_blocking = True - - def postprocess(self, batch_result, exist_finished_task=False): - pass - - def process_sampling_results(self): - """ - get output from model and process it - """ - while self._is_running: - try: - rank_id = 0 - get_output(self.output_tokens, rank_id, self._is_blocking) - - if self.output_tokens[0, 0] == -2: - continue - self._process_batch_output() - except Exception as e: - model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc()))) - - def stop(self): - """ - stop warm up thread - """ - self._is_running = False - self.worker.join() - model_server_logger.info("warm up thread stop") - del self.worker diff --git a/llm/server/server/engine/task_queue_manager.py b/llm/server/server/engine/task_queue_manager.py index 475365d47f..368cf96571 100644 --- a/llm/server/server/engine/task_queue_manager.py +++ b/llm/server/server/engine/task_queue_manager.py @@ -21,7 +21,7 @@ from server.utils import get_logger -logger = get_logger("infer_server", "task_queue_manager.log") +logger = get_logger("task_queue_manager", "task_queue_manager.log") class QueueManager(BaseManager): @@ -53,7 +53,22 @@ def __init__(self, rank=0, mp_num=8, port=56666): self.client_manager = QueueManager(address=('127.0.0.1', port), authkey=b'infer_queue' ) - self.client_manager.connect() + + retries = 10 + delay = 0.5 + for attempt in range(1, retries + 1): + try: + self.client_manager.connect() + logger.info(f"connect client manager success on attempt {attempt}") + break + except ConnectionRefusedError: + if attempt == retries: + logger.error(f"failed to connect after {retries} attempts.") + raise + else: + logger.warning(f"connection attempt {attempt} failed, retrying in {delay} seconds...") + time.sleep(delay) + self.list = self.client_manager.get_list() self.value = self.client_manager.get_value() self.lock = self.client_manager.get_lock() @@ -131,7 +146,7 @@ def get(self): return input_list, read_finish -def launch_queue_service(port, num_workers): +def launch_task_queue_manager(port, num_workers): """ Start the process 
communication queue service @@ -163,3 +178,4 @@ def launch_queue_service(port, num_workers): except Exception as e: logger.error(f"launch queue service failed, error_msg: {e}") raise e + logger.error("task queue manager exit") diff --git a/llm/server/server/triton_server_helper.py b/llm/server/server/health_checker.py similarity index 98% rename from llm/server/server/triton_server_helper.py rename to llm/server/server/health_checker.py index b299cd4204..558e7cb0b3 100644 --- a/llm/server/server/triton_server_helper.py +++ b/llm/server/server/health_checker.py @@ -25,11 +25,11 @@ import uvicorn from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, Response -from server.engine.config import Config +from server.engine.config import get_global_config from server.utils import get_logger app = FastAPI() -env_config = Config() +env_config = get_global_config() logger = get_logger("health_checker", "health_checker.log") diff --git a/llm/server/server/triton_server.py b/llm/server/server/triton_server.py index 12024c251a..4218a27acf 100644 --- a/llm/server/server/triton_server.py +++ b/llm/server/server/triton_server.py @@ -26,12 +26,14 @@ from datetime import datetime import numpy as np + +import server from server.checker import add_default_params, check_basic_params from server.engine import engine -from server.engine.config import Config +from server.engine.config import get_global_config from server.utils import error_logger, model_server_logger - -import server +from server.data.processor import DataProcessor +from server.common import get_global_output_queue try: import triton_python_backend_utils as pb_utils @@ -44,72 +46,6 @@ enc = os.environ["LANG"].split(".")[1] sys.stdout = codecs.getwriter(enc)(sys.stdout) - -class TritonConfig(Config): - """ - Triton Inference Server config - """ - def __init__(self, base_config): - super().__init__() - for k, v in base_config.__dict__.items(): - setattr(self, k, v) - - -class TritonTokenProcessor(engine.TokenProcessor): - """ - initialize Triton Processor - """ - def __init__(self, cfg, triton_server): - super().__init__(cfg) - self.triton_server = triton_server - self.cached_generated_tokens = queue.Queue() - self.token_buffer = dict() - self.score_buffer = dict() - - self.push_mode_sender_thread = threading.Thread(target=self._push_mode_sender_thread, args=()) - self.push_mode_sender_thread.daemon = True - self.push_mode_sender_thread.start() - - def _push_mode_sender_thread(self): - """ - push mode sender thread - """ - while True: - try: - batch_result = self.cached_generated_tokens.get() - for result in batch_result: - req_id = result["req_id"] - is_end = result.get("is_end", 0) - return_all_tokens = result.get("return_all_tokens", False) - if is_end == 0 and (return_all_tokens or self.cfg.disable_streaming): - continue - if return_all_tokens and "topk_tokens" in result: - del result["topk_tokens"] - result = self.triton_server.data_processor.process_response(result) - if "usage" in result: - result["usage"]["prompt_tokens"] = self.triton_server.task_info[req_id]["prompt_tokens"] - model_server_logger.debug(f"Send result to client under push mode: {result}") - with self.triton_server.thread_lock: - _send_result([result], self.triton_server.response_sender[req_id], is_end) - if is_end == 1: - del self.triton_server.response_sender[req_id] - del self.triton_server.task_info[req_id] - self.triton_server._update_metrics() - except Exception as e: - model_server_logger.error("Unexcepted error happend: {}, 
{}".format(e, str(traceback.format_exc()))) - - def postprocess(self, batch_result, exist_finished_task=False): - """ - single postprocess for triton - """ - try: - self.cached_generated_tokens.put(batch_result) - except Exception as e: - model_server_logger.info( - "Unexcepted problem happend while process output token: {}, {}" - .format(e, str(traceback.format_exc()))) - - class TritonServer(object): """ Triton Server @@ -119,21 +55,8 @@ def initialize(self, args): """ Triton initialization """ - # start health checker - use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1)) - # if set USE_CUSTOM_HEALTH_CHECKER=1, use custom health checker, need set --allow-http=false - # else use tritonserver's health checker, need set --http-port=${HTTP_PORT} - if use_custom_health_checker: - http_port = os.getenv("HTTP_PORT") - if http_port is None: - raise Exception("HTTP_PORT must be set") - from server.triton_server_helper import start_health_checker - multiprocessing.Process(target=start_health_checker, args=(int(http_port), )).start() - time.sleep(1) - model_config = json.loads(args["model_config"]) - using_decoupled = pb_utils.using_decoupled_model_transaction_policy( - model_config) + using_decoupled = pb_utils.using_decoupled_model_transaction_policy(model_config) if not using_decoupled: raise pb_utils.TritonModelException( """the model `{}` can generate any number of responses per request, @@ -148,235 +71,184 @@ def initialize(self, args): GAUGE, ) self.metrics = { - "batch_size": - self.metric_family.Metric(labels={"batch_size": "batch_size"}), - "block_num": - self.metric_family.Metric(labels={"block_num": "block_num"}), - "max_batch_size": - self.metric_family.Metric( - labels={"max_batch_size": "max_batch_size"}), - "max_block_num": - self.metric_family.Metric( - labels={"max_block_num": "max_block_num"}), - "available_resource": - self.metric_family.Metric( - labels={"available_resource": "available_resource"}), + "batch_size": self.metric_family.Metric(labels={"batch_size": "batch_size"}), + "block_num": self.metric_family.Metric(labels={"block_num": "block_num"}), + "max_batch_size": self.metric_family.Metric(labels={"max_batch_size": "max_batch_size"}), + "max_block_num": self.metric_family.Metric(labels={"max_block_num": "max_block_num"}), + "available_resource": self.metric_family.Metric(labels={"available_resource": "available_resource"}), } - # response_sender thread lock - self.thread_lock = threading.Lock() - - base_config = Config() - self.cfg = TritonConfig(base_config) + self.cfg = get_global_config() self.cfg.print(file="log/fastdeploy_init.info") + self.req_senders = dict() + self.cached_task_deque = deque() + self.is_stopping = False + + # if set USE_CUSTOM_HEALTH_CHECKER=1, use custom health checker, need set --allow-http=false + # else use tritonserver's health checker, need set --http-port=${HTTP_PORT} + if self.cfg.use_custom_health_checker: + from server.health_checker import start_health_checker + multiprocessing.Process(target=start_health_checker, args=(self.cfg.http_port, )).start() + time.sleep(1) + + self.engine = engine.Engine(self.cfg) + model_server_logger.info("create engine success") + + self.data_processor = DataProcessor() + model_server_logger.info("create data processor success") + + self.http_proc = None + self._launch_http_server() + model_server_logger.info("launch push server success") - # init engine - self.token_processor = TritonTokenProcessor(self.cfg, self) - self.engine = engine.Engine(self.cfg, self.token_processor) - 
model_server_logger.info("Creat engine...") - self.engine.start() - model_server_logger.info("Create engine success") + schedule_task_thread = threading.Thread(target=self._schedule_task, args=()) + schedule_task_thread.daemon = True + schedule_task_thread.start() + send_output_thread = threading.Thread(target=self._send_output, args=()) + send_output_thread.daemon = True + send_output_thread.start() - self._initialize_push_mode() - model_server_logger.info("Init triton server success") + model_server_logger.info("init triton server success") + def _launch_http_server(self): + """ + launch http server + """ + model_server_logger.info("launch http server...") + current_dir_path = os.path.split(os.path.abspath(__file__))[0] + http_py_file = "app.py" + http_py_path = os.path.join(current_dir_path, "http_server", http_py_file) + http_cmd = f"python3 {http_py_path} --port={self.cfg.push_mode_http_port} " \ + f"--workers={self.cfg.push_mode_http_workers} >log/launch_http.log 2>&1" + model_server_logger.info(f"launch HTTP server for push mode, command:{http_cmd}") + + self.http_proc = subprocess.Popen(http_cmd, shell=True, preexec_fn=os.setsid) + time.sleep(3) + exit_code = self.http_proc.poll() + if exit_code is None: + http_url = f"http://127.0.0.1:{self.cfg.push_mode_http_port}/v1/chat/completions" + model_server_logger.info(f"launch HTTP server for push mode success, http_url:{http_url}") + else: + error_msg = "\n Launch HTTP service for push mode failed in 3 seconds. " \ + "Please check log/launch_http.log file \n" + model_server_logger.error(error_msg) def execute(self, requests): """ Triton service main function, handling requests received by the Triton framework """ + # load request if len(requests) != 1: raise pb_utils.TritonModelException( "Only support batch=1, but now it's {}.".format(len(requests))) request = requests[0] - current_response_sender = request.get_response_sender() - request_tensor = pb_utils.get_input_tensor_by_name(request, "IN") - tasks = json.loads(request_tensor.as_numpy()[0]) - + sender = request.get_response_sender() + tasks = json.loads(pb_utils.get_input_tensor_by_name(request, "IN").as_numpy()[0]) model_server_logger.info(f"receive task: {tasks}") - self._process_task_push_mode(tasks, current_response_sender) - self._update_metrics() - def finalize(self): - """ - Triton service exit function - """ - model_server_logger.info("Triton service will be terminated...") - wait_time = 300 - while not self.engine.all_tasks_finished(): - if wait_time <= 0: - model_server_logger.warning(f"Ignore the unfinished tasks, force to stop.") - break - model_server_logger.info(f"There's unfinished tasks, wait {wait_time}...") - wait_time -= 5 - time.sleep(5) - model_server_logger.info("Terminate the engine now.") - self.enable_insert_task_push_mode = False - time.sleep(1) - del self.engine - if hasattr(self, "http_process"): - self.http_process.kill() - model_server_logger.info("Triton service is terminated!") - - def _initialize_push_mode(self): - from server.data.processor import DataProcessor - self.data_processor = DataProcessor() - model_server_logger.info("create data processor success") - - if self.cfg.push_mode_http_port < 0: - model_server_logger.info("HTTP server for push mode is disabled.") + # check request + tik = time.time() + task = tasks[0] + req_id = task["req_id"] + task["preprocess_start_time"] = datetime.now() + + if self.is_stopping: + _send_error("The server is stopping", sender, req_id=req_id) + return + cached_task_num = len(self.cached_task_deque) + if 
cached_task_num >= self.cfg.max_cached_task_num: + error_msg = f"cached task num ({cached_task_num}) exceeds " \ + f"the limit ({self.cfg.max_cached_task_num})" + _send_error(error_msg, sender, req_id=req_id) + return + + if len(tasks) != 1: + error_msg = f"request data should not be empty and query " \ + f"num {len(tasks)} should be 1" + _send_error(error_msg, sender, req_id=req_id) + return + if req_id in self.req_senders: + error_msg = f"The req_id {req_id} already exists in the current batch, " \ + f"the current request will be ignored." + _send_error(error_msg, sender, req_id=req_id) + return + + error_msg = check_basic_params(task) + if error_msg != []: + _send_error(error_msg, sender, req_id=req_id) + return + + # preprocess request + task = add_default_params(task) + + if int(task.get("enable_text_truncate", 1)): + real_seq_len = self.cfg.max_seq_len - task.get("max_dec_len", 800) + task = self.data_processor.process_request(task, max_seq_len=real_seq_len) else: - model_server_logger.info("launch http server...") - - current_dir_path = os.path.split(os.path.abspath(__file__))[0] - http_py_file = "app.py" - http_py_path = os.path.join(current_dir_path, "http_server", http_py_file) - http_cmd = f"python3 {http_py_path} --port={self.cfg.push_mode_http_port} " \ - f"--workers={self.cfg.push_mode_http_workers} >log/launch_http.log 2>&1" - - model_server_logger.info(f"Launch HTTP server for push mode, command:{http_cmd}") - self.http_process = subprocess.Popen(http_cmd, shell=True, preexec_fn=os.setsid) - time.sleep(3) - exit_code = self.http_process.poll() - if exit_code is None: - http_url = f"http://127.0.0.1:{self.cfg.push_mode_http_port}/v1/chat/completions" - model_server_logger.info(f"Launch HTTP server for push mode success, http_url:{http_url}") - else: - error_msg = "\n Launch HTTP service for push mode failed in 3 seconds. " \ - "Please check log/launch_http.log file \n" - model_server_logger.error(error_msg) - model_server_logger.info("init push server success") - - self.response_sender = dict() - self.task_info = dict() - self.cached_task_deque = deque() - self.enable_insert_task_push_mode = True - self.insert_task_to_engine_thread = threading.Thread( - target=self._insert_task_push_mode, args=()) - self.insert_task_to_engine_thread.daemon = True - self.insert_task_to_engine_thread.start() - - def _process_task_push_mode(self, tasks, current_response_sender): - """ - check request and insert into cached_task_deque + task = self.data_processor.process_request(task) + + # check token length + input_ids_len = len(task["input_ids"]) + if "max_dec_len" not in task: + task["max_dec_len"] = min(self.cfg.max_seq_len - input_ids_len, self.cfg.dec_len_limit) + min_dec_len = task["min_dec_len"] + if input_ids_len + min_dec_len >= self.cfg.max_seq_len: + error_msg = f"Input text is too long, input_ids_len ({input_ids_len}) " \ + f"+ min_dec_len ({min_dec_len}) >= max_seq_len " + _send_error(error_msg, sender, req_id=req_id) + return + if input_ids_len > self.cfg.seq_len_limit: + error_msg = f"Length of input token({input_ids_len}) exceeds the limit MAX_SEQ_LEN({self.cfg.seq_len_limit})." + _send_error(error_msg, sender, req_id=req_id) + return + if task["max_dec_len"] > self.cfg.dec_len_limit: + error_msg = f"The parameter max_dec_len({task['max_dec_len']}) exceeds the limit MAX_DEC_LEN({self.cfg.dec_len_limit})." 
+ _send_error(error_msg, sender, req_id=req_id) + return + required_block_num = self.engine.resource_manager.get_required_block_number(input_ids_len) + if required_block_num > self.engine.resource_manager.total_block_number(): + error_msg = f"The resources required by the input task exceed the limit, task={task}." + _send_error(error_msg, sender, req_id=req_id) + return + + # cache task + self.req_senders[req_id] = sender + task["preprocess_end_time"] = datetime.now() + self.cached_task_deque.appendleft(task) + tok = time.time() + model_server_logger.info(f"cache task with req_id ({req_id}), " + f"cost time: {tok-tik}s, cached_task_num: {len(self.cached_task_deque)}.") + model_server_logger.debug(f"cache task: {task}") - Args: - tasks (list): list of request - current_response_sender: response sender for current request - """ - try: - tik = time.time() - req_id = tasks[0]["req_id"] - cached_task_num = len(self.cached_task_deque) - if cached_task_num >= self.cfg.max_cached_task_num: - error_msg = f"cached task num ({cached_task_num}) exceeds " \ - f"the limit ({self.cfg.max_cached_task_num})" - _send_error(error_msg, current_response_sender, req_id=req_id) - return - - if not tasks or len(tasks) != 1 or not tasks[0]: - error_msg = f"request data should not be empty and query " \ - f"num {len(tasks)} should be 1" - _send_error(error_msg, current_response_sender, req_id=req_id) - return - - task = tasks[0] - task["preprocess_start_time"] = datetime.now() - - error_msg = check_basic_params(task) - if error_msg != []: - _send_error(error_msg, current_response_sender, req_id=req_id) - return - - task_id = task["req_id"] - with self.thread_lock: - if task_id in self.response_sender: - error_msg = f"The req_id {task_id} already exists in the current batch, " \ - f"the current request will be ignored." - _send_error(error_msg, current_response_sender, req_id=req_id) - return - - task = add_default_params(task) - - if int(task.get("enable_text_truncate", 1)): - real_seq_len = self.cfg.max_seq_len - task.get("max_dec_len", 800) - task = self.data_processor.process_request(task, max_seq_len=real_seq_len) - else: - task = self.data_processor.process_request(task) - - input_ids_len = len(task["input_ids"]) - if "max_dec_len" not in task: - task["max_dec_len"] = min(self.cfg.max_seq_len - input_ids_len, self.cfg.dec_len_limit) - min_dec_len = task["min_dec_len"] - if input_ids_len + min_dec_len >= self.cfg.max_seq_len: - error_msg = f"Input text is too long, input_ids_len ({input_ids_len}) " \ - f"+ min_dec_len ({min_dec_len}) >= max_seq_len " - _send_error(error_msg, current_response_sender, req_id=req_id) - return - - if input_ids_len > self.cfg.seq_len_limit: - error_msg = f"Length of input token({input_ids_len}) exceeds the limit MAX_SEQ_LEN({self.cfg.seq_len_limit})." - _send_error(error_msg, current_response_sender, req_id=req_id) - return - if task["max_dec_len"] > self.cfg.dec_len_limit: - error_msg = f"The parameter max_dec_len({task['max_dec_len']}) exceeds the limit MAX_DEC_LEN({self.cfg.dec_len_limit})." - _send_error(error_msg, current_response_sender, req_id=req_id) - return - - required_block_num = self.engine.resource_manager.get_required_block_number(input_ids_len) - if required_block_num > self.engine.resource_manager.total_block_number(): - error_msg = f"The input task required resources is exceed the limit, task={task}."
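To make the request-size checks introduced in execute() above easier to follow, here is a small, self-contained sketch of the same defaulting and limit arithmetic; the helper name and the three limit values are illustrative assumptions, not configuration or code shipped in this change:

    # Illustrative sketch of the max_dec_len defaulting and length checks in execute().
    # MAX_SEQ_LEN, SEQ_LEN_LIMIT and DEC_LEN_LIMIT are made-up example values standing in
    # for cfg.max_seq_len, cfg.seq_len_limit and cfg.dec_len_limit.
    MAX_SEQ_LEN = 8192
    SEQ_LEN_LIMIT = 7168
    DEC_LEN_LIMIT = 2048

    def check_task(input_ids_len, min_dec_len, max_dec_len=None):
        # default max_dec_len the same way the server does when the client omits it
        if max_dec_len is None:
            max_dec_len = min(MAX_SEQ_LEN - input_ids_len, DEC_LEN_LIMIT)
        if input_ids_len + min_dec_len >= MAX_SEQ_LEN:
            return "rejected: prompt plus min_dec_len does not fit into max_seq_len"
        if input_ids_len > SEQ_LEN_LIMIT:
            return "rejected: prompt length exceeds the MAX_SEQ_LEN limit"
        if max_dec_len > DEC_LEN_LIMIT:
            return "rejected: max_dec_len exceeds the MAX_DEC_LEN limit"
        return f"accepted, effective max_dec_len={max_dec_len}"

    print(check_task(input_ids_len=7000, min_dec_len=2))  # accepted, effective max_dec_len=1192

With these example limits, a 7000-token prompt that omits max_dec_len gets min(8192 - 7000, 2048) = 1192 tokens of decode budget, mirroring the defaulting rule used above.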
- _send_error(error_msg, current_response_sender, req_id=req_id) - return - - with self.thread_lock: - self.response_sender[task_id] = current_response_sender - self.task_info[task_id] = {"prompt_tokens": input_ids_len} - - task["preprocess_end_time"] = datetime.now() - self.cached_task_deque.appendleft(task) - tok = time.time() - model_server_logger.info(f"cache task with req_id ({task_id}), " - f"cost time: {tok-tik}s, cached_task_num: {len(self.cached_task_deque)}.") - model_server_logger.debug(f"cache task: {task}") - except Exception as e: - error_msg = "Unexcepted promblem happend while insert new task to server task queue: {}, {}".format( - e, str(traceback.format_exc())) - _send_error(error_msg, current_response_sender) - - def _insert_task_push_mode(self): + self._update_metrics() + + def _schedule_task(self): """ Insert task to engine thread, monitor cached_task_deque. if the engine has resource, insert task to engine """ - try: - while self.enable_insert_task_push_mode: - if not hasattr(self, "engine") or self.engine is None: - time.sleep(0.1) - continue - if self.engine.available_batch() == 0: - time.sleep(0.001) - continue - if len(self.cached_task_deque) == 0: - time.sleep(0.001) - continue - if not self.engine.is_queue_empty(): + while True: + try: + if self.engine.available_batch() == 0 \ + or len(self.cached_task_deque) == 0 \ + or (not self.engine.is_queue_empty()): time.sleep(0.001) continue i_bs = 0 for _ in range(self.cfg.max_prefill_batch): - if len(self.cached_task_deque) == 0: - break - if self.engine.available_batch() == 0: + if len(self.cached_task_deque) == 0 \ + or self.engine.available_batch() == 0: break + while i_bs < self.cfg.max_batch_size: if self.engine.task_is_finished(i_bs): break i_bs += 1 if i_bs >= self.cfg.max_batch_size: break + input_token_num = len(self.cached_task_deque[-1]["input_ids"]) if not self.engine.is_resource_sufficient(input_token_num): break @@ -386,14 +258,38 @@ def _insert_task_push_mode(self): except Exception as e: err_msg = "Error happend while insert task to engine: {}, {}.".format( e, str(traceback.format_exc())) - with self.thread_lock: - _send_result({"error_msg": err_msg}, - self.response_sender[task["req_id"]], 1) - del self.response_sender[task["req_id"]] - model_server_logger.info("finish insert_task_push_mode thread") - except Exception as e: - model_server_logger.error("insert_task_push_mode thread exit " - f"unexpectedly, {e}. {str(traceback.format_exc())}") + _send_result({"error_msg": err_msg}, + self.req_senders[task["req_id"]], 1) + del self.req_senders[task["req_id"]] + except Exception as e: + model_server_logger.error(f"schedule task has error: {e}. 
{str(traceback.format_exc())}") + model_server_logger.info("schedule task thread exit") + + def _send_output(self): + """ + process output and send it to user + """ + while True: + try: + out_queue = get_global_output_queue() + batch_result = out_queue.get() + for result in batch_result: + req_id = result["req_id"] + is_end = result.get("is_end", 0) + return_all_tokens = result.get("return_all_tokens", False) + if is_end == 0 and (return_all_tokens or self.cfg.disable_streaming): + continue + if return_all_tokens and "topk_tokens" in result: + del result["topk_tokens"] + result = self.data_processor.process_response(result) + model_server_logger.debug(f"send result to client under push mode: {result}") + _send_result([result], self.req_senders[req_id], is_end) + if is_end == 1: + del self.req_senders[req_id] + self._update_metrics() + except Exception as e: + model_server_logger.error("unexpected error happened: {}, {}".format(e, str(traceback.format_exc()))) + def _update_metrics(self): """ @@ -408,22 +304,26 @@ def _update_metrics(self): self.metrics["available_resource"].set(block_num * 1.0 / self.cfg.max_block_num) - def _get_current_server_info(self): + def finalize(self): """ - get server info + Triton service exit function """ - available_batch_size = min(self.cfg.max_prefill_batch, - self.engine.available_batch()) - available_block_num = self.engine.available_block_num() - server_info = { - "block_size": int(self.cfg.block_size), - "block_num": int(available_block_num), - "dec_token_num": int(self.cfg.dec_token_num), - "available_resource": - 1.0 * available_block_num / self.cfg.max_block_num, - "max_batch_size": int(available_batch_size), - } - return server_info + model_server_logger.info("Triton service will be terminated...") + self.is_stopping = True + wait_time = 300 + while not self.engine.all_tasks_finished(): + if wait_time <= 0: + model_server_logger.warning(f"Ignore the unfinished tasks, force to stop.") + break + model_server_logger.info(f"There's unfinished tasks, wait {wait_time}...") + wait_time -= 5 + time.sleep(5) + + del self.engine + if self.http_proc: + self.http_proc.kill() + model_server_logger.info("Triton service is terminated!") + def _send_result(result_dict, sender, end_flag=0): @@ -441,9 +341,8 @@ def _send_result(result_dict, sender, end_flag=0): end_output = pb_utils.Tensor("OUT", np.array([result_dict], dtype=np.object_)) response = pb_utils.InferenceResponse(output_tensors=[end_output]) - if response is None and end_flag == 0: - return - sender.send(response, flags=end_flag) + if response or end_flag != 0: + sender.send(response, flags=end_flag) def _send_error(error_msg, sender, error_code=200, req_id=None): """ @@ -457,7 +356,7 @@ def _send_error(error_msg, sender, error_code=200, req_id=None): """ if not isinstance(error_msg, str): error_msg = str(error_msg) - error_info = {"req_id": req_id, "error_msg": error_msg, "error_code": error_code, "version": "4.6", "timestamp": time.time()} + error_info = {"req_id": req_id, "error_msg": error_msg, "error_code": error_code, "timestamp": time.time()} error_logger.info(f"{error_info}") model_server_logger.error(error_msg) _send_result(error_info, sender, 1) diff --git a/llm/server/server/utils.py b/llm/server/server/utils.py index bb80f6b0a4..fc6e50cecf 100644 --- a/llm/server/server/utils.py +++ b/llm/server/server/utils.py @@ -135,6 +135,9 @@ def get_logger(name, file_name, without_formater=False): get logger """ log_dir = os.getenv("FD_LOG_DIR", default="log") + if not os.path.exists(log_dir): +
os.mkdir(log_dir) + is_debug = int(os.getenv("FD_DEBUG", default=0)) logger = logging.getLogger(name) if is_debug: @@ -142,10 +145,8 @@ def get_logger(name, file_name, without_formater=False): else: logger.setLevel(level=logging.INFO) - LOG_FILE = "{0}/{1}".format(log_dir, file_name) backup_count = int(os.getenv("FD_LOG_BACKUP_COUNT", 7)) - handler = DailyRotatingFileHandler(LOG_FILE, backupCount=backup_count) - + handler = DailyRotatingFileHandler("{0}/{1}".format(log_dir, file_name), backupCount=backup_count) formatter = logging.Formatter( "%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s" ) diff --git a/llm/server/trition_server_model/model/1/model.py b/llm/server/trition_server_model/model/1/model.py new file mode 100644 index 0000000000..eb15091e23 --- /dev/null +++ b/llm/server/trition_server_model/model/1/model.py @@ -0,0 +1 @@ +from server.triton_server import TritonPythonModel diff --git a/llm/server/config/config.pbtxt b/llm/server/trition_server_model/model/config.pbtxt similarity index 100% rename from llm/server/config/config.pbtxt rename to llm/server/trition_server_model/model/config.pbtxt diff --git a/llm/tests/test_grpc.py b/llm/tests/test_grpc.py new file mode 100644 index 0000000000..153fd83ca4 --- /dev/null +++ b/llm/tests/test_grpc.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import queue +import sys +import uuid +from functools import partial + +import numpy as np +import tritonclient.grpc as grpcclient +from tritonclient.utils import * + + +class OutputData: + def __init__(self): + self._completed_requests = queue.Queue() + + +def triton_callback(output_data, result, error): + if error: + output_data._completed_requests.put(error) + else: + output_data._completed_requests.put(result) + +def test_base(grpc_url, input_data, test_iters=1, log_level="simple"): + if log_level not in ["simple", "verbose"]: + raise ValueError("log_level must be simple or verbose") + + model_name = "model" + inputs = [grpcclient.InferInput("IN", [1], np_to_triton_dtype(np.object_))] + outputs = [grpcclient.InferRequestedOutput("OUT")] + output_data = OutputData() + + with grpcclient.InferenceServerClient(url=grpc_url, verbose=False) as triton_client: + triton_client.start_stream(callback=partial(triton_callback, output_data)) + for i in range(test_iters): + input_data = json.dumps([input_data]) + inputs[0].set_data_from_numpy(np.array([input_data], dtype=np.object_)) + + triton_client.async_stream_infer(model_name=model_name, + inputs=inputs, + request_id="{}".format(i), + outputs=outputs) + + print("output_data:") + while True: + output_item = output_data._completed_requests.get(timeout=10) + if type(output_item) == InferenceServerException: + print(f"Exception: status is {output_item.status()}, msg is {output_item.message()}") + break + else: + result = json.loads(output_item.as_numpy("OUT")[0]) + result = result[0] if isinstance(result, list) else result + if result.get("is_end") == 1 or result.get("error_msg"): + print(f"\n {result} \n") + break + else: + if log_level == "simple": + print(result['token'] if 'token' in result else result['token_ids'][0], end="") + else: + print(result) + +if __name__ == "__main__": + input_data = { + "req_id": 0, + "text": "hello", + "seq_len": 1024, + "min_dec_len": 2, + "penalty_score": 1.0, + "temperature": 0.8, + "topp": 0.8, + "frequency_score": 0.1, + "presence_score": 0.0 + } + grpc_url = "0.0.0.0:8891" + test_base(grpc_url=grpc_url, input_data=input_data) diff --git a/llm/tests/test_http.py b/llm/tests/test_http.py new file mode 100644 index 0000000000..aac05293e4 --- /dev/null +++ b/llm/tests/test_http.py @@ -0,0 +1,88 @@ +import argparse +import json +import uuid +from datetime import datetime + +import httpx +import requests + + +def http_no_stream(url, data): + print("--http_no_stream--") + headers = {'Content-Type': 'application/json'} + #resp = httpx.post(url=url, headers=headers, timeout=300, json=data) + resp = requests.post(url, headers=headers, json=data) + print(resp.text) + +def http_stream(url, data, show_chunk=False): + print("--http_stream--") + headers = {'Content-Type': 'application/json'} + data = data.copy() + data["stream"] = True + #with httpx.stream("POST", url, headers=headers, timeout=300,json=data) as r: + with requests.post(url, json=data, headers=headers, timeout=300, stream=True) as r: + result = "" + for chunk in r.iter_lines(): + if chunk: + resp = json.loads(chunk) + if resp["error_msg"] != "" or resp["error_code"] != 0: + print(resp) + return + else: + result += resp.get("token", "") + if show_chunk: + print(resp) + print(f"Result: {result}") + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--http_host", default="10.95.147.146", type=str, help="host to the http server") + parser.add_argument("--http_port", default=8894, type=int, help="port to the 
http server") + parser.add_argument("-o", "--open_source_model", action="store_true", help="test eb_model or open_source_model") + args = parser.parse_args() + return args + +if __name__ == '__main__': + args = parse_args() + url = f"http://{args.http_host}:{args.http_port}/v1/chat/completions" + print(f"url: {url}") + + print("\n\n=====single round test=====") + data = { + "req_id": str(uuid.uuid4()), + "text": "hello", + "max_dec_len": 1024, + "min_dec_len": 2, + "penalty_score": 1.0, + "temperature": 0.8, + "topp": 0, + "frequency_score": 0.1, + "presence_score": 0.0, + "timeout": 600, + "benchmark": True, + } + http_no_stream(url, data) + http_stream(url, data) + + print("\n\n=====single round test with default params=====") + data = {"text": "hello"} + http_no_stream(url, data) + http_stream(url, data) + + + print("\n\n=====test error case=====") + data = { + "req_id": str(uuid.uuid4()), + "text": "hello", + "max_dec_len": 1024, + "min_dec_len": 2, + "penalty_score": 1.0, + "temperature": 0.8, + "topp": 2, # topp should be in [0, 1] + "frequency_score": 0.1, + "presence_score": 0.0, + "history_QA": [], + "benchmark": True, + "timeout": 600} + http_no_stream(url, data) + http_stream(url, data)
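For reference, a minimal way to drive the new HTTP test helpers from Python rather than via the script's own --http_host/--http_port command line; the localhost URL and the payload values below are illustrative assumptions for a local deployment, not values required by this change:

    # Illustrative usage of the helpers defined in llm/tests/test_http.py, assuming the
    # module is importable (e.g. run from the llm/tests directory) and a server is
    # listening on the assumed local port.
    from test_http import http_no_stream, http_stream

    url = "http://127.0.0.1:8894/v1/chat/completions"
    payload = {
        "req_id": "demo-req-1",   # any unique id; the test script itself uses uuid.uuid4()
        "text": "hello",
        "max_dec_len": 64,
        "min_dec_len": 2,
    }

    http_no_stream(url, payload)                # one JSON body containing the full answer
    http_stream(url, payload, show_chunk=True)  # prints each streamed chunk, then the joined result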