diff --git a/common/observable/iterable.go b/common/observable/iterable.go deleted file mode 100644 index 2ac38b40..00000000 --- a/common/observable/iterable.go +++ /dev/null @@ -1,3 +0,0 @@ -package observable - -type Iterable <-chan any diff --git a/common/observable/observable.go b/common/observable/observable.go deleted file mode 100644 index 5caaf690..00000000 --- a/common/observable/observable.go +++ /dev/null @@ -1,65 +0,0 @@ -package observable - -import ( - "errors" - "sync" -) - -type Observable struct { - iterable Iterable - listener map[Subscription]*Subscriber - mux sync.Mutex - done bool -} - -func (o *Observable) process() { - for item := range o.iterable { - o.mux.Lock() - for _, sub := range o.listener { - sub.Emit(item) - } - o.mux.Unlock() - } - o.close() -} - -func (o *Observable) close() { - o.mux.Lock() - defer o.mux.Unlock() - - o.done = true - for _, sub := range o.listener { - sub.Close() - } -} - -func (o *Observable) Subscribe() (Subscription, error) { - o.mux.Lock() - defer o.mux.Unlock() - if o.done { - return nil, errors.New("observable is closed") - } - subscriber := newSubscriber() - o.listener[subscriber.Out()] = subscriber - return subscriber.Out(), nil -} - -func (o *Observable) UnSubscribe(sub Subscription) { - o.mux.Lock() - defer o.mux.Unlock() - subscriber, exist := o.listener[sub] - if !exist { - return - } - delete(o.listener, sub) - subscriber.Close() -} - -func NewObservable(any Iterable) *Observable { - observable := &Observable{ - iterable: any, - listener: map[Subscription]*Subscriber{}, - } - go observable.process() - return observable -} diff --git a/common/observable/observable_test.go b/common/observable/observable_test.go deleted file mode 100644 index f34b81fc..00000000 --- a/common/observable/observable_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package observable - -import ( - "sync" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "go.uber.org/atomic" -) - -func iterator(item []any) chan any { - ch := make(chan any) - go func() { - time.Sleep(100 * time.Millisecond) - for _, elm := range item { - ch <- elm - } - close(ch) - }() - return ch -} - -func TestObservable(t *testing.T) { - iter := iterator([]any{1, 2, 3, 4, 5}) - src := NewObservable(iter) - data, err := src.Subscribe() - assert.Nil(t, err) - count := 0 - for range data { - count++ - } - assert.Equal(t, count, 5) -} - -func TestObservable_MultiSubscribe(t *testing.T) { - iter := iterator([]any{1, 2, 3, 4, 5}) - src := NewObservable(iter) - ch1, _ := src.Subscribe() - ch2, _ := src.Subscribe() - count := atomic.NewInt32(0) - - var wg sync.WaitGroup - wg.Add(2) - waitCh := func(ch <-chan any) { - for range ch { - count.Inc() - } - wg.Done() - } - go waitCh(ch1) - go waitCh(ch2) - wg.Wait() - assert.Equal(t, int32(10), count.Load()) -} - -func TestObservable_UnSubscribe(t *testing.T) { - iter := iterator([]any{1, 2, 3, 4, 5}) - src := NewObservable(iter) - data, err := src.Subscribe() - assert.Nil(t, err) - src.UnSubscribe(data) - _, open := <-data - assert.False(t, open) -} - -func TestObservable_SubscribeClosedSource(t *testing.T) { - iter := iterator([]any{1}) - src := NewObservable(iter) - data, _ := src.Subscribe() - <-data - - _, closed := src.Subscribe() - assert.NotNil(t, closed) -} - -func TestObservable_UnSubscribeWithNotExistSubscription(t *testing.T) { - sub := Subscription(make(chan any)) - iter := iterator([]any{1}) - src := NewObservable(iter) - src.UnSubscribe(sub) -} - -func TestObservable_SubscribeGoroutineLeak(t *testing.T) { - iter := iterator([]any{1, 2, 3, 4, 5}) - src := NewObservable(iter) - max := 100 - - var list []Subscription - for i := 0; i < max; i++ { - ch, _ := src.Subscribe() - list = append(list, ch) - } - - var wg sync.WaitGroup - wg.Add(max) - waitCh := func(ch <-chan any) { - for range ch { - } - wg.Done() - } - - for _, ch := range list { - go waitCh(ch) - } - wg.Wait() - - for _, sub := range list { - _, more := <-sub - assert.False(t, more) - } - - if len(list) > 0 { - _, more := <-list[0] - assert.False(t, more) - } -} - -func Benchmark_Observable_1000(b *testing.B) { - ch := make(chan any) - o := NewObservable(ch) - num := 1000 - - var subs []Subscription - for i := 0; i < num; i++ { - sub, _ := o.Subscribe() - subs = append(subs, sub) - } - - wg := sync.WaitGroup{} - wg.Add(num) - - b.ResetTimer() - for _, sub := range subs { - go func(s Subscription) { - for range s { - } - wg.Done() - }(sub) - } - - for i := 0; i < b.N; i++ { - ch <- i - } - - close(ch) - wg.Wait() -} diff --git a/common/observable/subscriber.go b/common/observable/subscriber.go deleted file mode 100644 index 0d8559bc..00000000 --- a/common/observable/subscriber.go +++ /dev/null @@ -1,33 +0,0 @@ -package observable - -import ( - "sync" -) - -type Subscription <-chan any - -type Subscriber struct { - buffer chan any - once sync.Once -} - -func (s *Subscriber) Emit(item any) { - s.buffer <- item -} - -func (s *Subscriber) Out() Subscription { - return s.buffer -} - -func (s *Subscriber) Close() { - s.once.Do(func() { - close(s.buffer) - }) -} - -func newSubscriber() *Subscriber { - sub := &Subscriber{ - buffer: make(chan any, 200), - } - return sub -} diff --git a/log/event.go b/log/event.go deleted file mode 100644 index d3fb68aa..00000000 --- a/log/event.go +++ /dev/null @@ -1,39 +0,0 @@ -package log - -import ( - "fmt" - "time" - - "github.com/xjasonlyu/tun2socks/v2/common/observable" -) - -var ( - _logCh = make(chan any) - _source = observable.NewObservable(_logCh) -) - -type Event struct { - Level Level `json:"level"` - Message string `json:"msg"` - Time time.Time `json:"time"` -} - -func newEvent(level Level, format string, args ...any) *Event { - event := &Event{ - Level: level, - Time: time.Now(), - Message: fmt.Sprintf(format, args...), - } - _logCh <- event /* send all events to logCh */ - - return event -} - -func Subscribe() observable.Subscription { - sub, _ := _source.Subscribe() - return sub -} - -func UnSubscribe(sub observable.Subscription) { - _source.UnSubscribe(sub) -} diff --git a/log/log.go b/log/log.go index f1336328..5f98b399 100644 --- a/log/log.go +++ b/log/log.go @@ -3,6 +3,7 @@ package log import ( "io" "os" + "time" "github.com/sirupsen/logrus" "go.uber.org/atomic" @@ -45,19 +46,16 @@ func Fatalf(format string, args ...any) { } func logf(level Level, format string, args ...any) { - event := newEvent(level, format, args...) - if uint32(event.Level) > _defaultLevel.Load() { - return - } - switch level { case DebugLevel: - logrus.WithTime(event.Time).Debugln(event.Message) + logrus.WithTime(time.Now()).Debugf(format, args...) case InfoLevel: - logrus.WithTime(event.Time).Infoln(event.Message) + logrus.WithTime(time.Now()).Infof(format, args...) case WarnLevel: - logrus.WithTime(event.Time).Warnln(event.Message) + logrus.WithTime(time.Now()).Warnf(format, args...) case ErrorLevel: - logrus.WithTime(event.Time).Errorln(event.Message) + logrus.WithTime(time.Now()).Errorf(format, args...) + default: + // nop } } diff --git a/restapi/connections.go b/restapi/connections.go index b8c9c336..6e98215f 100644 --- a/restapi/connections.go +++ b/restapi/connections.go @@ -17,7 +17,7 @@ import ( const defaultInterval = 1000 func init() { - registerMountPoint("/connections", connectionRouter()) + registerEndpoint("/connections", connectionRouter()) } func connectionRouter() http.Handler { diff --git a/restapi/debug.go b/restapi/debug.go index 8a3e780a..7ac02858 100644 --- a/restapi/debug.go +++ b/restapi/debug.go @@ -10,7 +10,7 @@ import ( ) func init() { - registerMountPoint("/debug/pprof/", pprofRouter()) + registerEndpoint("/debug/pprof/", pprofRouter()) } func pprofRouter() http.Handler { diff --git a/restapi/netstats.go b/restapi/netstats.go index e62d9bf8..13dd6870 100644 --- a/restapi/netstats.go +++ b/restapi/netstats.go @@ -19,7 +19,7 @@ func SetStatsFunc(s func() tcpip.Stats) { } func init() { - registerMountPoint("/netstats", http.HandlerFunc(getNetStats)) + registerEndpoint("/netstats", http.HandlerFunc(getNetStats)) } func getNetStats(w http.ResponseWriter, r *http.Request) { diff --git a/restapi/server.go b/restapi/server.go index b55462a4..fdd9b41c 100644 --- a/restapi/server.go +++ b/restapi/server.go @@ -14,7 +14,6 @@ import ( "github.com/gorilla/websocket" V "github.com/xjasonlyu/tun2socks/v2/internal/version" - "github.com/xjasonlyu/tun2socks/v2/log" "github.com/xjasonlyu/tun2socks/v2/tunnel/statistic" ) @@ -25,11 +24,11 @@ var ( }, } - _mountPoints = make(map[string]http.Handler) + _endpoints = make(map[string]http.Handler) ) -func registerMountPoint(pattern string, handler http.Handler) { - _mountPoints[pattern] = handler +func registerEndpoint(pattern string, handler http.Handler) { + _endpoints[pattern] = handler } func Start(addr, token string) error { @@ -46,11 +45,10 @@ func Start(addr, token string) error { r.Group(func(r chi.Router) { r.Use(authenticator(token)) r.Get("/", hello) - r.Get("/logs", getLogs) r.Get("/traffic", traffic) r.Get("/version", version) // attach HTTP handlers - for pattern, handler := range _mountPoints { + for pattern, handler := range _endpoints { r.Mount(pattern, handler) } }) @@ -103,61 +101,6 @@ func authenticator(token string) func(http.Handler) http.Handler { } } -func getLogs(w http.ResponseWriter, r *http.Request) { - lvl := r.URL.Query().Get("level") - if lvl == "" { - lvl = "info" /* default */ - } - - level, err := log.ParseLevel(lvl) - if err != nil { - render.Status(r, http.StatusBadRequest) - render.JSON(w, r, ErrBadRequest) - return - } - - var wsConn *websocket.Conn - if websocket.IsWebSocketUpgrade(r) { - wsConn, err = _upgrader.Upgrade(w, r, nil) - if err != nil { - return - } - } - - if wsConn == nil { - w.Header().Set("Content-Type", "application/json") - render.Status(r, http.StatusOK) - } - - sub := log.Subscribe() - defer log.UnSubscribe(sub) - - buf := &bytes.Buffer{} - for elm := range sub { - buf.Reset() - - e := elm.(*log.Event) - if e.Level > level { - continue - } - - if err = json.NewEncoder(buf).Encode(e); err != nil { - break - } - - if wsConn == nil { - _, err = w.Write(buf.Bytes()) - w.(http.Flusher).Flush() - } else { - err = wsConn.WriteMessage(websocket.TextMessage, buf.Bytes()) - } - - if err != nil { - break - } - } -} - func traffic(w http.ResponseWriter, r *http.Request) { var ( err error