package iam import ( "crypto/hmac" "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/hex" "encoding/json" "fmt" "strings" "time" ) type tokenManager struct { secret []byte ttl time.Duration kind tokenKind } func newTokenManager(secret string, ttl time.Duration, kind tokenKind) *tokenManager { return &tokenManager{secret: []byte(secret), ttl: ttl, kind: kind} } func (t *tokenManager) generateAccess(userID int64, sessionID, jti string) (string, time.Time, error) { now := time.Now().UTC() expiresAt := now.Add(t.ttl) payload := accessClaims{ UserID: userID, SessionID: sessionID, JTI: jti, IssuedAt: now.Unix(), ExpiresAt: expiresAt.Unix(), Type: t.kind, } token, err := t.sign(payload) if err != nil { return "", time.Time{}, err } return token, expiresAt, nil } func (t *tokenManager) parseAccess(token string) (accessClaims, error) { var claims accessClaims if err := t.verifyAndDecode(token, &claims); err != nil { return claims, err } if claims.Type != tokenKindAccess { return claims, errInvalidToken } if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.JTI) == "" { return claims, errInvalidToken } if time.Now().UTC().Unix() > claims.ExpiresAt { return claims, errTokenExpired } return claims, nil } func (t *tokenManager) generateRefresh(userID int64, sessionID, refreshID string) (string, time.Time, error) { now := time.Now().UTC() expiresAt := now.Add(t.ttl) payload := refreshClaims{ UserID: userID, SessionID: sessionID, RefreshID: refreshID, IssuedAt: now.Unix(), ExpiresAt: expiresAt.Unix(), Type: t.kind, } token, err := t.sign(payload) if err != nil { return "", time.Time{}, err } return token, expiresAt, nil } func (t *tokenManager) parseRefresh(token string) (refreshClaims, error) { var claims refreshClaims if err := t.verifyAndDecode(token, &claims); err != nil { return claims, err } if claims.Type != tokenKindRefresh { return claims, errInvalidToken } if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.RefreshID) == "" { return claims, errInvalidToken } if time.Now().UTC().Unix() > claims.ExpiresAt { return claims, errTokenExpired } return claims, nil } func (t *tokenManager) sign(v any) (string, error) { raw, err := json.Marshal(v) if err != nil { return "", err } payload := base64.RawURLEncoding.EncodeToString(raw) sig := t.signPayload(payload) return payload + "." + sig, nil } func (t *tokenManager) verifyAndDecode(token string, out any) error { parts := strings.Split(token, ".") if len(parts) != 2 { return errInvalidToken } payload, sig := parts[0], parts[1] if !t.verifyPayload(payload, sig) { return errInvalidToken } raw, err := base64.RawURLEncoding.DecodeString(payload) if err != nil { return errInvalidToken } if err := json.Unmarshal(raw, out); err != nil { return errInvalidToken } return nil } func (t *tokenManager) signPayload(payload string) string { mac := hmac.New(sha256.New, t.secret) mac.Write([]byte(payload)) return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) } func (t *tokenManager) verifyPayload(payload, signature string) bool { expected := t.signPayload(payload) return hmac.Equal([]byte(signature), []byte(expected)) } func randomID(prefix string) (string, error) { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { return "", err } return fmt.Sprintf("%s_%s", prefix, hex.EncodeToString(b)), nil } func hashRefreshToken(token, pepper string) string { mac := hmac.New(sha256.New, []byte(pepper)) mac.Write([]byte(token)) return hex.EncodeToString(mac.Sum(nil)) }