From 86a863d237b276458732e77188f1352f3d575124 Mon Sep 17 00:00:00 2001 From: xjasonlyu Date: Wed, 28 Aug 2024 23:10:20 -0400 Subject: [PATCH] Refactor(tunnel): modularize tunnel pkg --- engine/engine.go | 5 +-- engine/mirror/tunnel.go | 18 --------- proxy/proxy.go | 7 ---- tunnel/global.go | 50 +++++++++++++++++++++++ tunnel/statistic/tracker.go | 10 ----- tunnel/tcp.go | 15 ++++--- tunnel/tunnel.go | 80 +++++++++++++++++++++++++++++-------- tunnel/udp.go | 26 +++++------- 8 files changed, 137 insertions(+), 74 deletions(-) delete mode 100644 engine/mirror/tunnel.go create mode 100644 tunnel/global.go diff --git a/engine/engine.go b/engine/engine.go index f513fac3..ec8fd78e 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -16,7 +16,6 @@ import ( "github.com/xjasonlyu/tun2socks/v2/core/device" "github.com/xjasonlyu/tun2socks/v2/core/option" "github.com/xjasonlyu/tun2socks/v2/dialer" - "github.com/xjasonlyu/tun2socks/v2/engine/mirror" "github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/restapi" @@ -130,7 +129,7 @@ func general(k *Key) error { if k.UDPTimeout < time.Second { return errors.New("invalid udp timeout value") } - tunnel.SetUDPTimeout(k.UDPTimeout) + tunnel.T().SetUDPTimeout(k.UDPTimeout) } return nil } @@ -226,7 +225,7 @@ func netstack(k *Key) (err error) { if _defaultStack, err = core.CreateStack(&core.Config{ LinkEndpoint: _defaultDevice, - TransportHandler: &mirror.Tunnel{}, + TransportHandler: tunnel.T(), MulticastGroups: multicastGroups, Options: opts, }); err != nil { diff --git a/engine/mirror/tunnel.go b/engine/mirror/tunnel.go deleted file mode 100644 index b753ac32..00000000 --- a/engine/mirror/tunnel.go +++ /dev/null @@ -1,18 +0,0 @@ -package mirror - -import ( - "github.com/xjasonlyu/tun2socks/v2/core/adapter" - "github.com/xjasonlyu/tun2socks/v2/tunnel" -) - -var _ adapter.TransportHandler = (*Tunnel)(nil) - -type Tunnel struct{} - -func (*Tunnel) HandleTCP(conn adapter.TCPConn) { - tunnel.TCPIn() <- conn -} - -func (*Tunnel) HandleUDP(conn adapter.UDPConn) { - tunnel.UDPIn() <- conn -} diff --git a/proxy/proxy.go b/proxy/proxy.go index 27cafc7b..5fbccb97 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -32,13 +32,6 @@ func SetDialer(d Dialer) { _defaultDialer = d } -// Dial uses default Dialer to dial TCP. -func Dial(metadata *M.Metadata) (net.Conn, error) { - ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) - defer cancel() - return _defaultDialer.DialContext(ctx, metadata) -} - // DialContext uses default Dialer to dial TCP with context. func DialContext(ctx context.Context, metadata *M.Metadata) (net.Conn, error) { return _defaultDialer.DialContext(ctx, metadata) diff --git a/tunnel/global.go b/tunnel/global.go new file mode 100644 index 00000000..591872e9 --- /dev/null +++ b/tunnel/global.go @@ -0,0 +1,50 @@ +package tunnel + +import ( + "context" + "net" + "sync" + + M "github.com/xjasonlyu/tun2socks/v2/metadata" + "github.com/xjasonlyu/tun2socks/v2/proxy" + "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" +) + +var ( + _globalMu sync.RWMutex + _globalT *Tunnel +) + +func init() { + ReplaceGlobals(New(wrapper{}, statistic.DefaultManager)) + go T().Process() +} + +type wrapper struct{} + +func (wrapper) DialContext(ctx context.Context, metadata *M.Metadata) (net.Conn, error) { + return proxy.DialContext(ctx, metadata) +} + +func (wrapper) DialUDP(metadata *M.Metadata) (net.PacketConn, error) { + return proxy.DialUDP(metadata) +} + +// T returns the global Tunnel, which can be reconfigured with +// ReplaceGlobals. It's safe for concurrent use. +func T() *Tunnel { + _globalMu.RLock() + t := _globalT + _globalMu.RUnlock() + return t +} + +// ReplaceGlobals replaces the global Tunnel, and returns a function +// to restore the original values. It's safe for concurrent use. +func ReplaceGlobals(t *Tunnel) func() { + _globalMu.Lock() + prev := _globalT + _globalT = t + _globalMu.Unlock() + return func() { ReplaceGlobals(prev) } +} diff --git a/tunnel/statistic/tracker.go b/tunnel/statistic/tracker.go index 393c200a..ee377613 100644 --- a/tunnel/statistic/tracker.go +++ b/tunnel/statistic/tracker.go @@ -50,11 +50,6 @@ func NewTCPTracker(conn net.Conn, metadata *M.Metadata, manager *Manager) net.Co return tt } -// DefaultTCPTracker returns a new net.Conn(*tcpTacker) with default manager. -func DefaultTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn { - return NewTCPTracker(conn, metadata, DefaultManager) -} - func (tt *tcpTracker) ID() string { return tt.UUID.String() } @@ -120,11 +115,6 @@ func NewUDPTracker(conn net.PacketConn, metadata *M.Metadata, manager *Manager) return ut } -// DefaultUDPTracker returns a new net.PacketConn(*udpTacker) with default manager. -func DefaultUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn { - return NewUDPTracker(conn, metadata, DefaultManager) -} - func (ut *udpTracker) ID() string { return ut.UUID.String() } diff --git a/tunnel/tcp.go b/tunnel/tcp.go index 03cebab5..5214318b 100644 --- a/tunnel/tcp.go +++ b/tunnel/tcp.go @@ -1,6 +1,7 @@ package tunnel import ( + "context" "io" "net" "sync" @@ -10,16 +11,17 @@ import ( "github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" - "github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) const ( + // tcpConnectTimeout is the default timeout for TCP handshake. + tcpConnectTimeout = 5 * time.Second // tcpWaitTimeout implements a TCP half-close timeout. tcpWaitTimeout = 60 * time.Second ) -func handleTCPConn(originConn adapter.TCPConn) { +func (t *Tunnel) handleTCPConn(originConn adapter.TCPConn) { defer originConn.Close() id := originConn.ID() @@ -31,21 +33,24 @@ func handleTCPConn(originConn adapter.TCPConn) { DstPort: id.LocalPort, } - remoteConn, err := proxy.Dial(metadata) + ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) + defer cancel() + + remoteConn, err := t.dialer.DialContext(ctx, metadata) if err != nil { log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err) return } metadata.MidIP, metadata.MidPort = parseAddr(remoteConn.LocalAddr()) - remoteConn = statistic.DefaultTCPTracker(remoteConn, metadata) + remoteConn = statistic.NewTCPTracker(remoteConn, metadata, t.manager) defer remoteConn.Close() log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) pipe(originConn, remoteConn) } -// pipe copies copy data to & from provided net.Conn(s) bidirectionally. +// pipe copies data to & from provided net.Conn(s) bidirectionally. func pipe(origin, remote net.Conn) { wg := sync.WaitGroup{} wg.Add(2) diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 8ced53f1..2aba74bd 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -1,36 +1,84 @@ package tunnel import ( + "context" + "sync" + "time" + "github.com/xjasonlyu/tun2socks/v2/core/adapter" + "github.com/xjasonlyu/tun2socks/v2/proxy" + "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) -// Unbuffered TCP/UDP queues. -var ( - _tcpQueue = make(chan adapter.TCPConn) - _udpQueue = make(chan adapter.UDPConn) -) +const udpSessionTimeout = 60 * time.Second + +var _ adapter.TransportHandler = (*Tunnel)(nil) + +type Tunnel struct { + // Unbuffered TCP/UDP queues. + tcpQueue chan adapter.TCPConn + udpQueue chan adapter.UDPConn + + // UDP session timeout. + udpTimeout time.Duration -func init() { - go process() + dialer proxy.Dialer + manager *statistic.Manager + + once sync.Once + cancel context.CancelFunc +} + +func New(dialer proxy.Dialer, manager *statistic.Manager) *Tunnel { + return &Tunnel{ + tcpQueue: make(chan adapter.TCPConn), + udpQueue: make(chan adapter.UDPConn), + udpTimeout: udpSessionTimeout, + dialer: dialer, + manager: manager, + cancel: func() {}, + } } // TCPIn return fan-in TCP queue. -func TCPIn() chan<- adapter.TCPConn { - return _tcpQueue +func (t *Tunnel) TCPIn() chan<- adapter.TCPConn { + return t.tcpQueue } // UDPIn return fan-in UDP queue. -func UDPIn() chan<- adapter.UDPConn { - return _udpQueue +func (t *Tunnel) UDPIn() chan<- adapter.UDPConn { + return t.udpQueue } -func process() { +func (t *Tunnel) HandleTCP(conn adapter.TCPConn) { + t.TCPIn() <- conn +} + +func (t *Tunnel) HandleUDP(conn adapter.UDPConn) { + t.UDPIn() <- conn +} + +func (t *Tunnel) process(ctx context.Context) { for { select { - case conn := <-_tcpQueue: - go handleTCPConn(conn) - case conn := <-_udpQueue: - go handleUDPConn(conn) + case conn := <-t.tcpQueue: + go t.handleTCPConn(conn) + case conn := <-t.udpQueue: + go t.handleUDPConn(conn) + case <-ctx.Done(): + return } } } + +func (t *Tunnel) Process() { + t.once.Do(func() { + ctx, cancel := context.WithCancel(context.Background()) + t.cancel = cancel + t.process(ctx) + }) +} + +func (t *Tunnel) Close() { + t.cancel() +} diff --git a/tunnel/udp.go b/tunnel/udp.go index a10e2d47..a78d111d 100644 --- a/tunnel/udp.go +++ b/tunnel/udp.go @@ -10,19 +10,15 @@ import ( "github.com/xjasonlyu/tun2socks/v2/core/adapter" "github.com/xjasonlyu/tun2socks/v2/log" M "github.com/xjasonlyu/tun2socks/v2/metadata" - "github.com/xjasonlyu/tun2socks/v2/proxy" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) -// _udpSessionTimeout is the default timeout for each UDP session. -var _udpSessionTimeout = 60 * time.Second - -func SetUDPTimeout(t time.Duration) { - _udpSessionTimeout = t +func (t *Tunnel) SetUDPTimeout(timeout time.Duration) { + t.udpTimeout = timeout } // TODO: Port Restricted NAT support. -func handleUDPConn(uc adapter.UDPConn) { +func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) { defer uc.Close() id := uc.ID() @@ -34,14 +30,14 @@ func handleUDPConn(uc adapter.UDPConn) { DstPort: id.LocalPort, } - pc, err := proxy.DialUDP(metadata) + pc, err := t.dialer.DialUDP(metadata) if err != nil { log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err) return } metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr()) - pc = statistic.DefaultUDPTracker(pc, metadata) + pc = statistic.NewUDPTracker(pc, metadata, t.manager) defer pc.Close() var remote net.Addr @@ -53,22 +49,22 @@ func handleUDPConn(uc adapter.UDPConn) { pc = newSymmetricNATPacketConn(pc, metadata) log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress()) - pipePacket(uc, pc, remote) + pipePacket(uc, pc, remote, t.udpTimeout) } -func pipePacket(origin, remote net.PacketConn, to net.Addr) { +func pipePacket(origin, remote net.PacketConn, to net.Addr, timeout time.Duration) { wg := sync.WaitGroup{} wg.Add(2) - go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg) - go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg) + go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg, timeout) + go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg, timeout) wg.Wait() } -func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup) { +func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) { defer wg.Done() - if err := copyPacketData(dst, src, to, _udpSessionTimeout); err != nil { + if err := copyPacketData(dst, src, to, timeout); err != nil { log.Debugf("[UDP] copy data for %s: %v", dir, err) } }