Skip to content

Commit

Permalink
udp: impl endpoint-independent filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
ignoramous committed Aug 22, 2024
1 parent 07ccf30 commit 0e8917b
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 65 deletions.
128 changes: 86 additions & 42 deletions intra/netstack/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,20 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)

var errMissingEp = errors.New("not connected to any endpoint")
var (
errMissingEp = errors.New("not connected to any endpoint")
errMissingReq = errors.New("missing forwarder request")

Check failure on line 28 in intra/netstack/udp.go

View workflow job for this annotation

GitHub Actions / 🧭 Lint

var errMissingReq is unused (U1000)
errFilteredOut = errors.New("no eif; filtered out")
)

type DemuxerFn func(dst netip.AddrPort) error

type GUDPConnHandler interface {
// Proxy proxies data between conn (src) and dst.
Proxy(conn *GUDPConn, src, dst netip.AddrPort) bool
// ProxyMux proxies data between conn and multiple destinations.
ProxyMux(conn *GUDPConn, src, dst netip.AddrPort) bool
// ProxyMux proxies data between conn and multiple destinations
// (endpoint-independent mapping).
ProxyMux(conn *GUDPConn, src, dst netip.AddrPort, dmx DemuxerFn) bool
// Error notes the error in connecting src to dst.
Error(conn *GUDPConn, src, dst netip.AddrPort, err error)
// CloseConns closes conns by ids, or all if ids is empty.
Expand All @@ -42,10 +49,17 @@ var _ core.UDPConn = (*GUDPConn)(nil)

type GUDPConn struct {
stack *stack.Stack
c *core.Volatile[*gonet.UDPConn] // conn exposes UDP semantics atop endpoint
src netip.AddrPort // local addr (remote addr in netstack)
dst netip.AddrPort // remote addr (local addr in netstack)
req *udp.ForwarderRequest // egress request as UDP

// conn exposes UDP semantics atop endpoint
c *core.Volatile[*gonet.UDPConn]
// local addr (remote addr in netstack)
// ex: 10.111.222.1:20716; same as endpoint.GetRemoteAddress
src netip.AddrPort
// remote addr (local addr in netstack)
// ex: 10.111.222.3:53; same as endpoint.GetLocalAddress
dst netip.AddrPort

req *udp.ForwarderRequest // egress request as UDP

eim bool // endpoint is muxed
eif bool // endpoint is transparent
Expand Down Expand Up @@ -85,6 +99,21 @@ func udpForwarder(s *stack.Stack, h GUDPConnHandler) *udp.Forwarder {
log.E("ns: udp: forwarder: nil request")
return
}

// owner app tun ns h
// repr socket packet endpoint socket
// type udp fd gudpconn core.minconn
//
// (src, dst) :1111, :53 :1111, :53 :53, :1111 :9999, :53
//
// write :1111 => :53 :1111, :53 :53 => :1111 :9999 => :53
// \ /
// \ /
// (pipe) \ /
// / \
// / \
// / \
// read :1111 <= :53 :1111, :53 :53 <= :1111 :9999 <= :53
id := req.ID()
// src 10.111.222.1:20716; same as endpoint.GetRemoteAddress
src := remoteAddrPort(id)
Expand All @@ -105,10 +134,30 @@ func udpForwarder(s *stack.Stack, h GUDPConnHandler) *udp.Forwarder {
}
}

demux := func(newdst netip.AddrPort) error {
if newdst == dst {
log.D("ns: udp: demuxer: no-op; src(%v) same as dst(%v)", src, newdst)
return nil
}
if !gc.eif {
return errFilteredOut
}
newgc := makeGUDPConn(s, nil /*not a forwarder req*/, src, newdst)
if !settings.SingleThreaded.Load() {
if err := newgc.Establish(); err != nil {
log.E("ns: udp: demuxer: dial: %v; src(%v) dst(%v)", err, src, newdst)
go h.Error(newgc, src, newdst, err)
return err
}
}
go h.Proxy(newgc, src, newdst)
return nil
}

// proxy in a separate gorountine; return immediately
// why? netstack/dispatcher.go:newReadvDispatcher
if gc.eim {
go h.ProxyMux(gc, src, dst)
go h.ProxyMux(gc, src, dst, demux)
} else {
go h.Proxy(gc, src, dst)
}
Expand All @@ -124,47 +173,35 @@ func (g *GUDPConn) conn() *gonet.UDPConn {
}

func (g *GUDPConn) StatefulTeardown() (fin bool) {
_ = g.tryConnect() // establish circuit then teardown
_ = g.Close() // then shutdown
return true // always fin
_ = g.Establish() // establish circuit then teardown
_ = g.Close() // then shutdown
return true // always fin
}

func (g *GUDPConn) Establish() error {
if g.eif {
return g.tryBind()
}
return g.tryConnect()
}

func (g *GUDPConn) tryConnect() error {
if g.ok() { // already setup
return nil
}

wq := new(waiter.Queue)
if endpoint, err := g.req.CreateEndpoint(wq); err != nil {
// ex: CONNECT endpoint for [fd66:f83a:c650::1]:15753 => [fd66:f83a:c650::3]:53; err(no route to host)
log.E("ns: udp: connect: endpoint for %v => %v; err(%v)", g.src, g.dst, err)
return e(err)
} else {
g.c.Store(gonet.NewUDPConn(wq, endpoint))
}
return nil
}

func (g *GUDPConn) tryBind() error {
if g.ok() { // already setup
return nil
}

src, proto := addrport2nsaddr(g.src)
// unconnected socket w/ gonet.DialUDP
if conn, err := gonet.DialUDP(g.stack, &src, nil, proto); err != nil {
log.E("ns: udp: bind: endpoint for %v [=> %v]; err(%v)", g.src, g.dst, err)
return err
if g.req == nil {
src, proto := addrport2nsaddr(g.dst) // remote addr is local addr in netstack
dst, _ := addrport2nsaddr(g.src) // local addr is remote addr in netstack
// ingress socket w/ gonet.DialUDP
if conn, err := gonet.DialUDP(g.stack, &src, &dst, proto); err != nil {
log.E("ns: udp: dial: endpoint for %v => %v; err(%v)", g.src, g.dst, err)
return err
} else {
g.c.Store(conn)
}
} else {
// todo: handle the first pkt like in g.req.CreateEndpoint
g.c.Store(conn)
wq := new(waiter.Queue)
if endpoint, err := g.req.CreateEndpoint(wq); err != nil {
// ex: CONNECT endpoint for [fd66:f83a:c650::1]:15753 => [fd66:f83a:c650::3]:53; err(no route to host)
log.E("ns: udp: connect: endpoint for %v => %v; err(%v)", g.src, g.dst, err)
return e(err)
} else {
g.c.Store(gonet.NewUDPConn(wq, endpoint))
}
}
return nil
}
Expand Down Expand Up @@ -196,7 +233,14 @@ func (g *GUDPConn) Write(data []byte) (int, error) {
// ep(state 3 / info &{2048 17 {53 10.111.222.3 17711 10.111.222.1} 1 10.111.222.3 1} / stats &{{{1}} {{0}} {{{0}} {{0}} {{0}} {{0}}} {{{0}} {{0}} {{0}}} {{{0}} {{0}}} {{{0}} {{0}} {{0}}}})
// 3: status:datagram-connected / {2048=>proto, 17=>transport, {53=>local-port localip 17711=>remote-port remoteip}=>endpoint-id, 1=>bind-nic-id, ip=>bind-addr, 1=>registered-nic-id}
// g.ep may be nil: log.V("ns: writeFrom: from(%v) / ep(state %v / info %v / stats %v)", addr, g.ep.State(), g.ep.Info(), g.ep.Stats())
return c.Write(data)
if g.eif {
// unexpected except in cases of DNS override;
// forward the packet to the dst as got from the first pkt
log.W("ns: udp: Write(To): unexpected; %s <= %s; sz: %d", g.src, g.dst, len(data))
return c.WriteTo(data, net.UDPAddrFromAddrPort(g.dst))
} else {
return c.Write(data)
}
}
return 0, netError(g, "udp", "write", io.ErrClosedPipe)
}
Expand Down
2 changes: 1 addition & 1 deletion intra/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ const (
)

const (
retrytimeout = 1 * time.Minute
retrytimeout = 15 * time.Second
onFlowTimeout = 5 * time.Second
)

Expand Down
57 changes: 36 additions & 21 deletions intra/udp.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ package intra

import (
"errors"
"io"
"net"
"net/netip"
"sync"
Expand Down Expand Up @@ -176,32 +175,48 @@ func (h *udpHandler) onFlow(localaddr, target netip.AddrPort, realips, domains,
}

// ProxyMux implements netstack.GUDPConnHandler
func (h *udpHandler) ProxyMux(gconn *netstack.GUDPConn, src, dst netip.AddrPort) (ok bool) {
func (h *udpHandler) ProxyMux(gconn *netstack.GUDPConn, src, dst netip.AddrPort, dmx netstack.DemuxerFn) (ok bool) {
defer core.Recover(core.Exit11, "udp.ProxyMux")
return h.proxy(gconn, src, dst, true)
return h.proxy(gconn, src, dst, dmx)
}

// Error implements netstack.GUDPConnHandler.
// Must be called from a goroutine.
func (h *udpHandler) Error(gconn *netstack.GUDPConn, src, dst netip.AddrPort, err error) {
ok := h.proxy(gconn, src, dst, false)
log.I("udp: proxy: %v -> %v; err %v; recovered? %t", src, dst, err, ok)
func (h *udpHandler) Error(gconn *netstack.GUDPConn, src, target netip.AddrPort, err error) {
log.W("udp: proxy: %v -> %v; err %v", src, target, err)
if !src.IsValid() || !target.IsValid() {
return
}

realips, domains, probableDomains, blocklists := undoAlg(h.resolver, target.Addr())

// flow is alg/nat-aware, do not change target or any addrs
res := h.onFlow(src, target, realips, domains, probableDomains, blocklists)
cid, pid, uid := splitCidPidUid(res)
smm := udpSummary(cid, pid, uid, target.Addr())

if h.status.Load() == UDPEND {
err = errUdpEnd
} else if pid == ipn.Block {
err = errUdpFirewalled
}
smm.done(err)
}

// Proxy implements netstack.GUDPConnHandler; thread-safe.
// Must be called from a goroutine.
func (h *udpHandler) Proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort) (ok bool) {
defer core.Recover(core.Exit11, "udp.Proxy")
return h.proxy(gconn, src, dst, false)
return h.proxy(gconn, src, dst, nil)
}

// proxy connects src to dst over a proxy; thread-safe.
func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, mux bool) (ok bool) {

remote, smm, ct, err := h.Connect(gconn, src, dst, mux) // remote may be nil; smm is never nil
func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, dmx netstack.DemuxerFn) (ok bool) {
mux := dmx != nil
remote, smm, err := h.Connect(gconn, src, dst, dmx) // remote may be nil; smm is never nil

if err != nil {
clos(gconn, remote)
core.Close(gconn, remote)
queueSummary(h.smmch, h.done, smm.done(err)) // smm may be nil
log.W("udp: proxy: mux? %t, unexpected %s -> %s; err: %v", mux, src, dst, err)
// dst addrs no longer tracked in h.Connect: h.conntracker.Untrack(ct.CID)
Expand All @@ -217,23 +232,23 @@ func (h *udpHandler) proxy(gconn *netstack.GUDPConn, src, dst netip.AddrPort, mu
cid = smm.ID
}

h.conntracker.Track(ct, gconn, remote)
h.conntracker.Track(cid, gconn, remote)
core.Go("udp.forward: "+cid, func() {
defer h.conntracker.Untrack(ct.CID)
defer h.conntracker.Untrack(cid)
forward(gconn, &rwext{remote}, h.smmch, h.done, smm)
})
return true // ok
}

// Connect connects the proxy server; thread-safe.
func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPort, mux bool) (dst core.UDPConn, smm *SocketSummary, ct core.ConnTuple, err error) {
var px ipn.Proxy = nil
var pc io.Closer = nil

// connect gconn right away, since we assume a duplex-stream from here on
// see: h.Connect -> dnsOverride
if err = gconn.Establish(); err != nil {
log.W("udp: %s gconn connect, mux? %t, err %s => %s", src, target, mux, err)
func (h *udpHandler) Connect(gconn *netstack.GUDPConn, src, target netip.AddrPort, dmx netstack.DemuxerFn) (pc net.Conn, smm *SocketSummary, err error) {
mux := dmx != nil

if !target.IsValid() { // must call h.Bind
err = errUdpUnconnected
} else { // connect gconn right away, since we assume a duplex-stream from here on
// see: h.Connect -> dnsOverride
err = gconn.Establish()
} // err handled after onFlow, so that the listener knows about this gconn/flow

realips, domains, probableDomains, blocklists := undoAlg(h.resolver, target.Addr())
Expand Down
13 changes: 12 additions & 1 deletion intra/udpmux.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/celzero/firestack/intra/core"
"github.com/celzero/firestack/intra/log"
"github.com/celzero/firestack/intra/netstack"
)

// from: github.com/pion/transport/blob/03c807b/udp/conn.go
Expand Down Expand Up @@ -61,7 +62,8 @@ type muxer struct {
dxconns chan *demuxconn // never closed
doneCh chan struct{} // stop vending, reading, and routing
once sync.Once
cb func() // muxer.stop() callback (new goroutine)
cb func() // muxer.stop() callback (in a new goroutine)
vnd netstack.DemuxerFn // for new routes in netstack

rmu sync.Mutex // protects routes
routes map[string]*demuxconn // remote addr -> demuxed conn
Expand Down Expand Up @@ -249,6 +251,11 @@ func (x *muxer) route(raddr net.Addr) (*demuxconn, error) {
case x.dxconns <- conn:
x.stats.dxcount.Add(1)
x.routes[addr] = conn
if dst, err := addr2netip(raddr); err == nil && dst.IsValid() {
go x.vnd(dst)
} else { // should never happen
log.E("udp: mux: %s route: invalid addr %s; err: %v", x.cid, raddr, err)
}
log.I("udp: mux: %s route: new for %s; stats: %d",
x.cid, raddr, x.stats)
}
Expand Down Expand Up @@ -488,3 +495,7 @@ func (e *muxTable) dissociate(id string, src netip.AddrPort) {
defer e.Unlock()
delete(e.t, src)
}

func addr2netip(addr net.Addr) (netip.AddrPort, error) {
return netip.ParseAddrPort(addr.String())
}

1 comment on commit 0e8917b

@ignoramous
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#77

Please sign in to comment.