Skip to content

Commit

Permalink
feat(inbound): refactor fastestA()
Browse files Browse the repository at this point in the history
  • Loading branch information
wolf-joe committed Feb 10, 2021
1 parent bc02269 commit af8d966
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 119 deletions.
5 changes: 5 additions & 0 deletions core/utils/mock/mocker.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ func (m *Mocker) FuncSeq(target interface{}, outputs []gomonkey.Params) {
m.patches = append(m.patches, gomonkey.ApplyFuncSeq(target, cells))
}

// Func gomonkey.ApplyFunc的封装
func (m *Mocker) Func(target interface{}, double interface{}) {
m.patches = append(m.patches, gomonkey.ApplyFunc(target, double))
}

// MethodSeq gomonkey.ApplyMethodSeq的封装
func (m *Mocker) MethodSeq(target interface{}, method string, outputs []gomonkey.Params) {
var cells []gomonkey.OutputCell
Expand Down
35 changes: 35 additions & 0 deletions core/utils/ping.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package utils

import (
"errors"
"net"
"strconv"
"time"

"github.com/sparrc/go-ping"
)

// PingIP 向指定ip地址发起icmp ping/tcp ping(如tcpPort大于0),返回值为nil代表ping成功
func PingIP(ipAddr string, tcpPort int, timeout time.Duration) error {
if tcpPort > 0 { // tcp ping
addr := ipAddr + ":" + strconv.Itoa(tcpPort)
conn, err := net.DialTimeout("tcp", addr, timeout)
if err != nil {
return err
}
_ = conn.Close()
return nil
}
// icmp ping
task, err := ping.NewPinger(ipAddr)
if err != nil {
return err
}
task.Count, task.Timeout = 1, timeout
task.SetPrivileged(true)
task.Run()
if stat := task.Statistics(); stat.PacketsRecv >= 1 {
return nil
}
return errors.New("package loss")
}
33 changes: 33 additions & 0 deletions core/utils/ping_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package utils

import (
"fmt"
"net"
"testing"
"time"

"github.com/agiledragon/gomonkey"
"github.com/sparrc/go-ping"
"github.com/stretchr/testify/assert"
"github.com/wolf-joe/ts-dns/core/utils/mock"
)

func TestPingIP(t *testing.T) {
// icmp ping
assert.NotNil(t, PingIP("299.299.299.299", -1, time.Second))
assert.NotNil(t, PingIP("111", -1, time.Second))
mocker := mock.Mocker{}
defer mocker.Reset()
mocker.MethodSeq(&ping.Pinger{}, "Statistics", []gomonkey.Params{
{&ping.Statistics{PacketsRecv: 1, AvgRtt: 100}},
})
assert.Nil(t, PingIP("1.1.1.1", -1, time.Second))

// tcp ping
mocker.FuncSeq(net.DialTimeout, []gomonkey.Params{
{nil, fmt.Errorf("err")}, {&net.TCPConn{}, nil},
})
mocker.MethodSeq(&net.TCPConn{}, "Close", []gomonkey.Params{{nil}})
assert.NotNil(t, PingIP("1.1.1.1", 80, time.Second))
assert.Nil(t, PingIP("1.1.1.1", 80, time.Second))
}
3 changes: 1 addition & 2 deletions inbound/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ func (group *Group) CallDNS(ctx context.Context, request *dns.Msg) *dns.Msg {
}
// 遍历DNS服务器
for _, caller := range group.Callers {
utils.CtxDebug(ctx, "forward question %v to %s", request.Question, caller)
if group.Concurrent || group.FastestV4 {
go call(caller, request)
} else if r := call(caller, request); r != nil {
Expand All @@ -66,7 +65,7 @@ func (group *Group) CallDNS(ctx context.Context, request *dns.Msg) *dns.Msg {
}
}
} else if group.FastestV4 { // 选择ping值最低的IPv4地址作为返回值
return fastestA(ch, len(group.Callers), group.TCPPingPort)
return fastestA(ctx, ch, len(group.Callers), group.TCPPingPort)
}
return nil
}
Expand Down
137 changes: 67 additions & 70 deletions inbound/tools.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
package inbound

import (
log "github.com/Sirupsen/logrus"
"context"
"net"
"time"

"github.com/miekg/dns"
"github.com/sparrc/go-ping"
"github.com/wolf-joe/ts-dns/cache"
"github.com/wolf-joe/ts-dns/core/common"
"math"
"net"
"strconv"
"sync"
"time"
"github.com/wolf-joe/ts-dns/core/utils"
)

const maxRtt = 500
const (
pingTimeout = 500 * time.Millisecond
)

// 如dns响应中所有ipv4地址都在目标范围内(或没有ipv4地址)返回true,否则返回False
func allInRange(r *dns.Msg, ipRange *cache.RamSet) bool {
Expand All @@ -25,76 +25,73 @@ func allInRange(r *dns.Msg, ipRange *cache.RamSet) bool {
return true
}

// 获取到目标ip的ping值(毫秒),当tcpPort大于0时使用tcp ping,否则使用icmp ping
func pingRtt(ip string, tcpPort int) (rtt int64) {
if tcpPort > 0 { // 使用tcp ping
begin, addr := time.Now(), ip+":"+strconv.Itoa(tcpPort)
conn, err := net.DialTimeout("tcp", addr, time.Millisecond*maxRtt)
if err != nil {
return maxRtt + 1
}
defer func() { _ = conn.Close() }()
rtt = time.Now().Sub(begin).Milliseconds()
return rtt
}
// 使用icmp ping
task, err := ping.NewPinger(ip)
if err != nil {
return maxRtt + 1
func fastestA(ctx context.Context, ch <-chan *dns.Msg, chLen int, tcpPort int) *dns.Msg {
if chLen == 0 {
return nil
}
task.Count, task.Timeout = 1, time.Millisecond*maxRtt
task.SetPrivileged(true)
task.Run()
stat := task.Statistics()
if stat.PacketsRecv >= 1 {
return stat.AvgRtt.Milliseconds()
}
return maxRtt + 1
}

// 从dns msg chan中找出ping值最低的ipv4地址并将其所属的A记录打包返回
func fastestA(ch chan *dns.Msg, chLen int, tcpPort int) (res *dns.Msg) {
aLock, rttLock, wg := new(sync.Mutex), new(sync.Mutex), new(sync.WaitGroup)
aMap, rttMap := map[string]dns.A{}, map[string]int64{}
const maxGoNum = 15 // 最大并发数量
// 从msg ch中提取所有IPv4地址,并建立IPv4地址到msg的映射
allIP := make([]string, 0, maxGoNum)
msgMap := make(map[string]*dns.Msg, maxGoNum)
var fastestMsg *dns.Msg // 最早抵达的msg,当测速失败时使用该响应返回
for i := 0; i < chLen; i++ {
msg := <-ch // 从chan中取出一个msg
if msg != nil {
res = msg // 防止被最后出现的nil覆盖
msg := <-ch
if len(msgMap) >= maxGoNum {
continue // 消费chLen内剩余msg
}
if fastestMsg == nil {
fastestMsg = msg
}
for _, a := range common.ExtractA(msg) {
ipv4, aObj := a.A.String(), *a // 用aObj实体变量来防止aMap的键值不一致
wg.Add(1)
go func() {
defer wg.Done()
aLock.Lock()
if _, ok := aMap[ipv4]; ok { // 防止重复ping
aLock.Unlock()
return
ipV4 := a.A.String()
if _, exists := msgMap[ipV4]; !exists {
allIP = append(allIP, ipV4)
msgMap[ipV4] = msg
if len(msgMap) >= maxGoNum {
break
}
aMap[ipv4] = aObj
aLock.Unlock()
// 并发测速
rtt := pingRtt(ipv4, tcpPort)
rttLock.Lock()
rttMap[ipv4] = rtt
rttLock.Unlock()
}()
}
}
}
wg.Wait()
// 查找ping最小的ipv4地址
lowestRtt, fastestIP := int64(math.MaxInt64), ""
for ipv4, rtt := range rttMap {
if rtt < maxRtt && rtt < lowestRtt {
lowestRtt, fastestIP = rtt, ipv4
switch len(msgMap) {
case 0: // 没有任何IPv4地址
return fastestMsg
case 1: // 只有一个IPv4地址
for _, msg := range msgMap {
return msg
}
}
// 用ping最小的ipv4地址覆盖msg
if aObj := aMap[fastestIP]; fastestIP != "" && res != nil {
common.RemoveA(res)
res.Answer = append(res.Answer, &aObj)
} else {
log.Error("find fastest ipv4 failed")
// 并发测速
doneCh := make(chan interface{}, 0)
resCh := make(chan string, 1)
for ipV4 := range msgMap {
go func(addr string) {
if err := utils.PingIP(addr, tcpPort, pingTimeout); err == nil {
select {
case resCh <- addr:
case <-doneCh:
}
}
}(ipV4)
}
var fastestIP string // 第一个从resCh返回的地址就是ping值最低的地址
begin := time.Now()
select {
case fastestIP = <-resCh:
case <-time.After(pingTimeout):
}
cost := time.Now().Sub(begin).Milliseconds()
close(doneCh)
utils.CtxDebug(ctx, "fastest ip of %s: %s(%dms)", allIP, fastestIP, cost)
if msg, exists := msgMap[fastestIP]; exists && fastestIP != "" {
// 删除msg内除fastestIP之外的其它A记录
for i := 0; i < len(msg.Answer); i++ {
if a, ok := msg.Answer[i].(*dns.A); ok && a.A.String() != fastestIP {
msg.Answer = append(msg.Answer[:i], msg.Answer[i+1:]...)
i--
}
}
return msg
}
return
return fastestMsg
}
94 changes: 47 additions & 47 deletions inbound/tools_test.go
Original file line number Diff line number Diff line change
@@ -1,74 +1,74 @@
package inbound

import (
"fmt"
"errors"
"time"

"github.com/agiledragon/gomonkey"
"github.com/Sirupsen/logrus"
"github.com/miekg/dns"
"github.com/sparrc/go-ping"
"github.com/stretchr/testify/assert"
"github.com/wolf-joe/ts-dns/cache"
"github.com/wolf-joe/ts-dns/core/utils"
"github.com/wolf-joe/ts-dns/core/utils/mock"

"net"
"testing"
)

func TestTools(t *testing.T) {
func TestAllInRange(t *testing.T) {
resp := &dns.Msg{Answer: []dns.RR{&dns.A{A: net.IPv4(1, 1, 1, 1)}}}
assert.False(t, allInRange(resp, cache.NewRamSetByText("")))
assert.True(t, allInRange(resp, cache.NewRamSetByText("1.1.1.1")))
}

func TestFastestA(t *testing.T) {
logrus.SetLevel(logrus.DebugLevel)
ctx := utils.NewCtx(nil, 0xffff)
tcpPort := -1
chLen := 3
ch := make(chan *dns.Msg, chLen)
emptyMsg := &dns.Msg{}

assert.True(t, pingRtt("299.299.299.299", -1) > maxRtt)
assert.True(t, pingRtt("111", -1) > maxRtt)
mocker := mock.Mocker{}
defer mocker.Reset()
mocker.MethodSeq(&ping.Pinger{}, "Statistics", []gomonkey.Params{
{&ping.Statistics{PacketsRecv: 1, AvgRtt: maxRtt - 1}},
mocker.Func(utils.PingIP, func(string, int, time.Duration) error {
return errors.New("cannot ping now")
})
assert.True(t, pingRtt("1.1.1.1", -1) < maxRtt)
// 测试tcp ping
mocker.FuncSeq(net.DialTimeout, []gomonkey.Params{
{nil, fmt.Errorf("err")}, {&net.TCPConn{}, nil},
})
mocker.MethodSeq(&net.TCPConn{}, "Close", []gomonkey.Params{{nil}})
assert.True(t, pingRtt("1.1.1.1", 80) > maxRtt)
assert.True(t, pingRtt("1.1.1.1", 80) < maxRtt)
}

func TestTools_FastestA(t *testing.T) {
// 预设ping rtt值
gomonkey.ApplyFunc(pingRtt, func(ip string, _ int) int64 {
if ip == "1.1.1.1" {
return 100
}
return 200
})
ch <- emptyMsg
assert.Nil(t, fastestA(ctx, ch, 0, tcpPort))
assert.Equal(t, emptyMsg, fastestA(ctx, ch, 1, tcpPort))

chLen := 4
ch := make(chan *dns.Msg, chLen)
ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: net.IPv4(1, 1, 1, 1)}}}
ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: net.IPv4(1, 1, 1, 2)}}}
ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: net.IPv4(1, 1, 1, 2)}}}
ch <- nil
msg := fastestA(ch, chLen, -1)
assert.NotNil(t, msg)
assert.Equal(t, msg.Answer[0].(*dns.A).A.String(), "1.1.1.1")

chLen = 0
ch = make(chan *dns.Msg, chLen)
msg = fastestA(ch, chLen, -1)
assert.Nil(t, msg)
ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: []byte{1, 1, 1, 1}}}}
ch <- &dns.Msg{Answer: []dns.RR{&dns.A{A: []byte{1, 1, 1, 1}}}}
assert.NotNil(t, fastestA(ctx, ch, 2, tcpPort))

chLen = 1
ch = make(chan *dns.Msg, chLen)
makeMsg := func() *dns.Msg {
msg := &dns.Msg{}
for i := byte(1); i < 255; i++ {
msg.Answer = append(msg.Answer, &dns.A{A: []byte{1, 1, 1, i}})
}
return msg
}
msg := makeMsg()
ch <- nil
ch <- msg
ch <- nil
msg = fastestA(ch, chLen, -1)
assert.Nil(t, msg)
assert.Equal(t, msg, fastestA(ctx, ch, chLen, tcpPort))

chLen = 1
ch = make(chan *dns.Msg, chLen)
ch <- &dns.Msg{Answer: []dns.RR{&dns.AAAA{}}}
msg = fastestA(ch, chLen, -1)
mocker.Func(utils.PingIP, func(addr string, _ int, _ time.Duration) error {
switch addr {
case "1.1.1.10":
return nil
default:
return errors.New("timeout")
}
})
ch <- nil
ch <- makeMsg()
ch <- nil
msg = fastestA(ctx, ch, chLen, tcpPort)
assert.NotNil(t, msg)
assert.Equal(t, 1, len(msg.Answer))
assert.Equal(t, "1.1.1.10", msg.Answer[0].(*dns.A).A.String())
}

0 comments on commit af8d966

Please sign in to comment.