diff --git a/README.md b/README.md index 91c0bb3..b730c40 100644 --- a/README.md +++ b/README.md @@ -172,7 +172,7 @@ $ kubectl apply -f k8s/agent.yaml Egress filter feature --- -alvd has an egress filter feature (filtering, sorting, translating, etc...) that is extensible by using Lua scripts. +alvd has an egress filter (= post filter) feature (filtering, sorting, translating, etc...) that is extensible by using Lua scripts. To enable it, run alvd server by passing a path to the Lua scripts. diff --git a/examples/egress-filter/retryable.lua b/examples/egress-filter/retryable.lua new file mode 100644 index 0000000..1aed919 --- /dev/null +++ b/examples/egress-filter/retryable.lua @@ -0,0 +1,20 @@ +-- if `retry.Enabled` is true, retry ANN search when the number of `results` is lower than the required number. +retry.Enabled = true +-- `retry.MaxRetries` represents maximum number of retries. +retry.MaxRetries = 3 +-- `retry.NextNumMultiplier` represents how to increase number of internal search results. +retry.NextNumMultiplier = 2 + +local remains = {} +for i, r in results() do + -- remove elements if ID lengths is lower than 3 + if string.len(r.Id) >= 3 then + remains[#remains+1] = r + end + + results[i] = nil +end + +for i, r in pairs(remains) do + results[i] = r +end diff --git a/pkg/alvd/extension/lua/filter/filter.go b/pkg/alvd/extension/lua/filter/filter.go index c4357d1..a1b4395 100644 --- a/pkg/alvd/extension/lua/filter/filter.go +++ b/pkg/alvd/extension/lua/filter/filter.go @@ -23,8 +23,14 @@ type filter struct { proto *lua.FunctionProto } +type RetryConfig struct { + Enabled bool + MaxRetries int + NextNumMultiplier int +} + type EgressFilter interface { - Do([]*payload.Object_Distance) ([]*payload.Object_Distance, error) + Do(origin []*payload.Object_Distance) (results []*payload.Object_Distance, retry *RetryConfig, err error) } func NewEgressFilter(filePath string) (EgressFilter, error) { @@ -60,22 +66,28 @@ func CompileLua(filePath string) (*lua.FunctionProto, error) { return proto, nil } -func (f *filter) Do(origin []*payload.Object_Distance) (results []*payload.Object_Distance, err error) { +func (f *filter) Do(origin []*payload.Object_Distance) (results []*payload.Object_Distance, retry *RetryConfig, err error) { state := lua.NewState() defer state.Close() libs.Preload(state) results = origin + retry = &RetryConfig{ + Enabled: false, + MaxRetries: 3, + NextNumMultiplier: 2, + } state.SetGlobal("results", luar.New(state, results)) + state.SetGlobal("retry", luar.New(state, retry)) fn := state.NewFunctionFromProto(f.proto) state.Push(fn) err = state.PCall(0, lua.MultRet, nil) if err != nil { - return origin, err + return origin, retry, err } - return results, nil + return results, retry, nil } diff --git a/pkg/alvd/server/service/gateway/handler/handler.go b/pkg/alvd/server/service/gateway/handler/handler.go index 6b56632..469b652 100644 --- a/pkg/alvd/server/service/gateway/handler/handler.go +++ b/pkg/alvd/server/service/gateway/handler/handler.go @@ -10,8 +10,10 @@ import ( "time" "github.com/rinx/alvd/internal/errors" + "github.com/rinx/alvd/internal/log" "github.com/rinx/alvd/internal/net/grpc/codes" "github.com/rinx/alvd/internal/net/grpc/status" + "github.com/rinx/alvd/pkg/alvd/extension/lua/filter" "github.com/rinx/alvd/pkg/alvd/server/service/manager" "github.com/vdaas/vald/apis/grpc/v1/payload" "github.com/vdaas/vald/apis/grpc/v1/vald" @@ -21,7 +23,7 @@ const ( defaultTimeout = 3 * time.Second ) -type EgressFilter = func([]*payload.Object_Distance) ([]*payload.Object_Distance, error) +type EgressFilter = func([]*payload.Object_Distance) ([]*payload.Object_Distance, *filter.RetryConfig, error) type server struct { manager manager.Manager @@ -83,6 +85,59 @@ func (s *server) Search( req *payload.Search_Request, ) (res *payload.Search_Response, err error) { cfg := req.GetConfig() + num := int(cfg.GetNum()) + + for i := 0; i < 50; i++ { + res, err = s.search(ctx, req) + if err != nil { + return nil, err + } + + if s.egressFilter == nil { + break + } + + filtered, retry, err := s.egressFilter(res.Results) + if err != nil { + log.Warnf("an error occurred while egress filtering: %s", err) + break + } + + res.Results = make([]*payload.Object_Distance, 0, len(filtered)) + for _, r := range filtered { + if r != nil { + res.Results = append(res.Results, r) + } + } + + if !retry.Enabled || i >= retry.MaxRetries || len(res.Results) >= num { + break + } + + req = &payload.Search_Request{ + Vector: req.GetVector(), + Config: &payload.Search_Config{ + RequestId: req.GetConfig().GetRequestId(), + Num: req.GetConfig().GetNum() * uint32(retry.NextNumMultiplier), + Radius: req.GetConfig().GetRadius(), + Epsilon: req.GetConfig().GetEpsilon(), + Timeout: req.GetConfig().GetTimeout(), + }, + } + } + + if num != 0 && len(res.GetResults()) > num { + res.Results = res.Results[:num] + } + + return res, err +} + +func (s *server) search( + ctx context.Context, + req *payload.Search_Request, +) (res *payload.Search_Response, err error) { + cfg := req.GetConfig() timeout := getTimeout(cfg) num := int(cfg.GetNum()) @@ -145,18 +200,6 @@ func (s *server) Search( res.Results = res.Results[:num] } - if s.egressFilter != nil { - filtered, err := s.egressFilter(res.Results) - if err == nil { - res.Results = make([]*payload.Object_Distance, 0, len(filtered)) - for _, r := range filtered { - if r != nil { - res.Results = append(res.Results, r) - } - } - } - } - return res, nil case dist := <-dch: nres := len(res.GetResults())