From af8d966663b251cb1a1f2cfb78ae407ba7822820 Mon Sep 17 00:00:00 2001 From: OrangeWolf Date: Wed, 10 Feb 2021 10:34:53 +0800 Subject: [PATCH] feat(inbound): refactor fastestA() --- core/utils/mock/mocker.go | 5 ++ core/utils/ping.go | 35 ++++++++++ core/utils/ping_test.go | 33 +++++++++ inbound/server.go | 3 +- inbound/tools.go | 137 +++++++++++++++++++------------------- inbound/tools_test.go | 94 +++++++++++++------------- 6 files changed, 188 insertions(+), 119 deletions(-) create mode 100644 core/utils/ping.go create mode 100644 core/utils/ping_test.go diff --git a/core/utils/mock/mocker.go b/core/utils/mock/mocker.go index 61b16e8..0ef6309 100644 --- a/core/utils/mock/mocker.go +++ b/core/utils/mock/mocker.go @@ -20,6 +20,11 @@ func (m *Mocker) FuncSeq(target interface{}, outputs []gomonkey.Params) { m.patches = append(m.patches, gomonkey.ApplyFuncSeq(target, cells)) } +// Func gomonkey.ApplyFunc的封装 +func (m *Mocker) Func(target interface{}, double interface{}) { + m.patches = append(m.patches, gomonkey.ApplyFunc(target, double)) +} + // MethodSeq gomonkey.ApplyMethodSeq的封装 func (m *Mocker) MethodSeq(target interface{}, method string, outputs []gomonkey.Params) { var cells []gomonkey.OutputCell diff --git a/core/utils/ping.go b/core/utils/ping.go new file mode 100644 index 0000000..6ec1096 --- /dev/null +++ b/core/utils/ping.go @@ -0,0 +1,35 @@ +package utils + +import ( + "errors" + "net" + "strconv" + "time" + + "github.com/sparrc/go-ping" +) + +// PingIP 向指定ip地址发起icmp ping/tcp ping(如tcpPort大于0),返回值为nil代表ping成功 +func PingIP(ipAddr string, tcpPort int, timeout time.Duration) error { + if tcpPort > 0 { // tcp ping + addr := ipAddr + ":" + strconv.Itoa(tcpPort) + conn, err := net.DialTimeout("tcp", addr, timeout) + if err != nil { + return err + } + _ = conn.Close() + return nil + } + // icmp ping + task, err := ping.NewPinger(ipAddr) + if err != nil { + return err + } + task.Count, task.Timeout = 1, timeout + task.SetPrivileged(true) + task.Run() + if stat := task.Statistics(); stat.PacketsRecv >= 1 { + return nil + } + return errors.New("package loss") +} diff --git a/core/utils/ping_test.go b/core/utils/ping_test.go new file mode 100644 index 0000000..d2889d7 --- /dev/null +++ b/core/utils/ping_test.go @@ -0,0 +1,33 @@ +package utils + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/agiledragon/gomonkey" + "github.com/sparrc/go-ping" + "github.com/stretchr/testify/assert" + "github.com/wolf-joe/ts-dns/core/utils/mock" +) + +func TestPingIP(t *testing.T) { + // icmp ping + assert.NotNil(t, PingIP("299.299.299.299", -1, time.Second)) + assert.NotNil(t, PingIP("111", -1, time.Second)) + mocker := mock.Mocker{} + defer mocker.Reset() + mocker.MethodSeq(&ping.Pinger{}, "Statistics", []gomonkey.Params{ + {&ping.Statistics{PacketsRecv: 1, AvgRtt: 100}}, + }) + assert.Nil(t, PingIP("1.1.1.1", -1, time.Second)) + + // tcp ping + mocker.FuncSeq(net.DialTimeout, []gomonkey.Params{ + {nil, fmt.Errorf("err")}, {&net.TCPConn{}, nil}, + }) + mocker.MethodSeq(&net.TCPConn{}, "Close", []gomonkey.Params{{nil}}) + assert.NotNil(t, PingIP("1.1.1.1", 80, time.Second)) + assert.Nil(t, PingIP("1.1.1.1", 80, time.Second)) +} diff --git a/inbound/server.go b/inbound/server.go index 1a00fa1..0ea55ba 100644 --- a/inbound/server.go +++ b/inbound/server.go @@ -51,7 +51,6 @@ func (group *Group) CallDNS(ctx context.Context, request *dns.Msg) *dns.Msg { } // 遍历DNS服务器 for _, caller := range group.Callers { - utils.CtxDebug(ctx, "forward question %v to %s", request.Question, caller) if group.Concurrent || group.FastestV4 { go call(caller, request) } else if r := call(caller, request); r != nil { @@ -66,7 +65,7 @@ func (group *Group) CallDNS(ctx context.Context, request *dns.Msg) *dns.Msg { } } } else if group.FastestV4 { // 选择ping值最低的IPv4地址作为返回值 - return fastestA(ch, len(group.Callers), group.TCPPingPort) + return fastestA(ctx, ch, len(group.Callers), group.TCPPingPort) } return nil } diff --git a/inbound/tools.go b/inbound/tools.go index 1dcfc90..5957f9d 100644 --- a/inbound/tools.go +++ b/inbound/tools.go @@ -1,19 +1,19 @@ package inbound import ( - log "github.com/Sirupsen/logrus" + "context" + "net" + "time" + "github.com/miekg/dns" - "github.com/sparrc/go-ping" "github.com/wolf-joe/ts-dns/cache" "github.com/wolf-joe/ts-dns/core/common" - "math" - "net" - "strconv" - "sync" - "time" + "github.com/wolf-joe/ts-dns/core/utils" ) -const maxRtt = 500 +const ( + pingTimeout = 500 * time.Millisecond +) // 如dns响应中所有ipv4地址都在目标范围内(或没有ipv4地址)返回true,否则返回False func allInRange(r *dns.Msg, ipRange *cache.RamSet) bool { @@ -25,76 +25,73 @@ func allInRange(r *dns.Msg, ipRange *cache.RamSet) bool { return true } -// 获取到目标ip的ping值(毫秒),当tcpPort大于0时使用tcp ping,否则使用icmp ping -func pingRtt(ip string, tcpPort int) (rtt int64) { - if tcpPort > 0 { // 使用tcp ping - begin, addr := time.Now(), ip+":"+strconv.Itoa(tcpPort) - conn, err := net.DialTimeout("tcp", addr, time.Millisecond*maxRtt) - if err != nil { - return maxRtt + 1 - } - defer func() { _ = conn.Close() }() - rtt = time.Now().Sub(begin).Milliseconds() - return rtt - } - // 使用icmp ping - task, err := ping.NewPinger(ip) - if err != nil { - return maxRtt + 1 +func fastestA(ctx context.Context, ch <-chan *dns.Msg, chLen int, tcpPort int) *dns.Msg { + if chLen == 0 { + return nil } - task.Count, task.Timeout = 1, time.Millisecond*maxRtt - task.SetPrivileged(true) - task.Run() - stat := task.Statistics() - if stat.PacketsRecv >= 1 { - return stat.AvgRtt.Milliseconds() - } - return maxRtt + 1 -} - -// 从dns msg chan中找出ping值最低的ipv4地址并将其所属的A记录打包返回 -func fastestA(ch chan *dns.Msg, chLen int, tcpPort int) (res *dns.Msg) { - aLock, rttLock, wg := new(sync.Mutex), new(sync.Mutex), new(sync.WaitGroup) - aMap, rttMap := map[string]dns.A{}, map[string]int64{} + const maxGoNum = 15 // 最大并发数量 + // 从msg ch中提取所有IPv4地址,并建立IPv4地址到msg的映射 + allIP := make([]string, 0, maxGoNum) + msgMap := make(map[string]*dns.Msg, maxGoNum) + var fastestMsg *dns.Msg // 最早抵达的msg,当测速失败时使用该响应返回 for i := 0; i < chLen; i++ { - msg := <-ch // 从chan中取出一个msg - if msg != nil { - res = msg // 防止被最后出现的nil覆盖 + msg := <-ch + if len(msgMap) >= maxGoNum { + continue // 消费chLen内剩余msg + } + if fastestMsg == nil { + fastestMsg = msg } for _, a := range common.ExtractA(msg) { - ipv4, aObj := a.A.String(), *a // 用aObj实体变量来防止aMap的键值不一致 - wg.Add(1) - go func() { - defer wg.Done() - aLock.Lock() - if _, ok := aMap[ipv4]; ok { // 防止重复ping - aLock.Unlock() - return + ipV4 := a.A.String() + if _, exists := msgMap[ipV4]; !exists { + allIP = append(allIP, ipV4) + msgMap[ipV4] = msg + if len(msgMap) >= maxGoNum { + break } - aMap[ipv4] = aObj - aLock.Unlock() - // 并发测速 - rtt := pingRtt(ipv4, tcpPort) - rttLock.Lock() - rttMap[ipv4] = rtt - rttLock.Unlock() - }() + } } } - wg.Wait() - // 查找ping最小的ipv4地址 - lowestRtt, fastestIP := int64(math.MaxInt64), "" - for ipv4, rtt := range rttMap { - if rtt < maxRtt && rtt < lowestRtt { - lowestRtt, fastestIP = rtt, ipv4 + switch len(msgMap) { + case 0: // 没有任何IPv4地址 + return fastestMsg + case 1: // 只有一个IPv4地址 + for _, msg := range msgMap { + return msg } } - // 用ping最小的ipv4地址覆盖msg - if aObj := aMap[fastestIP]; fastestIP != "" && res != nil { - common.RemoveA(res) - res.Answer = append(res.Answer, &aObj) - } else { - log.Error("find fastest ipv4 failed") + // 并发测速 + doneCh := make(chan interface{}, 0) + resCh := make(chan string, 1) + for ipV4 := range msgMap { + go func(addr string) { + if err := utils.PingIP(addr, tcpPort, pingTimeout); err == nil { + select { + case resCh <- addr: + case <-doneCh: + } + } + }(ipV4) + } + var fastestIP string // 第一个从resCh返回的地址就是ping值最低的地址 + begin := time.Now() + select { + case fastestIP = <-resCh: + case <-time.After(pingTimeout): + } + cost := time.Now().Sub(begin).Milliseconds() + close(doneCh) + utils.CtxDebug(ctx, "fastest ip of %s: %s(%dms)", allIP, fastestIP, cost) + if msg, exists := msgMap[fastestIP]; exists && fastestIP != "" { + // 删除msg内除fastestIP之外的其它A记录 + for i := 0; i < len(msg.Answer); i++ { + if a, ok := msg.Answer[i].(*dns.A); ok && a.A.String() != fastestIP { + msg.Answer = append(msg.Answer[:i], msg.Answer[i+1:]...) + i-- + } + } + return msg } - return + return fastestMsg } diff --git a/inbound/tools_test.go b/inbound/tools_test.go index bbd0a80..ce1d69e 100644 --- a/inbound/tools_test.go +++ b/inbound/tools_test.go @@ -1,74 +1,74 @@ package inbound import ( - "fmt" + "errors" + "time" - "github.com/agiledragon/gomonkey" + "github.com/Sirupsen/logrus" "github.com/miekg/dns" - "github.com/sparrc/go-ping" "github.com/stretchr/testify/assert" "github.com/wolf-joe/ts-dns/cache" + "github.com/wolf-joe/ts-dns/core/utils" "github.com/wolf-joe/ts-dns/core/utils/mock" "net" "testing" ) -func TestTools(t *testing.T) { +func TestAllInRange(t *testing.T) { resp := &dns.Msg{Answer: []dns.RR{&dns.A{A: net.IPv4(1, 1, 1, 1)}}} assert.False(t, allInRange(resp, cache.NewRamSetByText(""))) assert.True(t, allInRange(resp, cache.NewRamSetByText("1.1.1.1"))) +} + +func TestFastestA(t *testing.T) { + logrus.SetLevel(logrus.DebugLevel) + ctx := utils.NewCtx(nil, 0xffff) + tcpPort := -1 + chLen := 3 + ch := make(chan *dns.Msg, chLen) + emptyMsg := &dns.Msg{} - assert.True(t, pingRtt("299.299.299.299", -1) > maxRtt) - assert.True(t, pingRtt("111", -1) > maxRtt) mocker := mock.Mocker{} defer mocker.Reset() - mocker.MethodSeq(&ping.Pinger{}, "Statistics", []gomonkey.Params{ - {&ping.Statistics{PacketsRecv: 1, AvgRtt: maxRtt - 1}}, + mocker.Func(utils.PingIP, func(string, int, time.Duration) error { + return errors.New("cannot ping now") }) - assert.True(t, pingRtt("1.1.1.1", -1) < maxRtt) - // 测试tcp ping - mocker.FuncSeq(net.DialTimeout, []gomonkey.Params{ - {nil, fmt.Errorf("err")}, {&net.TCPConn{}, nil}, - }) - mocker.MethodSeq(&net.TCPConn{}, "Close", []gomonkey.Params{{nil}}) - assert.True(t, pingRtt("1.1.1.1", 80) > maxRtt) - assert.True(t, pingRtt("1.1.1.1", 80) < maxRtt) -} -func TestTools_FastestA(t *testing.T) { - // 预设ping rtt值 - gomonkey.ApplyFunc(pingRtt, func(ip string, _ int) int64 { - if ip == "1.1.1.1" { - return 100 - } - return 200 - }) + ch <- emptyMsg + assert.Nil(t, fastestA(ctx, ch, 0, tcpPort)) + assert.Equal(t, emptyMsg, fastestA(ctx, ch, 1, tcpPort)) - chLen := 4 - ch := make(chan *dns.Msg, chLen) - ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: net.IPv4(1, 1, 1, 1)}}} - ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: net.IPv4(1, 1, 1, 2)}}} - ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: net.IPv4(1, 1, 1, 2)}}} - ch <- nil - msg := fastestA(ch, chLen, -1) - assert.NotNil(t, msg) - assert.Equal(t, msg.Answer[0].(*dns.A).A.String(), "1.1.1.1") - - chLen = 0 - ch = make(chan *dns.Msg, chLen) - msg = fastestA(ch, chLen, -1) - assert.Nil(t, msg) + ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: []byte{1, 1, 1, 1}}}} + ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: []byte{1, 1, 1, 1}}}} + assert.NotNil(t, fastestA(ctx, ch, 2, tcpPort)) - chLen = 1 - ch = make(chan *dns.Msg, chLen) + makeMsg := func() *dns.Msg { + msg := &dns.Msg{} + for i := byte(1); i < 255; i++ { + msg.Answer = append(msg.Answer, &dns.A{A: []byte{1, 1, 1, i}}) + } + return msg + } + msg := makeMsg() + ch <- nil + ch <- msg ch <- nil - msg = fastestA(ch, chLen, -1) - assert.Nil(t, msg) + assert.Equal(t, msg, fastestA(ctx, ch, chLen, tcpPort)) - chLen = 1 - ch = make(chan *dns.Msg, chLen) - ch <- &dns.Msg{Answer: []dns.RR{&dns.AAAA{}}} - msg = fastestA(ch, chLen, -1) + mocker.Func(utils.PingIP, func(addr string, _ int, _ time.Duration) error { + switch addr { + case "1.1.1.10": + return nil + default: + return errors.New("timeout") + } + }) + ch <- nil + ch <- makeMsg() + ch <- nil + msg = fastestA(ctx, ch, chLen, tcpPort) assert.NotNil(t, msg) + assert.Equal(t, 1, len(msg.Answer)) + assert.Equal(t, "1.1.1.10", msg.Answer[0].(*dns.A).A.String()) }