-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: finish task #13
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,201 @@ | ||||||||||||||||||||||||
package AI | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import ( | ||||||||||||||||||||||||
"crypto/hmac" | ||||||||||||||||||||||||
"crypto/sha256" | ||||||||||||||||||||||||
"encoding/base64" | ||||||||||||||||||||||||
"encoding/json" | ||||||||||||||||||||||||
"fmt" | ||||||||||||||||||||||||
"github.com/gorilla/websocket" | ||||||||||||||||||||||||
"io" | ||||||||||||||||||||||||
"net/http" | ||||||||||||||||||||||||
"net/url" | ||||||||||||||||||||||||
"strings" | ||||||||||||||||||||||||
"time" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
const ( | ||||||||||||||||||||||||
xunfeiAIAPIUrl = "wss://spark-api.xf-yun.com/v3.5/chat" | ||||||||||||||||||||||||
apiSecret = "OTM2NGMxOWJjY2FkOGYwZTEyOTVjZGY2" | ||||||||||||||||||||||||
apiKey = "ad54d6374685da80a5f420297ab6af00" | ||||||||||||||||||||||||
appId = "cee63188" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// GenerateSum 通过WebSocket与AI模型交互以生成答案 | ||||||||||||||||||||||||
func GenerateSum(question string, answers []string) (string, error) { | ||||||||||||||||||||||||
d := websocket.Dialer{ | ||||||||||||||||||||||||
HandshakeTimeout: 5 * time.Second, | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
// 握手并建立websocket连接 | ||||||||||||||||||||||||
conn, resp, err := d.Dial(assembleAuthUrl1(xunfeiAIAPIUrl, apiKey, apiSecret), nil) | ||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||
return "", fmt.Errorf("连接失败: %s, %v", readResp(resp), err) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
defer func(conn *websocket.Conn) { | ||||||||||||||||||||||||
err := conn.Close() | ||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
} | ||||||||||||||||||||||||
Comment on lines
+43
to
+47
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle Errors in Deferred Function The deferred function closing the WebSocket connection does not handle potential errors from Apply this diff to handle any errors when closing the connection: defer func(conn *websocket.Conn) {
err := conn.Close()
if err != nil {
-
+ fmt.Printf("Error closing WebSocket connection: %v\n", err)
}
}(conn) // Ensure the connection is closed when the function ends This change will log any errors encountered when attempting to close the connection. 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||
}(conn) // 确保在函数结束时关闭连接 | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// 将所有的答案用 | 符号连接起来 | ||||||||||||||||||||||||
joinedAnswers := strings.Join(answers, "| ") | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// 构造最终的提示词 | ||||||||||||||||||||||||
prompt := fmt.Sprintf("我会给你一个问题和一组用 | 符号分隔的答案,帮我总结一个完整的回答,不要带有自己的评论和分析。 问题: %s\n答案: %s", question, joinedAnswers) | ||||||||||||||||||||||||
Comment on lines
+50
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Improve Prompt Clarity for the AI Model The prompt constructed for the AI model could be clearer to ensure it generates the desired response. Providing precise instructions can improve the quality of the AI's output. Consider rephrasing the prompt for better clarity: // 构造最终的提示词
-prompt := fmt.Sprintf("我会给你一个问题和一组用 | 符号分隔的答案,帮我总结一个完整的回答,不要带有自己的评论和分析。 问题: %s\n答案: %s", question, joinedAnswers)
+prompt := fmt.Sprintf("请根据以下问题和提供的多个答案,总结成一个完整的回答,不要添加任何评论或分析。\n问题:%s\n答案:%s", question, joinedAnswers) This rephrased prompt provides clear instructions, which can help the AI model generate the expected summary. 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||
data := genParams1(appId, prompt) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// 发送数据 | ||||||||||||||||||||||||
if err := conn.WriteJSON(data); err != nil { | ||||||||||||||||||||||||
return "", fmt.Errorf("发送数据失败: %v", err) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
var answer string | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// 获取返回的数据 | ||||||||||||||||||||||||
for { | ||||||||||||||||||||||||
_, msg, err := conn.ReadMessage() | ||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||
return "", fmt.Errorf("读取消息失败: %v", err) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
var data map[string]interface{} | ||||||||||||||||||||||||
if err := json.Unmarshal(msg, &data); err != nil { | ||||||||||||||||||||||||
return "", fmt.Errorf("解析JSON失败: %v", err) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
// 解析数据 | ||||||||||||||||||||||||
payload, ok := data["payload"].(map[string]interface{}) | ||||||||||||||||||||||||
if !ok { | ||||||||||||||||||||||||
return "", fmt.Errorf("无效的payload格式") | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
choices, ok := payload["choices"].(map[string]interface{}) | ||||||||||||||||||||||||
if !ok { | ||||||||||||||||||||||||
return "", fmt.Errorf("无效的choices格式") | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
header, ok := data["header"].(map[string]interface{}) | ||||||||||||||||||||||||
if !ok { | ||||||||||||||||||||||||
return "", fmt.Errorf("无效的header格式") | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
code, ok := header["code"].(float64) | ||||||||||||||||||||||||
if !ok || code != 0 { | ||||||||||||||||||||||||
return "", fmt.Errorf("错误的响应代码: %v", data["payload"]) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
status, ok := choices["status"].(float64) | ||||||||||||||||||||||||
if !ok { | ||||||||||||||||||||||||
return "", fmt.Errorf("无效的status格式") | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
text, ok := choices["text"].([]interface{}) | ||||||||||||||||||||||||
if !ok { | ||||||||||||||||||||||||
return "", fmt.Errorf("无效的text格式") | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
content, ok := text[0].(map[string]interface{})["content"].(string) | ||||||||||||||||||||||||
if !ok { | ||||||||||||||||||||||||
return "", fmt.Errorf("无效的content格式") | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if status != 2 { | ||||||||||||||||||||||||
answer += content | ||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||
answer += content | ||||||||||||||||||||||||
usage, ok := payload["usage"].(map[string]interface{}) | ||||||||||||||||||||||||
if ok { | ||||||||||||||||||||||||
temp, ok := usage["text"].(map[string]interface{}) | ||||||||||||||||||||||||
if ok { | ||||||||||||||||||||||||
totalTokens, ok := temp["total_tokens"].(float64) | ||||||||||||||||||||||||
if ok { | ||||||||||||||||||||||||
fmt.Println("total_tokens:", totalTokens) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
break | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
Comment on lines
+71
to
+121
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Simplify JSON Parsing with Structs Manually parsing JSON using nested maps and type assertions can be error-prone and hard to maintain. Defining Go structs that mirror the JSON response structure can simplify parsing and improve code readability. Define response structs and update the parsing logic: type AIResponse struct {
Header struct {
Code float64 `json:"code"`
} `json:"header"`
Payload struct {
Choices struct {
Status float64 `json:"status"`
Text []struct {
Content string `json:"content"`
} `json:"text"`
} `json:"choices"`
Usage struct {
Text struct {
TotalTokens float64 `json:"total_tokens"`
} `json:"text"`
} `json:"usage"`
} `json:"payload"`
} Update the parsing section: var answer string
// 获取返回的数据
for {
_, msg, err := conn.ReadMessage()
if err != nil {
return "", fmt.Errorf("读取消息失败: %v", err)
}
- var data map[string]interface{}
- if err := json.Unmarshal(msg, &data); err != nil {
+ var aiResp AIResponse
+ if err := json.Unmarshal(msg, &aiResp); err != nil {
return "", fmt.Errorf("解析JSON失败: %v", err)
}
- // Existing parsing logic with multiple type assertions...
+ // Check for errors in the response
+ if aiResp.Header.Code != 0 {
+ return "", fmt.Errorf("错误的响应代码: %v", aiResp.Header.Code)
+ }
+ // Append the content to the answer
+ for _, text := range aiResp.Payload.Choices.Text {
+ answer += text.Content
+ }
+ // Check if the response is complete
+ if aiResp.Payload.Choices.Status == 2 {
+ if aiResp.Payload.Usage.Text.TotalTokens > 0 {
+ fmt.Println("total_tokens:", aiResp.Payload.Usage.Text.TotalTokens)
+ }
+ break
+ }
} This refactoring enhances code clarity and reduces the risk of runtime errors due to incorrect type assertions. |
||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// 输出返回结果 | ||||||||||||||||||||||||
return answer, nil | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// 生成参数 | ||||||||||||||||||||||||
func genParams1(appid, question string) map[string]interface{} { // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
messages := []Message{ | ||||||||||||||||||||||||
{Role: "user", Content: question}, | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
data := map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"header": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"app_id": appid, // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
}, | ||||||||||||||||||||||||
"parameter": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"chat": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"domain": "general", // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"temperature": float64(0.8), // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"top_k": int64(6), // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"max_tokens": int64(2048), // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"auditing": "default", // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
}, | ||||||||||||||||||||||||
}, | ||||||||||||||||||||||||
"payload": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"message": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
"text": messages, // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
}, | ||||||||||||||||||||||||
}, | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
return data // 根据实际情况修改返回的数据结构和字段名 | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
// 创建鉴权url apikey 即 hmac username | ||||||||||||||||||||||||
func assembleAuthUrl1(hosturl string, apiKey, apiSecret string) string { | ||||||||||||||||||||||||
ul, err := url.Parse(hosturl) | ||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||
fmt.Println(err) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
//签名时间 | ||||||||||||||||||||||||
date := time.Now().UTC().Format(time.RFC1123) | ||||||||||||||||||||||||
//date = "Tue, 28 May 2019 09:10:42 MST" | ||||||||||||||||||||||||
//参与签名的字段 host ,date, request-line | ||||||||||||||||||||||||
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} | ||||||||||||||||||||||||
//拼接签名字符串 | ||||||||||||||||||||||||
sgin := strings.Join(signString, "\n") | ||||||||||||||||||||||||
// fmt.Println(sgin) | ||||||||||||||||||||||||
//签名结果 | ||||||||||||||||||||||||
sha := HmacWithShaTobase64("hmac-sha256", sgin, apiSecret) | ||||||||||||||||||||||||
// fmt.Println(sha) | ||||||||||||||||||||||||
//构建请求参数 此时不需要urlencoding | ||||||||||||||||||||||||
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, | ||||||||||||||||||||||||
"hmac-sha256", "host date request-line", sha) | ||||||||||||||||||||||||
//将请求参数使用base64编码 | ||||||||||||||||||||||||
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
v := url.Values{} | ||||||||||||||||||||||||
v.Add("host", ul.Host) | ||||||||||||||||||||||||
v.Add("date", date) | ||||||||||||||||||||||||
v.Add("authorization", authorization) | ||||||||||||||||||||||||
//将编码后的字符串url encode后添加到url后面 | ||||||||||||||||||||||||
callurl := hosturl + "?" + v.Encode() | ||||||||||||||||||||||||
return callurl | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
func HmacWithShaTobase64(algorithm, data, key string) string { | ||||||||||||||||||||||||
mac := hmac.New(sha256.New, []byte(key)) | ||||||||||||||||||||||||
mac.Write([]byte(data)) | ||||||||||||||||||||||||
encodeData := mac.Sum(nil) | ||||||||||||||||||||||||
return base64.StdEncoding.EncodeToString(encodeData) | ||||||||||||||||||||||||
Comment on lines
+190
to
+193
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Handle Potential Errors in HMAC Calculation While it's unlikely, the Modify the function to handle the error: func HmacWithShaTobase64(algorithm, data, key string) string {
mac := hmac.New(sha256.New, []byte(key))
- mac.Write([]byte(data))
+ if _, err := mac.Write([]byte(data)); err != nil {
+ fmt.Printf("Error writing data to HMAC: %v\n", err)
+ return ""
+ }
encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
} This addition will log any errors during HMAC calculation and prevent unexpected crashes.
|
||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
func readResp(resp *http.Response) string { | ||||||||||||||||||||||||
if resp == nil { | ||||||||||||||||||||||||
return "" | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
b, err := io.ReadAll(resp.Body) | ||||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||||
panic(err) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
Comment on lines
+202
to
+203
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid Using panic for Error Handling Using Apply this diff to return an empty response and log the error: func readResp(resp *http.Response) string {
if resp == nil {
return ""
}
b, err := io.ReadAll(resp.Body)
if err != nil {
- panic(err)
+ fmt.Printf("Error reading response body: %v\n", err)
+ return ""
}
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
} This change ensures that the application can handle the error without crashing. 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) | ||||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
type Message struct { | ||||||||||||||||||||||||
Role string `json:"role"` | ||||||||||||||||||||||||
Content string `json:"content"` | ||||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,92 @@ | ||||||||||||||||||||||
package auth | ||||||||||||||||||||||
|
||||||||||||||||||||||
import ( | ||||||||||||||||||||||
"errors" | ||||||||||||||||||||||
"github.com/gin-gonic/gin" | ||||||||||||||||||||||
"github.com/golang-jwt/jwt/v4" | ||||||||||||||||||||||
"net/http" | ||||||||||||||||||||||
"strings" | ||||||||||||||||||||||
"time" | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
// 秘钥 (确保在生产环境中安全存储此秘钥) | ||||||||||||||||||||||
var jwtSecret = []byte("yourSecretKey") | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid hardcoding the JWT secret key; use environment variables or secure configuration instead. Storing the secret key directly in the code is a security risk. It's best practice to load secrets from environment variables or a secure configuration service to prevent unauthorized access. Apply this diff to address the issue: +import "os"
// Secret key (ensure secure storage in production)
-var jwtSecret = []byte("yourSecretKey")
+var jwtSecret = []byte(os.Getenv("JWT_SECRET")) Ensure that the environment variable 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||
|
||||||||||||||||||||||
// 定义 JWT 的声明结构 | ||||||||||||||||||||||
type Claims struct { | ||||||||||||||||||||||
UserID uint `json:"user_id"` | ||||||||||||||||||||||
UserName string `json:"name"` | ||||||||||||||||||||||
jwt.RegisteredClaims | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
// 生成JWT Token | ||||||||||||||||||||||
func GenerateToken(userID uint, userName string) (string, error) { | ||||||||||||||||||||||
// 定义 Token 的声明,包含用户信息和到期时间 | ||||||||||||||||||||||
claims := &Claims{ | ||||||||||||||||||||||
UserID: userID, | ||||||||||||||||||||||
UserName: userName, | ||||||||||||||||||||||
RegisteredClaims: jwt.RegisteredClaims{ | ||||||||||||||||||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 72)), // Token 有效期 72 小时 | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Consider making the token expiration duration configurable. Currently, the token expiration is hardcoded to 72 hours. Making it configurable enhances flexibility and allows you to adjust expiration without changing code. Apply this diff to make the expiration time configurable via an environment variable: +import (
+ "os"
+ "strconv"
+)
// Generate JWT Token
func GenerateToken(userID uint, userName string) (string, error) {
// Define token claims, including user information and expiration time
+ expirationHours := 72 // default expiration
+ if envExpiry := os.Getenv("TOKEN_EXPIRATION_HOURS"); envExpiry != "" {
+ if hours, err := strconv.Atoi(envExpiry); err == nil {
+ expirationHours = hours
+ }
+ }
claims := &Claims{
UserID: userID,
UserName: userName,
RegisteredClaims: jwt.RegisteredClaims{
- ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Hour * 72)), // Token valid for 72 hours
+ ExpiresAt: jwt.NewNumericDate(time.Now().Add(time.Duration(expirationHours) * time.Hour)),
},
} Don't forget to handle any potential errors and ensure that the environment variable
|
||||||||||||||||||||||
}, | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
// 创建带有声明的 Token | ||||||||||||||||||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | ||||||||||||||||||||||
|
||||||||||||||||||||||
// 签署 Token 并返回 | ||||||||||||||||||||||
return token.SignedString(jwtSecret) | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
// 验证JWT的中间件 | ||||||||||||||||||||||
func AuthMiddleware() gin.HandlerFunc { | ||||||||||||||||||||||
return func(c *gin.Context) { | ||||||||||||||||||||||
// 获取请求头中的 Authorization 字段 | ||||||||||||||||||||||
tokenString := strings.TrimSpace(c.GetHeader("Authorization")) | ||||||||||||||||||||||
if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") { | ||||||||||||||||||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权,请登录"}) | ||||||||||||||||||||||
c.Abort() | ||||||||||||||||||||||
return | ||||||||||||||||||||||
} | ||||||||||||||||||||||
Comment on lines
+47
to
+50
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Refactor repetitive error handling into a helper function. Multiple blocks in your middleware handle errors in a similar way. Refactoring them into a helper function reduces code duplication and enhances readability. Create a helper function for unauthorized responses: func AuthMiddleware() gin.HandlerFunc {
+ // Helper function for unauthorized responses
+ unauthorized := func(c *gin.Context, message string) {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": message})
+ c.Abort()
+ }
return func(c *gin.Context) {
// Get the Authorization header
tokenString := strings.TrimSpace(c.GetHeader("Authorization"))
if tokenString == "" || !strings.HasPrefix(tokenString, "Bearer ") {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "未授权,请登录"})
- c.Abort()
- return
+ unauthorized(c, "未授权,请登录")
+ return
}
// ... existing code ...
// Check for token parsing errors or invalid token
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "Token格式错误"})
- c.Abort()
- return
+ unauthorized(c, "Token格式错误")
+ return
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "Token已过期"})
- c.Abort()
- return
+ unauthorized(c, "Token已过期")
+ return
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "Token尚未生效"})
- c.Abort()
- return
+ unauthorized(c, "Token尚未生效")
+ return
} else {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Token"})
- c.Abort()
- return
+ unauthorized(c, "无效的Token")
+ return
}
- c.Abort()
- return
}
}
// ... existing code ...
} else {
- c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Token"})
- c.Abort()
+ unauthorized(c, "无效的Token")
return
} This simplifies your middleware and makes it easier to manage error responses. Also applies to: 64-77, 85-87 |
||||||||||||||||||||||
|
||||||||||||||||||||||
// 移除 Bearer 前缀 | ||||||||||||||||||||||
tokenString = strings.TrimSpace(strings.TrimPrefix(tokenString, "Bearer ")) | ||||||||||||||||||||||
|
||||||||||||||||||||||
// 解析 Token | ||||||||||||||||||||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle unexpected errors during token parsing. Currently, if an error occurs that's not a Modify the error handling to catch unexpected errors: // Check for token parsing errors or invalid token
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
// ... existing validation error handling ...
+ } else {
+ unauthorized(c, "Token解析错误")
+ return
}
+ } else {
+ unauthorized(c, "无效的Token")
+ return
} This ensures that any unexpected parsing errors are also communicated to the client.
|
||||||||||||||||||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | ||||||||||||||||||||||
return nil, errors.New("无效的签名方法") | ||||||||||||||||||||||
} | ||||||||||||||||||||||
return jwtSecret, nil | ||||||||||||||||||||||
}) | ||||||||||||||||||||||
Comment on lines
+57
to
+61
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Explicitly check the signing algorithm to prevent security risks. Using type assertion to check the signing method may not be sufficient. It's more secure to compare the algorithm explicitly to avoid algorithm substitution attacks. Apply this diff to enhance the security of signing method validation: // Parse Token
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
- if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
+ if token.Method.Alg() != jwt.SigningMethodHS256.Alg() {
return nil, errors.New("无效的签名方法")
}
return jwtSecret, nil
}) This ensures that only tokens signed with 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||
|
||||||||||||||||||||||
// 检查 Token 解析是否出错或者无效 | ||||||||||||||||||||||
if err != nil { | ||||||||||||||||||||||
if ve, ok := err.(*jwt.ValidationError); ok { | ||||||||||||||||||||||
if ve.Errors&jwt.ValidationErrorMalformed != 0 { | ||||||||||||||||||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token格式错误"}) | ||||||||||||||||||||||
} else if ve.Errors&jwt.ValidationErrorExpired != 0 { | ||||||||||||||||||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token已过期"}) | ||||||||||||||||||||||
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 { | ||||||||||||||||||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Token尚未生效"}) | ||||||||||||||||||||||
} else { | ||||||||||||||||||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Token"}) | ||||||||||||||||||||||
} | ||||||||||||||||||||||
c.Abort() | ||||||||||||||||||||||
return | ||||||||||||||||||||||
} | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
// 检查 Token 是否有效 | ||||||||||||||||||||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid { | ||||||||||||||||||||||
// 将用户信息保存在上下文中 | ||||||||||||||||||||||
c.Set("user_id", claims.UserID) | ||||||||||||||||||||||
c.Set("user_name", claims.UserName) | ||||||||||||||||||||||
} else { | ||||||||||||||||||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的Token"}) | ||||||||||||||||||||||
c.Abort() | ||||||||||||||||||||||
return | ||||||||||||||||||||||
} | ||||||||||||||||||||||
|
||||||||||||||||||||||
c.Next() // 继续执行下一个处理器 | ||||||||||||||||||||||
} | ||||||||||||||||||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package database1 | ||
|
||
import ( | ||
"gorm.io/driver/mysql" | ||
"gorm.io/gorm" | ||
"log" | ||
) | ||
|
||
var DB *gorm.DB | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Consider encapsulating the DB variable. While it's common to have a package-level database variable, exposing it directly can lead to issues with encapsulation and make it harder to manage database access across the application. Consider making |
||
|
||
// InitDB 初始化数据库连接 | ||
func InitDB() { | ||
dsn := "root:Wu12345678@tcp(127.0.0.1:3306)/user_db?charset=utf8mb4&parseTime=True&loc=Local" | ||
var err error | ||
DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}) | ||
if err != nil { | ||
log.Fatal("无法连接到数据库:", err) | ||
} | ||
log.Println("数据库连接成功") | ||
|
||
// 测试数据库连接 | ||
sqlDB, err := DB.DB() | ||
if err != nil { | ||
log.Fatal("获取数据库实例失败:", err) | ||
} | ||
|
||
// Ping 数据库 | ||
if err := sqlDB.Ping(); err != nil { | ||
log.Fatal("数据库连接失败:", err) | ||
} | ||
log.Println("数据库连接测试成功") | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Improve database initialization practices.
|
||
|
||
func AutoMigrate(models ...interface{}) { | ||
err := DB.AutoMigrate(models...) | ||
if err != nil { | ||
log.Fatal("自动迁移失败:", err) | ||
} | ||
} | ||
Comment on lines
+47
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Enhance the AutoMigrate function.
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package database2 | ||
|
||
import ( | ||
"gorm.io/driver/mysql" | ||
"gorm.io/gorm" | ||
"log" | ||
) | ||
|
||
var DB *gorm.DB | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Consider using dependency injection instead of a global variable. While using a global variable for the database connection is common, it can make unit testing and managing dependencies more challenging. Consider using dependency injection by passing the |
||
|
||
// InitDB 初始化数据库连接 | ||
func InitDB() { | ||
dsn := "root:Wu12345678@tcp(127.0.0.1:3306)/problem_db?charset=utf8mb4&parseTime=True&loc=Local" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid hardcoding sensitive information in the source code. Hardcoding database credentials in the source code is a security risk. Consider using environment variables or a configuration file to store sensitive information. Here's an example of how you could use environment variables: import "os"
// ...
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
os.Getenv("DB_USER"),
os.Getenv("DB_PASSWORD"),
os.Getenv("DB_HOST"),
os.Getenv("DB_NAME")) |
||
var err error | ||
DB, err = gorm.Open(mysql.Open(dsn), &gorm.Config{}) | ||
if err != nil { | ||
log.Fatal("无法连接到数据库:", err) | ||
} | ||
log.Println("数据库连接成功") | ||
|
||
// 测试数据库连接 | ||
sqlDB, err := DB.DB() | ||
if err != nil { | ||
log.Fatal("获取数据库实例失败:", err) | ||
} | ||
|
||
// Ping 数据库 | ||
if err := sqlDB.Ping(); err != nil { | ||
log.Fatal("数据库连接失败:", err) | ||
} | ||
log.Println("数据库连接测试成功") | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Consider adding configuration options for the database connection. The Here's an example of how you could modify the function: type DBConfig struct {
Host string
Port string
User string
Password string
DBName string
}
func InitDB(config DBConfig) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
config.User, config.Password, config.Host, config.Port, config.DBName)
// ... rest of the function
} |
||
|
||
func AutoMigrate(models ...interface{}) { | ||
err := DB.AutoMigrate(models...) | ||
if err != nil { | ||
log.Fatal("自动迁移失败:", err) | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid Hardcoding API Keys and Secrets in Code
Storing API keys, secrets, and other credentials directly in the code is a security risk and can lead to potential breaches if the code is exposed. It's recommended to use environment variables or a secure configuration management system to handle sensitive information.
Apply this diff to remove hardcoded credentials and retrieve them from environment variables:
Ensure that the necessary environment variables are set in your deployment environment.
📝 Committable suggestion
🧰 Tools
🪛 Gitleaks