From 4b482fcce3f040a0ee6c3a2cf777424568516be3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8F=A9=E8=BD=A9?= Date: Tue, 4 Jun 2024 17:23:29 +0800 Subject: [PATCH] 1. Optimize the reconcile process, use the ipsec status command to obtain the current VPN link status, and compare the expected status with the current status for configuration. (Wireguard has the same principle) 2. Add a regular synchronization mechanism --- charts/raven-agent/templates/daemonset.yaml | 2 + charts/raven-agent/values.yaml | 5 +- cmd/agent/app/config/config.go | 8 +- cmd/agent/app/options/options.go | 17 +- cmd/agent/app/start.go | 13 +- cmd/agent/main.go | 2 +- pkg/engine/engine.go | 84 +++-- pkg/engine/proxy.go | 21 +- pkg/engine/tunnel.go | 48 ++- pkg/networkengine/util/ipset/ipset.go | 32 +- pkg/networkengine/util/netlink/netlink.go | 11 + pkg/networkengine/util/utils.go | 14 +- pkg/networkengine/vpndriver/ipset/ipset.go | 129 ++++++++ .../vpndriver/libreswan/libreswan.go | 300 +++++++++++------- .../vpndriver/libreswan/libreswan_test.go | 8 +- .../vpndriver/wireguard/wireguard.go | 256 +++++++-------- 16 files changed, 608 insertions(+), 342 deletions(-) create mode 100644 pkg/networkengine/vpndriver/ipset/ipset.go diff --git a/charts/raven-agent/templates/daemonset.yaml b/charts/raven-agent/templates/daemonset.yaml index 0b4653b..e0b5f7a 100644 --- a/charts/raven-agent/templates/daemonset.yaml +++ b/charts/raven-agent/templates/daemonset.yaml @@ -56,6 +56,8 @@ spec: - --vpn-bind-port={{.Values.vpn.tunnelAddr}} - --keep-alive-interval={{.Values.vpn.keepAliveInterval}} - --keep-alive-timeout={{.Values.vpn.keepAliveTimeout}} + - --sync-raven-rules={{.Values.sync.syncRule}} + - --sync-raven-rules-period={{.Values.sync.syncPeriod}} - --proxy-metric-bind-addr={{.Values.proxy.metricsBindAddr}} - --proxy-internal-secure-addr={{.Values.proxy.internalSecureAddr}} - --proxy-internal-insecure-addr={{.Values.proxy.internalInsecureAddr}} diff --git a/charts/raven-agent/values.yaml 
b/charts/raven-agent/values.yaml index ddd5287..e794cbe 100644 --- a/charts/raven-agent/values.yaml +++ b/charts/raven-agent/values.yaml @@ -59,6 +59,9 @@ containerEnv: secretKeyRef: key: vpn-connection-psk name: raven-agent-secret +sync: + syncRule: true + syncPeriod: 30m vpn: driver: libreswan @@ -86,4 +89,4 @@ proxy: metricsBindAddr: ":10266" rollingUpdate: - maxUnavailable: 5% \ No newline at end of file + maxUnavailable: 20% \ No newline at end of file diff --git a/cmd/agent/app/config/config.go b/cmd/agent/app/config/config.go index b9b4a3b..7667418 100644 --- a/cmd/agent/app/config/config.go +++ b/cmd/agent/app/config/config.go @@ -17,14 +17,18 @@ package config import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/rest" "sigs.k8s.io/controller-runtime/pkg/manager" ) // Config is the main context object for raven agent type Config struct { - NodeName string - NodeIP string + NodeName string + NodeIP string + SyncRules bool + SyncPeriod metav1.Duration + MetricsBindAddress string HealthProbeAddr string diff --git a/cmd/agent/app/options/options.go b/cmd/agent/app/options/options.go index a3c1f78..016dbd0 100644 --- a/cmd/agent/app/options/options.go +++ b/cmd/agent/app/options/options.go @@ -8,6 +8,7 @@ import ( "regexp" "strconv" "strings" + "time" "github.com/spf13/pflag" v1 "k8s.io/api/core/v1" @@ -50,6 +51,8 @@ type AgentOptions struct { Kubeconfig string MetricsBindAddress string HealthProbeAddr string + SyncRules bool + SyncPeriod metav1.Duration } type TunnelOptions struct { @@ -91,6 +94,12 @@ func (o *AgentOptions) Validate() error { } } } + if o.SyncPeriod.Duration < time.Minute { + o.SyncPeriod.Duration = time.Minute + } + if o.SyncPeriod.Duration > 24*time.Hour { + o.SyncPeriod.Duration = 24 * time.Hour + } return nil } @@ -103,6 +112,8 @@ func (o *AgentOptions) AddFlags(fs *pflag.FlagSet) { fs.StringVar(&o.RouteDriver, "route-driver", o.RouteDriver, `The Route driver name. 
(default "vxlan")`) fs.StringVar(&o.MetricsBindAddress, "metric-bind-addr", o.MetricsBindAddress, `Binding address of tunnel metrics. (default ":10265")`) fs.StringVar(&o.HealthProbeAddr, "health-probe-addr", o.HealthProbeAddr, `The address the healthz/readyz endpoint binds to.. (default ":10275")`) + fs.BoolVar(&o.SyncRules, "sync-raven-rules", true, "Whether to synchronize raven rules regularly") + fs.DurationVar(&o.SyncPeriod.Duration, "sync-raven-rules-period", 10*time.Minute, "The period for regularly synchronizing raven rules. The minimum value is 1 minute and the maximum value is 24 hours") fs.StringVar(&o.VPNPort, "vpn-bind-port", o.VPNPort, `Binding port of vpn. (default ":4500")`) fs.BoolVar(&o.NATTraversal, "nat-traversal", o.NATTraversal, `Enable NAT Traversal or not. (default "false")`) @@ -141,8 +152,10 @@ func (o *AgentOptions) Config() (*config.Config, error) { } cfg = restclient.AddUserAgent(cfg, "raven-agent-ds") c := &config.Config{ - NodeName: o.NodeName, - NodeIP: o.NodeIP, + NodeName: o.NodeName, + NodeIP: o.NodeIP, + SyncRules: o.SyncRules, + SyncPeriod: o.SyncPeriod, } c.KubeConfig = cfg c.MetricsBindAddress = resolveAddress(c.MetricsBindAddress, resolveLocalHost(), strconv.Itoa(DefaultTunnelMetricsPort)) diff --git a/cmd/agent/app/start.go b/cmd/agent/app/start.go index e542bab..fe23038 100644 --- a/cmd/agent/app/start.go +++ b/cmd/agent/app/start.go @@ -19,15 +19,17 @@ package app import ( "context" "fmt" + "sync" + "time" "github.com/lorenzosaino/go-sysctl" + "github.com/spf13/cobra" "k8s.io/klog/v2" "github.com/openyurtio/raven/cmd/agent/app/config" "github.com/openyurtio/raven/cmd/agent/app/options" ravenengine "github.com/openyurtio/raven/pkg/engine" "github.com/openyurtio/raven/pkg/features" - "github.com/spf13/cobra" ) // NewRavenAgentCommand creates a new raven agent command @@ -70,6 +72,15 @@ func Run(ctx context.Context, cfg *config.CompletedConfig) error { } klog.Info("engine successfully start") engine.Start() 
+ var wg sync.WaitGroup + wg.Add(1) + go func() { + <-ctx.Done() + time.Sleep(time.Second) + engine.Cleanup() + wg.Done() + }() + wg.Wait() return nil } diff --git a/cmd/agent/main.go b/cmd/agent/main.go index e1b2733..4a37a8e 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -32,7 +32,7 @@ var GitCommit string func main() { klog.InitFlags(nil) defer klog.Flush() - rand.Seed(time.Now().UnixNano()) + rand.NewSource(time.Now().UnixNano()) klog.Infof("component: %s, git commit: %s\n", "raven-agent-ds", GitCommit) cmd := app.NewRavenAgentCommand(server.SetupSignalContext()) cmd.Flags().AddGoFlagSet(flag.CommandLine) diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index c06a8a9..70acbfe 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -4,6 +4,7 @@ import ( "context" "time" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/util/workqueue" @@ -22,13 +23,16 @@ import ( ) type Engine struct { - nodeName string - nodeIP string - context context.Context - manager manager.Manager - client client.Client - option *Option - queue workqueue.RateLimitingInterface + nodeName string + nodeIP string + syncRules bool + syncPeriod metav1.Duration + + context context.Context + manager manager.Manager + client client.Client + option *Option + queue workqueue.RateLimitingInterface tunnel *TunnelEngine proxy *ProxyEngine @@ -36,12 +40,14 @@ type Engine struct { func NewEngine(ctx context.Context, cfg *config.Config) (*Engine, error) { engine := &Engine{ - nodeName: cfg.NodeName, - nodeIP: cfg.NodeIP, - manager: cfg.Manager, - context: ctx, - option: NewEngineOption(), - queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "raven"), + nodeName: cfg.NodeName, + nodeIP: cfg.NodeIP, + syncRules: cfg.SyncRules, + syncPeriod: cfg.SyncPeriod, + manager: cfg.Manager, + context: ctx, + option: NewEngineOption(), + queue: 
workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "raven"), } err := ctrl.NewControllerManagedBy(engine.manager). For(&v1beta1.Gateway{}, builder.WithPredicates(predicate.Funcs{ @@ -53,7 +59,7 @@ func NewEngine(ctx context.Context, cfg *config.Config) (*Engine, error) { return reconcile.Result{}, nil })) if err != nil { - klog.Errorf(utils.FormatRavenEngine("fail to new controller with manager, error %s", err.Error())) + klog.Errorf("fail to new controller with manager, error %s", err.Error()) return engine, err } engine.client = engine.manager.GetClient() @@ -66,7 +72,7 @@ func NewEngine(ctx context.Context, cfg *config.Config) (*Engine, error) { } err = engine.tunnel.InitDriver() if err != nil { - klog.Errorf(utils.FormatRavenEngine("fail to init tunnel driver, error %s", err.Error())) + klog.Errorf("fail to init tunnel driver, error %s", err.Error()) return engine, err } @@ -90,9 +96,12 @@ func (e *Engine) Start() { klog.ErrorS(err, "failed to start engine controller") } }() + go wait.Until(e.worker, time.Second, e.context.Done()) - <-e.context.Done() - e.cleanup() + + if e.syncRules { + go wait.Until(e.regularSync, e.syncPeriod.Duration, e.context.Done()) + } } func (e *Engine) worker() { @@ -110,19 +119,29 @@ func (e *Engine) processNextWorkItem() bool { return false } defer e.queue.Done(gw) - e.findLocalGateway() - err := e.tunnel.Handler() + err := e.sync() if err != nil { e.handleEventErr(err, gw) } - e.option.SetTunnelStatus(e.tunnel.Status()) + return true +} - err = e.proxy.Handler() +func (e *Engine) sync() error { + e.findLocalGateway() + err := e.proxy.Handler() if err != nil { - e.handleEventErr(err, gw) + return err + } + err = e.tunnel.Handler() + if err != nil { + return err } + e.option.SetTunnelStatus(e.tunnel.Status()) + return nil +} - return true +func (e *Engine) regularSync() { + e.queue.Add(&v1beta1.Gateway{ObjectMeta: metav1.ObjectMeta{Name: "gw-sync"}}) } func (e *Engine) findLocalGateway() { @@ -144,12 +163,9 
@@ func (e *Engine) findLocalGateway() { } } -func (e *Engine) cleanup() { +func (e *Engine) Cleanup() { if e.option.GetTunnelStatus() { - err := e.tunnel.CleanupDriver() - if err != nil { - klog.Errorf(utils.FormatRavenEngine("failed to cleanup tunnel driver, error %s", err.Error())) - } + e.tunnel.CleanupDriver() } if e.option.GetProxyStatus() { e.proxy.stop() @@ -163,18 +179,18 @@ func (e *Engine) handleEventErr(err error, gw *v1beta1.Gateway) { } if e.queue.NumRequeues(gw) < utils.MaxRetries { - klog.Info(utils.FormatRavenEngine("error syncing event %s: %s", gw.GetName(), err.Error())) + klog.Infof("error syncing event %s: %s", gw.GetName(), err.Error()) e.queue.AddRateLimited(gw) return } - klog.Info(utils.FormatRavenEngine("dropping event %s out of the queue: %s", gw.GetName(), err.Error())) + klog.Infof("dropping event %s out of the queue: %s", gw.GetName(), err.Error()) e.queue.Forget(gw) } func (e *Engine) addGateway(evt event.CreateEvent) bool { gw, ok := evt.Object.(*v1beta1.Gateway) if ok { - klog.InfoS(utils.FormatRavenEngine("adding gateway %s", gw.GetName())) + klog.Infof("adding gateway %s", gw.GetName()) e.queue.Add(gw.DeepCopy()) } return ok @@ -187,10 +203,8 @@ func (e *Engine) updateGateway(evt event.UpdateEvent) bool { if ok1 && ok2 { if oldGw.ResourceVersion != newGw.ResourceVersion { update = true - klog.InfoS(utils.FormatRavenEngine("updating gateway, %s", newGw.GetName())) + klog.Infof("updating gateway, %s", newGw.GetName()) e.queue.Add(newGw.DeepCopy()) - } else { - klog.InfoS(utils.FormatRavenEngine("skip handle update gateway"), klog.KObj(newGw)) } } return update @@ -199,7 +213,7 @@ func (e *Engine) updateGateway(evt event.UpdateEvent) bool { func (e *Engine) deleteGateway(evt event.DeleteEvent) bool { gw, ok := evt.Object.(*v1beta1.Gateway) if ok { - klog.InfoS(utils.FormatRavenEngine("deleting gateway, %s", gw.GetName())) + klog.Infof("deleting gateway, %s", gw.GetName()) e.queue.Add(gw.DeepCopy()) } return ok diff --git 
a/pkg/engine/proxy.go b/pkg/engine/proxy.go index 87058b3..8f3f016 100644 --- a/pkg/engine/proxy.go +++ b/pkg/engine/proxy.go @@ -79,7 +79,7 @@ func (p *ProxyEngine) Handler() error { srcAddr := getSrcAddressForProxyServer(p.client, p.nodeName) err = p.startProxyServer() if err != nil { - klog.Errorf(utils.FormatProxyServer("failed to start proxy server, error %s", err.Error())) + klog.Errorf("failed to start proxy server, error %s", err.Error()) return err } p.serverLocalEndpoints = srcAddr @@ -93,7 +93,7 @@ func (p *ProxyEngine) Handler() error { time.Sleep(2 * time.Second) err = p.startProxyServer() if err != nil { - klog.Errorf(utils.FormatProxyServer("failed to start proxy server, error %s", err.Error())) + klog.Errorf("failed to start proxy server, error %s", err.Error()) return err } p.serverLocalEndpoints = srcAddr @@ -106,7 +106,7 @@ func (p *ProxyEngine) Handler() error { case StartType: err = p.startProxyClient() if err != nil { - klog.Errorf(utils.FormatProxyServer("failed to start proxy client, error %s", err.Error())) + klog.Errorf("failed to start proxy client, error %s", err.Error()) return err } case StopType: @@ -114,7 +114,7 @@ func (p *ProxyEngine) Handler() error { case RestartType: dstAddr := getDestAddressForProxyClient(p.client, p.localGateway) if len(dstAddr) < 1 { - klog.Infoln(utils.FormatProxyClient("dest address is empty, will not connected it")) + klog.Infoln("dest address is empty, will not connected it") return nil } if strings.Join(p.clientRemoteEndpoints, ",") != strings.Join(dstAddr, ",") { @@ -122,7 +122,7 @@ func (p *ProxyEngine) Handler() error { time.Sleep(2 * time.Second) err = p.startProxyClient() if err != nil { - klog.Errorf(utils.FormatProxyServer("failed to start proxy client, error %s", err.Error())) + klog.Errorf("failed to start proxy client, error %s", err.Error()) return err } } @@ -133,7 +133,7 @@ func (p *ProxyEngine) Handler() error { } func (p *ProxyEngine) startProxyServer() error { - 
klog.Infoln(utils.FormatProxyServer("start raven l7 proxy server")) + klog.Infoln("start raven l7 proxy server") if p.localGateway == nil { return fmt.Errorf("unknown gateway for node %s, can not start proxy server", p.nodeName) } @@ -164,7 +164,7 @@ func (p *ProxyEngine) startProxyServer() error { } func (p *ProxyEngine) stopProxyServer() { - klog.Infoln(utils.FormatProxyServer("Stop raven l7 proxy server")) + klog.Infoln("Stop raven l7 proxy server") cancel := p.proxyCtx.GetServerCancelFunc() cancel() p.proxyOption.SetServerStatus(false) @@ -172,11 +172,11 @@ func (p *ProxyEngine) stopProxyServer() { } func (p *ProxyEngine) startProxyClient() error { - klog.Infoln(utils.FormatProxyClient("start raven l7 proxy client")) + klog.Infoln("start raven l7 proxy client") var err error dstAddr := getDestAddressForProxyClient(p.client, p.localGateway) if len(dstAddr) < 1 { - klog.Infoln(utils.FormatProxyClient("dest address is empty, will not connected it")) + klog.Infoln("dest address is empty, will not connected it") return nil } p.clientRemoteEndpoints = dstAddr @@ -195,13 +195,14 @@ func (p *ProxyEngine) startProxyClient() error { err = pc.Start(ctx) if err != nil { klog.Errorf("failed to start proxy client, error %s", err.Error()) + return err } p.proxyOption.SetClientStatus(true) return nil } func (p *ProxyEngine) stopProxyClient() { - klog.Infoln(utils.FormatProxyClient("stop raven l7 proxy client")) + klog.Infoln("stop raven l7 proxy client") cancel := p.proxyCtx.GetClientCancelFunc() cancel() p.proxyOption.SetClientStatus(false) diff --git a/pkg/engine/tunnel.go b/pkg/engine/tunnel.go index 6bae19c..ac92d75 100644 --- a/pkg/engine/tunnel.go +++ b/pkg/engine/tunnel.go @@ -19,9 +19,10 @@ package engine import ( "context" "fmt" + "k8s.io/apimachinery/pkg/util/wait" "net" - "reflect" "strconv" + "time" "github.com/EvilSuperstars/go-cidrman" v1 "k8s.io/api/core/v1" @@ -51,9 +52,8 @@ type TunnelEngine struct { routeDriver routedriver.Driver vpnDriver vpndriver.Driver - 
nodeInfos map[types.NodeName]*v1beta1.NodeInfo - network *types.Network - lastSeenNetwork *types.Network + nodeInfos map[types.NodeName]*v1beta1.NodeInfo + network *types.Network } func (c *TunnelEngine) InitDriver() error { @@ -74,20 +74,24 @@ func (c *TunnelEngine) InitDriver() error { if err != nil { return fmt.Errorf("fail to initialize vpn driver: %s, %s", c.config.Tunnel.VPNDriver, err) } - klog.Info(utils.FormatTunnel("route driver %s and vpn driver %s are initialized", c.config.Tunnel.RouteDriver, c.config.Tunnel.VPNDriver)) + klog.Infof("route driver %s and vpn driver %s are initialized", c.config.Tunnel.RouteDriver, c.config.Tunnel.VPNDriver) return nil } -func (c *TunnelEngine) CleanupDriver() error { - err := c.routeDriver.Cleanup() - if err != nil { - return fmt.Errorf("fail to cleanup route driver: %s", err.Error()) - } - err = c.vpnDriver.Cleanup() - if err != nil { - return fmt.Errorf("fail to cleanup vpn driver: %s", err.Error()) - } - return nil +func (c *TunnelEngine) CleanupDriver() { + _ = wait.PollImmediate(time.Second, 5*time.Second, func() (done bool, err error) { + err = c.vpnDriver.Cleanup() + if err != nil { + klog.Errorf("fail to cleanup vpn driver: %s", err.Error()) + return false, nil + } + err = c.routeDriver.Cleanup() + if err != nil { + klog.Errorf("fail to cleanup route driver: %s", err.Error()) + return false, nil + } + return true, nil + }) } func (c *TunnelEngine) Status() bool { @@ -105,7 +109,7 @@ func (c *TunnelEngine) Status() bool { func (c *TunnelEngine) Handler() error { if c.config.Tunnel.NATTraversal { if err := c.checkNatCapability(); err != nil { - klog.Errorf(utils.FormatTunnel("fail to check the capability of NAT, error %s", err.Error())) + klog.Errorf("fail to check the capability of NAT, error %s", err.Error()) return err } } @@ -154,26 +158,18 @@ func (c *TunnelEngine) Handler() error { } c.syncGateway(gw) } - if reflect.DeepEqual(c.network, c.lastSeenNetwork) { - klog.Info("network not changed, skip to process") 
- return nil - } nw := c.network.Copy() klog.InfoS("applying network", "localEndpoint", nw.LocalEndpoint, "remoteEndpoint", nw.RemoteEndpoints) err = c.vpnDriver.Apply(nw, c.routeDriver.MTU) if err != nil { - klog.ErrorS(err, "error apply vpn driver") + klog.Errorf("error apply vpn driver, error %s", err.Error()) return err } err = c.routeDriver.Apply(nw, c.vpnDriver.MTU) if err != nil { - klog.ErrorS(err, "error apply route driver") + klog.Errorf("error apply route driver, error %s", err.Error()) return err } - - // Only update lastSeenNetwork when all operations succeeded. - c.lastSeenNetwork = c.network - return nil } diff --git a/pkg/networkengine/util/ipset/ipset.go b/pkg/networkengine/util/ipset/ipset.go index 66e34e6..ee844eb 100644 --- a/pkg/networkengine/util/ipset/ipset.go +++ b/pkg/networkengine/util/ipset/ipset.go @@ -33,14 +33,26 @@ type IPSetInterface interface { Del(entry *netlink.IPSetEntry) error Flush() error Destroy() error + Key(entry *netlink.IPSetEntry) string } +var DefaultKeyFunc = EntryKey + type ipSetWrapper struct { setName string + setType string + keyFunc func(setEntry *netlink.IPSetEntry) string +} + +type IpsetWrapperOption struct { + KeyFunc func(setEntry *netlink.IPSetEntry) string } -func New(setName string) (IPSetInterface, error) { - err := netlink.IpsetCreate(setName, "hash:net", netlink.IpsetCreateOptions{ +func New(setName, setTypeName string, options IpsetWrapperOption) (IPSetInterface, error) { + if options.KeyFunc == nil { + options.KeyFunc = DefaultKeyFunc + } + err := netlink.IpsetCreate(setName, setTypeName, netlink.IpsetCreateOptions{ Replace: true, }) if err != nil { @@ -50,7 +62,7 @@ func New(setName string) (IPSetInterface, error) { if klog.V(5).Enabled() { klog.V(5).InfoS("netlink.IpsetCreate succeeded", "setName", setName) } - return &ipSetWrapper{setName}, nil + return &ipSetWrapper{setName, setTypeName, options.KeyFunc}, nil } func (i *ipSetWrapper) List() (*netlink.IPSetResult, error) { @@ -72,11 +84,11 @@ func 
(i *ipSetWrapper) Name() string { func (i *ipSetWrapper) Add(entry *netlink.IPSetEntry) (err error) { err = netlink.IpsetAdd(i.Name(), entry) if err != nil { - klog.ErrorS(err, "error on netlink.IpsetAdd", "setName", i.Name(), "entry", SetEntryKey(entry)) + klog.ErrorS(err, "error on netlink.IpsetAdd", "setName", i.Name(), "entry", i.Key(entry)) return } if klog.V(5).Enabled() { - klog.V(5).InfoS("netlink.IpsetAdd succeeded", "setName", i.Name(), "entry", SetEntryKey(entry)) + klog.V(5).InfoS("netlink.IpsetAdd succeeded", "setName", i.Name(), "entry", i.Key(entry)) } return } @@ -84,11 +96,11 @@ func (i *ipSetWrapper) Add(entry *netlink.IPSetEntry) (err error) { func (i *ipSetWrapper) Del(entry *netlink.IPSetEntry) (err error) { err = netlink.IpsetDel(i.Name(), entry) if err != nil { - klog.ErrorS(err, "error on netlink.IpsetDel", "setName", i.Name(), "entry", SetEntryKey(entry)) + klog.ErrorS(err, "error on netlink.IpsetDel", "setName", i.Name(), "entry", i.Key(entry)) return } if klog.V(5).Enabled() { - klog.V(5).InfoS("netlink.IpsetDel succeeded", "setName", i.Name(), "entry", SetEntryKey(entry)) + klog.V(5).InfoS("netlink.IpsetDel succeeded", "setName", i.Name(), "entry", i.Key(entry)) } return } @@ -117,6 +129,10 @@ func (i *ipSetWrapper) Destroy() (err error) { return } -func SetEntryKey(setEntry *netlink.IPSetEntry) string { +func (i *ipSetWrapper) Key(entry *netlink.IPSetEntry) string { + return i.keyFunc(entry) +} + +func EntryKey(setEntry *netlink.IPSetEntry) string { return fmt.Sprintf("%s/%d", setEntry.IP.String(), setEntry.CIDR) } diff --git a/pkg/networkengine/util/netlink/netlink.go b/pkg/networkengine/util/netlink/netlink.go index 2c05366..a503f81 100644 --- a/pkg/networkengine/util/netlink/netlink.go +++ b/pkg/networkengine/util/netlink/netlink.go @@ -41,6 +41,7 @@ var ( RuleDel = ruleDel XfrmPolicyFlush = xfrmPolicyFlush + XfrmStateFlush = xfrmStateFlush NeighAdd = neighAdd NeighReplace = neighReplace @@ -127,6 +128,16 @@ func xfrmPolicyFlush() (err 
error) { return nil } +func xfrmStateFlush() (err error) { + err = netlink.XfrmStateFlush(0) + if err != nil { + klog.ErrorS(err, "error on netlink.XfrmStateFlush") + return + } + klog.V(5).InfoS("netlink.XfrmStateFlush succeeded") + return nil +} + func ruleListFiltered(family int, filter *netlink.Rule, filterMask uint64) (rules []netlink.Rule, err error) { rules, err = netlink.RuleListFiltered(family, filter, filterMask) if err != nil { diff --git a/pkg/networkengine/util/utils.go b/pkg/networkengine/util/utils.go index 23b0add..1b9c834 100644 --- a/pkg/networkengine/util/utils.go +++ b/pkg/networkengine/util/utils.go @@ -21,7 +21,6 @@ package networkutil import ( "fmt" - "net" "syscall" "github.com/vdobler/ht/errorlist" @@ -32,11 +31,6 @@ import ( netlinkutil "github.com/openyurtio/raven/pkg/networkengine/util/netlink" ) -var ( - AllZeroMAC = net.HardwareAddr{0, 0, 0, 0, 0, 0} - AllZeroAddress = "0.0.0.0/0" -) - func NewRavenRule(rulePriority int, routeTableID int) *netlink.Rule { rule := netlink.NewRule() rule.Priority = rulePriority @@ -94,7 +88,7 @@ func ListIPSetOnNode(set ipsetutil.IPSetInterface) (map[string]*netlink.IPSetEnt } ro := make(map[string]*netlink.IPSetEntry) for i := range info.Entries { - ro[ipsetutil.SetEntryKey(&info.Entries[i])] = &info.Entries[i] + ro[set.Key(&info.Entries[i])] = &info.Entries[i] } return ro, nil } @@ -114,7 +108,11 @@ func ApplyRules(current, desired map[string]*netlink.Rule) (err error) { } } // add expect ip rules - for _, v := range desired { + for k, v := range desired { + _, ok := current[k] + if ok { + continue + } klog.InfoS("adding rule", "src", v.Src, "lookup", v.Table) err = netlinkutil.RuleAdd(v) errList = errList.Append(err) diff --git a/pkg/networkengine/vpndriver/ipset/ipset.go b/pkg/networkengine/vpndriver/ipset/ipset.go new file mode 100644 index 0000000..f3fdfd2 --- /dev/null +++ b/pkg/networkengine/vpndriver/ipset/ipset.go @@ -0,0 +1,129 @@ +/* +Copyright 2023 The OpenYurt Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package ipset + +import ( + "fmt" + "net" + + "github.com/EvilSuperstars/go-cidrman" + "github.com/vdobler/ht/errorlist" + "github.com/vishvananda/netlink" + "k8s.io/klog/v2" + + ipsetutil "github.com/openyurtio/raven/pkg/networkengine/util/ipset" + "github.com/openyurtio/raven/pkg/types" +) + +const ( + RavenSkipNatSet = "raven-skip-nat-set" + RavenSkipNatSetType = "hash:net,net" +) + +var KeyFunc = func(entry *netlink.IPSetEntry) string { + return fmt.Sprintf("%s/%d-%s/%d", entry.IP.String(), entry.CIDR, entry.IP2.String(), entry.CIDR2) +} + +func IsGatewayRole(network *types.Network, nodeName types.NodeName) bool { + return network != nil && + network.LocalEndpoint != nil && + network.LocalEndpoint.NodeName == nodeName +} + +func IsCentreGatewayRole(centralGw *types.Endpoint, localNodeName types.NodeName) bool { + return centralGw != nil && centralGw.NodeName == localNodeName +} + +func CalIPSetOnNode(network *types.Network, centralGw *types.Endpoint, nodeName types.NodeName, ipset ipsetutil.IPSetInterface) map[string]*netlink.IPSetEntry { + set := make(map[string]*netlink.IPSetEntry) + subnets := make([]string, 0) + for _, v := range network.RemoteNodeInfo { + nodeInfo := network.RemoteNodeInfo[types.NodeName(v.NodeName)] + if nodeInfo == nil { + klog.Errorf("node %s not found in RemoteNodeInfo", v.NodeName) + continue + } + subnets = append(subnets, nodeInfo.Subnets...) 
+ } + var err error + subnets, err = cidrman.MergeCIDRs(subnets) + if err != nil { + return set + } + if IsCentreGatewayRole(centralGw, nodeName) { + subnets = append(subnets, network.LocalEndpoint.Subnets...) + for _, srcCIDR := range subnets { + _, ipNet, err := net.ParseCIDR(srcCIDR) + if err != nil { + klog.Errorf("parse node subnet %s error %s", srcCIDR, err.Error()) + continue + } + ones, _ := ipNet.Mask.Size() + entry := &netlink.IPSetEntry{ + IP: ipNet.IP, + CIDR: uint8(ones), + IP2: ipNet.IP, + CIDR2: uint8(ones), + Replace: true, + } + set[ipset.Key(entry)] = entry + } + } else { + for _, localCIDR := range network.LocalEndpoint.Subnets { + _, localIPNet, err := net.ParseCIDR(localCIDR) + if err != nil { + klog.Errorf("parse node subnet %s error %s", localCIDR, err.Error()) + continue + } + localOnes, _ := localIPNet.Mask.Size() + for _, remoteCIDR := range subnets { + _, remoteIPNet, err := net.ParseCIDR(remoteCIDR) + if err != nil { + klog.Errorf("parse node subnet %s error %s", remoteCIDR, err.Error()) + continue + } + remoteOnes, _ := remoteIPNet.Mask.Size() + entry := &netlink.IPSetEntry{ + IP: localIPNet.IP, + CIDR: uint8(localOnes), + IP2: remoteIPNet.IP, + CIDR2: uint8(remoteOnes), + Replace: true, + } + set[ipset.Key(entry)] = entry + } + } + } + return set +} + +func CleanupRavenSkipNATIPSet() error { + errList := errorlist.List{} + ipset, err := ipsetutil.New(RavenSkipNatSet, RavenSkipNatSetType, ipsetutil.IpsetWrapperOption{}) + if err != nil { + return fmt.Errorf("error ensure ip set %s: %s", RavenSkipNatSet, err) + } + err = ipset.Flush() + if err != nil { + errList = errList.Append(fmt.Errorf("error flushing ipset: %s", err)) + } + err = ipset.Destroy() + if err != nil { + errList = errList.Append(fmt.Errorf("error destroying ipset: %s", err)) + } + return errList.AsError() +} diff --git a/pkg/networkengine/vpndriver/libreswan/libreswan.go b/pkg/networkengine/vpndriver/libreswan/libreswan.go index 50395b0..e94db6c 100644 
--- a/pkg/networkengine/vpndriver/libreswan/libreswan.go +++ b/pkg/networkengine/vpndriver/libreswan/libreswan.go @@ -17,10 +17,14 @@ package libreswan import ( + "bufio" + "bytes" "fmt" "os" "os/exec" + "regexp" "strconv" + "strings" "syscall" "time" @@ -28,9 +32,12 @@ import ( "k8s.io/klog/v2" "github.com/openyurtio/raven/cmd/agent/app/config" + networkutil "github.com/openyurtio/raven/pkg/networkengine/util" + ipsetutil "github.com/openyurtio/raven/pkg/networkengine/util/ipset" iptablesutil "github.com/openyurtio/raven/pkg/networkengine/util/iptables" netlinkutil "github.com/openyurtio/raven/pkg/networkengine/util/netlink" "github.com/openyurtio/raven/pkg/networkengine/vpndriver" + vpndriveripset "github.com/openyurtio/raven/pkg/networkengine/vpndriver/ipset" "github.com/openyurtio/raven/pkg/types" "github.com/openyurtio/raven/pkg/utils" ) @@ -40,12 +47,16 @@ const ( // DriverName specifies name of libreswan VPN backend driver. DriverName = "libreswan" + + IKESAESTABLISHED = "STATE_V2_ESTABLISHED_IKE_SA" + ChILDSAESTABLISHED = "STATE_V2_ESTABLISHED_CHILD_SA" ) var _ vpndriver.Driver = (*libreswan)(nil) // can be modified for testing. 
var whackCmd = whackCmdFn +var ipsecCmd = ipsecCmdFn var findCentralGw = vpndriver.FindCentralGwFn var enableCreateEdgeConnection = vpndriver.EnableCreateEdgeConnection @@ -58,11 +69,11 @@ const ( ) type libreswan struct { - relayConnections map[string]*vpndriver.Connection - edgeConnections map[string]*vpndriver.Connection + connections map[string]bool nodeName types.NodeName centralGw *types.Endpoint iptables iptablesutil.IPTablesInterface + ipset ipsetutil.IPSetInterface listenPort string keepaliveInterval int keepaliveTimeout int @@ -73,6 +84,11 @@ func (l *libreswan) Init() (err error) { if err != nil { return err } + l.ipset, err = ipsetutil.New(vpndriveripset.RavenSkipNatSet, vpndriveripset.RavenSkipNatSetType, ipsetutil.IpsetWrapperOption{}) + if err != nil { + return err + } + // Ensure secrets file _, err = os.Stat(SecretFile) if err == nil { @@ -95,8 +111,7 @@ func (l *libreswan) Init() (err error) { func New(cfg *config.Config) (vpndriver.Driver, error) { return &libreswan{ - relayConnections: make(map[string]*vpndriver.Connection), - edgeConnections: make(map[string]*vpndriver.Connection), + connections: make(map[string]bool), nodeName: types.NodeName(cfg.NodeName), listenPort: cfg.Tunnel.VPNPort, keepaliveInterval: cfg.Tunnel.KeepAliveInterval, @@ -110,12 +125,13 @@ func (l *libreswan) Apply(network *types.Network, routeDriverMTUFn func(*types.N return l.Cleanup() } if network.LocalEndpoint.NodeName != l.nodeName { - klog.Infof(utils.FormatTunnel("the current node is not gateway node, cleaning vpn connections")) + klog.Infof("the current node is not gateway node, cleaning vpn connections") return l.Cleanup() } - if err := l.createConnections(network); err != nil { - return fmt.Errorf("error create VPN tunnels: %v", err) + l.centralGw = findCentralGw(network) + if err := l.ensureConnections(network); err != nil { + return fmt.Errorf("error ensure VPN tunnels: %s", err.Error()) } return nil @@ -176,136 +192,181 @@ func (l *libreswan) 
getEndpointResolver(network *types.Network) func(centralGw, } } -func (l *libreswan) createConnections(network *types.Network) error { - l.centralGw = findCentralGw(network) +func (l *libreswan) ensureConnections(network *types.Network) error { + defer func() { + // wait connection is established + time.Sleep(5 * time.Second) + }() + + l.connections = currentConnections() + if err := l.deleteUnavailableConn(); err != nil { + return fmt.Errorf("delete unavailabel connections error %s", err.Error()) + } desiredEdgeConns, desiredRelayConns := l.computeDesiredConnections(network) - if len(desiredEdgeConns) == 0 && len(desiredRelayConns) == 0 { - klog.Infof(utils.FormatTunnel("no desired connections, cleaning vpn connections")) - return l.Cleanup() + klog.Infof("desired edge connections: %+v, desired relay connections: %+v", desiredEdgeConns, desiredRelayConns) + + if err := l.deleteUndesiredConn(desiredEdgeConns, desiredRelayConns); err != nil { + return fmt.Errorf("ensure delete undesired connections error %s", err.Error()) } - klog.Infof(utils.FormatTunnel("desired edge connections: %+v, desired relay connections: %+v", desiredEdgeConns, desiredRelayConns)) + if err := l.ensureEdgeConnections(desiredEdgeConns); err != nil { + return fmt.Errorf("ensure delete edge-edge connections error %s", err.Error()) + } - if err := l.createEdgeConnections(desiredEdgeConns); err != nil { - return err + if err := l.ensureRelayConnections(desiredRelayConns); err != nil { + return fmt.Errorf("ensure delete cloud-edge connections error %s", err.Error()) } - if err := l.createRelayConnections(desiredRelayConns); err != nil { - return err + + if err := l.ensureRavenSkipNAT(network); err != nil { + return fmt.Errorf("ensure raven skip nat error %s", err.Error()) } return nil } -func (l *libreswan) createEdgeConnections(desiredEdgeConns map[string]*vpndriver.Connection) error { - if len(desiredEdgeConns) == 0 { - klog.Infof("no desired edge connections") - return nil +func 
currentConnections() map[string]bool { + connections := make(map[string]bool) + reg := regexp.MustCompile(`"([^"]+)"`) + out, err := ipsecCmd("auto", "--status") + if err != nil { + return connections + } + foundConnectionList := false + scanner := bufio.NewScanner(out) + for scanner.Scan() { + line := scanner.Text() + if strings.Contains(line, "Connection list") { + foundConnectionList = true + continue + } + if foundConnectionList { + matches := reg.FindAllStringSubmatch(line, -1) + for _, match := range matches { + if len(match) > 1 { + connections[match[1]] = false + } + } + } } + for k := range connections { + out, err = ipsecCmd("whack", "--showstates") + if err != nil { + continue + } + foundIKESAEstablished := false + foundChildSAEstablished := false + scanner = bufio.NewScanner(out) + for scanner.Scan() { + line := scanner.Text() + if !strings.Contains(line, k) { + continue + } + if strings.Contains(line, IKESAESTABLISHED) { + foundIKESAEstablished = true + } + if strings.Contains(line, ChILDSAESTABLISHED) { + foundChildSAEstablished = true + } + } + if foundIKESAEstablished && foundChildSAEstablished { + connections[k] = true + } + } + return connections +} +func (l *libreswan) deleteUnavailableConn() error { errList := errorlist.List{} - - // remove unwanted connections - for connName := range l.edgeConnections { - if _, ok := desiredEdgeConns[connName]; !ok { + for connName, established := range l.connections { + if !established { err := l.whackDelConnection(connName) if err != nil { errList = errList.Append(err) - klog.ErrorS(err, "error disconnecting endpoint", "connectionName", connName) + klog.ErrorS(err, "error delete unavailable connection", "connectionName", connName) continue } - delete(l.edgeConnections, connName) + delete(l.connections, connName) } } - - // add new connections - for name, connection := range desiredEdgeConns { - err := l.connectToEdgeEndpoint(name, connection) - errList = errList.Append(err) - } - return errList.AsError() } 
-func (l *libreswan) createRelayConnections(desiredRelayConns map[string]*vpndriver.Connection) error { - if len(desiredRelayConns) == 0 { - klog.Infof("no desired relay connections") - return nil - } - +func (l *libreswan) deleteUndesiredConn(desiredEdgeConns, desiredRelayConns map[string]*vpndriver.Connection) error { errList := errorlist.List{} - - // remove unwanted connections - for connName := range l.relayConnections { - if _, ok := desiredRelayConns[connName]; !ok { + desireConn := make(map[string]struct{}) + for k := range desiredEdgeConns { + desireConn[k] = struct{}{} + } + for k := range desiredRelayConns { + desireConn[k] = struct{}{} + } + for connName := range l.connections { + if _, ok := desireConn[connName]; !ok { err := l.whackDelConnection(connName) if err != nil { errList = errList.Append(err) - klog.ErrorS(err, "error disconnecting endpoint", "connectionName", connName) + klog.ErrorS(err, "error delete undesired connection", "connectionName", connName) continue } - if l.centralGw.NodeName == l.nodeName { - if conn, ok := l.relayConnections[connName]; ok && conn != nil { - err := l.deleteRavenSkipNAT(conn) - if err != nil { - errList = errList.Append(err) - } - } - } - delete(l.relayConnections, connName) + delete(l.connections, connName) } } + return errList.AsError() +} + +func (l *libreswan) ensureEdgeConnections(desiredEdgeConns map[string]*vpndriver.Connection) error { + errList := errorlist.List{} + for name, connection := range desiredEdgeConns { + err := l.connectToEdgeEndpoint(name, connection) + errList = errList.Append(err) + } + return errList.AsError() +} - // add new connections +func (l *libreswan) ensureRelayConnections(desiredRelayConns map[string]*vpndriver.Connection) error { + errList := errorlist.List{} for name, connection := range desiredRelayConns { err := l.connectToEndpoint(name, connection) errList = errList.Append(err) - if l.centralGw.NodeName == l.nodeName { - err = l.ensureRavenSkipNAT(connection) - if err != nil 
{ - errList = errList.Append(err) - } - } } - return errList.AsError() } -func (l *libreswan) ensureRavenSkipNAT(connection *vpndriver.Connection) errorlist.List { - errList := errorlist.List{} - for _, subnet := range l.centralGw.Subnets { - if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { - return errList - } +func (l *libreswan) ensureRavenSkipNAT(network *types.Network) error { + if !vpndriveripset.IsGatewayRole(network, l.nodeName) { + klog.Infof("node %s is not gateway, skip add skip nat", l.nodeName) + return nil } - // for raven skip nat - if err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { - errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err)) + + // The desired and current ipset entries calculated from given network. + // The key is ip set entry + var err error + l.ipset, err = ipsetutil.New(vpndriveripset.RavenSkipNatSet, vpndriveripset.RavenSkipNatSetType, ipsetutil.IpsetWrapperOption{KeyFunc: vpndriveripset.KeyFunc}) + if err != nil { + return fmt.Errorf("error ensure ipset %s, type %s", vpndriveripset.RavenSkipNatSet, vpndriveripset.RavenSkipNatSetType) } - if err := l.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain); err != nil { - errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err)) + currentSet, err := networkutil.ListIPSetOnNode(l.ipset) + if err != nil { + return fmt.Errorf("error listing ip set %s on node: %s", l.ipset.Name(), err.Error()) } - if err := l.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT"); err != nil { - errList = errList.Append(fmt.Errorf("error adding chain %s rule: %s", 
iptablesutil.RavenPostRoutingChain, err)) + desiredSet := vpndriveripset.CalIPSetOnNode(network, l.centralGw, l.nodeName, l.ipset) + err = networkutil.ApplyIPSet(l.ipset, currentSet, desiredSet) + if err != nil { + return fmt.Errorf("error applying ip set: %s", err) } - return errList -} -func (l *libreswan) deleteRavenSkipNAT(connection *vpndriver.Connection) errorlist.List { - errList := errorlist.List{} - err := l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) - if err != nil { - errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) + // for raven skip nat + if err = l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err) } - for _, subnet := range l.centralGw.Subnets { - if connection.LocalSubnet == subnet || connection.RemoteSubnet == subnet { - return errList - } + if err = l.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err) } - err = l.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-s", connection.LocalSubnet, "-d", connection.RemoteSubnet, "-j", "ACCEPT") - if err != nil { - errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.RavenPostRoutingChain, err)) + if err = l.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-m", "set", "--match-set", vpndriveripset.RavenSkipNatSet, "src,dst", "-j", "ACCEPT"); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err) } - return errList + + return nil } func (l *libreswan) 
computeDesiredConnections(network *types.Network) (map[string]*vpndriver.Connection, map[string]*vpndriver.Connection) { @@ -456,6 +517,24 @@ func whackCmdFn(args ...string) error { return nil } +func ipsecCmdFn(args ...string) (*bytes.Buffer, error) { + var err error + var output bytes.Buffer + for i := 0; i < 5; i++ { + cmd := exec.Command("ipsec", args...) + cmd.Stdout = &output + err = cmd.Run() + if err == nil { + break + } + time.Sleep(1 * time.Second) + } + if err != nil { + return nil, fmt.Errorf("error ipsec with %v, error %s", args, err.Error()) + } + return &output, nil +} + func (l *libreswan) whackDelConnection(conn string) error { return whackCmd("--delete", "--name", conn) } @@ -466,43 +545,36 @@ func connectionName(localID, remoteID, leftSubnet, rightSubnet string) string { func (l *libreswan) Cleanup() error { errList := errorlist.List{} - for name := range l.relayConnections { + connections := currentConnections() + for name := range connections { if err := l.whackDelConnection(name); err != nil { errList = errList.Append(err) klog.ErrorS(err, "fail to delete connection", "connectionName", name) } - if l.centralGw != nil && l.centralGw.NodeName == l.nodeName { - if conn, ok := l.relayConnections[name]; ok && conn != nil { - err := l.deleteRavenSkipNAT(conn) - if err != nil { - errList = errList.Append(err) - } - } - } } - for name := range l.edgeConnections { - if err := l.whackDelConnection(name); err != nil { - errList = errList.Append(err) - klog.ErrorS(err, "fail to delete connection", "connectionName", name) - } - } - l.relayConnections = make(map[string]*vpndriver.Connection) - l.edgeConnections = make(map[string]*vpndriver.Connection) err := netlinkutil.XfrmPolicyFlush() errList = errList.Append(err) + err = netlinkutil.XfrmStateFlush() + errList = errList.Append(err) + + err = vpndriveripset.CleanupRavenSkipNATIPSet() + if err != nil { + errList = errList.Append(fmt.Errorf("error cleanup ipset %s, %s", vpndriveripset.RavenSkipNatSet, 
err.Error())) + } err = l.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) if err != nil { - errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) + errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err.Error())) } err = l.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, "-m", "comment", "--comment", "raven traffic should skip NAT", "-j", iptablesutil.RavenPostRoutingChain) if err != nil { - errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err)) + errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err.Error())) } err = l.iptables.ClearAndDeleteChain(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) if err != nil { - errList = errList.Append(fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err)) + errList = errList.Append(fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err.Error())) } + return errList.AsError() } @@ -542,8 +614,7 @@ func (l *libreswan) runPluto() error { func (l *libreswan) connectToEndpoint(name string, connection *vpndriver.Connection) errorlist.List { errList := errorlist.List{} - if _, ok := l.relayConnections[name]; ok { - klog.InfoS("skipping connect because connection already exists", "connectionName", name) + if _, ok := l.connections[name]; ok { return errList } err := l.whackConnectToEndpoint(name, connection) @@ -552,14 +623,12 @@ func (l *libreswan) connectToEndpoint(name string, connection *vpndriver.Connect klog.ErrorS(err, "error connect connection", "connectionName", name) return errList } - l.relayConnections[name] = connection return errList } func (l *libreswan) connectToEdgeEndpoint(name string, connection *vpndriver.Connection) errorlist.List { errList := errorlist.List{} - if _, ok := 
l.edgeConnections[name]; ok { - klog.InfoS("skipping connect because connection already exists", "connectionName", name) + if _, ok := l.connections[name]; ok { return errList } err := l.whackConnectToEdgeEndpoint(name, connection) @@ -568,6 +637,5 @@ func (l *libreswan) connectToEdgeEndpoint(name string, connection *vpndriver.Con klog.ErrorS(err, "error connect connection", "connectionName", name) return errList } - l.edgeConnections[name] = connection return errList } diff --git a/pkg/networkengine/vpndriver/libreswan/libreswan_test.go b/pkg/networkengine/vpndriver/libreswan/libreswan_test.go index 529963d..f35e4cd 100644 --- a/pkg/networkengine/vpndriver/libreswan/libreswan_test.go +++ b/pkg/networkengine/vpndriver/libreswan/libreswan_test.go @@ -25,7 +25,6 @@ import ( iptablesutil "github.com/openyurtio/raven/pkg/networkengine/util/iptables" netlinkutil "github.com/openyurtio/raven/pkg/networkengine/util/netlink" - "github.com/openyurtio/raven/pkg/networkengine/vpndriver" "github.com/openyurtio/raven/pkg/types" ) @@ -130,7 +129,7 @@ func TestLibreswan_Apply(t *testing.T) { nodeName: "localGwNode", // It is unable to set up any vpn connections in such case and should clean up vpn connections expectedConnName: map[string]struct{}{}, - shouldCleanup: true, + shouldCleanup: false, network: &types.Network{ LocalEndpoint: &types.Endpoint{ GatewayName: "localGw", @@ -372,9 +371,8 @@ func TestLibreswan_Apply(t *testing.T) { whackCmd = w.whackCmd a := assert.New(t) l := &libreswan{ - relayConnections: make(map[string]*vpndriver.Connection), - edgeConnections: make(map[string]*vpndriver.Connection), - nodeName: types.NodeName(v.nodeName), + connections: make(map[string]bool), + nodeName: types.NodeName(v.nodeName), } var err error l.iptables, err = iptablesutil.New() diff --git a/pkg/networkengine/vpndriver/wireguard/wireguard.go b/pkg/networkengine/vpndriver/wireguard/wireguard.go index e17b24c..eb01238 100644 --- a/pkg/networkengine/vpndriver/wireguard/wireguard.go +++ 
b/pkg/networkengine/vpndriver/wireguard/wireguard.go @@ -26,7 +26,6 @@ import ( "strconv" "time" - "github.com/openyurtio/api/raven/v1beta1" "github.com/pkg/errors" "github.com/vdobler/ht/errorlist" "github.com/vishvananda/netlink" @@ -36,10 +35,13 @@ import ( "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" + "github.com/openyurtio/api/raven/v1beta1" "github.com/openyurtio/raven/cmd/agent/app/config" networkutil "github.com/openyurtio/raven/pkg/networkengine/util" + ipsetutil "github.com/openyurtio/raven/pkg/networkengine/util/ipset" iptablesutil "github.com/openyurtio/raven/pkg/networkengine/util/iptables" "github.com/openyurtio/raven/pkg/networkengine/vpndriver" + vpnipset "github.com/openyurtio/raven/pkg/networkengine/vpndriver/ipset" "github.com/openyurtio/raven/pkg/types" "github.com/openyurtio/raven/pkg/utils" ) @@ -61,6 +63,9 @@ const ( DeviceName = "raven-wg0" // DefaultListenPort specifies port of WireGuard listened. DefaultListenPort = 4500 + + ravenSkipNatSet = "raven-skip-nat-set" + ravenSkipNatSetType = "hash:net,net" ) var findCentralGw = vpndriver.FindCentralGwFn @@ -78,10 +83,10 @@ type wireguard struct { psk wgtypes.Key wgLink netlink.Link - relayConnections map[string]*vpndriver.Connection - edgeConnections map[string]*vpndriver.Connection iptables iptablesutil.IPTablesInterface + ipset ipsetutil.IPSetInterface nodeName types.NodeName + centralGw *types.Endpoint ravenClient client.Client listenPort int keepaliveInterval int @@ -93,8 +98,6 @@ func New(cfg *config.Config) (vpndriver.Driver, error) { port = DefaultListenPort } return &wireguard{ - relayConnections: make(map[string]*vpndriver.Connection), - edgeConnections: make(map[string]*vpndriver.Connection), nodeName: types.NodeName(cfg.NodeName), ravenClient: cfg.Manager.GetClient(), listenPort: port, @@ -217,56 +220,78 @@ func (w *wireguard) ensureWgLink(network *types.Network, routeDriverMTUFn func(* return nil } -func (w *wireguard) createConnections(network *types.Network) error 
{ +func (w *wireguard) ensureConnections(network *types.Network) error { desiredEdgeConns, desiredRelayConns, centralAllowedIPs := w.computeDesiredConnections(network) if len(desiredEdgeConns) == 0 && len(desiredRelayConns) == 0 { klog.Infof("no desired connections, cleaning vpn connections") return w.Cleanup() } - klog.Infof("desired edge connections: %+v, desired relay connections: %+v", desiredEdgeConns, desiredRelayConns) - centralGw := findCentralGw(network) - if err := w.createEdgeConnections(desiredEdgeConns); err != nil { - return err + var err error + + peers := w.currentPeers() + klog.Infof("current peers: %v", peers) + + if err = w.deleteUndesiredPeers(peers, desiredEdgeConns, desiredRelayConns); err != nil { + return fmt.Errorf("ensure edge-edge peers error %s", err.Error()) } - if err := w.createRelayConnections(desiredRelayConns, centralAllowedIPs, centralGw); err != nil { - return err + + if err = w.ensureEdgePeers(desiredEdgeConns); err != nil { + return fmt.Errorf("ensure edge-edge peers error %s", err.Error()) + } + if err = w.ensureRelayPeers(desiredRelayConns, centralAllowedIPs); err != nil { + return fmt.Errorf("ensure cloud-edge peers error %s", err.Error()) + } + + if err = w.ensureRavenSkipNAT(network); err != nil { + return fmt.Errorf("ensure raven skip nat error %s", err.Error()) } return nil } -func (w *wireguard) createEdgeConnections(desiredEdgeConns map[string]*vpndriver.Connection) error { - if len(desiredEdgeConns) == 0 { - klog.Infof("no desired edge connections") - return nil +func (w *wireguard) currentPeers() map[string]wgtypes.Peer { + set := make(map[string]wgtypes.Peer) + dev, err := w.wgClient.Device(DeviceName) + if err != nil { + klog.Errorf("can not found wireguard device %s, error %s", DeviceName, err.Error()) + return set } + for _, peer := range dev.Peers { + set[peer.PublicKey.String()] = peer + } + return set +} - for connName, connection := range w.edgeConnections { - if _, ok := desiredEdgeConns[connName]; !ok { - 
remoteKey := keyFromEndpoint(connection.RemoteEndpoint) - if err := w.removePeer(remoteKey); err == nil { - delete(w.edgeConnections, connName) - } +func (w *wireguard) deleteUndesiredPeers(currentConns map[string]wgtypes.Peer, desiredEdgeConns, desiredRelayConns map[string]*vpndriver.Connection) error { + errList := errorlist.List{} + desiredPeers := make(map[string]struct{}) + for _, connection := range desiredEdgeConns { + desiredPeers[keyFromEndpoint(connection.RemoteEndpoint).String()] = struct{}{} + } + for _, connection := range desiredRelayConns { + desiredPeers[keyFromEndpoint(connection.RemoteEndpoint).String()] = struct{}{} + } + var err error + for key, peer := range currentConns { + if _, ok := desiredPeers[key]; !ok { + err = w.removePeer(&peer.PublicKey) + errList = errList.Append(err) } } + return errList.AsError() +} +func (w *wireguard) ensureEdgePeers(desiredEdgeConns map[string]*vpndriver.Connection) error { + if len(desiredEdgeConns) == 0 { + klog.Infof("no desired edge connections") + return nil + } peerConfigs := make([]wgtypes.PeerConfig, 0) - for name, newConn := range desiredEdgeConns { - newKey := keyFromEndpoint(newConn.RemoteEndpoint) - - if oldConn, ok := w.edgeConnections[name]; ok { - oldKey := keyFromEndpoint(oldConn.RemoteEndpoint) - if oldKey.String() != newKey.String() { - if err := w.removePeer(oldKey); err == nil { - delete(w.edgeConnections, name) - } - } - } - + for _, newConn := range desiredEdgeConns { klog.InfoS("create edge-to-edge connection", "c", newConn) - + newKey := keyFromEndpoint(newConn.RemoteEndpoint) allowedIPs := parseSubnets(newConn.RemoteEndpoint.Subnets) ka := time.Duration(w.keepaliveInterval) var remotePort int @@ -284,59 +309,29 @@ func (w *wireguard) createEdgeConnections(desiredEdgeConns map[string]*vpndriver IP: net.ParseIP(newConn.RemoteEndpoint.PublicIP), Port: remotePort, }, - PersistentKeepaliveInterval: &ka, ReplaceAllowedIPs: true, AllowedIPs: allowedIPs, }) } - - if err := 
w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ + return w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ ReplacePeers: true, Peers: peerConfigs, - }); err != nil { - return fmt.Errorf("error add peers: %v", err) - } - - w.edgeConnections = desiredEdgeConns - - return nil + }) } -func (w *wireguard) createRelayConnections(desiredRelayConns map[string]*vpndriver.Connection, centralAllowedIPs []string, centralGw *types.Endpoint) error { +func (w *wireguard) ensureRelayPeers(desiredRelayConns map[string]*vpndriver.Connection, centralAllowedIPs []string) error { if len(desiredRelayConns) == 0 { klog.Infof("no desired relay connections") return nil } - - // delete unwanted connections - for connName, connection := range w.relayConnections { - if _, ok := desiredRelayConns[connName]; !ok { - remoteKey := keyFromEndpoint(connection.RemoteEndpoint) - if err := w.removePeer(remoteKey); err == nil { - delete(w.relayConnections, connName) - } - } - } - // add or update connections peerConfigs := make([]wgtypes.PeerConfig, 0) - for name, newConn := range desiredRelayConns { - newKey := keyFromEndpoint(newConn.RemoteEndpoint) - - if oldConn, ok := w.relayConnections[name]; ok { - oldKey := keyFromEndpoint(oldConn.RemoteEndpoint) - if oldKey.String() != newKey.String() { - if err := w.removePeer(oldKey); err == nil { - delete(w.relayConnections, name) - } - } - } - + for _, newConn := range desiredRelayConns { klog.InfoS("create connection", "c", newConn) - + newKey := keyFromEndpoint(newConn.RemoteEndpoint) allowedIPs := parseSubnets(newConn.RemoteEndpoint.Subnets) - if newConn.RemoteEndpoint.NodeName == centralGw.NodeName { + if w.centralGw != nil && newConn.RemoteEndpoint.NodeName == w.centralGw.NodeName { allowedIPs = append(allowedIPs, parseSubnets(centralAllowedIPs)...) 
} @@ -357,16 +352,10 @@ func (w *wireguard) createRelayConnections(desiredRelayConns map[string]*vpndriv }) } - if err := w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ + return w.wgClient.ConfigureDevice(DeviceName, wgtypes.Config{ ReplacePeers: false, Peers: peerConfigs, - }); err != nil { - return fmt.Errorf("error add peers: %v", err) - } - - w.relayConnections = desiredRelayConns - - return nil + }) } func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.Network) (int, error)) error { @@ -378,7 +367,7 @@ func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.N klog.Infof("the current node is not gateway node, cleaning vpn connections") return w.Cleanup() } - + w.centralGw = findCentralGw(network) if _, ok := network.LocalEndpoint.Config[PublicKey]; !ok || network.LocalEndpoint.Config[PublicKey] != w.privateKey.PublicKey().String() { err := w.configGatewayPublicKey(string(network.LocalEndpoint.GatewayName), string(network.LocalEndpoint.NodeName)) if err != nil { @@ -387,24 +376,17 @@ func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.N return errors.New("retry to config public key") } - centralGw := findCentralGw(network) - if centralGw.NodeName == w.nodeName { - if err := w.ensureRavenSkipNAT(); err != nil { - return fmt.Errorf("error ensure raven skip nat: %s", err) - } - } - if err := w.ensureWgLink(network, routeDriverMTUFn); err != nil { - return fmt.Errorf("fail to ensure wireguar link: %v", err) + return fmt.Errorf("fail to ensure wireguar link: %s", err.Error()) } // 3. 
Config device route and rules currentRoutes, err := networkutil.ListRoutesOnNode(wgRouteTableID) if err != nil { - return fmt.Errorf("error listing wireguard routes on node: %s", err) + return fmt.Errorf("error listing wireguard routes on node: %s", err.Error()) } currentRules, err := networkutil.ListRulesOnNode(wgRouteTableID) if err != nil { - return fmt.Errorf("error listing wireguard rules on node: %s", err) + return fmt.Errorf("error listing wireguard rules on node: %s", err.Error()) } desiredRoutes := w.calWgRoutes(network) @@ -412,15 +394,52 @@ func (w *wireguard) Apply(network *types.Network, routeDriverMTUFn func(*types.N err = networkutil.ApplyRoutes(currentRoutes, desiredRoutes) if err != nil { - return fmt.Errorf("error applying wireguard routes: %s", err) + return fmt.Errorf("error applying wireguard routes: %s", err.Error()) } err = networkutil.ApplyRules(currentRules, desiredRules) if err != nil { - return fmt.Errorf("error applying wireguard rules: %s", err) + return fmt.Errorf("error applying wireguard rules: %s", err.Error()) + } + + if err = w.ensureConnections(network); err != nil { + return fmt.Errorf("error ensure VPN tunnels: %s", err.Error()) } - if err := w.createConnections(network); err != nil { - return fmt.Errorf("error create VPN tunnels: %v", err) + return nil +} + +func (w *wireguard) ensureRavenSkipNAT(network *types.Network) error { + if !vpnipset.IsGatewayRole(network, w.nodeName) { + klog.Infof("node %s is not gateway, skip add skip nat", w.nodeName) + return nil + } + + // The desired and current ipset entries calculated from given network. 
+ // The key is ip set entry + var err error + w.ipset, err = ipsetutil.New(ravenSkipNatSet, ravenSkipNatSetType, ipsetutil.IpsetWrapperOption{KeyFunc: vpnipset.KeyFunc}) + if err != nil { + return fmt.Errorf("error new ipset %s, type %s", vpnipset.RavenSkipNatSet, vpnipset.RavenSkipNatSetType) + } + currentSet, err := networkutil.ListIPSetOnNode(w.ipset) + if err != nil { + return fmt.Errorf("error listing ip set %s on node: %s", w.ipset.Name(), err.Error()) + } + desiredSet := vpnipset.CalIPSetOnNode(network, w.centralGw, w.nodeName, w.ipset) + err = networkutil.ApplyIPSet(w.ipset, currentSet, desiredSet) + if err != nil { + return fmt.Errorf("error applying ip set: %s", err) + } + + // for raven skip nat + if err = w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err) + } + if err = w.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", DeviceName, "-j", iptablesutil.RavenPostRoutingChain); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err) + } + if err = w.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-m", "set", "--match-set", vpnipset.RavenSkipNatSet, "src,dst", "-j", "ACCEPT"); err != nil { + return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err) } return nil @@ -457,12 +476,24 @@ func (w *wireguard) Cleanup() error { errList = errList.Append(fmt.Errorf("error delete existing wireguard device %q: %v", DeviceName, err)) } - if err = w.deleteRavenSkipNAT(); err != nil { - errList = errList.Append(err) + err = vpnipset.CleanupRavenSkipNATIPSet() + if err != nil { + errList = errList.Append(fmt.Errorf("error cleanup ipset %s, %s", vpnipset.RavenSkipNatSet, err.Error())) + } + + err = 
w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) + if err != nil { + errList = errList.Append(fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err)) + } + err = w.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", DeviceName, "-j", iptablesutil.RavenPostRoutingChain) + if err != nil { + errList = errList.Append(fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err)) + } + err = w.iptables.ClearAndDeleteChain(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain) + if err != nil { + errList = errList.Append(fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err)) } - w.relayConnections = make(map[string]*vpndriver.Connection) - w.edgeConnections = make(map[string]*vpndriver.Connection) return errList.AsError() } @@ -604,32 +635,3 @@ func parseSubnets(subnets []string) []net.IPNet { } return nets } - -func (w *wireguard) ensureRavenSkipNAT() error { - // for raven skip nat - if err := w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error create %s chain: %s", iptablesutil.RavenPostRoutingChain, err) - } - if err := w.iptables.InsertIfNotExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, 1, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", "raven-wg0", "-j", iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.PostRoutingChain, err) - } - if err := w.iptables.AppendIfNotExists(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain, "-j", "ACCEPT"); err != nil { - return fmt.Errorf("error adding chain %s rule: %s", iptablesutil.RavenPostRoutingChain, err) - } - - return nil -} - -func (w *wireguard) deleteRavenSkipNAT() error { - if err := 
w.iptables.NewChainIfNotExist(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error create %s chain: %s", iptablesutil.PostRoutingChain, err) - } - if err := w.iptables.DeleteIfExists(iptablesutil.NatTable, iptablesutil.PostRoutingChain, "-m", "comment", "--comment", "raven traffic should skip NAT", "-o", "raven-wg0", "-j", iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error deleting %s chain rule: %s", iptablesutil.PostRoutingChain, err) - } - if err := w.iptables.ClearAndDeleteChain(iptablesutil.NatTable, iptablesutil.RavenPostRoutingChain); err != nil { - return fmt.Errorf("error deleting %s chain %s", iptablesutil.RavenPostRoutingChain, err) - } - - return nil -}