Skip to content

Commit

Permalink
允许非admin用户登录
Browse files Browse the repository at this point in the history
  • Loading branch information
Jrohy committed Mar 29, 2020
1 parent e4b2fe7 commit 37522c2
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 8 deletions.
23 changes: 23 additions & 0 deletions core/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,29 @@ func (mysql *Mysql) CleanData(id uint) error {
return nil
}

// GetUserByName 通过用户名来获取用户
func (mysql *Mysql) GetUserByName(name string) *User {
db := mysql.GetDB()
if db == nil {
return nil
}
defer db.Close()
var (
username string
originPass string
download uint64
upload uint64
quota int64
id uint
)
row := db.QueryRow(fmt.Sprintf("SELECT * FROM users WHERE username='%s'", name))
if err := row.Scan(&id, &username, &originPass, &quota, &download, &upload); err != nil {
fmt.Println(err)
return nil
}
return &User{ID: id, Username: username, Password: originPass, Download: download, Upload: upload, Quota: quota}
}

// GetData 获取用户记录
func (mysql *Mysql) GetData(ids ...string) []*User {
var dataList []*User
Expand Down
4 changes: 4 additions & 0 deletions trojan/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ func AddUser() {
fmt.Println(util.Yellow("不能新建用户名为'admin'的用户!"))
return
}
if _, err := core.GetValue(inputUser + "_pass"); err == nil {
fmt.Println(util.Yellow("已存在用户名为: " + inputUser + " 的用户!"))
return
}
inputPass := util.Input(fmt.Sprintf("生成随机密码: %s, 使用直接回车, 否则输入自定义密码: ", randomPass), randomPass)
mysql := core.GetMysql()
if mysql.CreateUser(inputUser, inputPass) == nil {
Expand Down
31 changes: 25 additions & 6 deletions web/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ func init() {
}
},
Authenticator: func(c *gin.Context) (interface{}, error) {
var loginVals Login
var (
password string
loginVals Login
)
if err := c.ShouldBind(&loginVals); err != nil {
return "", jwt.ErrMissingLoginValues
}
Expand All @@ -53,15 +56,25 @@ func init() {
if err != nil {
return nil, err
}
if value, err := core.GetValue(userID + "_pass"); err != nil {
return nil, err
} else if value == pass {
if userID != "admin" {
mysql := core.GetMysql()
user := mysql.GetUserByName(userID)
if user == nil {
return nil, jwt.ErrFailedAuthentication
}
password = user.Password
} else {
if password, err = core.GetValue(userID + "_pass"); err != nil {
return nil, err
}
}
if password == pass {
return &loginVals, nil
}
return nil, jwt.ErrFailedAuthentication
},
Authorizator: func(data interface{}, c *gin.Context) bool {
if v, ok := data.(*Login); ok && v.Username == "admin" {
if _, ok := data.(*Login); ok {
return true
}
return false
Expand All @@ -85,7 +98,7 @@ func init() {
func updateUser(c *gin.Context) {
responseBody := controller.ResponseBody{Msg: "success"}
defer controller.TimeCost(time.Now(), &responseBody)
username := c.DefaultPostForm("username", "admin")
username := c.PostForm("username")
pass := c.PostForm("password")
err := core.SetValue(fmt.Sprintf("%s_pass", username), pass)
if err != nil {
Expand All @@ -94,6 +107,12 @@ func updateUser(c *gin.Context) {
c.JSON(200, responseBody)
}

// RequestUsername 获取请求接口的用户名
func RequestUsername(c *gin.Context) string {
claims := jwt.ExtractClaims(c)
return claims[identityKey].(string)
}

// Auth 权限router
func Auth(r *gin.Engine) *jwt.GinJWTMiddleware {
r.NoRoute(authMiddleware.MiddlewareFunc(), func(c *gin.Context) {
Expand Down
14 changes: 13 additions & 1 deletion web/controller/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,19 @@ import (
)

// UserList 获取用户列表
func UserList() *ResponseBody {
func UserList(findUser string) *ResponseBody {
responseBody := ResponseBody{Msg: "success"}
defer TimeCost(time.Now(), &responseBody)
mysql := core.GetMysql()
userList := mysql.GetData()
if findUser != "" {
for _, user := range userList {
if user.Username == findUser {
userList = []*core.User{user}
break
}
}
}
if userList == nil {
responseBody.Msg = "连接mysql失败!"
return &responseBody
Expand All @@ -35,6 +43,10 @@ func CreateUser(username string, password string) *ResponseBody {
responseBody.Msg = "不能创建用户名为admin的用户!"
return &responseBody
}
if _, err := core.GetValue(username + "_pass"); err == nil {
responseBody.Msg = "已存在用户名为: " + username + " 的用户!"
return &responseBody
}
mysql := core.GetMysql()
pass, err := base64.StdEncoding.DecodeString(password)
if err != nil {
Expand Down
7 changes: 6 additions & 1 deletion web/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ func userRouter(router *gin.Engine) {
user := router.Group("/trojan/user")
{
user.GET("", func(c *gin.Context) {
c.JSON(200, controller.UserList())
requestUser := RequestUsername(c)
if requestUser == "admin" {
c.JSON(200, controller.UserList(""))
} else {
c.JSON(200, controller.UserList(requestUser))
}
})
user.POST("", func(c *gin.Context) {
username := c.PostForm("username")
Expand Down

0 comments on commit 37522c2

Please sign in to comment.