Files

148 lines
3.6 KiB
Go

package iam
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"time"
)
// tokenManager mints and verifies HMAC-SHA256-signed tokens of the form
// "<base64url JSON payload>.<base64url signature>".
type tokenManager struct {
	// secret is the HMAC key used to sign and verify payloads.
	secret []byte
	// ttl is added to the issue time to compute each generated token's expiry.
	ttl time.Duration
	// kind is written into the Type claim of every token this manager generates.
	kind tokenKind
}
// newTokenManager returns a tokenManager that signs with secret (copied into a
// byte slice), issues tokens that expire ttl after issuance, and stamps them
// with kind.
//
// NOTE(review): an empty secret is accepted here; an HMAC keyed with a known
// (empty) key offers no forgery protection — confirm callers validate it.
func newTokenManager(secret string, ttl time.Duration, kind tokenKind) *tokenManager {
	m := new(tokenManager)
	m.secret = []byte(secret)
	m.ttl = ttl
	m.kind = kind
	return m
}
// generateAccess mints a signed access token for the given user, session and
// token ID (jti).
//
// IssuedAt is the current time and ExpiresAt is t.ttl later, both stored as
// Unix seconds. Type is the manager's configured kind. It returns the encoded
// token and the sub-second-precise expiry instant, or the error reported by
// sign.
//
// NOTE(review): parseAccess accepts only tokenKindAccess, so this presumably
// must be called on a manager built with that kind — confirm at call sites.
func (t *tokenManager) generateAccess(userID int64, sessionID, jti string) (string, time.Time, error) {
	issued := time.Now().UTC()
	expiry := issued.Add(t.ttl)

	var claims accessClaims
	claims.UserID = userID
	claims.SessionID = sessionID
	claims.JTI = jti
	claims.IssuedAt = issued.Unix()
	claims.ExpiresAt = expiry.Unix()
	claims.Type = t.kind

	signed, err := t.sign(claims)
	if err != nil {
		return "", time.Time{}, err
	}
	return signed, expiry, nil
}
// parseAccess verifies token's signature, decodes its claims, and validates
// them as an access token.
//
// It returns errInvalidToken when the token is malformed, is badly signed, is
// not of type tokenKindAccess, or lacks a user ID, session ID or JTI. It
// returns errTokenExpired once the current time reaches ExpiresAt. Whatever
// claims were decoded are returned alongside any error.
func (t *tokenManager) parseAccess(token string) (accessClaims, error) {
	var claims accessClaims
	if err := t.verifyAndDecode(token, &claims); err != nil {
		return claims, err
	}
	if claims.Type != tokenKindAccess {
		return claims, errInvalidToken
	}
	if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.JTI) == "" {
		return claims, errInvalidToken
	}
	// ExpiresAt is the expiry instant truncated to whole seconds. Treating the
	// token as expired at that second, rather than one second later, ensures
	// it is never accepted after the expiry time reported when it was minted.
	// This mirrors JWT "exp" semantics: valid only while now is before exp.
	if time.Now().Unix() >= claims.ExpiresAt {
		return claims, errTokenExpired
	}
	return claims, nil
}
// generateRefresh mints a signed refresh token for the given user, session and
// refresh ID.
//
// IssuedAt is the current time and ExpiresAt is t.ttl later, both stored as
// Unix seconds. Type is the manager's configured kind. It returns the encoded
// token and the sub-second-precise expiry instant, or the error reported by
// sign.
//
// NOTE(review): parseRefresh accepts only tokenKindRefresh, so this presumably
// must be called on a manager built with that kind — confirm at call sites.
func (t *tokenManager) generateRefresh(userID int64, sessionID, refreshID string) (string, time.Time, error) {
	issued := time.Now().UTC()
	expiry := issued.Add(t.ttl)

	var claims refreshClaims
	claims.UserID = userID
	claims.SessionID = sessionID
	claims.RefreshID = refreshID
	claims.IssuedAt = issued.Unix()
	claims.ExpiresAt = expiry.Unix()
	claims.Type = t.kind

	signed, err := t.sign(claims)
	if err != nil {
		return "", time.Time{}, err
	}
	return signed, expiry, nil
}
// parseRefresh verifies token's signature, decodes its claims, and validates
// them as a refresh token.
//
// It returns errInvalidToken when the token is malformed, is badly signed, is
// not of type tokenKindRefresh, or lacks a user ID, session ID or refresh ID.
// It returns errTokenExpired once the current time reaches ExpiresAt.
// Whatever claims were decoded are returned alongside any error.
func (t *tokenManager) parseRefresh(token string) (refreshClaims, error) {
	var claims refreshClaims
	if err := t.verifyAndDecode(token, &claims); err != nil {
		return claims, err
	}
	if claims.Type != tokenKindRefresh {
		return claims, errInvalidToken
	}
	if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.RefreshID) == "" {
		return claims, errInvalidToken
	}
	// ExpiresAt is the expiry instant truncated to whole seconds. Treating the
	// token as expired at that second, rather than one second later, ensures
	// it is never accepted after the expiry time reported when it was minted.
	// This mirrors JWT "exp" semantics: valid only while now is before exp.
	if time.Now().Unix() >= claims.ExpiresAt {
		return claims, errTokenExpired
	}
	return claims, nil
}
// sign serializes v as JSON and encodes it with unpadded URL-safe base64. It
// returns the encoded payload, a ".", and the payload's signature. A
// json.Marshal failure is returned unchanged.
func (t *tokenManager) sign(v any) (string, error) {
	body, err := json.Marshal(v)
	if err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(body)
	return encoded + "." + t.signPayload(encoded), nil
}
// verifyAndDecode checks that token has exactly the form "<payload>.<sig>" and
// that sig matches the signature of payload. It then base64url-decodes the
// payload and unmarshals the JSON into out.
//
// The signature is checked before any decoding takes place. Every failure is
// reported as errInvalidToken.
func (t *tokenManager) verifyAndDecode(token string, out any) error {
	payload, sig, found := strings.Cut(token, ".")
	// Exactly one separator is allowed: reject tokens with none or with several.
	if !found || strings.Contains(sig, ".") {
		return errInvalidToken
	}
	if !t.verifyPayload(payload, sig) {
		return errInvalidToken
	}
	decoded, err := base64.RawURLEncoding.DecodeString(payload)
	if err != nil {
		return errInvalidToken
	}
	if json.Unmarshal(decoded, out) != nil {
		return errInvalidToken
	}
	return nil
}
// signPayload returns the HMAC-SHA256 of payload, keyed with the manager's
// secret and encoded as unpadded URL-safe base64.
func (t *tokenManager) signPayload(payload string) string {
	h := hmac.New(sha256.New, t.secret)
	// hash.Hash's Write is documented never to return an error.
	h.Write([]byte(payload))
	digest := h.Sum(nil)
	return base64.RawURLEncoding.EncodeToString(digest)
}
// verifyPayload reports whether signature equals the signature signPayload
// computes for payload. hmac.Equal performs the comparison without leaking
// timing information about where the values differ.
func (t *tokenManager) verifyPayload(payload, signature string) bool {
	want := []byte(t.signPayload(payload))
	return hmac.Equal([]byte(signature), want)
}
// randomID returns prefix, an underscore, and 32 lowercase hex characters
// that encode 16 bytes from crypto/rand. An error from rand.Read is returned
// unchanged.
func randomID(prefix string) (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	suffix := hex.EncodeToString(raw[:])
	return fmt.Sprintf("%s_%s", prefix, suffix), nil
}
// hashRefreshToken returns the lowercase hex encoding of HMAC-SHA256(token),
// keyed with pepper. The result is deterministic for a given token and
// pepper. Presumably this digest is what gets persisted instead of the raw
// refresh token — confirm at call sites.
func hashRefreshToken(token, pepper string) string {
	h := hmac.New(sha256.New, []byte(pepper))
	// hash.Hash's Write is documented never to return an error.
	h.Write([]byte(token))
	sum := h.Sum(nil)
	return hex.EncodeToString(sum)
}