ai-css/middleware/xpink_auth/auth_jwt.go
2026-02-12 08:50:11 +00:00

296 lines
6.4 KiB
Go

package xpink_auth
import (
"ai-css/library/logger"
"crypto/rsa"
"fmt"
"net/url"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
)
// Tunable settings for locating and verifying tokens. They are declared as
// variables rather than constants, so other code may reassign them.
var (
	// IdentityKey is the gin context key under which MiddlewareSetIdentity
	// stores the caller's identity.
	IdentityKey = "XPINK_USER"
	// TokenLookup is a comma-separated list of "source:name" pairs that
	// ParseToken tries in order.
	TokenLookup = "header:Authorization,query:Authorization,referer:Authorization"
	// SigningAlgorithm is the only JWT signing method ParseToken accepts.
	SigningAlgorithm = "RS256"
	// TokenHeadName is the scheme the header and referer extractors expect
	// in front of the token value.
	TokenHeadName = "Bearer"
)

// Verification key material.
var (
	// pubkeyContent is the PEM-encoded RSA public key used to verify
	// RSA-signed tokens; it is parsed into pubKey once, guarded by loadPbk.
	pubkeyContent = []byte(`-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAqD/o6TI7AZyNEbFQTy4g
K4Hd+aLAoLRwOe0iKqDWK4HRZABtLLvLFZLdwP4iUNAQOoy+WXz3CGqwzvs36531
6rOzeCKtYGSN64+Pnn6UWaricnCZ2Tqng2eNln9kHALbguGVtrOSQNZr97OCOOk3
ZDCnNwnz0hA9AhIRX1LNswPPC18q2Itdb5C//nxoEJPyY3u0r1YDL6sPD1eUDI0x
+4A8Dgqny4Z84XALn2ucR9bcUGSbtyTR1pg42MYyw6I7MV4P0YGXD3kcItd+9qlX
rULFZh5RLFl52PeA7bmXUpxKeg2lvv4CzNlk+eM7UyHctjYmM5rk+6QencjHk+qo
doVMzeX0e3sby72aq7g66QWThwGgVwwRFxsodtSwl6TAXH3TAVd3nyZ9tSqM/BT7
B8acMVzG/lzMVvrEtJHUcPlfHNDKmWuLWo6ywblc/MGj7z8Fe/pk+wJ1Nv4WCBMj
3kv4durqVNh4YhPvxt+wAZzsNxmliFEGXb+yC/8qpZv13EgNt4f1voKYML7StIj5
oYslqoYvzN3j5ROBRDlJaxqErEwDLwEeiqBuSME6H6hJFD3SRujmcdFtl4GYyZb9
F7VlEGHjQqKljkjB5DOno2tV5EzGNu21dAwBHSHfto7nqG781QmQrDAVs681pNpU
iWNoAGc0L/VR0YPuV2X+ml8CAwEAAQ==
-----END PUBLIC KEY-----`)
	// key is the verification key ParseToken falls back to when
	// SigningAlgorithm does not use pubKey. It is empty by default.
	key = ""
	pubKey  *rsa.PublicKey // runtime load: parsed from pubkeyContent on first use
	loadPbk sync.Once      // guards the one-time parse of pubkeyContent
)

// Errors reported by the token extractors when a token is missing or malformed.
var (
	ErrEmptyFormToken    = fmt.Errorf("empty form token")
	ErrEmptyParamToken   = fmt.Errorf("empty param token")
	ErrEmptyCookieToken  = fmt.Errorf("empty cookie token")
	ErrEmptyQueryToken   = fmt.Errorf("empty query token")
	ErrEmptyAuthHeader   = fmt.Errorf("empty auth header")
	ErrInvalidAuthHeader = fmt.Errorf("invalid auth header")
)
// UserSession is the caller identity decoded from the "Id", "Userno",
// "NickName" and "Jti" JWT claims by JwtToUserSession. The zero value
// represents an unauthenticated caller.
type UserSession struct {
	Id                    uint64
	Userno, NickName, Jti string
}
// Identity resolves the caller's UserSession from the request's JWT.
//
// On first use it parses the embedded RSA public key (pubkeyContent) into
// pubKey. It then parses and verifies the token with GetClaimsFromJWT, stores
// the claims in the context under "JWT_PAYLOAD", and converts them with
// JwtToUserSession.
//
// When no valid token is present the failure is logged and an empty claims
// map is stored, so the returned UserSession is the zero value. The returned
// interface is never nil.
func Identity(c *gin.Context) interface{} {
	loadPbk.Do(func() {
		var err error
		pubKey, err = jwt.ParseRSAPublicKeyFromPEM(pubkeyContent)
		if err != nil {
			logger.Error("parse rsa public key fail err:%v", err)
		}
	})
	claims, err := GetClaimsFromJWT(c)
	if err != nil {
		// logger.Error is used printf-style in this file, so the error needs a
		// verb; passing it as a bare extra argument garbled the message.
		logger.Error("parse claims failed err:%v", err)
		// Store an empty, writable map rather than a nil one so readers of
		// JWT_PAYLOAD never receive a nil map.
		claims = jwt.MapClaims{}
	}
	c.Set("JWT_PAYLOAD", claims)
	return JwtToUserSession(ExtractClaims(c))
}
// JwtToUserSession converts JWT claims into a UserSession.
//
// It reads the "Id", "Userno", "NickName" and "Jti" claims. A claim that is
// missing, null, or of an unexpected type yields the field's zero value
// instead of panicking.
//
// "Id" accepts float64 (what a JSON-decoded number becomes), int64, int and
// uint64. Negative, NaN or out-of-range values are treated as absent, because
// converting them to uint64 would wrap or be implementation-dependent.
func JwtToUserSession(payload jwt.MapClaims) UserSession {
	return UserSession{
		Id:       uint64Claim(payload, "Id"),
		Userno:   stringClaim(payload, "Userno"),
		NickName: stringClaim(payload, "NickName"),
		Jti:      stringClaim(payload, "Jti"),
	}
}

// uint64Claim returns payload[name] as a non-negative integer, or 0 if the
// claim is absent, not numeric, negative, or not representable as uint64.
func uint64Claim(payload map[string]interface{}, name string) uint64 {
	switch v := payload[name].(type) {
	case float64:
		// NaN fails both comparisons; 1<<64 is exactly representable as float64.
		if v >= 0 && v < 1<<64 {
			return uint64(v)
		}
	case int64:
		if v >= 0 {
			return uint64(v)
		}
	case int:
		if v >= 0 {
			return uint64(v)
		}
	case uint64:
		return v
	}
	return 0
}

// stringClaim returns payload[name] if it is a string, otherwise "".
func stringClaim(payload map[string]interface{}, name string) string {
	s, _ := payload[name].(string)
	return s
}
// ExtractClaims returns the JWT claims stored in the context under
// "JWT_PAYLOAD" (as done by Identity).
//
// It never returns nil and never panics. When the key is missing, holds a nil
// map, or holds a value of another type, it returns a fresh empty map, so
// callers may both read and write the result.
func ExtractClaims(c *gin.Context) jwt.MapClaims {
	if raw, exists := c.Get("JWT_PAYLOAD"); exists {
		if claims, ok := raw.(jwt.MapClaims); ok && claims != nil {
			return claims
		}
	}
	return make(jwt.MapClaims)
}
// ParseToken locates the request's JWT and parses and verifies it.
//
// Token lookup:
//   - Sources are tried in the order given by TokenLookup, a comma-separated
//     list of "source:name" pairs. Supported sources are header, query,
//     cookie, param, form and referer.
//   - The first non-empty token wins. If none is found, the error from the
//     last source tried is returned.
//   - Malformed or unknown entries are skipped instead of panicking.
//
// Storage: the raw token string is stored in the context under "JWT_TOKEN"
// only after it has been verified successfully.
//
// Verification:
//   - Only SigningAlgorithm is accepted.
//   - RSA-based algorithms (RS*/PS*) verify against the embedded public key.
//     The key is loaded on demand, so ParseToken is safe to call even if
//     Identity has not run.
//   - Other algorithms use the package-level key.
func ParseToken(c *gin.Context) (*jwt.Token, error) {
	tokenString, err := lookupTokenString(c)
	if err != nil {
		return nil, err
	}
	token, err := jwt.Parse(tokenString, verificationKey)
	if err != nil {
		return token, err
	}
	// Only a verified token is recorded in the context.
	c.Set("JWT_TOKEN", tokenString)
	return token, nil
}

// lookupTokenString returns the first non-empty token found via TokenLookup.
func lookupTokenString(c *gin.Context) (string, error) {
	var lastErr error
	for _, entry := range strings.Split(TokenLookup, ",") {
		source, name, ok := strings.Cut(strings.TrimSpace(entry), ":")
		if !ok {
			continue // malformed entry; indexing parts[1] used to panic here
		}
		source, name = strings.TrimSpace(source), strings.TrimSpace(name)

		var token string
		var err error
		switch source {
		case "header":
			token, err = jwtFromHeader(c, name)
		case "query":
			token, err = jwtFromQuery(c, name)
		case "cookie":
			token, err = jwtFromCookie(c, name)
		case "param":
			token, err = jwtFromParam(c, name)
		case "form":
			token, err = jwtFromForm(c, name)
		case "referer":
			token, err = jwtFromReferer(c, name)
		default:
			continue
		}
		if err == nil && token != "" {
			// Never log the token itself: it is a bearer credential.
			logger.Infof("capture token from %s:%s", source, name)
			return token, nil
		}
		lastErr = err
	}
	if lastErr == nil {
		lastErr = fmt.Errorf("no token found via %q", TokenLookup)
	}
	return "", lastErr
}

// verificationKey is the jwt.Keyfunc used by ParseToken. It rejects tokens
// whose signing method differs from SigningAlgorithm.
func verificationKey(t *jwt.Token) (interface{}, error) {
	if jwt.GetSigningMethod(SigningAlgorithm) != t.Method {
		return nil, fmt.Errorf("err invalid signingalgorithm")
	}
	switch SigningAlgorithm {
	case "RS256", "RS384", "RS512", "PS256", "PS384", "PS512":
		pk := ensurePublicKey()
		if pk == nil {
			// A typed-nil *rsa.PublicKey would make the RSA verifier
			// dereference nil; fail cleanly instead.
			return nil, fmt.Errorf("rsa public key not available")
		}
		return pk, nil
	}
	return key, nil
}

// ensurePublicKey parses pubkeyContent into pubKey exactly once (sharing
// loadPbk with Identity) and returns it, or nil if parsing failed.
func ensurePublicKey() *rsa.PublicKey {
	loadPbk.Do(func() {
		var err error
		pubKey, err = jwt.ParseRSAPublicKeyFromPEM(pubkeyContent)
		if err != nil {
			logger.Error("parse rsa public key fail err:%v", err)
		}
	})
	return pubKey
}
// GetClaimsFromJWT parses and verifies the request's JWT via ParseToken and
// returns a copy of its claims, so callers may modify the result without
// touching the token. A ParseToken error is returned unchanged with nil claims.
func GetClaimsFromJWT(c *gin.Context) (jwt.MapClaims, error) {
	parsed, parseErr := ParseToken(c)
	if parseErr != nil {
		return nil, parseErr
	}
	source := parsed.Claims.(jwt.MapClaims)
	copied := make(jwt.MapClaims, len(source))
	for name, val := range source {
		copied[name] = val
	}
	return copied, nil
}
// MiddlewareSetIdentity is a gin handler that resolves the caller via
// Identity and stores the result in the context under IdentityKey.
// It never aborts the request; unauthenticated callers simply get the
// zero UserSession.
func MiddlewareSetIdentity(c *gin.Context) {
	if identity := Identity(c); identity != nil {
		c.Set(IdentityKey, identity)
	}
}
// GetXPINKUser returns the UserSession stored under IdentityKey. It returns
// the zero UserSession when nothing is stored there or the stored value is
// not a UserSession.
func GetXPINKUser(c *gin.Context) UserSession {
	raw, exists := c.Get(IdentityKey)
	if !exists {
		return UserSession{}
	}
	// A nil or foreign-typed value fails the assertion and yields the zero value.
	session, _ := raw.(UserSession)
	return session
}
// jwtFromHeader extracts a token from the request header named key. The
// header must have the form "<TokenHeadName> <token>".
//
// The scheme is compared case-insensitively, as HTTP auth schemes are
// (RFC 7235 §2.1), and whitespace around the token is ignored.
//
// It returns ErrEmptyAuthHeader when the header is absent or empty. It
// returns ErrInvalidAuthHeader when the scheme is wrong or no token follows
// it; previously "Bearer " yielded an empty token with a nil error.
func jwtFromHeader(c *gin.Context, key string) (string, error) {
	authHeader := c.Request.Header.Get(key)
	if authHeader == "" {
		return "", ErrEmptyAuthHeader
	}
	scheme, token, found := strings.Cut(authHeader, " ")
	token = strings.TrimSpace(token)
	if !found || !strings.EqualFold(scheme, TokenHeadName) || token == "" {
		return "", ErrInvalidAuthHeader
	}
	return token, nil
}
// jwtFromQuery returns the raw value of the URL query parameter named key
// (no scheme prefix is expected), or ErrEmptyQueryToken when it is missing
// or empty.
func jwtFromQuery(c *gin.Context, key string) (string, error) {
	if value := c.Query(key); value != "" {
		return value, nil
	}
	return "", ErrEmptyQueryToken
}
// jwtFromCookie returns the value of the cookie named key, or
// ErrEmptyCookieToken when no non-empty value is available.
func jwtFromCookie(c *gin.Context, key string) (string, error) {
	// Any error from c.Cookie is deliberately ignored: an empty result is
	// reported as ErrEmptyCookieToken either way.
	if value, _ := c.Cookie(key); value != "" {
		return value, nil
	}
	return "", ErrEmptyCookieToken
}
// jwtFromParam returns the route parameter named key, or ErrEmptyParamToken
// when it is missing or empty.
func jwtFromParam(c *gin.Context, key string) (string, error) {
	if value := c.Param(key); value != "" {
		return value, nil
	}
	return "", ErrEmptyParamToken
}
// jwtFromForm returns the POST form field named key, or ErrEmptyFormToken
// when it is missing or empty.
func jwtFromForm(c *gin.Context, key string) (string, error) {
	if value := c.PostForm(key); value != "" {
		return value, nil
	}
	return "", ErrEmptyFormToken
}
// jwtFromReferer extracts a token from the query parameter named key on the
// request's Referer URL. The parameter must hold "<TokenHeadName> <token>";
// the scheme is matched case-insensitively and whitespace around the token is
// ignored.
//
// It returns an error when the Referer is missing or unparsable, and
// ErrEmptyQueryToken when the parameter is absent. It returns
// ErrInvalidAuthHeader when the scheme is wrong or no token follows it.
//
// NOTE(review): bearer tokens in URLs can leak through browser history,
// proxies and access logs; prefer the Authorization header where possible.
func jwtFromReferer(c *gin.Context, key string) (string, error) {
	refererPath := c.GetHeader("Referer")
	if refererPath == "" {
		return "", fmt.Errorf("err empty Referer")
	}
	refererURL, err := url.Parse(refererPath)
	if err != nil {
		// *url.Error embeds the full URL, which may carry the token and is
		// later logged by callers; keep only the underlying cause.
		if urlErr, ok := err.(*url.Error); ok {
			err = urlErr.Err
		}
		return "", fmt.Errorf("err invalid Referer: %w", err)
	}
	value := refererURL.Query().Get(key)
	if value == "" {
		return "", ErrEmptyQueryToken
	}
	scheme, token, found := strings.Cut(value, " ")
	token = strings.TrimSpace(token)
	if !found || !strings.EqualFold(scheme, TokenHeadName) || token == "" {
		return "", ErrInvalidAuthHeader
	}
	return token, nil
}