diff --git a/changelogs/current.yaml b/changelogs/current.yaml index c864088fa37f..3b67f584d979 100644 --- a/changelogs/current.yaml +++ b/changelogs/current.yaml @@ -265,6 +265,10 @@ new_features: change: | Added support in SNI dynamic forward proxy for saving the resolved upstream address in the filter state. The state is saved with the key ``envoy.stream.upstream_address``. +- area: lua + change: | + Added a new ``setUpstreamOverrideHost()`` which could be used to set the given host as the upstream host for the + current request. deprecated: - area: rbac diff --git a/docs/root/configuration/http/http_filters/lua_filter.rst b/docs/root/configuration/http/http_filters/lua_filter.rst index 71c8b502bb8f..49bd10473304 100644 --- a/docs/root/configuration/http/http_filters/lua_filter.rst +++ b/docs/root/configuration/http/http_filters/lua_filter.rst @@ -558,6 +558,36 @@ Returns connection-level :repo:`information ` r Returns a connection-level :ref:`stream info object `. +``setUpstreamOverrideHost()`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: lua + + handle:setUpstreamOverrideHost(host, strict) + +Sets an upstream address override for the request. When the overridden host is available and can be selected directly, +the load balancer bypasses its algorithm and routes traffic directly to the specified host. The strict flag determines +whether the HTTP request must strictly use the overridden destination. If the destination is unavailable and strict is +set to true, Envoy responds with a 503 Service Unavailable error. + +The function takes two arguments: + +* ``host`` (string): The host address to be used for upstream requests. +* ``strict`` (boolean, optional): If set to true, the request is strictly routed to the overridden host. If the host is + unavailable, Envoy returns a 503 error. Defaults to false. + +Example: + +.. code-block:: lua + + function envoy_on_request(request_handle) + -- Override upstream host without strict mode + request_handle:setUpstreamOverrideHost("192.168.21.13", false) + + -- Override upstream host with strict mode + request_handle:setUpstreamOverrideHost("192.168.21.13", true) + end + ``importPublicKey()`` ^^^^^^^^^^^^^^^^^^^^^ diff --git a/source/extensions/filters/http/lua/lua_filter.cc b/source/extensions/filters/http/lua/lua_filter.cc index d32ddf0cfa32..26b14351f0bf 100644 --- a/source/extensions/filters/http/lua/lua_filter.cc +++ b/source/extensions/filters/http/lua/lua_filter.cc @@ -963,6 +963,32 @@ void Filter::scriptLog(spdlog::level::level_enum level, absl::string_view messag } } +int StreamHandleWrapper::luaSetUpstreamOverrideHost(lua_State* state) { + // Get the host address argument + size_t len; + const char* host = luaL_checklstring(state, 2, &len); + + // Validate that host is not null and is an IP address + if (host == nullptr) { + luaL_error(state, "host argument is required"); + } + if (!Http::Utility::parseAuthority(host).is_ip_address_) { + luaL_error(state, "host is not a valid IP address"); + } + + // Get the optional strict flag (defaults to false) + bool strict = false; + if (lua_gettop(state) >= 3) { + luaL_checktype(state, 3, LUA_TBOOLEAN); + strict = lua_toboolean(state, 3); + } + + // Set the upstream override host + callbacks_.setUpstreamOverrideHost(std::make_pair(std::string(host, len), strict)); + + return 0; +} + void Filter::DecoderCallbacks::respond(Http::ResponseHeaderMapPtr&& headers, Buffer::Instance* body, lua_State*) { uint64_t status = Http::Utility::getResponseStatus(*headers); diff --git a/source/extensions/filters/http/lua/lua_filter.h b/source/extensions/filters/http/lua/lua_filter.h index 305189172d92..154c94b92398 100644 --- a/source/extensions/filters/http/lua/lua_filter.h +++ b/source/extensions/filters/http/lua/lua_filter.h @@ -111,6 +111,12 @@ class FilterCallbacks { * @return const Tracing::Span& the current tracing active span. */ virtual Tracing::Span& activeSpan() PURE; + + /** + * Set the upstream host override. + * @param host_and_strict supplies the host and whether the host should be treated as strict. + */ + virtual void setUpstreamOverrideHost(std::pair host_and_strict) PURE; }; class Filter; @@ -187,7 +193,8 @@ class StreamHandleWrapper : public Filters::Common::Lua::BaseLuaObject { return callbacks_->connection().ptr(); } Tracing::Span& activeSpan() override { return callbacks_->activeSpan(); } + void setUpstreamOverrideHost(std::pair host_and_strict) override { + callbacks_->setUpstreamOverrideHost(std::move(host_and_strict)); + } Filter& parent_; Http::StreamDecoderFilterCallbacks* callbacks_{}; @@ -595,6 +612,9 @@ class Filter : public Http::StreamFilter, Logger::Loggable { return callbacks_->connection().ptr(); } Tracing::Span& activeSpan() override { return callbacks_->activeSpan(); } + void setUpstreamOverrideHost(std::pair host_and_strict) override { + UNREFERENCED_PARAMETER(host_and_strict); + } Filter& parent_; Http::StreamEncoderFilterCallbacks* callbacks_{}; diff --git a/test/extensions/filters/http/lua/lua_filter_test.cc b/test/extensions/filters/http/lua/lua_filter_test.cc index 77a578610252..cde58e704de8 100644 --- a/test/extensions/filters/http/lua/lua_filter_test.cc +++ b/test/extensions/filters/http/lua/lua_filter_test.cc @@ -3033,6 +3033,144 @@ TEST_F(LuaHttpFilterTest, StatsWithPerFilterPrefix) { EXPECT_EQ(2, stats_store_.counter("test.lua.my_script.errors").value()); } +// Test successful upstream host override +TEST_F(LuaHttpFilterTest, SetUpstreamOverrideHost) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + request_handle:setUpstreamOverrideHost("192.168.21.11", false) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; + EXPECT_CALL(decoder_callbacks_, + setUpstreamOverrideHost(testing::Pair(testing::Eq("192.168.21.11"), false))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); +} + +// Test upstream host override with strict flag set to true +TEST_F(LuaHttpFilterTest, SetUpstreamOverrideHostStrict) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + request_handle:setUpstreamOverrideHost("192.168.21.11", true) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; + EXPECT_CALL(decoder_callbacks_, + setUpstreamOverrideHost(testing::Pair(testing::Eq("192.168.21.11"), true))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); +} + +// Test that setUpstreamOverrideHost requires a host argument +TEST_F(LuaHttpFilterTest, SetUpstreamOverrideHostNoArgument) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + request_handle:setUpstreamOverrideHost() + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; + EXPECT_CALL(*filter_, + scriptLog(spdlog::level::err, + StrEq("[string \"...\"]:3: bad argument #1 to 'setUpstreamOverrideHost' " + "(string expected, got no value)"))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + EXPECT_EQ(1, stats_store_.counter("test.lua.errors").value()); +} + +// Test that setUpstreamOverrideHost validates the argument type for strict flag +TEST_F(LuaHttpFilterTest, SetUpstreamOverrideHostInvalidStrictType) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + request_handle:setUpstreamOverrideHost("192.168.21.11", "not_a_boolean") + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; + EXPECT_CALL(*filter_, + scriptLog(spdlog::level::err, + StrEq("[string \"...\"]:3: bad argument #2 to 'setUpstreamOverrideHost' " + "(boolean expected, got string)"))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + EXPECT_EQ(1, stats_store_.counter("test.lua.errors").value()); +} + +// Test that setUpstreamOverrideHost can be called on different paths +TEST_F(LuaHttpFilterTest, SetUpstreamOverrideHostDifferentPaths) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + request_handle:setUpstreamOverrideHost("192.168.21.11", true) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + { + Http::TestRequestHeaderMapImpl request_headers{{":path", "/path1"}}; + EXPECT_CALL(decoder_callbacks_, + setUpstreamOverrideHost(testing::Pair(testing::Eq("192.168.21.11"), true))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + } + + setupFilter(); + + { + Http::TestRequestHeaderMapImpl request_headers{{":path", "/path2"}}; + EXPECT_CALL(decoder_callbacks_, + setUpstreamOverrideHost(testing::Pair(testing::Eq("192.168.21.11"), true))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + } +} + +// Test empty host argument +TEST_F(LuaHttpFilterTest, SetUpstreamOverrideHostEmptyHost) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + request_handle:setUpstreamOverrideHost("", false) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; + EXPECT_CALL(*filter_, scriptLog(spdlog::level::err, + StrEq("[string \"...\"]:3: host is not a valid IP address"))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + EXPECT_EQ(1, stats_store_.counter("test.lua.errors").value()); +} + +// Test that setUpstreamOverrideHost rejects non-IP hosts +TEST_F(LuaHttpFilterTest, SetUpstreamOverrideHostNonIpHost) { + const std::string SCRIPT{R"EOF( + function envoy_on_request(request_handle) + request_handle:setUpstreamOverrideHost("example.com", false) + end + )EOF"}; + + InSequence s; + setup(SCRIPT); + + Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; + EXPECT_CALL(*filter_, scriptLog(spdlog::level::err, + StrEq("[string \"...\"]:3: host is not a valid IP address"))); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, true)); + EXPECT_EQ(1, stats_store_.counter("test.lua.errors").value()); +} + } // namespace } // namespace Lua } // namespace HttpFilters