148 lines
3.6 KiB
Go
148 lines
3.6 KiB
Go
package iam
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type tokenManager struct {
|
|
secret []byte
|
|
ttl time.Duration
|
|
kind tokenKind
|
|
}
|
|
|
|
func newTokenManager(secret string, ttl time.Duration, kind tokenKind) *tokenManager {
|
|
return &tokenManager{secret: []byte(secret), ttl: ttl, kind: kind}
|
|
}
|
|
|
|
func (t *tokenManager) generateAccess(userID int64, sessionID, jti string) (string, time.Time, error) {
|
|
now := time.Now().UTC()
|
|
expiresAt := now.Add(t.ttl)
|
|
payload := accessClaims{
|
|
UserID: userID,
|
|
SessionID: sessionID,
|
|
JTI: jti,
|
|
IssuedAt: now.Unix(),
|
|
ExpiresAt: expiresAt.Unix(),
|
|
Type: t.kind,
|
|
}
|
|
token, err := t.sign(payload)
|
|
if err != nil {
|
|
return "", time.Time{}, err
|
|
}
|
|
return token, expiresAt, nil
|
|
}
|
|
|
|
func (t *tokenManager) parseAccess(token string) (accessClaims, error) {
|
|
var claims accessClaims
|
|
if err := t.verifyAndDecode(token, &claims); err != nil {
|
|
return claims, err
|
|
}
|
|
if claims.Type != tokenKindAccess {
|
|
return claims, errInvalidToken
|
|
}
|
|
if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.JTI) == "" {
|
|
return claims, errInvalidToken
|
|
}
|
|
if time.Now().UTC().Unix() > claims.ExpiresAt {
|
|
return claims, errTokenExpired
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
func (t *tokenManager) generateRefresh(userID int64, sessionID, refreshID string) (string, time.Time, error) {
|
|
now := time.Now().UTC()
|
|
expiresAt := now.Add(t.ttl)
|
|
payload := refreshClaims{
|
|
UserID: userID,
|
|
SessionID: sessionID,
|
|
RefreshID: refreshID,
|
|
IssuedAt: now.Unix(),
|
|
ExpiresAt: expiresAt.Unix(),
|
|
Type: t.kind,
|
|
}
|
|
token, err := t.sign(payload)
|
|
if err != nil {
|
|
return "", time.Time{}, err
|
|
}
|
|
return token, expiresAt, nil
|
|
}
|
|
|
|
func (t *tokenManager) parseRefresh(token string) (refreshClaims, error) {
|
|
var claims refreshClaims
|
|
if err := t.verifyAndDecode(token, &claims); err != nil {
|
|
return claims, err
|
|
}
|
|
if claims.Type != tokenKindRefresh {
|
|
return claims, errInvalidToken
|
|
}
|
|
if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.RefreshID) == "" {
|
|
return claims, errInvalidToken
|
|
}
|
|
if time.Now().UTC().Unix() > claims.ExpiresAt {
|
|
return claims, errTokenExpired
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
func (t *tokenManager) sign(v any) (string, error) {
|
|
raw, err := json.Marshal(v)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
payload := base64.RawURLEncoding.EncodeToString(raw)
|
|
sig := t.signPayload(payload)
|
|
return payload + "." + sig, nil
|
|
}
|
|
|
|
func (t *tokenManager) verifyAndDecode(token string, out any) error {
|
|
parts := strings.Split(token, ".")
|
|
if len(parts) != 2 {
|
|
return errInvalidToken
|
|
}
|
|
payload, sig := parts[0], parts[1]
|
|
if !t.verifyPayload(payload, sig) {
|
|
return errInvalidToken
|
|
}
|
|
raw, err := base64.RawURLEncoding.DecodeString(payload)
|
|
if err != nil {
|
|
return errInvalidToken
|
|
}
|
|
if err := json.Unmarshal(raw, out); err != nil {
|
|
return errInvalidToken
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (t *tokenManager) signPayload(payload string) string {
|
|
mac := hmac.New(sha256.New, t.secret)
|
|
mac.Write([]byte(payload))
|
|
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
|
}
|
|
|
|
func (t *tokenManager) verifyPayload(payload, signature string) bool {
|
|
expected := t.signPayload(payload)
|
|
return hmac.Equal([]byte(signature), []byte(expected))
|
|
}
|
|
|
|
// randomID returns prefix + "_" + 32 lowercase hex characters encoding 16
// bytes read from crypto/rand. An error from the random source is returned
// as-is with an empty string.
func randomID(prefix string) (string, error) {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err != nil {
		return "", err
	}
	suffix := hex.EncodeToString(buf[:])
	return fmt.Sprintf("%s_%s", prefix, suffix), nil
}
|
|
|
|
// hashRefreshToken returns the lowercase hex HMAC-SHA256 of token keyed with
// pepper. The result is deterministic for a given (token, pepper) pair and is
// always 64 characters long.
func hashRefreshToken(token, pepper string) string {
	keyed := hmac.New(sha256.New, []byte(pepper))
	// hash.Hash.Write never returns an error, so its results are discarded.
	_, _ = keyed.Write([]byte(token))
	digest := keyed.Sum(nil)
	return hex.EncodeToString(digest)
}
|