[New feature] Add inference load balance controller for fastdeploy llm #2276

Open · wants to merge 9 commits into base: llm
27 changes: 27 additions & 0 deletions llm_ic/README.md
@@ -0,0 +1,27 @@
# Load balancing component for LLM serving

## Requirements

- python >= 3.7
- A running Redis server, used as the database backing the load balancer

## Environment variables
The currently supported environment variables are defined in `config.py` under `fastdeploy_ic`.

| Environment variable | Meaning |
| -------- | ------- |
| REDIS_HOST | IP of the Redis server |
| REDIS_PORT | Port of the Redis server |
| REDIS_USERNAME | Redis username for authentication |
| REDIS_PASSWORD | Redis password for authentication |
| RESPONSE_TIMEOUT | Timeout for fetching streamed tokens from the inference service |


## Startup example

```shell
export REDIS_HOST="localhost"
export REDIS_PORT="6379"
python main.py
```

Empty file.
30 changes: 30 additions & 0 deletions llm_ic/fastdeploy_ic/config.py
@@ -0,0 +1,30 @@
import os
import multiprocessing
import json

class GlobalConfig():
    """ global config """

    def __init__(self):
        """Initialize the configuration from environment variables.
        Args:
            None
        Returns:
            None
        """
        # Redis
        self.redis_host = os.getenv('REDIS_HOST', default="localhost")
        self.redis_port = int(os.getenv('REDIS_PORT', default="6379"))
        self.redis_db = int(os.getenv('REDIS_DB', default="0"))
        self.redis_username = os.getenv('REDIS_USERNAME', default=None)
        self.redis_password = os.getenv('REDIS_PASSWORD', default=None)

        # Response
        self.response_timeout = int(os.getenv('RESPONSE_TIMEOUT', default="120"))

        # Server
        self.num_process = int(os.getenv('NUM_PROCESS', default=multiprocessing.cpu_count()))

        # Logger
        self.log_dir = os.getenv('IC_LOG_DIR', default='ic_logs')
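A minimal usage sketch, not part of the PR. It assumes the process is started from the `llm_ic` directory so that `fastdeploy_ic` is importable, and that the environment variables from the README have been exported beforehand:

```python
from fastdeploy_ic.config import GlobalConfig

config = GlobalConfig()
# Each attribute falls back to the default shown above when its variable is unset.
print(config.redis_host, config.redis_port, config.response_timeout)
```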

Empty file.
139 changes: 139 additions & 0 deletions llm_ic/fastdeploy_ic/data/manager.py
@@ -0,0 +1,139 @@

import json
import math
import asyncio

import aioredis

import fastdeploy_ic.proto.ic_pb2 as ic_pb2
from fastdeploy_ic.utils import get_logger

logger = get_logger("data_manager", "ic_data_manager.log")

__retry_times = 5  # the redis client can raise transient errors; retry the known ones a few times
def retry_wrapper(f):
    async def wrapper(*args, **kwargs):
        for i in range(__retry_times):
            try:
                return await f(*args, **kwargs)
            except asyncio.CancelledError:
                logger.info("asyncio.CancelledError occurred in {}, retry times: {}".format(f.__name__, i+1))
                continue
            except aioredis.ConnectionError:
                args[0].renew_client()
                logger.info("aioredis.ConnectionError occurred in {}, retry times: {}".format(f.__name__, i+1))
                continue
            except aioredis.TimeoutError:
                args[0].renew_client()
                logger.info("aioredis.TimeoutError occurred in {}, retry times: {}".format(f.__name__, i+1))
                continue
        # all retries exhausted: fall through and implicitly return None
    return wrapper



class DataManager:
    def __init__(self, redis_conf) -> None:
        self.redis_conf = redis_conf
        self.client = aioredis.Redis(**redis_conf)
        self.internal_check_key_prefix = '__keymap_'

    def renew_client(self):
        self.client = aioredis.Redis(**self.redis_conf)

    @retry_wrapper
    async def check_req_id_exist(self, model_id, req_id):
        key = '{}{}'.format(self.internal_check_key_prefix, model_id)
        logger.info("check_req_id_exist: key: {} value: {}".format(key, req_id))
        is_exist = await self.client.sismember(key, req_id)
        return is_exist

    @retry_wrapper
    async def add_req_id_to_map(self, model_id, req_id):
        key = '{}{}'.format(self.internal_check_key_prefix, model_id)
        logger.info("add_req_id_to_map: key: {} value: {}".format(key, req_id))
        await self.client.sadd(key, req_id)

    @retry_wrapper
    async def remove_req_id_from_map(self, model_id, req_id):
        key = '{}{}'.format(self.internal_check_key_prefix, model_id)
        logger.info("remove_req_id_from_map: key: {} value: {}".format(key, req_id))
        await self.client.srem(key, req_id)

    @retry_wrapper
    async def enque_request(self, model_id, req, to_end=True):
        serialized_req = req.SerializeToString()
        # key = model_id
        logger.info("enque_request: key: {} value: {}".format(model_id, req))
        if to_end:
            await self.client.rpush(model_id, serialized_req)
        else:
            await self.client.lpush(model_id, serialized_req)

    @retry_wrapper
    async def deque_request(self, model_id):
        data = await self.client.lpop(model_id)
        if data is not None:
            data = ic_pb2.ModelInferRequest.FromString(data)
            logger.info("deque_request: key: {} value: {}".format(model_id, data))
        return data

    @retry_wrapper
    async def remove_request(self, model_id, req):
        serialized_req = req.SerializeToString()
        logger.info("remove_request: key: {} value: {}".format(model_id, req))
        await self.client.lrem(model_id, 1, serialized_req)

    @retry_wrapper
    async def enque_response(self, model_id, req_id, res, to_end=True):
        serialized_res = res.SerializeToString()
        key = '{}/{}'.format(model_id, req_id)
        logger.info("enque_response: key: {} value: {}".format(key, res))
        if to_end:
            await self.client.rpush(key, serialized_res)
        else:
            await self.client.lpush(key, serialized_res)

    @retry_wrapper
    async def deque_response(self, model_id, req_id):
        key = '{}/{}'.format(model_id, req_id)
        data = await self.client.lpop(key)
        if data is not None:
            data = ic_pb2.ModelInferResponse.FromString(data)
            logger.info("deque_response: key: {} value: {}".format(key, data))
        return data

    @retry_wrapper
    async def clear_response(self, model_id, req_id):
        key = '{}/{}'.format(model_id, req_id)
        logger.info("clear_response: key: {}".format(key))
        await self.client.delete(key)

    async def get_requests_by_number(self, model_id, max_request_num):
        # return requests by ByRequest strategy
        requests = []
        for i in range(max_request_num):
            request = await self.deque_request(model_id)
            if request is not None:
                requests.append(request)
            else:
                break
        logger.info("get_requests_by_number: model_id: {} length: {}".format(model_id, len(requests)))
        return requests

    async def get_requests_by_block(self, model_id, max_request_num, block_num, block_size, dec_token_num):
        # return requests by ByToken strategy
        requests = []
        left_block_num = block_num
        for i in range(max_request_num):
            request = await self.deque_request(model_id)
            if request is not None:
                text_words_num = json.loads(request.input)['text_words_num']
                need_block_num = math.ceil((text_words_num + dec_token_num)/block_size)
                if need_block_num < left_block_num:
                    requests.append(request)
                    left_block_num -= need_block_num
                else:
                    await self.enque_request(model_id, request, to_end=False)
                    break
        logger.info("get_requests_by_block: model_id: {} length: {}".format(model_id, len(requests)))
        return requests
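To illustrate how the queue helpers above fit together, here is a minimal sketch. It is not part of the PR: it assumes a Redis instance on localhost:6379, that `llm_ic` is the working directory so the `fastdeploy_ic` imports resolve, and example field values of my own choosing.

```python
import asyncio

import fastdeploy_ic.proto.ic_pb2 as ic_pb2
from fastdeploy_ic.data.manager import DataManager

async def demo():
    manager = DataManager({"host": "localhost", "port": 6379, "db": 0})
    req = ic_pb2.ModelInferRequest(
        model_id="demo-model",
        request_id="req-1",
        trace_id="trace-1",
        input='{"text": "hello", "text_words_num": 5}',
    )
    # Queue the request so an inference server can later fetch it.
    await manager.enque_request("demo-model", req)
    # Pull back up to 4 queued requests using the ByRequest strategy.
    fetched = await manager.get_requests_by_number("demo-model", max_request_num=4)
    print(len(fetched))

asyncio.run(demo())
```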
3 changes: 3 additions & 0 deletions llm_ic/fastdeploy_ic/proto/__init__.py
@@ -0,0 +1,3 @@
import sys
import os
sys.path.append(os.path.dirname(__file__))
99 changes: 99 additions & 0 deletions llm_ic/fastdeploy_ic/proto/ic.proto
@@ -0,0 +1,99 @@
syntax = "proto3";
package language_inference;

// Inference Server GRPC endpoints.
service GRPCInferenceService
{
// 模型推理请求入口,给上层dispatch调用
// 输入一个请求,流式返回多个response
rpc ModelStreamInfer(ModelInferRequest) returns (stream ModelInferResponse) {}

// 拉取一个请求,给inference server调用
rpc ModelFetchRequest(ModelFetchRequestParams) returns (ModelFetchRequestResult) {}

// 发送请求的返回结果,给inference server调用
// response是流式的发送
rpc ModelSendResponse(stream ModelInferResponse) returns (ModelSendResponseResult) {}

// 批量发送请求的返回结果,给inference server调用
// response是流式的发送
rpc ModelSendResponseList(stream ModelInferResponseList) returns (ModelSendResponseResult) {}
}

message ModelFetchRequestParams
{
// 模型全局唯一id
repeated string model_id = 1;

// 一次返回的最大请求数
int32 max_request_num = 2;

FetchStrategy strategy = 3;

ByTokenParams by_token_params = 4;
}
// 根据 token 数量拉取请求的计算公式:
// 每个query需要的block数量: block_num = ceil((text_words_num + dec_token_num)/block_size)
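// Illustrative example (the numbers are editorial, not from this PR): with
// text_words_num = 100, dec_token_num = 50 and block_size = 64,
// block_num = ceil((100 + 50) / 64) = 3, so that query consumes 3 blocks.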

enum FetchStrategy {
    // Fetch requests by request count.
    ByRequest = 0; // default

    // Fetch requests by token count.
    ByToken = 1;
}

message ByTokenParams
{
    // Number of available blocks.
    int32 block_num = 1;

    // Number of tokens each block can hold.
    int32 block_size = 2;

    // Number of tokens reserved for the output of each query.
    int32 dec_token_num = 3;
}

message ModelFetchRequestResult
{
    // Array of fetched requests.
    repeated ModelInferRequest requests = 1;
}

// The return value of SendResponse can be ignored.
message ModelSendResponseResult {
}

message ModelInferRequest
{
    // Unique model id.
    string model_id = 1;

    // Unique request id; must be globally unique.
    string request_id = 2;

    // Id that links upstream and downstream logs, used for troubleshooting.
    string trace_id = 3;

    // Input to the language model.
    string input = 4;
}

message ModelInferResponseList {
    repeated ModelInferResponse response_list = 1;
}

message ModelInferResponse
{
    // Unique request id.
    string request_id = 1;

    // Id of the returned sentence (which segment this is), used for deduplication and ordering.
    int32 sentence_id = 2;

    // Output of the language model.
    string output = 3;
}
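As an illustration of how a dispatcher might call the streaming entry point, here is a hedged client sketch. It assumes gRPC stubs generated with grpcio-tools (an `ic_pb2_grpc` module next to `ic_pb2`) and a controller listening on `localhost:50051`; neither the stub module nor the port appears in this section, so treat both as placeholders.

```python
import grpc

import fastdeploy_ic.proto.ic_pb2 as ic_pb2
import fastdeploy_ic.proto.ic_pb2_grpc as ic_pb2_grpc  # assumed generated stub module

def stream_infer():
    channel = grpc.insecure_channel("localhost:50051")  # placeholder address
    stub = ic_pb2_grpc.GRPCInferenceServiceStub(channel)
    request = ic_pb2.ModelInferRequest(
        model_id="demo-model",
        request_id="req-1",
        trace_id="trace-1",
        input='{"text": "hello", "text_words_num": 5}',
    )
    # ModelStreamInfer is server-streaming: iterate over the ModelInferResponse stream.
    for response in stub.ModelStreamInfer(request):
        print(response.sentence_id, response.output)

if __name__ == "__main__":
    stream_infer()
```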


41 changes: 41 additions & 0 deletions llm_ic/fastdeploy_ic/proto/ic_pb2.py

This file is auto-generated protobuf code and is not rendered in the diff.