Skip to content

Commit

Permalink
[patch] ✨ Add egress filter retrying feature
Browse files Browse the repository at this point in the history
Signed-off-by: Rintaro Okamura <[email protected]>
  • Loading branch information
rinx committed May 7, 2021
1 parent 3881d56 commit 27da6d4
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ $ kubectl apply -f k8s/agent.yaml
Egress filter feature
---
alvd has an egress filter feature (filtering, sorting, translating, etc...) that is extensible by using Lua scripts.
alvd has an egress filter (= post filter) feature (filtering, sorting, translating, etc...) that is extensible by using Lua scripts.
To enable it, run alvd server by passing a path to the Lua scripts.
Expand Down
20 changes: 20 additions & 0 deletions examples/egress-filter/retryable.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
-- if `retry.Enabled` is true, retry ANN search when the number of `results` is lower than the required number.
retry.Enabled = true
-- `retry.MaxRetries` represents maximum number of retries.
retry.MaxRetries = 3
-- `retry.NextNumMultiplier` represents how to increase number of internal search results.
retry.NextNumMultiplier = 2

local remains = {}
for i, r in results() do
-- remove elements if ID lengths is lower than 3
if string.len(r.Id) >= 3 then
remains[#remains+1] = r
end

results[i] = nil
end

for i, r in pairs(remains) do
results[i] = r
end
20 changes: 16 additions & 4 deletions pkg/alvd/extension/lua/filter/filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ type filter struct {
proto *lua.FunctionProto
}

type RetryConfig struct {
Enabled bool
MaxRetries int
NextNumMultiplier int
}

type EgressFilter interface {
Do([]*payload.Object_Distance) ([]*payload.Object_Distance, error)
Do(origin []*payload.Object_Distance) (results []*payload.Object_Distance, retry *RetryConfig, err error)
}

func NewEgressFilter(filePath string) (EgressFilter, error) {
Expand Down Expand Up @@ -60,22 +66,28 @@ func CompileLua(filePath string) (*lua.FunctionProto, error) {
return proto, nil
}

func (f *filter) Do(origin []*payload.Object_Distance) (results []*payload.Object_Distance, err error) {
func (f *filter) Do(origin []*payload.Object_Distance) (results []*payload.Object_Distance, retry *RetryConfig, err error) {
state := lua.NewState()
defer state.Close()

libs.Preload(state)

results = origin
retry = &RetryConfig{
Enabled: false,
MaxRetries: 3,
NextNumMultiplier: 2,
}

state.SetGlobal("results", luar.New(state, results))
state.SetGlobal("retry", luar.New(state, retry))

fn := state.NewFunctionFromProto(f.proto)
state.Push(fn)
err = state.PCall(0, lua.MultRet, nil)
if err != nil {
return origin, err
return origin, retry, err
}

return results, nil
return results, retry, nil
}
69 changes: 56 additions & 13 deletions pkg/alvd/server/service/gateway/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ import (
"time"

"github.com/rinx/alvd/internal/errors"
"github.com/rinx/alvd/internal/log"
"github.com/rinx/alvd/internal/net/grpc/codes"
"github.com/rinx/alvd/internal/net/grpc/status"
"github.com/rinx/alvd/pkg/alvd/extension/lua/filter"
"github.com/rinx/alvd/pkg/alvd/server/service/manager"
"github.com/vdaas/vald/apis/grpc/v1/payload"
"github.com/vdaas/vald/apis/grpc/v1/vald"
Expand All @@ -21,7 +23,7 @@ const (
defaultTimeout = 3 * time.Second
)

type EgressFilter = func([]*payload.Object_Distance) ([]*payload.Object_Distance, error)
type EgressFilter = func([]*payload.Object_Distance) ([]*payload.Object_Distance, *filter.RetryConfig, error)

type server struct {
manager manager.Manager
Expand Down Expand Up @@ -83,6 +85,59 @@ func (s *server) Search(
req *payload.Search_Request,
) (res *payload.Search_Response, err error) {
cfg := req.GetConfig()
num := int(cfg.GetNum())

for i := 0; i < 50; i++ {
res, err = s.search(ctx, req)
if err != nil {
return nil, err
}

if s.egressFilter == nil {
break
}

filtered, retry, err := s.egressFilter(res.Results)
if err != nil {
log.Warnf("an error occurred while egress filtering: %s", err)
break
}

res.Results = make([]*payload.Object_Distance, 0, len(filtered))
for _, r := range filtered {
if r != nil {
res.Results = append(res.Results, r)
}
}

if !retry.Enabled || i >= retry.MaxRetries || len(res.Results) >= num {
break
}

req = &payload.Search_Request{
Vector: req.GetVector(),
Config: &payload.Search_Config{
RequestId: req.GetConfig().GetRequestId(),
Num: req.GetConfig().GetNum() * uint32(retry.NextNumMultiplier),
Radius: req.GetConfig().GetRadius(),
Epsilon: req.GetConfig().GetEpsilon(),
Timeout: req.GetConfig().GetTimeout(),
},
}
}

if num != 0 && len(res.GetResults()) > num {
res.Results = res.Results[:num]
}

return res, err
}

func (s *server) search(
ctx context.Context,
req *payload.Search_Request,
) (res *payload.Search_Response, err error) {
cfg := req.GetConfig()

timeout := getTimeout(cfg)
num := int(cfg.GetNum())
Expand Down Expand Up @@ -145,18 +200,6 @@ func (s *server) Search(
res.Results = res.Results[:num]
}

if s.egressFilter != nil {
filtered, err := s.egressFilter(res.Results)
if err == nil {
res.Results = make([]*payload.Object_Distance, 0, len(filtered))
for _, r := range filtered {
if r != nil {
res.Results = append(res.Results, r)
}
}
}
}

return res, nil
case dist := <-dch:
nres := len(res.GetResults())
Expand Down

0 comments on commit 27da6d4

Please sign in to comment.