Skip to content

Commit

Permalink
✨ Add Search Query Interceptor
Browse files Browse the repository at this point in the history
Signed-off-by: Rintaro Okamura <[email protected]>
  • Loading branch information
rinx committed Jun 4, 2021
1 parent 8b00b3e commit d303fbf
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 2 deletions.
8 changes: 7 additions & 1 deletion examples/config/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ server = {
end
end,

-- server-side Search Query Interceptor
search_query_interceptor = function (request)
-- print(string.format("Searching top %d neighbors", request.Config.Num))
end,

-- server-side Insert Data Interceptor
insert_data_interceptor = function (request)
print(string.format("Inserting ID: %s", request.Vector.Id))
-- print(string.format("Inserting ID: %s", request.Vector.Id))
end,
}
4 changes: 3 additions & 1 deletion pkg/alvd/cli/server/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ type Opts struct {
CreateIndexThreshold uint

SearchResultInterceptor *lua.LFunction
InsertDataInterceptor *lua.LFunction
SearchQueryInterceptor *lua.LFunction
InsertDataInterceptor *lua.LFunction
}

var Flags = []cli.Flag{
Expand Down Expand Up @@ -155,6 +156,7 @@ func ToConfig(opts *Opts, agentOpts *agent.Opts) (*config.Config, error) {
config.WithCheckIndexInterval(opts.CheckIndexInterval),
config.WithCreateIndexThreshold(opts.CreateIndexThreshold),
config.WithSearchResultInterceptor(opts.SearchResultInterceptor),
config.WithSearchQueryInterceptor(opts.SearchQueryInterceptor),
config.WithInsertDataInterceptor(opts.InsertDataInterceptor),
)
if err != nil {
Expand Down
33 changes: 33 additions & 0 deletions pkg/alvd/extension/lua/lua.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,39 @@ func NewSearchResultInterceptorFn(sri *LFunction) SearchResultInterceptor {
}
}

type SearchQueryInterceptor = func(*payload.Search_Request) (
*payload.Search_Request,
error,
)

func NewSearchQueryInterceptorFn(sqi *LFunction) SearchQueryInterceptor {
return func(origin *payload.Search_Request) (
req *payload.Search_Request,
err error,
) {
state := lua.NewState()
defer state.Close()

libs.Preload(state)

req = origin

err = state.CallByParam(
lua.P{
Fn: sqi,
NRet: 0,
Protect: true,
},
luar.New(state, req),
)
if err != nil {
return origin, err
}

return req, nil
}
}

type InsertDataInterceptor = func(*payload.Insert_Request) (
*payload.Insert_Request,
error,
Expand Down
1 change: 1 addition & 0 deletions pkg/alvd/server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Config struct {
CreateIndexThreshold int

SearchResultInterceptor *lua.LFunction
SearchQueryInterceptor *lua.LFunction
InsertDataInterceptor *lua.LFunction
}

Expand Down
8 changes: 8 additions & 0 deletions pkg/alvd/server/config/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ func WithSearchResultInterceptor(sri *lua.LFunction) OptionFunc {
}
}

func WithSearchQueryInterceptor(sqi *lua.LFunction) OptionFunc {
return func(c *Config) error {
c.SearchQueryInterceptor = sqi

return nil
}
}

func WithInsertDataInterceptor(idi *lua.LFunction) OptionFunc {
return func(c *Config) error {
c.InsertDataInterceptor = idi
Expand Down
6 changes: 6 additions & 0 deletions pkg/alvd/server/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ func New(cfg *config.Config) (Daemon, error) {
)
}

if cfg.SearchQueryInterceptor != nil {
h.RegisterSearchQueryInterceptor(
lua.NewSearchQueryInterceptorFn(cfg.SearchQueryInterceptor),
)
}

if cfg.InsertDataInterceptor != nil {
h.RegisterInsertDataInterceptor(
lua.NewInsertDataInterceptorFn(cfg.InsertDataInterceptor),
Expand Down
13 changes: 13 additions & 0 deletions pkg/alvd/server/service/gateway/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ type server struct {
numReplica int

searchResultInterceptor lua.SearchResultInterceptor
searchQueryInterceptor lua.SearchQueryInterceptor
insertDataInterceptor lua.InsertDataInterceptor
}

type Server interface {
vald.Server
RegisterSearchResultInterceptor(sri lua.SearchResultInterceptor)
RegisterSearchQueryInterceptor(sri lua.SearchQueryInterceptor)
RegisterInsertDataInterceptor(idi lua.InsertDataInterceptor)
}

Expand All @@ -48,6 +50,10 @@ func (s *server) RegisterSearchResultInterceptor(sri lua.SearchResultInterceptor
s.searchResultInterceptor = sri
}

func (s *server) RegisterSearchQueryInterceptor(sqi lua.SearchQueryInterceptor) {
s.searchQueryInterceptor = sqi
}

func (s *server) RegisterInsertDataInterceptor(idi lua.InsertDataInterceptor) {
s.insertDataInterceptor = idi
}
Expand Down Expand Up @@ -88,6 +94,13 @@ func (s *server) Search(
ctx context.Context,
req *payload.Search_Request,
) (res *payload.Search_Response, err error) {
if s.searchQueryInterceptor != nil {
req, err = s.searchQueryInterceptor(req)
if err != nil {
log.Warnf("an error occurred while using search query interceptor: %s", err)
}
}

cfg := req.GetConfig()
num := int(cfg.GetNum())

Expand Down

0 comments on commit d303fbf

Please sign in to comment.