From 6a96a614757807709a2bf68729b16a9f9308dfb2 Mon Sep 17 00:00:00 2001 From: xjasonlyu Date: Thu, 29 Aug 2024 10:30:47 -0400 Subject: [PATCH] up --- engine/engine.go | 2 +- tunnel/global.go | 23 +++++------------------ tunnel/tcp.go | 2 +- tunnel/tunnel.go | 44 +++++++++++++++++++++++++++++++++++++------- tunnel/udp.go | 8 ++------ 5 files changed, 46 insertions(+), 33 deletions(-) diff --git a/engine/engine.go b/engine/engine.go index ec8fd78e..2ad28fc8 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -191,7 +191,7 @@ func netstack(k *Key) (err error) { if _defaultProxy, err = parseProxy(k.Proxy); err != nil { return } - proxy.SetDialer(_defaultProxy) + tunnel.T().SetDialer(_defaultProxy) if _defaultDevice, err = parseDevice(k.Device, uint32(k.MTU)); err != nil { return diff --git a/tunnel/global.go b/tunnel/global.go index 591872e9..00851f2d 100644 --- a/tunnel/global.go +++ b/tunnel/global.go @@ -1,11 +1,8 @@ package tunnel import ( - "context" - "net" "sync" - M "github.com/xjasonlyu/tun2socks/v2/metadata" "github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) @@ -16,22 +13,12 @@ var ( ) func init() { - ReplaceGlobals(New(wrapper{}, statistic.DefaultManager)) + ReplaceGlobal(New(&proxy.Base{}, statistic.DefaultManager)) go T().Process() } -type wrapper struct{} - -func (wrapper) DialContext(ctx context.Context, metadata *M.Metadata) (net.Conn, error) { - return proxy.DialContext(ctx, metadata) -} - -func (wrapper) DialUDP(metadata *M.Metadata) (net.PacketConn, error) { - return proxy.DialUDP(metadata) -} - // T returns the global Tunnel, which can be reconfigured with -// ReplaceGlobals. It's safe for concurrent use. +// ReplaceGlobal. It's safe for concurrent use. func T() *Tunnel { _globalMu.RLock() t := _globalT @@ -39,12 +26,12 @@ func T() *Tunnel { return t } -// ReplaceGlobals replaces the global Tunnel, and returns a function +// ReplaceGlobal replaces the global Tunnel, and returns a function // to restore the original values. It's safe for concurrent use. -func ReplaceGlobals(t *Tunnel) func() { +func ReplaceGlobal(t *Tunnel) func() { _globalMu.Lock() prev := _globalT _globalT = t _globalMu.Unlock() - return func() { ReplaceGlobals(prev) } + return func() { ReplaceGlobal(prev) } } diff --git a/tunnel/tcp.go b/tunnel/tcp.go index 990ea46f..86f0e75e 100644 --- a/tunnel/tcp.go +++ b/tunnel/tcp.go @@ -36,7 +36,7 @@ func (t *Tunnel) handleTCPConn(originConn adapter.TCPConn) { ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) defer cancel() - remoteConn, err := t.dialer.DialContext(ctx, metadata) + remoteConn, err := t.Dialer().DialContext(ctx, metadata) if err != nil { log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err) return diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 2aba74bd..f1830a26 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -22,11 +22,15 @@ type Tunnel struct { // UDP session timeout. udpTimeout time.Duration - dialer proxy.Dialer + mu sync.RWMutex + dialer proxy.Dialer + + // Internal statistic.Manager for Tunnel. manager *statistic.Manager - once sync.Once - cancel context.CancelFunc + // Process controls. + procOnce sync.Once + procCancel context.CancelFunc } func New(dialer proxy.Dialer, manager *statistic.Manager) *Tunnel { @@ -36,7 +40,7 @@ func New(dialer proxy.Dialer, manager *statistic.Manager) *Tunnel { udpTimeout: udpSessionTimeout, dialer: dialer, manager: manager, - cancel: func() {}, + procCancel: func() {}, } } @@ -58,6 +62,32 @@ func (t *Tunnel) HandleUDP(conn adapter.UDPConn) { t.UDPIn() <- conn } +func (t *Tunnel) Dialer() proxy.Dialer { + t.mu.RLock() + d := t.dialer + t.mu.RUnlock() + return d +} + +func (t *Tunnel) SetDialer(dialer proxy.Dialer) { + t.mu.Lock() + t.dialer = dialer + t.mu.Unlock() +} + +func (t *Tunnel) UDPTimeout() time.Duration { + t.mu.RLock() + timeout := t.udpTimeout + t.mu.RUnlock() + return timeout +} + +func (t *Tunnel) SetUDPTimeout(timeout time.Duration) { + t.mu.Lock() + t.udpTimeout = timeout + t.mu.Unlock() +} + func (t *Tunnel) process(ctx context.Context) { for { select { @@ -72,13 +102,13 @@ func (t *Tunnel) process(ctx context.Context) { } func (t *Tunnel) Process() { - t.once.Do(func() { + t.procOnce.Do(func() { ctx, cancel := context.WithCancel(context.Background()) - t.cancel = cancel + t.procCancel = cancel t.process(ctx) }) } func (t *Tunnel) Close() { - t.cancel() + t.procCancel() } diff --git a/tunnel/udp.go b/tunnel/udp.go index a78d111d..870b824b 100644 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -13,10 +13,6 @@ import ( "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) -func (t *Tunnel) SetUDPTimeout(timeout time.Duration) { - t.udpTimeout = timeout -} - // TODO: Port Restricted NAT support. func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) { defer uc.Close() @@ -30,7 +26,7 @@ func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) { DstPort: id.LocalPort, } - pc, err := t.dialer.DialUDP(metadata) + pc, err := t.Dialer().DialUDP(metadata) if err != nil { log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err) return @@ -49,7 +45,7 @@ func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) { pc = newSymmetricNATPacketConn(pc, metadata) log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) - pipePacket(uc, pc, remote, t.udpTimeout) + pipePacket(uc, pc, remote, t.UDPTimeout()) } func pipePacket(origin, remote net.PacketConn, to net.Addr, timeout time.Duration) {