diff --git a/Makefile b/Makefile index b2d2da4..ae8b221 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ IMG ?= openyurt/raven-agent:latest VPN_DRIVER ?= libreswan FORWARD_NODE_IP ?= false +NAT_TRAVERSAL ?= false METRIC_BIND_ADDR ?= ":8080" BUILDPLATFORM ?= linux/amd64 @@ -65,7 +66,7 @@ docker-push: ## Push docker image with the agent. ##@ Deploy gen-deploy-yaml: - bash hack/gen-yaml.sh ${IMG} ${VPN_DRIVER} ${FORWARD_NODE_IP} ${METRIC_BIND_ADDR} + bash hack/gen-yaml.sh ${IMG} ${VPN_DRIVER} ${FORWARD_NODE_IP} ${METRIC_BIND_ADDR} ${NAT_TRAVERSAL} deploy: gen-deploy-yaml ## Deploy agent daemon. kubectl apply -f _output/yamls/raven-agent.yaml diff --git a/charts/raven-agent/templates/config.yaml b/charts/raven-agent/templates/config.yaml index 6108bdc..d6b6cc1 100644 --- a/charts/raven-agent/templates/config.yaml +++ b/charts/raven-agent/templates/config.yaml @@ -11,6 +11,7 @@ apiVersion: v1 data: vpn-driver: {{ .Values.vpn.driver }} forward-node-ip: {{ .Values.vpn.forwardNodeIP | quote }} + nat-traversal: {{ .Values.vpn.natTraversal | quote }} metric-bind-addr: {{ .Values.vpn.metricBindAddr }} tunnel-bind-addr: {{ .Values.vpn.tunnelAddr }} proxy-external-addr: {{ .Values.proxy.externalAddr }} diff --git a/charts/raven-agent/templates/daemonset.yaml b/charts/raven-agent/templates/daemonset.yaml index d12b966..c44998c 100644 --- a/charts/raven-agent/templates/daemonset.yaml +++ b/charts/raven-agent/templates/daemonset.yaml @@ -30,6 +30,7 @@ spec: - --v=2 - --vpn-driver={{.Values.vpn.driver}} - --forward-node-ip={{.Values.vpn.forwardNodeIP}} + - --nat-traversal={{.Values.vpn.natTraversal}} - --metric-bind-addr={{.Values.vpn.metricBindAddr}} - --vpn-bind-port={{.Values.vpn.tunnelAddr}} - --proxy-metric-bind-addr={{.Values.proxy.metricsBindAddr}} diff --git a/charts/raven-agent/values.yaml b/charts/raven-agent/values.yaml index 96a362b..fc776a3 100644 --- a/charts/raven-agent/values.yaml +++ b/charts/raven-agent/values.yaml @@ -62,6 +62,7 @@ containerEnv: vpn: driver: libreswan forwardNodeIP: false + natTraversal: false # raven-agent requires a unique vpn psk # You can generate one with the command: # 'openssl rand -hex 64' diff --git a/cmd/agent/app/config/config.go b/cmd/agent/app/config/config.go index f27cd3b..fefcda7 100644 --- a/cmd/agent/app/config/config.go +++ b/cmd/agent/app/config/config.go @@ -37,6 +37,7 @@ type TunnelConfig struct { VPNPort string RouteDriver string ForwardNodeIP bool + NATTraversal bool } type ProxyConfig struct { diff --git a/cmd/agent/app/options/options.go b/cmd/agent/app/options/options.go index 3446a8d..942b9ac 100644 --- a/cmd/agent/app/options/options.go +++ b/cmd/agent/app/options/options.go @@ -45,6 +45,7 @@ type TunnelOptions struct { VPNPort string RouteDriver string ForwardNodeIP bool + NATTraversal bool } type ProxyOptions struct { @@ -84,6 +85,7 @@ func (o *AgentOptions) AddFlags(fs *pflag.FlagSet) { fs.StringVar(&o.VPNDriver, "vpn-driver", o.VPNDriver, `The VPN driver name. (default "libreswan")`) fs.StringVar(&o.RouteDriver, "route-driver", o.RouteDriver, `The Route driver name. (default "vxlan")`) fs.BoolVar(&o.ForwardNodeIP, "forward-node-ip", o.ForwardNodeIP, `Forward node IP or not. (default "false")`) + fs.BoolVar(&o.NATTraversal, "nat-traversal", o.NATTraversal, `Enable NAT Traversal or not. (default "false")`) fs.StringVar(&o.MetricsBindAddress, "metric-bind-addr", o.MetricsBindAddress, `Binding address of tunnel metrics. (default ":10265")`) fs.StringVar(&o.VPNPort, "vpn-bind-port", o.VPNPort, `Binding port of vpn. (default ":4500")`) fs.StringVar(&o.ProxyMetricsAddress, "proxy-metric-bind-addr", o.ProxyMetricsAddress, `Binding address of proxy metrics. (default ":10266")`) @@ -125,6 +127,7 @@ func (o *AgentOptions) Config() (*config.Config, error) { VPNDriver: o.VPNDriver, RouteDriver: o.RouteDriver, ForwardNodeIP: o.ForwardNodeIP, + NATTraversal: o.NATTraversal, } c.Proxy = &config.ProxyConfig{ ProxyMetricsAddress: o.ProxyMetricsAddress, diff --git a/config/raven-agent/agent/agent.yaml b/config/raven-agent/agent/agent.yaml index aa03dc2..887864f 100644 --- a/config/raven-agent/agent/agent.yaml +++ b/config/raven-agent/agent/agent.yaml @@ -43,6 +43,11 @@ spec: configMapKeyRef: name: agent-config key: forward-node-ip + - name: NAT_TRAVERSAL + valueFrom: + configMapKeyRef: + name: agent-config + key: nat-traversal - name: METRIC_BIND_ADDR valueFrom: configMapKeyRef: diff --git a/go.mod b/go.mod index 33ed13c..8924a11 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.18 require ( github.com/EvilSuperstars/go-cidrman v0.0.0-20190607145828-28e79e32899a + github.com/ccding/go-stun v0.1.5-0.20230908213042-0f417a9a4966 github.com/coreos/go-iptables v0.6.0 github.com/gorilla/mux v1.8.0 github.com/lorenzosaino/go-sysctl v0.3.1 diff --git a/go.sum b/go.sum index 1e47f96..f35dfa0 100644 --- a/go.sum +++ b/go.sum @@ -96,6 +96,8 @@ github.com/bketelsen/crypt v0.0.4/go.mod h1:aI6NrJ0pMGgvZKL1iVgXLnfIFJtfV+bKCoqO github.com/blang/semver v3.5.0+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= github.com/blang/semver v3.5.1+incompatible h1:cQNTCjp13qL8KC3Nbxr/y2Bqb63oX6wdnnjpJbkM4JQ= github.com/blang/semver v3.5.1+incompatible/go.mod h1:kRBLl5iJ+tD4TcOOxsy/0fnwebNt5EWlYSAyrTnjyyk= +github.com/ccding/go-stun v0.1.5-0.20230908213042-0f417a9a4966 h1:ugTbop8ITMmnyZRFFQZ0LDnEi+m28dDU7Jxf6cYoA5M= +github.com/ccding/go-stun v0.1.5-0.20230908213042-0f417a9a4966/go.mod h1:cCZjJ1J3WFSJV6Wj8Y9Di8JMTsEXh6uv2eNmLzKaUeM= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/certifi/gocertifi v0.0.0-20191021191039-0944d244cd40/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= github.com/certifi/gocertifi v0.0.0-20200922220541-2c3bb06c6054/go.mod h1:sGbDF6GwGcLpkNXPUTkMRoywsNa/ol15pxFe6ERfguA= diff --git a/hack/gen-yaml.sh b/hack/gen-yaml.sh index dee1979..ace6f25 100644 --- a/hack/gen-yaml.sh +++ b/hack/gen-yaml.sh @@ -22,6 +22,7 @@ gen_yaml() { local VPN_DRIVER=$2 local FORWARD_NODE_IP=$3 local METRIC_BIND_ADDR=$4 + local NAT_TRAVERSAL=$5 local OUT_YAML_DIR=${YURT_ROOT}/_output/yamls local BUILD_YAML_DIR=${OUT_YAML_DIR}/build [ -f "${BUILD_YAML_DIR}" ] || mkdir -p "${BUILD_YAML_DIR}" @@ -38,6 +39,7 @@ gen_yaml() { [ -f "${BUILD_YAML_DIR}"/default/config.env ] || echo "vpn-driver=${VPN_DRIVER}" > "${BUILD_YAML_DIR}"/default/config.env echo "forward-node-ip=${FORWARD_NODE_IP}" >> "${BUILD_YAML_DIR}"/default/config.env echo "metric-bind-addr=${METRIC_BIND_ADDR}" >> "${BUILD_YAML_DIR}"/default/config.env + echo "nat-traversal=${NAT_TRAVERSAL}" >> "${BUILD_YAML_DIR}"/default/config.env kustomize build "${BUILD_YAML_DIR}"/default > "${OUT_YAML_DIR}"/raven-agent.yaml rm -Rf "${BUILD_YAML_DIR}" } diff --git a/pkg/engine/tunnel.go b/pkg/engine/tunnel.go index afade72..1ac0896 100644 --- a/pkg/engine/tunnel.go +++ b/pkg/engine/tunnel.go @@ -53,6 +53,9 @@ func (t *TunnelEngine) processNextWorkItem() bool { func (t *TunnelEngine) handler(gw *v1beta1.Gateway) error { klog.Info(utils.FormatRavenEngine("update raven l3 tunnel config for gateway %s", gw.GetName())) + if err := t.checkNatCapability(); err != nil { + return err + } if t.routeDriver == nil || t.vpnDriver == nil { err := t.initDriver() if err != nil { @@ -104,6 +107,24 @@ func (t *TunnelEngine) clearDriver() error { return nil } +func (t *TunnelEngine) checkNatCapability() error { + natType, err := utils.GetNATType() + if err != nil { + return err + } + + if natType == utils.NATSymmetric { + return nil + } + + _, err = utils.GetPublicPort() + if err != nil { + return err + } + + return nil +} + func (t *TunnelEngine) handleEventErr(err error, event interface{}) { if err == nil { t.queue.Forget(event) diff --git a/pkg/networkengine/vpndriver/driver.go b/pkg/networkengine/vpndriver/driver.go index 7eca5e7..59bada2 100644 --- a/pkg/networkengine/vpndriver/driver.go +++ b/pkg/networkengine/vpndriver/driver.go @@ -28,6 +28,7 @@ import ( "github.com/openyurtio/raven/cmd/agent/app/config" netlinkutil "github.com/openyurtio/raven/pkg/networkengine/util/netlink" "github.com/openyurtio/raven/pkg/types" + "github.com/openyurtio/raven/pkg/utils" ) const ( @@ -65,6 +66,7 @@ type Factory func(cfg *config.Config) (Driver, error) var ( driversMutex sync.Mutex drivers = make(map[string]Factory) + natTraversal bool ) func RegisterDriver(name string, factory Factory) { @@ -78,6 +80,9 @@ func RegisterDriver(name string, factory Factory) { } func New(name string, cfg *config.Config) (Driver, error) { + if cfg.Tunnel != nil { + natTraversal = cfg.Tunnel.NATTraversal + } driversMutex.Lock() defer driversMutex.Unlock() if _, found := drivers[name]; !found { @@ -110,6 +115,22 @@ func FindCentralGwFn(network *types.Network) *types.Endpoint { return central } +// EnableCreateEdgeConnection determine whether VPN tunnels can be established between edges. +func EnableCreateEdgeConnection(localEndpoint *types.Endpoint, remoteEndpoint *types.Endpoint) bool { + if !natTraversal { + return false + } + if localEndpoint.NATType == utils.NATUndefined || remoteEndpoint.NATType == utils.NATUndefined { + return false + } + if !localEndpoint.UnderNAT || !remoteEndpoint.UnderNAT { + return false + } + return !((localEndpoint.NATType == utils.NATSymmetric && remoteEndpoint.NATType == utils.NATSymmetric) || + (localEndpoint.NATType == utils.NATSymmetric && remoteEndpoint.NATType == utils.NATPortRestricted) || + (localEndpoint.NATType == utils.NATPortRestricted && remoteEndpoint.NATType == utils.NATSymmetric)) +} + func DefaultMTU() (int, error) { routes, err := netlinkutil.RouteListFiltered( netlink.FAMILY_V4, diff --git a/pkg/networkengine/vpndriver/libreswan/libreswan.go b/pkg/networkengine/vpndriver/libreswan/libreswan.go index 9db1802..aab5e95 100644 --- a/pkg/networkengine/vpndriver/libreswan/libreswan.go +++ b/pkg/networkengine/vpndriver/libreswan/libreswan.go @@ -20,6 +20,7 @@ import ( "fmt" "os" "os/exec" + "strconv" "syscall" "time" @@ -46,6 +47,7 @@ var _ vpndriver.Driver = (*libreswan)(nil) // can be modified for testing. var whackCmd = whackCmdFn var findCentralGw = vpndriver.FindCentralGwFn +var enableCreateEdgeConnection = vpndriver.EnableCreateEdgeConnection func init() { vpndriver.RegisterDriver(DriverName, New) @@ -56,9 +58,11 @@ const ( ) type libreswan struct { - connections map[string]*vpndriver.Connection - nodeName types.NodeName - listenPort string + relayConnections map[string]*vpndriver.Connection + edgeConnections map[string]*vpndriver.Connection + nodeName types.NodeName + listenPort string + centralGw *types.Endpoint iptables iptablesutil.IPTablesInterface } @@ -90,58 +94,28 @@ func (l *libreswan) Init() (err error) { func New(cfg *config.Config) (vpndriver.Driver, error) { return &libreswan{ - connections: make(map[string]*vpndriver.Connection), - nodeName: types.NodeName(cfg.NodeName), - listenPort: cfg.Tunnel.VPNPort, + relayConnections: make(map[string]*vpndriver.Connection), + edgeConnections: make(map[string]*vpndriver.Connection), + nodeName: types.NodeName(cfg.NodeName), + listenPort: cfg.Tunnel.VPNPort, }, nil } func (l *libreswan) Apply(network *types.Network, routeDriverMTUFn func(*types.Network) (int, error)) (err error) { - errList := errorlist.List{} if network.LocalEndpoint == nil || len(network.RemoteEndpoints) == 0 { - klog.Info(utils.FormatTunnel("no local gateway or remote gateway is found, cleaning vpn connections")) + klog.Info("no local gateway or remote gateway is found, cleaning vpn connections") return l.Cleanup() } if network.LocalEndpoint.NodeName != l.nodeName { - klog.Info(utils.FormatTunnel("the current node is not gateway node, cleaning vpn connections")) - return l.Cleanup() - } - - desiredConnections := l.computeDesiredConnections(network) - if len(desiredConnections) == 0 { - klog.Info(utils.FormatTunnel("no desired connections, cleaning vpn connections")) + klog.Infof(utils.FormatTunnel("the current node is not gateway node, cleaning vpn connections")) return l.Cleanup() } - centralGw := findCentralGw(network) - - // remove unwanted connections - for connName := range l.connections { - if _, ok := desiredConnections[connName]; !ok { - err := l.whackDelConnection(connName) - if err != nil { - errList = errList.Append(err) - klog.ErrorS(err, "error disconnecting endpoint", "connectionName", connName) - continue - } - delete(l.connections, connName) - if centralGw.NodeName == l.nodeName { - errList = errList.Append(l.deleteRavenSkipNAT(centralGw, l.connections[connName])) - } - } + if err := l.createConnections(network); err != nil { + return fmt.Errorf("error create VPN tunnels: %v", err) } - // add new connections - for name, connection := range desiredConnections { - err := l.connectToEndpoint(name, connection) - errList = errList.Append(err) - if centralGw.NodeName == l.nodeName { - err = l.ensureRavenSkipNAT(centralGw, connection) - errList = errList.Append(err) - } - } - - return errList.AsError() + return nil } func (l *libreswan) MTU() (int, error) { @@ -154,10 +128,10 @@ func (l *libreswan) MTU() (int, error) { // getEndpointResolver returns a function that resolve the left subnets and the Endpoint that should connect to. func (l *libreswan) getEndpointResolver(network *types.Network) func(centralGw, remoteGw *types.Endpoint) (leftSubnets []string, connectTo *types.Endpoint) { - snUnderNAT := make(map[types.GatewayName][]string) + snUnderNAT := make(map[types.GatewayName]*types.Endpoint) for _, v := range network.RemoteEndpoints { if v.UnderNAT { - snUnderNAT[v.GatewayName] = v.Subnets + snUnderNAT[v.GatewayName] = v } } return func(centralGw, remoteGw *types.Endpoint) (leftSubnets []string, connectTo *types.Endpoint) { @@ -178,7 +152,9 @@ func (l *libreswan) getEndpointResolver(network *types.Network) func(centralGw, // append all subnets of other NATed gateways into left subnets. for gwName, v := range snUnderNAT { if gwName != remoteGw.GatewayName { - leftSubnets = append(leftSubnets, v...) + if !enableCreateEdgeConnection(v, remoteGw) { + leftSubnets = append(leftSubnets, v.Subnets...) + } } } } @@ -186,15 +162,171 @@ func (l *libreswan) getEndpointResolver(network *types.Network) func(centralGw, } // If both local and remote are NATed, and the local gateway is not the central gateway, - // connects to central gateway to forward traffic. + // and can't create edge to edge tunnel, connects to central gateway to forward traffic. if network.LocalEndpoint.UnderNAT && remoteGw.UnderNAT { - return leftSubnets, centralGw + if !enableCreateEdgeConnection(network.LocalEndpoint, remoteGw) { + return leftSubnets, centralGw + } } return leftSubnets, remoteGw } } +func (l *libreswan) createConnections(network *types.Network) error { + l.centralGw = findCentralGw(network) + desiredEdgeConns, desiredRelayConns := l.computeDesiredConnections(network) + if len(desiredEdgeConns) == 0 && len(desiredRelayConns) == 0 { + klog.Infof(utils.FormatTunnel("no desired connections, cleaning vpn connections")) + return l.Cleanup() + } + + klog.Infof(utils.FormatTunnel("desired edge connections: %+v, desired relay connections: %+v", desiredEdgeConns, desiredRelayConns)) + + if err := l.createEdgeConnections(desiredEdgeConns); err != nil { + return err + } + if err := l.createRelayConnections(desiredRelayConns); err != nil { + return err + } + + return nil +} + +func (l *libreswan) createEdgeConnections(desiredEdgeConns map[string]*vpndriver.Connection) error { + if len(desiredEdgeConns) == 0 { + klog.Infof("no desired edge connections") + return nil + } + + errList := errorlist.List{} + + // remove unwanted connections + for connName := range l.edgeConnections { + if _, ok := desiredEdgeConns[connName]; !ok { + err := l.whackDelConnection(connName) + if err != nil { + errList = errList.Append(err) + klog.ErrorS(err, "error disconnecting endpoint", "connectionName", connName) + continue + } + delete(l.edgeConnections, connName) + } + } + + // add new connections + for name, connection := range desiredEdgeConns { + err := l.connectToEdgeEndpoint(name, connection) + errList = errList.Append(err) + } + + return errList.AsError() +} + +func (l *libreswan) createRelayConnections(desiredRelayConns map[string]*vpndriver.Connection) error { + if len(desiredRelayConns) == 0 { + klog.Infof("no desired relay connections") + return nil + } + + errList := errorlist.List{} + + // remove unwanted connections + for connName := range l.relayConnections { + if _, ok := desiredRelayConns[connName]; !ok { + err := l.whackDelConnection(connName) + if err != nil { + errList = errList.Append(err) + klog.ErrorS(err, "error disconnecting endpoint", "connectionName", connName) + continue + } + if l.centralGw.NodeName == l.nodeName { + errList = errList.Append(l.deleteRavenSkipNAT(l.relayConnections[connName])) + } + delete(l.relayConnections, connName) + } + } + + // add new connections + for name, connection := range desiredRelayConns { + err := l.connectToEndpoint(name, connection) + errList = errList.Append(err) + if l.centralGw.NodeName == l.nodeName { + err = l.ensureRavenSkipNAT(connection) + errList = errList.Append(err) + } + } + + return errList.AsError() +} + +func (l *libreswan) ensureRavenSkipNAT(connection *vpndriver.Connection) errorlist.List { + errList := errorlist.List{} + for _, subnet := range l.centralGw.Subnets { + if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { + return errList + } + } + // for raven skip nat + if err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err)) + } + if err := l.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain); err != nil { + errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err)) + } + if err := l.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT"); err != nil { + errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err)) + } + return errList +} + +func (l *libreswan) deleteRavenSkipNAT(connection *vpndriver.Connection) errorlist.List { + errList := errorlist.List{} + err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) + if err != nil { + errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) + } + for _, subnet := range l.centralGw.Subnets { + if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { + return errList + } + } + err = l.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT") + if err != nil { + errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.RavenPostRoutingChain, err)) + } + return errList +} + +func (l *libreswan) computeDesiredConnections(network *types.Network) (map[string]*vpndriver.Connection, map[string]*vpndriver.Connection) { + desiredEdgeConns := make(map[string]*vpndriver.Connection) + desiredRelayConns := make(map[string]*vpndriver.Connection) + resolveEndpoint := l.getEndpointResolver(network) + + leftEndpoint := network.LocalEndpoint + for _, remoteGw := range network.RemoteEndpoints { + leftSubnets, connectTo := resolveEndpoint(l.centralGw, remoteGw) + for _, leftSubnet := range leftSubnets { + for _, rightSubnet := range remoteGw.Subnets { + name := connectionName(leftEndpoint.PrivateIP, connectTo.PrivateIP, leftSubnet, rightSubnet) + connect := &vpndriver.Connection{ + LocalEndpoint: leftEndpoint.Copy(), + RemoteEndpoint: connectTo.Copy(), + LocalSubnet: leftSubnet, + RemoteSubnet: rightSubnet, + } + if enableCreateEdgeConnection(leftEndpoint.Copy(), connectTo.Copy()) { + desiredEdgeConns[name] = connect + } else { + desiredRelayConns[name] = connect + } + } + } + } + + return desiredEdgeConns, desiredRelayConns +} + func (l *libreswan) whackConnectToEndpoint(connectionName string, connection *vpndriver.Connection) error { args := make([]string, 0) leftID := fmt.Sprintf("@%s-%s-%s", connection.LocalEndpoint.PrivateIP, connection.LocalSubnet, connection.RemoteSubnet) @@ -244,66 +376,48 @@ func (l *libreswan) whackConnectToEndpoint(connectionName string, connection *vp return nil } -func (l *libreswan) ensureRavenSkipNAT(centralGw *types.Endpoint, connection *vpndriver.Connection) errorlist.List { - errList := errorlist.List{} - // for raven skip nat - if err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { - errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err)) - } - if err := l.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain); err != nil { - errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err)) +func (l *libreswan) whackConnectToEdgeEndpoint(connectionName string, connection *vpndriver.Connection) error { + args := make([]string, 0) + leftID := fmt.Sprintf("@%s-%s-%s", connection.LocalEndpoint.PrivateIP, connection.LocalSubnet, connection.RemoteSubnet) + rightID := fmt.Sprintf("@%s-%s-%s", connection.RemoteEndpoint.PrivateIP, connection.RemoteSubnet, connection.LocalSubnet) + + if err := whackCmd("--delete", "--name", connectionName); err != nil { + return err } - for _, subnet := range centralGw.Subnets { - if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { - return errList + // local + args = append(args, "--psk", "--encrypt", "--forceencaps", "--name", connectionName, + "--id", leftID, + "--host", connection.LocalEndpoint.String(), + "--client", connection.LocalSubnet) + // remote + if connection.RemoteEndpoint.NATType == utils.NATSymmetric { + args = append(args, "--to", + "--id", rightID, + "--host", "%any", + "--client", connection.RemoteSubnet) + if err := whackCmd(args...); err != nil { + return err } + return nil } - if err := l.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT"); err != nil { - errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err)) - } - return errList -} -func (l *libreswan) deleteRavenSkipNAT(centralGw *types.Endpoint, connection *vpndriver.Connection) errorlist.List { - errList := errorlist.List{} - err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) - if err != nil { - errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) + args = append(args, "--to", + "--id", rightID, + "--host", connection.RemoteEndpoint.PublicIP, + "--client", connection.RemoteSubnet, + "--ikeport", strconv.Itoa(connection.RemoteEndpoint.PublicPort)) + + if err := whackCmd(args...); err != nil { + return err } - for _, subnet := range centralGw.Subnets { - if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { - return errList - } + if err := whackCmd("--route", "--name", connectionName); err != nil { + return err } - err = l.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT") - if err != nil { - errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.RavenPostRoutingChain, err)) + if err := whackCmd("--initiate", "--asynchronous", "--name", connectionName); err != nil { + return err } - return errList -} - -func (l *libreswan) computeDesiredConnections(network *types.Network) map[string]*vpndriver.Connection { - centralGw := findCentralGw(network) - resolveEndpoint := l.getEndpointResolver(network) - // This is the desired connection calculated from given *types.Network - desiredConns := make(map[string]*vpndriver.Connection) - leftEndpoint := network.LocalEndpoint - for _, remoteGw := range network.RemoteEndpoints { - leftSubnets, connectTo := resolveEndpoint(centralGw, remoteGw) - for _, leftSubnet := range leftSubnets { - for _, rightSubnet := range remoteGw.Subnets { - name := connectionName(leftEndpoint.PrivateIP, connectTo.PrivateIP, leftSubnet, rightSubnet) - desiredConns[name] = &vpndriver.Connection{ - LocalEndpoint: leftEndpoint.Copy(), - RemoteEndpoint: connectTo.Copy(), - LocalSubnet: leftSubnet, - RemoteSubnet: rightSubnet, - } - } - } - } - return desiredConns + return nil } func whackCmdFn(args ...string) error { @@ -319,7 +433,7 @@ func whackCmdFn(args ...string) error { time.Sleep(1 * time.Second) } if err != nil { - return fmt.Errorf("error whacking with %v: status code %v, error %s", args, err, string(output)) + return fmt.Errorf("error whacking with %v: %v", args, err) } return nil } @@ -334,13 +448,23 @@ func connectionName(localID, remoteID, leftSubnet, rightSubnet string) string { func (l *libreswan) Cleanup() error { errList := errorlist.List{} - for name := range l.connections { + for name := range l.relayConnections { + if err := l.whackDelConnection(name); err != nil { + errList = errList.Append(err) + klog.ErrorS(err, "fail to delete connection", "connectionName", name) + } + if l.centralGw != nil && l.centralGw.NodeName == l.nodeName { + errList = errList.Append(l.deleteRavenSkipNAT(l.relayConnections[name])) + } + } + for name := range l.edgeConnections { if err := l.whackDelConnection(name); err != nil { errList = errList.Append(err) klog.ErrorS(err, "fail to delete connection", "connectionName", name) } } - l.connections = make(map[string]*vpndriver.Connection) + l.relayConnections = make(map[string]*vpndriver.Connection) + l.edgeConnections = make(map[string]*vpndriver.Connection) err := netlinkutil.XfrmPolicyFlush() errList = errList.Append(err) @@ -395,7 +519,7 @@ func (l *libreswan) runPluto() error { func (l *libreswan) connectToEndpoint(name string, connection *vpndriver.Connection) errorlist.List { errList := errorlist.List{} - if _, ok := l.connections[name]; ok { + if _, ok := l.relayConnections[name]; ok { klog.InfoS("skipping connect because connection already exists", "connectionName", name) return errList } @@ -405,6 +529,22 @@ func (l *libreswan) connectToEndpoint(name string, connection *vpndriver.Connect klog.ErrorS(err, "error connect connection", "connectionName", name) return errList } - l.connections[name] = connection + l.relayConnections[name] = connection + return errList +} + +func (l *libreswan) connectToEdgeEndpoint(name string, connection *vpndriver.Connection) errorlist.List { + errList := errorlist.List{} + if _, ok := l.edgeConnections[name]; ok { + klog.InfoS("skipping connect because connection already exists", "connectionName", name) + return errList + } + err := l.whackConnectToEdgeEndpoint(name, connection) + if err != nil { + errList = errList.Append(err) + klog.ErrorS(err, "error connect connection", "connectionName", name) + return errList + } + l.edgeConnections[name] = connection return errList } diff --git a/pkg/networkengine/vpndriver/libreswan/libreswan_test.go b/pkg/networkengine/vpndriver/libreswan/libreswan_test.go index 5817baa..529963d 100644 --- a/pkg/networkengine/vpndriver/libreswan/libreswan_test.go +++ b/pkg/networkengine/vpndriver/libreswan/libreswan_test.go @@ -372,8 +372,9 @@ func TestLibreswan_Apply(t *testing.T) { whackCmd = w.whackCmd a := assert.New(t) l := &libreswan{ - connections: make(map[string]*vpndriver.Connection), - nodeName: types.NodeName(v.nodeName), + relayConnections: make(map[string]*vpndriver.Connection), + edgeConnections: make(map[string]*vpndriver.Connection), + nodeName: types.NodeName(v.nodeName), } var err error l.iptables, err = iptablesutil.New() diff --git a/pkg/networkengine/vpndriver/wireguard/wireguard.go b/pkg/networkengine/vpndriver/wireguard/wireguard.go index 9cdadb8..0b95bc7 100644 --- a/pkg/networkengine/vpndriver/wireguard/wireguard.go +++ b/pkg/networkengine/vpndriver/wireguard/wireguard.go @@ -41,6 +41,7 @@ import ( iptablesutil "github.com/openyurtio/raven/pkg/networkengine/util/iptables" "github.com/openyurtio/raven/pkg/networkengine/vpndriver" "github.com/openyurtio/raven/pkg/types" + "github.com/openyurtio/raven/pkg/utils" ) const ( @@ -63,6 +64,7 @@ const ( ) var findCentralGw = vpndriver.FindCentralGwFn +var enableCreateEdgeConnection = vpndriver.EnableCreateEdgeConnection var _ vpndriver.Driver = (*wireguard)(nil) @@ -76,12 +78,12 @@ type wireguard struct { psk wgtypes.Key wgLink netlink.Link - connections map[string]*vpndriver.Connection - crossEdgeConnections map[string]*vpndriver.Connection - iptables iptablesutil.IPTablesInterface - nodeName types.NodeName - ravenClient client.Client - listenPort int + relayConnections map[string]*vpndriver.Connection + edgeConnections map[string]*vpndriver.Connection + iptables iptablesutil.IPTablesInterface + nodeName types.NodeName + ravenClient client.Client + listenPort int } func New(cfg *config.Config) (vpndriver.Driver, error) { @@ -90,10 +92,11 @@ func New(cfg *config.Config) (vpndriver.Driver, error) { port = DefaultListenPort } return &wireguard{ - connections: make(map[string]*vpndriver.Connection), - nodeName: types.NodeName(cfg.NodeName), - ravenClient: cfg.Manager.GetClient(), - listenPort: port, + relayConnections: make(map[string]*vpndriver.Connection), + edgeConnections: make(map[string]*vpndriver.Connection), + nodeName: types.NodeName(cfg.NodeName), + ravenClient: cfg.Manager.GetClient(), + listenPort: port, }, nil } @@ -212,88 +215,117 @@ func (w *wireguard) ensureWgLink(network *types.Network, routeDriverMTUFn func(* return nil } -func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.Network) (int, error)) error { - if network.LocalEndpoint == nil || len(network.RemoteEndpoints) == 0 { - klog.Info("no local gateway or remote gateway is found, cleaning vpn connections") - return w.Cleanup() - } - if network.LocalEndpoint.NodeName != w.nodeName { - klog.Infof("the current node is not gateway node, cleaning vpn connections") - return w.Cleanup() - } - - if _, ok := network.LocalEndpoint.Config[PublicKey]; !ok || network.LocalEndpoint.Config[PublicKey] != w.privateKey.PublicKey().String() { - err := w.configGatewayPublicKey(string(network.LocalEndpoint.GatewayName), string(network.LocalEndpoint.NodeName)) - if err != nil { - klog.ErrorS(err, "error config gateway public key", "gateway", network.LocalEndpoint.GatewayName) - } - return errors.New("retry to config public key") - } - // 1. Compute desiredConnections - centralGw := findCentralGw(network) - desiredConnections, centralAllowedIPs, desiredCrossEdgeConns := w.computeDesiredConnections(network) - if len(desiredConnections) == 0 { +func (w *wireguard) createConnections(network *types.Network) error { + desiredEdgeConns, desiredRelayConns, centralAllowedIPs := w.computeDesiredConnections(network) + if len(desiredEdgeConns) == 0 && len(desiredRelayConns) == 0 { klog.Infof("no desired connections, cleaning vpn connections") return w.Cleanup() } - // 2. Ensure WireGuard link - if err := w.ensureWgLink(network, routeDriverMTUFn); err != nil { - return fmt.Errorf("fail to ensure wireguar link: %v", err) - } + klog.Infof("desired edge connections: %+v, desired relay connections: %+v", desiredEdgeConns, desiredRelayConns) - // 3. Config device route and rules - currentRoutes, err := networkutil.ListRoutesOnNode(wgRouteTableID) - if err != nil { - return fmt.Errorf("error listing wireguard routes on node: %s", err) + centralGw := findCentralGw(network) + if err := w.createEdgeConnections(desiredEdgeConns); err != nil { + return err } - currentRules, err := networkutil.ListRulesOnNode(wgRouteTableID) - if err != nil { - return fmt.Errorf("error listing wireguard rules on node: %s", err) + if err := w.createRelayConnections(desiredRelayConns, centralAllowedIPs, centralGw); err != nil { + return err } - desiredRoutes := w.calWgRoutes(network) - desiredRules := w.calWgRules() + return nil +} - err = networkutil.ApplyRoutes(currentRoutes, desiredRoutes) - if err != nil { - return fmt.Errorf("error applying wireguard routes: %s", err) - } - err = networkutil.ApplyRules(currentRules, desiredRules) - if err != nil { - return fmt.Errorf("error applying wireguard rules: %s", err) +func (w *wireguard) createEdgeConnections(desiredEdgeConns map[string]*vpndriver.Connection) error { + if len(desiredEdgeConns) == 0 { + klog.Infof("no desired edge connections") + return nil } - // 4. delete unwanted connections - for connName, connection := range w.connections { - if _, ok := desiredConnections[connName]; !ok { + for connName, connection := range w.edgeConnections { + if _, ok := desiredEdgeConns[connName]; !ok { remoteKey := keyFromEndpoint(connection.RemoteEndpoint) if err := w.removePeer(remoteKey); err == nil { - delete(w.connections, connName) + delete(w.edgeConnections, connName) } } } - if centralGw.NodeName == w.nodeName { - for connName, connection := range w.crossEdgeConnections { - if _, ok := desiredCrossEdgeConns[connName]; !ok { - delete(w.crossEdgeConnections, connName) - if err := w.deleteRavenSkipNAT(centralGw, connection); err != nil { - return err + + peerConfigs := make([]wgtypes.PeerConfig, 0) + for name, newConn := range desiredEdgeConns { + newKey := keyFromEndpoint(newConn.RemoteEndpoint) + + if oldConn, ok := w.edgeConnections[name]; ok { + oldKey := keyFromEndpoint(oldConn.RemoteEndpoint) + if oldKey.String() != newKey.String() { + if err := w.removePeer(oldKey); err == nil { + delete(w.edgeConnections, name) } } } + + klog.InfoS("create edge-to-edge connection", "c", newConn) + + allowedIPs := parseSubnets(newConn.RemoteEndpoint.Subnets) + var remotePort int + if newConn.RemoteEndpoint.NATType == utils.NATSymmetric { + remotePort = w.listenPort + } else { + remotePort = newConn.RemoteEndpoint.PublicPort + } + ka := KeepAliveInterval + peerConfigs = append(peerConfigs, wgtypes.PeerConfig{ + PublicKey: *newKey, + Remove: false, + UpdateOnly: false, + PresharedKey: &w.psk, + Endpoint: &net.UDPAddr{ + IP: net.ParseIP(newConn.RemoteEndpoint.PublicIP), + Port: remotePort, + }, + PersistentKeepaliveInterval: &ka, + ReplaceAllowedIPs: true, + AllowedIPs: allowedIPs, + }) + } + + if err := w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ + ReplacePeers: true, + Peers: peerConfigs, + }); err != nil { + return fmt.Errorf("error add peers: %v", err) + } + + w.edgeConnections = desiredEdgeConns + + return nil +} + +func (w *wireguard) createRelayConnections(desiredRelayConns map[string]*vpndriver.Connection, centralAllowedIPs []string, centralGw *types.Endpoint) error { + if len(desiredRelayConns) == 0 { + klog.Infof("no desired relay connections") + return nil } - // 5. add or update connections + // delete unwanted connections + for connName, connection := range w.relayConnections { + if _, ok := desiredRelayConns[connName]; !ok { + remoteKey := keyFromEndpoint(connection.RemoteEndpoint) + if err := w.removePeer(remoteKey); err == nil { + delete(w.relayConnections, connName) + } + } + } + + // add or update connections peerConfigs := make([]wgtypes.PeerConfig, 0) - for name, newConn := range desiredConnections { + for name, newConn := range desiredRelayConns { newKey := keyFromEndpoint(newConn.RemoteEndpoint) - if oldConn, ok := w.connections[name]; ok { + if oldConn, ok := w.relayConnections[name]; ok { oldKey := keyFromEndpoint(oldConn.RemoteEndpoint) if oldKey.String() != newKey.String() { if err := w.removePeer(oldKey); err == nil { - delete(w.connections, name) + delete(w.relayConnections, name) } } } @@ -321,15 +353,6 @@ func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.N AllowedIPs: allowedIPs, }) } - if centralGw.NodeName == w.nodeName { - for name, newConn := range desiredCrossEdgeConns { - if _, ok := w.crossEdgeConnections[name]; !ok { - if err := w.ensureRavenSkipNAT(centralGw, newConn); err != nil { - return err - } - } - } - } if err := w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ ReplacePeers: false, @@ -338,8 +361,64 @@ func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.N return fmt.Errorf("error add peers: %v", err) } - w.connections = desiredConnections - w.crossEdgeConnections = desiredCrossEdgeConns + w.relayConnections = desiredRelayConns + + return nil +} + +func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.Network) (int, error)) error { + if network.LocalEndpoint == nil || len(network.RemoteEndpoints) == 0 { + klog.Info("no local gateway or remote gateway is found, cleaning vpn connections") + return w.Cleanup() + } + if network.LocalEndpoint.NodeName != w.nodeName { + klog.Infof("the current node is not gateway node, cleaning vpn connections") + return w.Cleanup() + } + + if _, ok := network.LocalEndpoint.Config[PublicKey]; !ok || network.LocalEndpoint.Config[PublicKey] != w.privateKey.PublicKey().String() { + err := w.configGatewayPublicKey(string(network.LocalEndpoint.GatewayName), string(network.LocalEndpoint.NodeName)) + if err != nil { + klog.ErrorS(err, "error config gateway public key", "gateway", network.LocalEndpoint.GatewayName) + } + return errors.New("retry to config public key") + } + + centralGw := findCentralGw(network) + if centralGw.NodeName == w.nodeName { + if err := w.ensureRavenSkipNAT(); err != nil { + return fmt.Errorf("error ensure raven skip nat: %s", err) + } + } + + if err := w.ensureWgLink(network, routeDriverMTUFn); err != nil { + return fmt.Errorf("fail to ensure wireguar link: %v", err) + } + // 3. Config device route and rules + currentRoutes, err := networkutil.ListRoutesOnNode(wgRouteTableID) + if err != nil { + return fmt.Errorf("error listing wireguard routes on node: %s", err) + } + currentRules, err := networkutil.ListRulesOnNode(wgRouteTableID) + if err != nil { + return fmt.Errorf("error listing wireguard rules on node: %s", err) + } + + desiredRoutes := w.calWgRoutes(network) + desiredRules := w.calWgRules() + + err = networkutil.ApplyRoutes(currentRoutes, desiredRoutes) + if err != nil { + return fmt.Errorf("error applying wireguard routes: %s", err) + } + err = networkutil.ApplyRules(currentRules, desiredRules) + if err != nil { + return fmt.Errorf("error applying wireguard rules: %s", err) + } + + if err := w.createConnections(network); err != nil { + return fmt.Errorf("error create VPN tunnels: %v", err) + } return nil } @@ -374,90 +453,44 @@ func (w *wireguard) Cleanup() error { if err = netlink.LinkDel(link); err != nil { errList = errList.Append(fmt.Errorf("error delete existing wireguard device %q: %v", DeviceName, err)) } - w.connections = make(map[string]*vpndriver.Connection) - err = w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) - if err != nil { - errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) - } - err = w.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain) - if err != nil { - errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err)) - } - err = w.iptables.ClearAndDeleteChain(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) - if err != nil { - errList = errList.Append(fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err)) + if err = w.deleteRavenSkipNAT(); err != nil { + errList = errList.Append(err) } - w.crossEdgeConnections = make(map[string]*vpndriver.Connection) + w.relayConnections = make(map[string]*vpndriver.Connection) + w.edgeConnections = make(map[string]*vpndriver.Connection) return errList.AsError() } -// getSubnetResolver returns a function that resolve the left subnets. -func (w *wireguard) getSubnetResolver(network *types.Network) func(remoteGw *types.Endpoint) (leftSubnets []string) { - snUnderNAT := make(map[types.GatewayName][]string) - for _, v := range network.RemoteEndpoints { - if v.UnderNAT { - snUnderNAT[v.GatewayName] = v.Subnets - } - } - return func(remoteGw *types.Endpoint) (leftSubnets []string) { - if remoteGw.UnderNAT { - // In order to forward traffic from other NATed gateway to the NATed remoteGw, - // append all subnets of other NATed gateways into left subnets. - for gwName, v := range snUnderNAT { - if gwName != remoteGw.GatewayName { - leftSubnets = append(leftSubnets, v...) - } - } - } - return leftSubnets - } -} - -func (w *wireguard) computeDesiredConnections(network *types.Network) (map[string]*vpndriver.Connection, []string, map[string]*vpndriver.Connection) { - - // This is the desired connection calculated from given *types.Network - desiredConns := make(map[string]*vpndriver.Connection) - centralGw := findCentralGw(network) - desiredCrossEdgeConns := make(map[string]*vpndriver.Connection) +func (w *wireguard) computeDesiredConnections(network *types.Network) (map[string]*vpndriver.Connection, map[string]*vpndriver.Connection, []string) { + // This is the desired edge connections and relay connections calculated from given *types.Network + desiredEdgeConns := make(map[string]*vpndriver.Connection) + desiredRelayConns := make(map[string]*vpndriver.Connection) centralAllowedIPs := make([]string, 0) for _, remote := range network.RemoteEndpoints { if _, ok := remote.Config[PublicKey]; !ok { continue } - - // if local gateway is not central gateway and remote endpoint is NATed - // append all subnets of remote gateway into central allowed IPs. - if network.LocalEndpoint.UnderNAT && remote.UnderNAT { - centralAllowedIPs = append(centralAllowedIPs, remote.Subnets...) - continue - } - name := connectionName(string(network.LocalEndpoint.NodeName), string(remote.NodeName)) - desiredConns[name] = &vpndriver.Connection{ + connect := &vpndriver.Connection{ LocalEndpoint: network.LocalEndpoint.Copy(), RemoteEndpoint: remote.Copy(), } - } - - if centralGw.NodeName == w.nodeName { - resolveSubnet := w.getSubnetResolver(network) - for _, remoteGw := range network.RemoteEndpoints { - leftSubnets := resolveSubnet(remoteGw) - for _, leftSubnet := range leftSubnets { - for _, rightSubnet := range remoteGw.Subnets { - name := connectionName(leftSubnet, rightSubnet) - desiredCrossEdgeConns[name] = &vpndriver.Connection{ - LocalSubnet: leftSubnet, - RemoteSubnet: rightSubnet, - } - } + if enableCreateEdgeConnection(network.LocalEndpoint, remote) { + desiredEdgeConns[name] = connect + } else { + // if local gateway is not central gateway and remote endpoint is NATed + // append all subnets of remote gateway into central allowed IPs. + if network.LocalEndpoint.UnderNAT && remote.UnderNAT { + centralAllowedIPs = append(centralAllowedIPs, remote.Subnets...) + continue } + desiredRelayConns[name] = connect } } - return desiredConns, centralAllowedIPs, desiredCrossEdgeConns + return desiredEdgeConns, desiredRelayConns, centralAllowedIPs } func (w *wireguard) removePeer(key *wgtypes.Key) error { @@ -569,28 +602,31 @@ func parseSubnets(subnets []string) []net.IPNet { return nets } -func (w *wireguard) ensureRavenSkipNAT(centralGw *types.Endpoint, connection *vpndriver.Connection) error { +func (w *wireguard) ensureRavenSkipNAT() error { // for raven skip nat if err := w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { return fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err) } - if err := w.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain); err != nil { + if err := w.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", "raven-wg0", "-j", iptablesutil.RavenPostRoutingChain); err != nil { return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err) } - if err := w.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT"); err != nil { + if err := w.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-j", "ACCEPT"); err != nil { return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err) } + return nil } -func (w *wireguard) deleteRavenSkipNAT(centralGw *types.Endpoint, connection *vpndriver.Connection) error { - err := w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) - if err != nil { +func (w *wireguard) deleteRavenSkipNAT() error { + if err := w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { return fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err) } - err = w.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT") - if err != nil { - return fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.RavenPostRoutingChain, err) + if err := w.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", "raven-wg0", "-j", iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err) } + if err := w.iptables.ClearAndDeleteChain(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err) + } + return nil } diff --git a/pkg/tunnelengine/tunnelagent.go b/pkg/tunnelengine/tunnelagent.go index 4c84c5d..c859193 100644 --- a/pkg/tunnelengine/tunnelagent.go +++ b/pkg/tunnelengine/tunnelagent.go @@ -86,12 +86,20 @@ func (c *TunnelHandler) Handler() error { for i := range gws.Items { // try to update public IP if empty. gw := &gws.Items[i] - if ep := getTunnelActiveEndpoints(gw); ep != nil && ep.PublicIP == "" { - err := c.configGatewayPublicIP(gw) - if err != nil { - klog.ErrorS(err, "error config gateway public ip", "gateway", klog.KObj(gw)) + if ep := getTunnelActiveEndpoints(gw); ep != nil { + if ep.PublicIP == "" || ep.NATType == "" || ep.PublicPort == 0 && ep.NATType != utils.NATSymmetric { + if ep.PublicIP == "" { + if err := c.configGatewayPublicIP(gw); err != nil { + klog.ErrorS(err, "error config gateway public ip", "gateway", klog.KObj(gw)) + } + } + if ep.NATType == "" || ep.PublicPort == 0 && ep.NATType != utils.NATSymmetric { + if err := c.configGatewayStunInfo(gw); err != nil { + klog.ErrorS(err, "error config gateway stun info", "gateway", klog.KObj(gw)) + } + } + continue } - continue } if !c.shouldHandleGateway(gw) { continue @@ -170,8 +178,10 @@ func (c *TunnelHandler) syncGateway(gw *v1beta1.Gateway) { NodeName: types.NodeName(aep.NodeName), Subnets: subnets, PrivateIP: nodeInfo.PrivateIP, + PublicPort: aep.PublicPort, PublicIP: aep.PublicIP, UnderNAT: aep.UnderNAT, + NATType: aep.NATType, Config: cfg, } var isLocalGateway bool @@ -195,14 +205,23 @@ func (c *TunnelHandler) syncGateway(gw *v1beta1.Gateway) { } func (c *TunnelHandler) shouldHandleGateway(gateway *v1beta1.Gateway) bool { - if getTunnelActiveEndpoints(gateway) == nil { + ep := getTunnelActiveEndpoints(gateway) + if ep == nil { klog.InfoS("no active endpoint , waiting for sync", "gateway", klog.KObj(gateway)) return false } - if getTunnelActiveEndpoints(gateway).PublicIP == "" { + if ep.PublicIP == "" { klog.InfoS("no public IP for gateway, waiting for sync", "gateway", klog.KObj(gateway)) return false } + if ep.NATType == "" { + klog.InfoS("no nat type for gateway, waiting for sync", "gateway", klog.KObj(gateway)) + return false + } + if ep.NATType != utils.NATSymmetric && ep.PublicPort == 0 { + klog.InfoS("no public port for gateway, waiting for sync", "gateway", klog.KObj(gateway)) + return false + } if c.ownGateway == nil { klog.InfoS(fmt.Sprintf("no own gateway for node %s, skip it", c.nodeName), "gateway", klog.KObj(gateway)) return false @@ -250,6 +269,49 @@ func (c *TunnelHandler) configGatewayPublicIP(gateway *v1beta1.Gateway) error { return err } +func (c *TunnelHandler) configGatewayStunInfo(gateway *v1beta1.Gateway) error { + if getTunnelActiveEndpoints(gateway).NodeName != c.nodeName { + return nil + } + + natType, err := utils.GetNATType() + if err != nil { + return err + } + + var publicPort int + if natType != utils.NATSymmetric { + publicPort, err = utils.GetPublicPort() + if err != nil { + return err + } + } + + // retry to update nat type of localGateway + err = retry.RetryOnConflict(retry.DefaultBackoff, func() error { + // get localGateway from api server + var apiGw v1beta1.Gateway + err := c.ravenClient.Get(context.Background(), client.ObjectKey{ + Name: gateway.Name, + }, &apiGw) + if err != nil { + return err + } + for k, v := range apiGw.Spec.Endpoints { + if v.NodeName == c.nodeName { + apiGw.Spec.Endpoints[k].NATType = natType + if natType != utils.NATSymmetric { + apiGw.Spec.Endpoints[k].PublicPort = publicPort + } + err = c.ravenClient.Update(context.Background(), &apiGw) + return err + } + } + return nil + }) + return err +} + func (c *TunnelHandler) getLoadBalancerPublicIP(gwName string) (string, error) { var svcList v1.ServiceList err := c.ravenClient.List(context.TODO(), &svcList, &client.ListOptions{ diff --git a/pkg/types/network.go b/pkg/types/network.go index cc80538..e508695 100644 --- a/pkg/types/network.go +++ b/pkg/types/network.go @@ -32,11 +32,13 @@ type Endpoint struct { // NodeName is the name of the Node holding this Endpoint. NodeName NodeName // Subnets stores subnets of the nodes managed by the gateway. - Subnets []string - PrivateIP string - PublicIP string - UnderNAT bool - Config map[string]string + Subnets []string + PrivateIP string + PublicIP string + PublicPort int + UnderNAT bool + NATType string + Config map[string]string } func (e *Endpoint) String() string { diff --git a/pkg/utils/constants.go b/pkg/utils/constants.go index a4b59fd..d2ae8f6 100644 --- a/pkg/utils/constants.go +++ b/pkg/utils/constants.go @@ -51,4 +51,8 @@ const ( GatewayProxyInternalService = "x-raven-proxy-internal-svc" LabelCurrentGatewayEndpoints = "raven.openyurt.io/endpoints-name" LabelCurrentGatewayType = "raven.openyurt.io/gateway-type" + + NATSymmetric = "Symmetric NAT" + NATPortRestricted = "Port Restricted cone NAT" + NATUndefined = "Undefined" ) diff --git a/pkg/utils/stun.go b/pkg/utils/stun.go new file mode 100644 index 0000000..d78f26c --- /dev/null +++ b/pkg/utils/stun.go @@ -0,0 +1,73 @@ +/* + * Copyright 2023 The OpenYurt Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package utils + +import ( + "fmt" + + "github.com/ccding/go-stun/stun" + "github.com/vdobler/ht/errorlist" +) + +var ( + stunAPIs = [...]string{ + "stun.qq.com:3478", + "stun.miwifi.com:3478", + } + stunClient *stun.Client + NATType string + PublicPort int +) + +func init() { + stunClient = stun.NewClient() + stunClient.SetLocalPort(4500) +} + +func GetNATType() (string, error) { + if NATType != "" { + return NATType, nil + } + errList := errorlist.List{} + for _, api := range stunAPIs { + stunClient.SetServerAddr(api) + natBehavior, err := stunClient.BehaviorTest() + if err == nil { + NATType = natBehavior.NormalType() + return NATType, nil + } + errList = errList.Append(err) + } + return "", fmt.Errorf("error get nat type by any of the apis[%v]: %s", stunAPIs, errList.AsError()) +} + +func GetPublicPort() (int, error) { + if PublicPort != 0 { + return PublicPort, nil + } + errList := errorlist.List{} + for _, api := range stunAPIs { + stunClient.SetServerAddr(api) + _, host, err := stunClient.Discover() + if err == nil { + PublicPort = int(host.Port()) + return PublicPort, nil + } + errList = errList.Append(err) + } + return 0, fmt.Errorf("error get public port by any of the apis[%v]: %s", stunAPIs, errList.AsError()) +}