mirror of
https://github.com/andatoshiki/shikigrid.git
synced 2026-06-06 04:04:15 +00:00
90 lines
2.3 KiB
Go
90 lines
2.3 KiB
Go
package api
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"github.com/evilsocket/islazy/log"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/andatoshiki/shikigrid/models"
|
|
"net/http"
|
|
"os"
|
|
"time"
|
|
)
|
|
|
|
// ErrTokenClaims is returned by validateToken when the parsed token's claims
// are not a jwt.MapClaims.
var ErrTokenClaims = errors.New("can't extract claims from jwt token")

// ErrTokenInvalid is returned by validateToken when the parsed token is not
// marked as valid.
var ErrTokenInvalid = errors.New("jwt token not valid")

// ErrTokenExpired is returned by validateToken when the "expires_at" claim
// cannot be parsed as an RFC3339 timestamp or lies in the past.
var ErrTokenExpired = errors.New("jwt token expired")

// ErrTokenIncomplete is returned by validateToken when one of the required
// claims ("expires_at", "authorized", "unit_id", "unit_ident") is missing.
var ErrTokenIncomplete = errors.New("jwt token is missing required fields")

// ErrTokenUnauthorized is returned by validateToken when the "authorized"
// claim is not true.
var ErrTokenUnauthorized = errors.New("jwt token authorized field is false (?!)")
|
|
|
|
func validateToken(header string) (jwt.MapClaims, error) {
|
|
token, err := jwt.Parse(header, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return []byte(os.Getenv("API_SECRET")), nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, ErrTokenClaims
|
|
} else if !token.Valid {
|
|
return nil, ErrTokenInvalid
|
|
}
|
|
|
|
required := []string{
|
|
"expires_at",
|
|
"authorized",
|
|
"unit_id",
|
|
"unit_ident",
|
|
}
|
|
for _, req := range required {
|
|
if _, found := claims[req]; !found {
|
|
return nil, ErrTokenIncomplete
|
|
}
|
|
}
|
|
|
|
log.Debug("%+v", claims)
|
|
|
|
if expiresAt, err := time.Parse(time.RFC3339, claims["expires_at"].(string)); err != nil {
|
|
return nil, ErrTokenExpired
|
|
} else if expiresAt.Before(time.Now()) {
|
|
return nil, ErrTokenExpired
|
|
} else if claims["authorized"].(bool) != true {
|
|
return nil, ErrTokenUnauthorized
|
|
}
|
|
return claims, err
|
|
}
|
|
|
|
func Authenticate(w http.ResponseWriter, r *http.Request) *models.Unit {
|
|
client := clientIP(r)
|
|
tokenHeader := reqToken(r)
|
|
if tokenHeader == "" {
|
|
log.Debug("unauthenticated request from %s", client)
|
|
ERROR(w, http.StatusUnauthorized, ErrUnauthorized)
|
|
return nil
|
|
}
|
|
|
|
claims, err := validateToken(tokenHeader)
|
|
if err != nil {
|
|
log.Debug("token error for %s: %v", client, err)
|
|
ERROR(w, http.StatusUnauthorized, ErrUnauthorized)
|
|
return nil
|
|
}
|
|
|
|
log.Debug("claims[unit_id] = %+v", claims["unit_id"])
|
|
unit := models.FindUnit(uint(claims["unit_id"].(float64)))
|
|
if unit == nil {
|
|
log.Warning("client %s authenticated with unknown claims '%v'", client, claims)
|
|
ERROR(w, http.StatusUnauthorized, ErrUnauthorized)
|
|
return nil
|
|
}
|
|
|
|
return unit
|
|
}
|