From 4a5cd82abadf9bb7aae4630b6d4fe00c4ab4dd35 Mon Sep 17 00:00:00 2001 From: Liang Deng <283304489@qq.com> Date: Tue, 7 Nov 2023 17:00:23 +0800 Subject: [PATCH] fix: fix some nat traversal bugs Signed-off-by: Liang Deng <283304489@qq.com> --- go.mod | 2 +- go.sum | 2 + pkg/engine/tunnel.go | 75 +++++ .../vpndriver/libreswan/libreswan.go | 289 ++++++++++++++---- .../vpndriver/libreswan/libreswan_test.go | 10 +- .../vpndriver/wireguard/wireguard.go | 67 ++-- pkg/tunnelengine/tunnelagent.go | 64 +--- pkg/utils/stun.go | 89 +++--- 8 files changed, 412 insertions(+), 186 deletions(-) diff --git a/go.mod b/go.mod index 1de7d16..e70d901 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.18 require ( github.com/EvilSuperstars/go-cidrman v0.0.0-20190607145828-28e79e32899a - github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029 github.com/coreos/go-iptables v0.6.0 github.com/gorilla/mux v1.8.0 github.com/lorenzosaino/go-sysctl v0.3.1 @@ -37,6 +36,7 @@ require ( github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver v3.5.1+incompatible // indirect + github.com/ccding/go-stun v0.1.5-0.20230908213042-0f417a9a4966 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-systemd/v22 v22.3.2 // indirect diff --git a/go.sum b/go.sum index 9f75057..53b33dd 100644 --- a/go.sum +++ b/go.sum @@ -96,6 +96,8 @@ github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqO github.com/blang/semver v3.5.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= +github.com/ccding/go-stun v0.1.5-0.20230908213042-0f417a9a4966 h1:ugTbop8ITMmnyZRFFQZ0LDnEi+m28dDU7Jxf6cYoA5M= +github.com/ccding/go-stun v0.1.5-0.20230908213042-0f417a9a4966/go.mod h1:cCZjJ1J3WFSJV6Wj8Y9Di8JMTsEXh6uv2eNmLzKaUeM= github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029 h1:POmUHfxXdeyM8Aomg4tKDcwATCFuW+cYLkj6pwsw9pc= github.com/ccding/go-stun/stun v0.0.0-20200514191101-4dc67bcdb029/go.mod h1:Rpr5n9cGHYdM3S3IK8ROSUUUYjQOu+MSUCZDcJbYWi8= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= diff --git a/pkg/engine/tunnel.go b/pkg/engine/tunnel.go index 9b39558..595e849 100644 --- a/pkg/engine/tunnel.go +++ b/pkg/engine/tunnel.go @@ -1,8 +1,10 @@ package engine import ( + "context" "fmt" + "k8s.io/client-go/util/retry" "k8s.io/client-go/util/workqueue" "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" @@ -97,7 +99,71 @@ func (t *TunnelEngine) clearDriver() error { return nil } +func (t *TunnelEngine) configGatewayListStunInfo() error { + var gws v1beta1.GatewayList + if err := t.client.List(context.Background(), &gws); err != nil { + return err + } + for i := range gws.Items { + // try to update info required by nat traversal + gw := &gws.Items[i] + if ep := getTunnelActiveEndpoints(gw); ep != nil { + if ep.NATType == "" || ep.NATType != utils.NATSymmetric && ep.PublicPort == 0 { + err := t.configGatewayStunInfo(gw) + if err != nil { + klog.ErrorS(err, "error config gateway nat type", "gateway", klog.KObj(gw)) + } + } + + } + } + return nil +} + +func (t *TunnelEngine) configGatewayStunInfo(gateway *v1beta1.Gateway) error { + if getTunnelActiveEndpoints(gateway).NodeName != t.nodeName { + return nil + } + + natType, err := utils.GetNATType() + if err != nil { + return err + } + + publicPort, err := utils.GetPublicPort() + if err != nil { + return err + } + + // retry to update nat type of localGateway + err = retry.RetryOnConflict(retry.DefaultBackoff, func() error { + // get localGateway from api server + var apiGw v1beta1.Gateway + err := t.client.Get(context.Background(), client.ObjectKey{ + Name: gateway.Name, + }, &apiGw) + if err != nil { + return err + } + for k, v := range apiGw.Spec.Endpoints { + if v.NodeName == t.nodeName { + apiGw.Spec.Endpoints[k].NATType = natType + if natType != utils.NATSymmetric { + apiGw.Spec.Endpoints[k].PublicPort = publicPort + } + err = t.client.Update(context.Background(), &apiGw) + return err + } + } + return nil + }) + return err +} + func (t *TunnelEngine) reconcile() error { + if err := t.configGatewayListStunInfo(); err != nil { + return err + } if t.routeDriver == nil || t.vpnDriver == nil { err := t.initDriver() if err != nil { @@ -124,3 +190,12 @@ func (t *TunnelEngine) handleEventErr(err error, event interface{}) { klog.Info(utils.FormatRavenEngine("dropping event %q out of the queue: %v", event, err)) t.queue.Forget(event) } + +func getTunnelActiveEndpoints(gw *v1beta1.Gateway) *v1beta1.Endpoint { + for _, aep := range gw.Status.ActiveEndpoints { + if aep.Type == v1beta1.Tunnel { + return aep.DeepCopy() + } + } + return nil +} diff --git a/pkg/networkengine/vpndriver/libreswan/libreswan.go b/pkg/networkengine/vpndriver/libreswan/libreswan.go index 13d8d31..e909685 100644 --- a/pkg/networkengine/vpndriver/libreswan/libreswan.go +++ b/pkg/networkengine/vpndriver/libreswan/libreswan.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "os/exec" + "strconv" "syscall" "time" @@ -57,9 +58,10 @@ const ( ) type libreswan struct { - connections map[string]*vpndriver.Connection - nodeName types.NodeName - listenPort string + relayConnections map[string]*vpndriver.Connection + edgeConnections map[string]*vpndriver.Connection + nodeName types.NodeName + listenPort string iptables iptablesutil.IPTablesInterface } @@ -91,9 +93,10 @@ func (l *libreswan) Init() (err error) { func New(cfg *config.Config) (vpndriver.Driver, error) { return &libreswan{ - connections: make(map[string]*vpndriver.Connection), - nodeName: types.NodeName(cfg.NodeName), - listenPort: cfg.Tunnel.VPNPort, + relayConnections: make(map[string]*vpndriver.Connection), + edgeConnections: make(map[string]*vpndriver.Connection), + nodeName: types.NodeName(cfg.NodeName), + listenPort: cfg.Tunnel.VPNPort, }, nil } @@ -122,97 +125,205 @@ func (l *libreswan) MTU() (int, error) { return mtu - IPSecEncapLen, nil } +// getEndpointResolver returns a function that resolve the left subnets and the Endpoint that should connect to. +func (l *libreswan) getEndpointResolver(network *types.Network) func(centralGw, remoteGw *types.Endpoint) (leftSubnets []string, connectTo *types.Endpoint) { + snUnderNAT := make(map[types.GatewayName]*types.Endpoint) + for _, v := range network.RemoteEndpoints { + if v.UnderNAT { + snUnderNAT[v.GatewayName] = v + } + } + return func(centralGw, remoteGw *types.Endpoint) (leftSubnets []string, connectTo *types.Endpoint) { + leftSubnets = network.LocalEndpoint.Subnets + if centralGw == nil { + // If both local and remote gateway are NATed but no central gateway found, + // we cannot set up vpn connections between the local and remote gateway. + if network.LocalEndpoint.UnderNAT && remoteGw.UnderNAT { + return nil, nil + } + return leftSubnets, remoteGw + } + + if centralGw.NodeName == l.nodeName { + if remoteGw.UnderNAT { + // If the local gateway is the central gateway, + // in order to forward traffic from other NATed gateway to the NATed remoteGw, + // append all subnets of other NATed gateways into left subnets. + for gwName, v := range snUnderNAT { + if gwName != remoteGw.GatewayName { + if !enableCreateEdgeConnection(v, remoteGw) { + leftSubnets = append(leftSubnets, v.Subnets...) + } + } + } + } + return leftSubnets, remoteGw + } + + // If both local and remote are NATed, and the local gateway is not the central gateway, + // and can't create edge to edge tunnel, connects to central gateway to forward traffic. + if network.LocalEndpoint.UnderNAT && remoteGw.UnderNAT { + if !enableCreateEdgeConnection(network.LocalEndpoint, remoteGw) { + return leftSubnets, centralGw + } + } + + return leftSubnets, remoteGw + } +} + func (l *libreswan) createConnections(network *types.Network) error { - errList := errorlist.List{} - desiredConnections := l.computeDesiredConnections(network) - if len(desiredConnections) == 0 { + desiredEdgeConns, desiredRelayConns := l.computeDesiredConnections(network) + if len(desiredEdgeConns) == 0 && len(desiredRelayConns) == 0 { klog.Infof(utils.FormatTunnel("no desired connections, cleaning vpn connections")) l.Cleanup() return nil } + centralGw := findCentralGw(network) + if err := l.createEdgeConnections(desiredEdgeConns); err != nil { + return err + } + if err := l.createRelayConnections(desiredRelayConns, centralGw); err != nil { + return err + } + + return nil +} + +func (l *libreswan) createEdgeConnections(desiredEdgeConns map[string]*vpndriver.Connection) error { + if len(desiredEdgeConns) == 0 { + klog.Infof("no desired edge connections") + return nil + } + + errList := errorlist.List{} + // remove unwanted connections - for connName := range l.connections { - if _, ok := desiredConnections[connName]; !ok { + for connName := range l.edgeConnections { + if _, ok := desiredEdgeConns[connName]; !ok { err := l.whackDelConnection(connName) if err != nil { errList = errList.Append(err) klog.ErrorS(err, "error disconnecting endpoint", "connectionName", connName) continue } - delete(l.connections, connName) + delete(l.edgeConnections, connName) } } // add new connections - for name, connection := range desiredConnections { - err := l.connectToEndpoint(name, connection) + for name, connection := range desiredEdgeConns { + err := l.connectToEdgeEndpoint(name, connection) errList = errList.Append(err) } return errList.AsError() } -func (l *libreswan) computeDesiredConnections(network *types.Network) map[string]*vpndriver.Connection { - centralGw := findCentralGw(network) - // This is the desired connection calculated from given *types.Network - desiredConns := make(map[string]*vpndriver.Connection) +func (l *libreswan) createRelayConnections(desiredRelayConns map[string]*vpndriver.Connection, centralGw *types.Endpoint) error { + if len(desiredRelayConns) == 0 { + klog.Infof("no desired relay connections") + return nil + } - leftEndpoint := network.LocalEndpoint - for _, remote := range network.RemoteEndpoints { - leftSubnets, connectTo := l.resolveEndpoint(network, centralGw, remote) - for _, leftSubnet := range leftSubnets { - for _, rightSubnet := range remote.Subnets { - name := connectionName(leftEndpoint.PrivateIP, remote.PrivateIP, leftSubnet, rightSubnet) - desiredConns[name] = &vpndriver.Connection{ - LocalEndpoint: leftEndpoint.Copy(), - RemoteEndpoint: connectTo.Copy(), - LocalSubnet: leftSubnet, - RemoteSubnet: rightSubnet, - } + errList := errorlist.List{} + + // remove unwanted connections + for connName := range l.relayConnections { + if _, ok := desiredRelayConns[connName]; !ok { + err := l.whackDelConnection(connName) + if err != nil { + errList = errList.Append(err) + klog.ErrorS(err, "error disconnecting endpoint", "connectionName", connName) + continue } + delete(l.relayConnections, connName) + if centralGw.NodeName == l.nodeName { + errList = errList.Append(l.deleteRavenSkipNAT(centralGw, l.relayConnections[connName])) + } + } + } + + // add new connections + for name, connection := range desiredRelayConns { + err := l.connectToEndpoint(name, connection) + errList = errList.Append(err) + if centralGw.NodeName == l.nodeName { + err = l.ensureRavenSkipNAT(centralGw, connection) + errList = errList.Append(err) } } - return desiredConns + return errList.AsError() } -func (l *libreswan) resolveEndpoint(network *types.Network, centralGw, remoteGw *types.Endpoint) (leftSubnets []string, connectTo *types.Endpoint) { - snUnderNAT := make(map[types.GatewayName][]string) - for _, v := range network.RemoteEndpoints { - if v.UnderNAT && !enableCreateEdgeConnection(v, remoteGw) { - snUnderNAT[v.GatewayName] = v.Subnets +func (l *libreswan) ensureRavenSkipNAT(centralGw *types.Endpoint, connection *vpndriver.Connection) errorlist.List { + errList := errorlist.List{} + // for raven skip nat + if err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err)) + } + if err := l.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain); err != nil { + errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err)) + } + for _, subnet := range centralGw.Subnets { + if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { + return errList } } - leftSubnets = network.LocalEndpoint.Subnets - if centralGw == nil { - // If both local and remote gateway are NATed but no central gateway found, - // we cannot set up vpn connections between the local and remote gateway. - if network.LocalEndpoint.UnderNAT && remoteGw.UnderNAT { - return nil, nil + if err := l.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT"); err != nil { + errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err)) + } + return errList +} + +func (l *libreswan) deleteRavenSkipNAT(centralGw *types.Endpoint, connection *vpndriver.Connection) errorlist.List { + errList := errorlist.List{} + err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) + if err != nil { + errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) + } + for _, subnet := range centralGw.Subnets { + if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { + return errList } - return leftSubnets, remoteGw } + err = l.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT") + if err != nil { + errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.RavenPostRoutingChain, err)) + } + return errList +} + +func (l *libreswan) computeDesiredConnections(network *types.Network) (map[string]*vpndriver.Connection, map[string]*vpndriver.Connection) { + centralGw := findCentralGw(network) + desiredEdgeConns := make(map[string]*vpndriver.Connection) + desiredRelayConns := make(map[string]*vpndriver.Connection) + resolveEndpoint := l.getEndpointResolver(network) - if centralGw.NodeName == l.nodeName { - if remoteGw.UnderNAT { - // If the local gateway is the central gateway, - // in order to forward traffic from other NATed gateway to the NATed remoteGw, - // append all subnets of other NATed gateways into left subnets. - for gwName, v := range snUnderNAT { - if gwName != remoteGw.GatewayName { - leftSubnets = append(leftSubnets, v...) + leftEndpoint := network.LocalEndpoint + for _, remoteGw := range network.RemoteEndpoints { + leftSubnets, connectTo := resolveEndpoint(centralGw, remoteGw) + for _, leftSubnet := range leftSubnets { + for _, rightSubnet := range remoteGw.Subnets { + name := connectionName(leftEndpoint.PrivateIP, connectTo.PrivateIP, leftSubnet, rightSubnet) + connect := &vpndriver.Connection{ + LocalEndpoint: leftEndpoint.Copy(), + RemoteEndpoint: connectTo.Copy(), + LocalSubnet: leftSubnet, + RemoteSubnet: rightSubnet, + } + if enableCreateEdgeConnection(leftEndpoint.Copy(), connectTo.Copy()) { + desiredEdgeConns[name] = connect + } else { + desiredRelayConns[name] = connect } } } - return leftSubnets, remoteGw - } - - if !enableCreateEdgeConnection(network.LocalEndpoint, remoteGw) { - return leftSubnets, centralGw } - return leftSubnets, remoteGw + return desiredEdgeConns, desiredRelayConns } func (l *libreswan) whackConnectToEndpoint(connectionName string, connection *vpndriver.Connection) error { @@ -264,6 +375,39 @@ func (l *libreswan) whackConnectToEndpoint(connectionName string, connection *vp return nil } +func (l *libreswan) whackConnectToEdgeEndpoint(connectionName string, connection *vpndriver.Connection) error { + args := make([]string, 0) + leftID := fmt.Sprintf("@%s-%s-%s", connection.LocalEndpoint.PrivateIP, connection.LocalSubnet, connection.RemoteSubnet) + rightID := fmt.Sprintf("@%s-%s-%s", connection.RemoteEndpoint.PrivateIP, connection.RemoteSubnet, connection.LocalSubnet) + + if err := whackCmd("--delete", "--name", connectionName); err != nil { + return err + } + // local + args = append(args, "--psk", "--encrypt", "--forceencaps", "--name", connectionName, + "--id", leftID, + "--host", connection.LocalEndpoint.String(), + "--client", connection.LocalSubnet) + // remote + args = append(args, "--to", + "--id", rightID, + "--host", connection.RemoteEndpoint.PublicIP, + "--client", connection.RemoteSubnet, + "--ikeport", strconv.Itoa(connection.RemoteEndpoint.PublicPort)) + + if err := whackCmd(args...); err != nil { + return err + } + if err := whackCmd("--route", "--name", connectionName); err != nil { + return err + } + if err := whackCmd("--initiate", "--asynchronous", "--name", connectionName); err != nil { + return err + } + + return nil +} + func whackCmdFn(args ...string) error { var err error var output []byte @@ -292,13 +436,20 @@ func connectionName(localID, remoteID, leftSubnet, rightSubnet string) string { func (l *libreswan) Cleanup() error { errList := errorlist.List{} - for name := range l.connections { + for name := range l.relayConnections { + if err := l.whackDelConnection(name); err != nil { + errList = errList.Append(err) + klog.ErrorS(err, "fail to delete connection", "connectionName", name) + } + } + for name := range l.edgeConnections { if err := l.whackDelConnection(name); err != nil { errList = errList.Append(err) klog.ErrorS(err, "fail to delete connection", "connectionName", name) } } - l.connections = make(map[string]*vpndriver.Connection) + l.relayConnections = make(map[string]*vpndriver.Connection) + l.edgeConnections = make(map[string]*vpndriver.Connection) err := netlinkutil.XfrmPolicyFlush() errList = errList.Append(err) @@ -353,7 +504,7 @@ func (l *libreswan) runPluto() error { func (l *libreswan) connectToEndpoint(name string, connection *vpndriver.Connection) errorlist.List { errList := errorlist.List{} - if _, ok := l.connections[name]; ok { + if _, ok := l.relayConnections[name]; ok { klog.InfoS("skipping connect because connection already exists", "connectionName", name) return errList } @@ -363,6 +514,22 @@ func (l *libreswan) connectToEndpoint(name string, connection *vpndriver.Connect klog.ErrorS(err, "error connect connection", "connectionName", name) return errList } - l.connections[name] = connection + l.relayConnections[name] = connection + return errList +} + +func (l *libreswan) connectToEdgeEndpoint(name string, connection *vpndriver.Connection) errorlist.List { + errList := errorlist.List{} + if _, ok := l.edgeConnections[name]; ok { + klog.InfoS("skipping connect because connection already exists", "connectionName", name) + return errList + } + err := l.whackConnectToEdgeEndpoint(name, connection) + if err != nil { + errList = errList.Append(err) + klog.ErrorS(err, "error connect connection", "connectionName", name) + return errList + } + l.edgeConnections[name] = connection return errList } diff --git a/pkg/networkengine/vpndriver/libreswan/libreswan_test.go b/pkg/networkengine/vpndriver/libreswan/libreswan_test.go index 5817baa..e46e9ec 100644 --- a/pkg/networkengine/vpndriver/libreswan/libreswan_test.go +++ b/pkg/networkengine/vpndriver/libreswan/libreswan_test.go @@ -27,6 +27,7 @@ import ( netlinkutil "github.com/openyurtio/raven/pkg/networkengine/util/netlink" "github.com/openyurtio/raven/pkg/networkengine/vpndriver" "github.com/openyurtio/raven/pkg/types" + "github.com/openyurtio/raven/pkg/utils" ) type whackMock struct { @@ -185,6 +186,7 @@ func TestLibreswan_Apply(t *testing.T) { PrivateIP: localGwIP, PublicIP: "1.1.1.1", UnderNAT: true, + NATType: utils.NATSymmetric, }, RemoteEndpoints: map[types.GatewayName]*types.Endpoint{ "centralGw": { @@ -202,6 +204,7 @@ func TestLibreswan_Apply(t *testing.T) { PrivateIP: remoteGw2IP, PublicIP: "1.1.1.3", UnderNAT: true, + NATType: utils.NATSymmetric, }, }, }, @@ -299,6 +302,7 @@ func TestLibreswan_Apply(t *testing.T) { PrivateIP: remoteGw1IP, PublicIP: "1.1.1.2", UnderNAT: true, + NATType: utils.NATSymmetric, }, "remoteGw2": { GatewayName: "remoteGw2", @@ -307,6 +311,7 @@ func TestLibreswan_Apply(t *testing.T) { PrivateIP: remoteGw2IP, PublicIP: "1.1.1.3", UnderNAT: true, + NATType: utils.NATSymmetric, }, }, }, @@ -372,8 +377,9 @@ func TestLibreswan_Apply(t *testing.T) { whackCmd = w.whackCmd a := assert.New(t) l := &libreswan{ - connections: make(map[string]*vpndriver.Connection), - nodeName: types.NodeName(v.nodeName), + relayConnections: make(map[string]*vpndriver.Connection), + edgeConnections: make(map[string]*vpndriver.Connection), + nodeName: types.NodeName(v.nodeName), } var err error l.iptables, err = iptablesutil.New() diff --git a/pkg/networkengine/vpndriver/wireguard/wireguard.go b/pkg/networkengine/vpndriver/wireguard/wireguard.go index 2674ca6..6f75eb9 100644 --- a/pkg/networkengine/vpndriver/wireguard/wireguard.go +++ b/pkg/networkengine/vpndriver/wireguard/wireguard.go @@ -95,7 +95,7 @@ func New(cfg *config.Config) (vpndriver.Driver, error) { edgeConnections: make(map[string]*vpndriver.Connection), nodeName: types.NodeName(cfg.NodeName), ravenClient: cfg.Manager.GetClient(), - listenPort: port, + listenPort: port, }, nil } @@ -377,6 +377,13 @@ func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.N return errors.New("retry to config public key") } + centralGw := findCentralGw(network) + if centralGw.NodeName == w.nodeName { + if err := w.ensureRavenSkipNAT(); err != nil { + return fmt.Errorf("error ensure raven skip nat: %s", err) + } + } + if err := w.ensureWgLink(network, routeDriverMTUFn); err != nil { return fmt.Errorf("fail to ensure wireguar link: %v", err) } @@ -439,33 +446,16 @@ func (w *wireguard) Cleanup() error { if err = netlink.LinkDel(link); err != nil { errList = errList.Append(fmt.Errorf("error delete existing wireguard device %q: %v", DeviceName, err)) } + + if err = w.deleteRavenSkipNAT(); err != nil { + errList = errList.Append(err) + } + w.relayConnections = make(map[string]*vpndriver.Connection) w.edgeConnections = make(map[string]*vpndriver.Connection) return errList.AsError() } -// getSubnetResolver returns a function that resolve the left subnets. -func (w *wireguard) getSubnetResolver(network *types.Network) func(remoteGw *types.Endpoint) (leftSubnets []string) { - snUnderNAT := make(map[types.GatewayName][]string) - for _, v := range network.RemoteEndpoints { - if v.UnderNAT { - snUnderNAT[v.GatewayName] = v.Subnets - } - } - return func(remoteGw *types.Endpoint) (leftSubnets []string) { - if remoteGw.UnderNAT { - // In order to forward traffic from other NATed gateway to the NATed remoteGw, - // append all subnets of other NATed gateways into left subnets. - for gwName, v := range snUnderNAT { - if gwName != remoteGw.GatewayName { - leftSubnets = append(leftSubnets, v...) - } - } - } - return leftSubnets - } -} - func (w *wireguard) computeDesiredConnections(network *types.Network) (map[string]*vpndriver.Connection, map[string]*vpndriver.Connection, []string) { // This is the desired edge connections and relay connections calculated from given *types.Network desiredEdgeConns := make(map[string]*vpndriver.Connection) @@ -529,7 +519,7 @@ func (w *wireguard) configGatewayPublicKey(gwName string, nodeName string) error return err } for k, v := range apiGw.Spec.Endpoints { - if v.NodeName == nodeName && v.Type == v1beta1.Tunnel { + if v.NodeName == nodeName && v.Type == v1beta1.Tunnel { if apiGw.Spec.Endpoints[k].Config == nil { apiGw.Spec.Endpoints[k].Config = make(map[string]string) } @@ -604,3 +594,32 @@ func parseSubnets(subnets []string) []net.IPNet { } return nets } + +func (w *wireguard) ensureRavenSkipNAT() error { + // for raven skip nat + if err := w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err) + } + if err := w.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", "raven-wg0", "-j", iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err) + } + if err := w.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-j", "ACCEPT"); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err) + } + + return nil +} + +func (w *wireguard) deleteRavenSkipNAT() error { + if err := w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err) + } + if err := w.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", "raven-wg0", "-j", iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err) + } + if err := w.iptables.ClearAndDeleteChain(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err) + } + + return nil +} diff --git a/pkg/tunnelengine/tunnelagent.go b/pkg/tunnelengine/tunnelagent.go index 73404fc..00b08d1 100644 --- a/pkg/tunnelengine/tunnelagent.go +++ b/pkg/tunnelengine/tunnelagent.go @@ -87,24 +87,12 @@ func (c *TunnelHandler) Handler() error { for i := range gws.Items { // try to update public IP if empty. gw := &gws.Items[i] - if ep := getTunnelActiveEndpoints(gw); ep != nil { - if ep.PublicIP == "" || ep.NATType == "" || ep.NATType != utils.NATSymmetric && ep.PublicPort == 0 { - // try to update public IP if empty. - if ep.PublicIP == "" { - err := c.configGatewayPublicIP(gw) - if err != nil { - klog.ErrorS(err, "error config gateway public ip", "gateway", klog.KObj(gw)) - } - } - // try to update NAT type if empty - if ep.NATType == "" || ep.NATType != utils.NATSymmetric && ep.PublicPort == 0 { - err := c.configGatewayStun(gw) - if err != nil { - klog.ErrorS(err, "error config gateway nat type", "gateway", klog.KObj(gw)) - } - } - continue + if ep := getTunnelActiveEndpoints(gw); ep != nil && ep.PublicIP == "" { + err := c.configGatewayPublicIP(gw) + if err != nil { + klog.ErrorS(err, "error config gateway public ip", "gateway", klog.KObj(gw)) } + continue } if !c.shouldHandleGateway(gw) { continue @@ -220,9 +208,11 @@ func (c *TunnelHandler) shouldHandleGateway(gateway *v1beta1.Gateway) bool { } if getTunnelActiveEndpoints(gateway).NATType == "" { klog.InfoS("no nat type for gateway, waiting for sync", "gateway", klog.KObj(gateway)) + return false } if getTunnelActiveEndpoints(gateway).NATType != utils.NATSymmetric && getTunnelActiveEndpoints(gateway).PublicPort == 0 { klog.InfoS("no public port for gateway, waiting for sync", "gateway", klog.KObj(gateway)) + return false } if c.ownGateway == nil { klog.InfoS(fmt.Sprintf("no own gateway for node %s, skip it", c.nodeName), "gateway", klog.KObj(gateway)) @@ -271,46 +261,6 @@ func (c *TunnelHandler) configGatewayPublicIP(gateway *v1beta1.Gateway) error { return err } -func (c *TunnelHandler) configGatewayStun(gateway *v1beta1.Gateway) error { - if getTunnelActiveEndpoints(gateway).NodeName != c.nodeName { - return nil - } - - natType, err := utils.GetNATType() - if err != nil { - return err - } - - publicPort, err := utils.GetPublicPort() - if err != nil { - return err - } - - // retry to update nat type of localGateway - err = retry.RetryOnConflict(retry.DefaultBackoff, func() error { - // get localGateway from api server - var apiGw v1beta1.Gateway - err := c.ravenClient.Get(context.Background(), client.ObjectKey{ - Name: gateway.Name, - }, &apiGw) - if err != nil { - return err - } - for k, v := range apiGw.Spec.Endpoints { - if v.NodeName == c.nodeName { - apiGw.Spec.Endpoints[k].NATType = natType - if natType != utils.NATSymmetric { - apiGw.Spec.Endpoints[k].PublicPort = publicPort - } - err = c.ravenClient.Update(context.Background(), &apiGw) - return err - } - } - return nil - }) - return err -} - func (c *TunnelHandler) getLoadBalancerPublicIP(gwName string) (string, error) { var svcList v1.ServiceList err := c.ravenClient.List(context.TODO(), &svcList, &client.ListOptions{ diff --git a/pkg/utils/stun.go b/pkg/utils/stun.go index d4bc8d6..32954e8 100644 --- a/pkg/utils/stun.go +++ b/pkg/utils/stun.go @@ -14,45 +14,52 @@ * limitations under the License. */ - package utils +package utils - import ( - "fmt" - - "github.com/ccding/go-stun/stun" - ) - - var ( - stunAPIs = [...]string{ - "stun.qq.com:3478", - "stun.miwifi.com:3478", - } - stunClient *stun.Client - ) - - func init() { - stunClient = stun.NewClient() - stunClient.SetLocalPort(DefaultVPNPort) - } - - func GetNATType() (string, error) { - for _, api := range stunAPIs { - stunClient.SetServerAddr(api) - natBehavior, err := stunClient.BehaviorTest() - if err == nil { - return natBehavior.NormalType(), nil - } - } - return "", fmt.Errorf("error get nat type by any of the apis: %v", stunAPIs) - } - - func GetPublicPort() (int, error) { - for _, api := range stunAPIs { - stunClient.SetServerAddr(api) - _, host, err := stunClient.Discover() - if err == nil { - return int(host.Port()), nil - } - } - return 0, fmt.Errorf("error get public port by any of the apis: %v", stunAPIs) - } \ No newline at end of file +import ( + "fmt" + + "github.com/ccding/go-stun/stun" + "github.com/vdobler/ht/errorlist" +) + +var ( + stunAPIs = [...]string{ + "stun.qq.com:3478", + "stun.miwifi.com:3478", + } + stunClient *stun.Client +) + +func init() { + stunClient = stun.NewClient() + stunClient.SetLocalPort(4500) +} + +func GetNATType() (string, error) { + errList := errorlist.List{} + for _, api := range stunAPIs { + stunClient.SetServerAddr(api) + natBehavior, err := stunClient.BehaviorTest() + if err == nil { + return natBehavior.NormalType(), nil + } else { + errList = errList.Append(err) + } + } + return "", fmt.Errorf("error get nat type by any of the apis[%v]: %s", stunAPIs, errList.AsError()) +} + +func GetPublicPort() (int, error) { + errList := errorlist.List{} + for _, api := range stunAPIs { + stunClient.SetServerAddr(api) + _, host, err := stunClient.Discover() + if err == nil { + return int(host.Port()), nil + } else { + errList.Append(err) + } + } + return 0, fmt.Errorf("error get public port by any of the apis[%v]: %s", stunAPIs, errList.AsError()) +}