Skip to content

Commit

Permalink
Refactor(tunnel): modularize tunnel pkg
Browse files Browse the repository at this point in the history
  • Loading branch information
xjasonlyu committed Aug 29, 2024
1 parent 601601a commit 86a863d
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 74 deletions.
5 changes: 2 additions & 3 deletions engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/xjasonlyu/tun2socks/v2/core/device"
"github.com/xjasonlyu/tun2socks/v2/core/option"
"github.com/xjasonlyu/tun2socks/v2/dialer"
"github.com/xjasonlyu/tun2socks/v2/engine/mirror"
"github.com/xjasonlyu/tun2socks/v2/log"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/restapi"
Expand Down Expand Up @@ -130,7 +129,7 @@ func general(k *Key) error {
if k.UDPTimeout < time.Second {
return errors.New("invalid udp timeout value")
}
tunnel.SetUDPTimeout(k.UDPTimeout)
tunnel.T().SetUDPTimeout(k.UDPTimeout)
}
return nil
}
Expand Down Expand Up @@ -226,7 +225,7 @@ func netstack(k *Key) (err error) {

if _defaultStack, err = core.CreateStack(&core.Config{
LinkEndpoint: _defaultDevice,
TransportHandler: &mirror.Tunnel{},
TransportHandler: tunnel.T(),
MulticastGroups: multicastGroups,
Options: opts,
}); err != nil {
Expand Down
18 changes: 0 additions & 18 deletions engine/mirror/tunnel.go

This file was deleted.

7 changes: 0 additions & 7 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,6 @@ func SetDialer(d Dialer) {
_defaultDialer = d
}

// Dial uses default Dialer to dial TCP.
func Dial(metadata *M.Metadata) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel()
return _defaultDialer.DialContext(ctx, metadata)
}

// DialContext uses default Dialer to dial TCP with context.
func DialContext(ctx context.Context, metadata *M.Metadata) (net.Conn, error) {
return _defaultDialer.DialContext(ctx, metadata)
Expand Down
50 changes: 50 additions & 0 deletions tunnel/global.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package tunnel

import (
"context"
"net"
"sync"

M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)

var (
_globalMu sync.RWMutex
_globalT *Tunnel
)

func init() {
ReplaceGlobals(New(wrapper{}, statistic.DefaultManager))
go T().Process()
}

type wrapper struct{}

func (wrapper) DialContext(ctx context.Context, metadata *M.Metadata) (net.Conn, error) {
return proxy.DialContext(ctx, metadata)
}

func (wrapper) DialUDP(metadata *M.Metadata) (net.PacketConn, error) {
return proxy.DialUDP(metadata)
}

// T returns the global Tunnel, which can be reconfigured with
// ReplaceGlobals. It's safe for concurrent use.
func T() *Tunnel {
_globalMu.RLock()
t := _globalT
_globalMu.RUnlock()
return t
}

// ReplaceGlobals replaces the global Tunnel, and returns a function
// to restore the original values. It's safe for concurrent use.
func ReplaceGlobals(t *Tunnel) func() {
_globalMu.Lock()
prev := _globalT
_globalT = t
_globalMu.Unlock()
return func() { ReplaceGlobals(prev) }
}
10 changes: 0 additions & 10 deletions tunnel/statistic/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ func NewTCPTracker(conn net.Conn, metadata *M.Metadata, manager *Manager) net.Co
return tt
}

// DefaultTCPTracker returns a new net.Conn(*tcpTacker) with default manager.
func DefaultTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn {
return NewTCPTracker(conn, metadata, DefaultManager)
}

func (tt *tcpTracker) ID() string {
return tt.UUID.String()
}
Expand Down Expand Up @@ -120,11 +115,6 @@ func NewUDPTracker(conn net.PacketConn, metadata *M.Metadata, manager *Manager)
return ut
}

// DefaultUDPTracker returns a new net.PacketConn(*udpTacker) with default manager.
func DefaultUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn {
return NewUDPTracker(conn, metadata, DefaultManager)
}

func (ut *udpTracker) ID() string {
return ut.UUID.String()
}
Expand Down
15 changes: 10 additions & 5 deletions tunnel/tcp.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tunnel

import (
"context"
"io"
"net"
"sync"
Expand All @@ -10,16 +11,17 @@ import (
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)

const (
// tcpConnectTimeout is the default timeout for TCP handshake.
tcpConnectTimeout = 5 * time.Second
// tcpWaitTimeout implements a TCP half-close timeout.
tcpWaitTimeout = 60 * time.Second
)

func handleTCPConn(originConn adapter.TCPConn) {
func (t *Tunnel) handleTCPConn(originConn adapter.TCPConn) {
defer originConn.Close()

id := originConn.ID()
Expand All @@ -31,21 +33,24 @@ func handleTCPConn(originConn adapter.TCPConn) {
DstPort: id.LocalPort,
}

remoteConn, err := proxy.Dial(metadata)
ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
defer cancel()

remoteConn, err := t.dialer.DialContext(ctx, metadata)
if err != nil {
log.Warnf("[TCP] dial %s: %v", metadata.DestinationAddress(), err)
return
}
metadata.MidIP, metadata.MidPort = parseAddr(remoteConn.LocalAddr())

remoteConn = statistic.DefaultTCPTracker(remoteConn, metadata)
remoteConn = statistic.NewTCPTracker(remoteConn, metadata, t.manager)
defer remoteConn.Close()

log.Infof("[TCP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress())
pipe(originConn, remoteConn)
}

// pipe copies copy data to & from provided net.Conn(s) bidirectionally.
// pipe copies data to & from provided net.Conn(s) bidirectionally.
func pipe(origin, remote net.Conn) {
wg := sync.WaitGroup{}
wg.Add(2)
Expand Down
80 changes: 64 additions & 16 deletions tunnel/tunnel.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,84 @@
package tunnel

import (
"context"
"sync"
"time"

"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)

// Unbuffered TCP/UDP queues.
var (
_tcpQueue = make(chan adapter.TCPConn)
_udpQueue = make(chan adapter.UDPConn)
)
const udpSessionTimeout = 60 * time.Second

var _ adapter.TransportHandler = (*Tunnel)(nil)

type Tunnel struct {
// Unbuffered TCP/UDP queues.
tcpQueue chan adapter.TCPConn
udpQueue chan adapter.UDPConn

// UDP session timeout.
udpTimeout time.Duration

func init() {
go process()
dialer proxy.Dialer
manager *statistic.Manager

once sync.Once
cancel context.CancelFunc
}

func New(dialer proxy.Dialer, manager *statistic.Manager) *Tunnel {
return &Tunnel{
tcpQueue: make(chan adapter.TCPConn),
udpQueue: make(chan adapter.UDPConn),
udpTimeout: udpSessionTimeout,
dialer: dialer,
manager: manager,
cancel: func() {},
}
}

// TCPIn return fan-in TCP queue.
func TCPIn() chan<- adapter.TCPConn {
return _tcpQueue
func (t *Tunnel) TCPIn() chan<- adapter.TCPConn {
return t.tcpQueue
}

// UDPIn return fan-in UDP queue.
func UDPIn() chan<- adapter.UDPConn {
return _udpQueue
func (t *Tunnel) UDPIn() chan<- adapter.UDPConn {
return t.udpQueue
}

func process() {
func (t *Tunnel) HandleTCP(conn adapter.TCPConn) {
t.TCPIn() <- conn
}

func (t *Tunnel) HandleUDP(conn adapter.UDPConn) {
t.UDPIn() <- conn
}

func (t *Tunnel) process(ctx context.Context) {
for {
select {
case conn := <-_tcpQueue:
go handleTCPConn(conn)
case conn := <-_udpQueue:
go handleUDPConn(conn)
case conn := <-t.tcpQueue:
go t.handleTCPConn(conn)
case conn := <-t.udpQueue:
go t.handleUDPConn(conn)
case <-ctx.Done():
return
}
}
}

func (t *Tunnel) Process() {
t.once.Do(func() {
ctx, cancel := context.WithCancel(context.Background())
t.cancel = cancel
t.process(ctx)
})
}

func (t *Tunnel) Close() {
t.cancel()
}
26 changes: 11 additions & 15 deletions tunnel/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,15 @@ import (
"github.com/xjasonlyu/tun2socks/v2/core/adapter"
"github.com/xjasonlyu/tun2socks/v2/log"
M "github.com/xjasonlyu/tun2socks/v2/metadata"
"github.com/xjasonlyu/tun2socks/v2/proxy"
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
)

// _udpSessionTimeout is the default timeout for each UDP session.
var _udpSessionTimeout = 60 * time.Second

func SetUDPTimeout(t time.Duration) {
_udpSessionTimeout = t
func (t *Tunnel) SetUDPTimeout(timeout time.Duration) {
t.udpTimeout = timeout
}

// TODO: Port Restricted NAT support.
func handleUDPConn(uc adapter.UDPConn) {
func (t *Tunnel) handleUDPConn(uc adapter.UDPConn) {
defer uc.Close()

id := uc.ID()
Expand All @@ -34,14 +30,14 @@ func handleUDPConn(uc adapter.UDPConn) {
DstPort: id.LocalPort,
}

pc, err := proxy.DialUDP(metadata)
pc, err := t.dialer.DialUDP(metadata)
if err != nil {
log.Warnf("[UDP] dial %s: %v", metadata.DestinationAddress(), err)
return
}
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr())

pc = statistic.DefaultUDPTracker(pc, metadata)
pc = statistic.NewUDPTracker(pc, metadata, t.manager)
defer pc.Close()

var remote net.Addr
Expand All @@ -53,22 +49,22 @@ func handleUDPConn(uc adapter.UDPConn) {
pc = newSymmetricNATPacketConn(pc, metadata)

log.Infof("[UDP] %s <-> %s", metadata.SourceAddress(), metadata.DestinationAddress())
pipePacket(uc, pc, remote)
pipePacket(uc, pc, remote, t.udpTimeout)
}

func pipePacket(origin, remote net.PacketConn, to net.Addr) {
func pipePacket(origin, remote net.PacketConn, to net.Addr, timeout time.Duration) {
wg := sync.WaitGroup{}
wg.Add(2)

go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg)
go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg)
go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg, timeout)
go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg, timeout)

wg.Wait()
}

func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup) {
func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) {
defer wg.Done()
if err := copyPacketData(dst, src, to, _udpSessionTimeout); err != nil {
if err := copyPacketData(dst, src, to, timeout); err != nil {
log.Debugf("[UDP] copy data for %s: %v", dir, err)
}
}
Expand Down

0 comments on commit 86a863d

Please sign in to comment.