refactor(auth): split IAM module and add access/refresh session flow

This commit is contained in:
2026-03-01 21:26:37 +08:00
parent 6a2d2c9724
commit 57c27e9102
13 changed files with 1377 additions and 345 deletions

104
internal/iam/cache_redis.go Normal file
View File

@@ -0,0 +1,104 @@
package iam
import (
"context"
"errors"
"time"
"github.com/redis/go-redis/v9"
)
type tokenCache struct {
client *redis.Client
}
func newTokenCache(addr, password string, db int) (*tokenCache, error) {
if addr == "" {
return nil, nil
}
client := redis.NewClient(&redis.Options{Addr: addr, Password: password, DB: db})
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, err
}
return &tokenCache{client: client}, nil
}
func (c *tokenCache) Close() error {
if c == nil || c.client == nil {
return nil
}
return c.client.Close()
}
func (c *tokenCache) SetSessionActive(ctx context.Context, sid string, ttl time.Duration) error {
if c == nil {
return nil
}
return c.client.Set(ctx, "auth:session:"+sid, "active", ttl).Err()
}
func (c *tokenCache) IsSessionActive(ctx context.Context, sid string) (bool, error) {
if c == nil {
return false, nil
}
count, err := c.client.Exists(ctx, "auth:session:"+sid).Result()
if err != nil {
return false, err
}
return count == 1, nil
}
func (c *tokenCache) DeleteSession(ctx context.Context, sid string) error {
if c == nil {
return nil
}
return c.client.Del(ctx, "auth:session:"+sid).Err()
}
func (c *tokenCache) SetRefreshHash(ctx context.Context, rid, tokenHash string, ttl time.Duration) error {
if c == nil {
return nil
}
return c.client.Set(ctx, "auth:refresh:"+rid, tokenHash, ttl).Err()
}
func (c *tokenCache) GetRefreshHash(ctx context.Context, rid string) (string, error) {
if c == nil {
return "", nil
}
v, err := c.client.Get(ctx, "auth:refresh:"+rid).Result()
if err != nil {
if errors.Is(err, redis.Nil) {
return "", nil
}
return "", err
}
return v, nil
}
func (c *tokenCache) DeleteRefresh(ctx context.Context, rid string) error {
if c == nil {
return nil
}
return c.client.Del(ctx, "auth:refresh:"+rid).Err()
}
func (c *tokenCache) DenyAccessJTI(ctx context.Context, jti string, ttl time.Duration) error {
if c == nil {
return nil
}
return c.client.Set(ctx, "auth:deny:access:"+jti, "1", ttl).Err()
}
func (c *tokenCache) IsAccessJTIDenied(ctx context.Context, jti string) (bool, error) {
if c == nil {
return false, nil
}
count, err := c.client.Exists(ctx, "auth:deny:access:"+jti).Result()
if err != nil {
return false, err
}
return count == 1, nil
}

95
internal/iam/config.go Normal file
View File

@@ -0,0 +1,95 @@
package iam
import (
"os"
"strconv"
"strings"
"time"
)
type Config struct {
AccessSecret string
RefreshSecret string
RefreshPepper string
AccessTTL time.Duration
RefreshTTL time.Duration
SessionTTL time.Duration
RefreshCookieName string
CookieSecure bool
CookieDomain string
CookiePath string
CookieSameSite string
RedisAddr string
RedisPassword string
RedisDB int
}
func LoadConfig() Config {
cfg := Config{
AccessSecret: firstNonEmpty(strings.TrimSpace(os.Getenv("ACCESS_SECRET")), strings.TrimSpace(os.Getenv("AUTH_SECRET")), "dev-access-secret-change-me"),
RefreshSecret: firstNonEmpty(strings.TrimSpace(os.Getenv("REFRESH_SECRET")), strings.TrimSpace(os.Getenv("AUTH_SECRET")), "dev-refresh-secret-change-me"),
RefreshPepper: firstNonEmpty(strings.TrimSpace(os.Getenv("REFRESH_PEPPER")), "dev-refresh-pepper-change-me"),
AccessTTL: parseDuration("ACCESS_TTL", 60*time.Minute),
RefreshTTL: parseDuration("REFRESH_TTL", 7*24*time.Hour),
RefreshCookieName: firstNonEmpty(strings.TrimSpace(os.Getenv("REFRESH_COOKIE_NAME")), "rt"),
CookieSecure: parseBool("COOKIE_SECURE", isProduction()),
CookieDomain: strings.TrimSpace(os.Getenv("COOKIE_DOMAIN")),
CookiePath: firstNonEmpty(strings.TrimSpace(os.Getenv("COOKIE_PATH")), "/"),
CookieSameSite: firstNonEmpty(strings.TrimSpace(os.Getenv("COOKIE_SAMESITE")), "Lax"),
RedisAddr: strings.TrimSpace(os.Getenv("REDIS_ADDR")),
RedisPassword: os.Getenv("REDIS_PASSWORD"),
RedisDB: parseInt("REDIS_DB", 0),
}
cfg.SessionTTL = parseDuration("SESSION_TTL", cfg.RefreshTTL)
return cfg
}
func parseDuration(key string, fallback time.Duration) time.Duration {
value := strings.TrimSpace(os.Getenv(key))
if value == "" {
return fallback
}
d, err := time.ParseDuration(value)
if err != nil || d <= 0 {
return fallback
}
return d
}
func parseBool(key string, fallback bool) bool {
value := strings.TrimSpace(os.Getenv(key))
if value == "" {
return fallback
}
parsed, err := strconv.ParseBool(value)
if err != nil {
return fallback
}
return parsed
}
func parseInt(key string, fallback int) int {
value := strings.TrimSpace(os.Getenv(key))
if value == "" {
return fallback
}
parsed, err := strconv.Atoi(value)
if err != nil {
return fallback
}
return parsed
}
func isProduction() bool {
env := strings.ToLower(strings.TrimSpace(os.Getenv("APP_ENV")))
return env == "prod" || env == "production"
}
func firstNonEmpty(values ...string) string {
for _, v := range values {
if strings.TrimSpace(v) != "" {
return v
}
}
return ""
}

View File

@@ -0,0 +1,192 @@
package iam
import (
"errors"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
)
type Handler struct {
service *Service
cfg Config
}
func NewHandler(service *Service, cfg Config) *Handler {
return &Handler{service: service, cfg: cfg}
}
func (h *Handler) RegisterRoutes(router *gin.Engine) {
auth := router.Group("/api/v1/auth")
auth.POST("/register", h.register)
auth.POST("/login", h.login)
auth.POST("/refresh", h.refresh)
protected := auth.Group("")
protected.Use(h.service.RequireAccess())
protected.POST("/logout", h.logout)
protected.POST("/logout-all", h.logoutAll)
protected.GET("/sessions", h.listSessions)
protected.DELETE("/sessions/:sid", h.revokeSession)
}
func (h *Handler) register(c *gin.Context) {
var input struct {
Email string `json:"email"`
Password string `json:"password"`
AutoLogin *bool `json:"auto_login"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
autoLogin := true
if input.AutoLogin != nil {
autoLogin = *input.AutoLogin
}
result, refreshToken, err := h.service.Register(c.Request.Context(), input.Email, input.Password, autoLogin, requestMetaFromContext(c))
if err != nil {
h.writeAuthError(c, err, "registration failed")
return
}
if !autoLogin {
c.JSON(http.StatusCreated, gin.H{"email": strings.TrimSpace(strings.ToLower(input.Email))})
return
}
h.setRefreshCookie(c, refreshToken)
c.JSON(http.StatusCreated, result)
}
func (h *Handler) login(c *gin.Context) {
var input struct {
Email string `json:"email"`
Password string `json:"password"`
DeviceInfo string `json:"device_info"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
meta := requestMetaFromContext(c)
meta.DeviceInfo = input.DeviceInfo
result, refreshToken, err := h.service.Login(c.Request.Context(), input.Email, input.Password, meta)
if err != nil {
h.writeAuthError(c, err, "login failed")
return
}
h.setRefreshCookie(c, refreshToken)
c.JSON(http.StatusOK, result)
}
func (h *Handler) refresh(c *gin.Context) {
refreshToken, err := c.Cookie(h.cfg.RefreshCookieName)
if err != nil || strings.TrimSpace(refreshToken) == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing refresh token"})
return
}
result, newRefreshToken, err := h.service.Refresh(c.Request.Context(), refreshToken)
if err != nil {
h.writeAuthError(c, err, "refresh failed")
return
}
h.setRefreshCookie(c, newRefreshToken)
c.JSON(http.StatusOK, result)
}
func (h *Handler) logout(c *gin.Context) {
token := extractBearerToken(c.GetHeader("Authorization"))
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
return
}
if err := h.service.Logout(c.Request.Context(), token); err != nil {
h.writeAuthError(c, err, "logout failed")
return
}
h.clearRefreshCookie(c)
c.Status(http.StatusNoContent)
}
func (h *Handler) logoutAll(c *gin.Context) {
token := extractBearerToken(c.GetHeader("Authorization"))
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
return
}
if err := h.service.LogoutAll(c.Request.Context(), token); err != nil {
h.writeAuthError(c, err, "logout-all failed")
return
}
h.clearRefreshCookie(c)
c.Status(http.StatusNoContent)
}
func (h *Handler) listSessions(c *gin.Context) {
uid := c.GetInt64(ContextUserIDKey)
sessions, err := h.service.ListSessions(c.Request.Context(), uid)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list sessions"})
return
}
c.JSON(http.StatusOK, sessions)
}
func (h *Handler) revokeSession(c *gin.Context) {
uid := c.GetInt64(ContextUserIDKey)
sid := strings.TrimSpace(c.Param("sid"))
if sid == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid sid"})
return
}
if err := h.service.RevokeSession(c.Request.Context(), uid, sid); err != nil {
h.writeAuthError(c, err, "revoke session failed")
return
}
c.Status(http.StatusNoContent)
}
func requestMetaFromContext(c *gin.Context) requestMeta {
return requestMeta{
IP: c.ClientIP(),
UserAgent: c.Request.UserAgent(),
}
}
func (h *Handler) writeAuthError(c *gin.Context, err error, fallback string) {
switch {
case errors.Is(err, errInvalidCredentials):
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid credentials"})
case errors.Is(err, errAlreadyExists):
c.JSON(http.StatusConflict, gin.H{"error": "user already exists"})
case errors.Is(err, errInvalidToken), errors.Is(err, errTokenExpired), errors.Is(err, errUnauthorized), errors.Is(err, errSessionRevoked):
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
case errors.Is(err, errForbidden):
c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": fallback})
}
}
func (h *Handler) setRefreshCookie(c *gin.Context, token string) {
h.applySameSite(c)
maxAge := int(h.cfg.RefreshTTL / time.Second)
c.SetCookie(h.cfg.RefreshCookieName, token, maxAge, h.cfg.CookiePath, h.cfg.CookieDomain, h.cfg.CookieSecure, true)
}
func (h *Handler) clearRefreshCookie(c *gin.Context) {
h.applySameSite(c)
c.SetCookie(h.cfg.RefreshCookieName, "", -1, h.cfg.CookiePath, h.cfg.CookieDomain, h.cfg.CookieSecure, true)
}
func (h *Handler) applySameSite(c *gin.Context) {
switch strings.ToLower(strings.TrimSpace(h.cfg.CookieSameSite)) {
case "strict":
c.SetSameSite(http.SameSiteStrictMode)
case "none":
c.SetSameSite(http.SameSiteNoneMode)
default:
c.SetSameSite(http.SameSiteLaxMode)
}
}

View File

@@ -0,0 +1,40 @@
package iam
import (
"errors"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
func (s *Service) RequireAccess() gin.HandlerFunc {
return func(c *gin.Context) {
token := extractBearerToken(c.GetHeader("Authorization"))
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
return
}
claims, err := s.ValidateAccessToken(c.Request.Context(), token)
if err != nil {
status := http.StatusUnauthorized
if errors.Is(err, errSessionRevoked) {
status = http.StatusUnauthorized
}
c.AbortWithStatusJSON(status, gin.H{"error": "invalid token"})
return
}
c.Set(ContextUserIDKey, claims.UserID)
c.Set(ContextSessionIDKey, claims.SessionID)
c.Set(ContextJTIKey, claims.JTI)
c.Next()
}
}
func extractBearerToken(header string) string {
authHeader := strings.TrimSpace(header)
if strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimSpace(authHeader[7:])
}
return authHeader
}

302
internal/iam/service.go Normal file
View File

@@ -0,0 +1,302 @@
package iam
import (
"context"
"errors"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
type Service struct {
cfg Config
store *postgresStore
cache *tokenCache
accessToken *tokenManager
refreshToken *tokenManager
}
func NewService(ctx context.Context, pool *pgxpool.Pool, cfg Config) (*Service, error) {
store := newPostgresStore(pool)
if err := store.initSchema(ctx); err != nil {
return nil, err
}
cache, err := newTokenCache(cfg.RedisAddr, cfg.RedisPassword, cfg.RedisDB)
if err != nil {
return nil, err
}
return &Service{
cfg: cfg,
store: store,
cache: cache,
accessToken: newTokenManager(cfg.AccessSecret, cfg.AccessTTL, tokenKindAccess),
refreshToken: newTokenManager(cfg.RefreshSecret, cfg.RefreshTTL, tokenKindRefresh),
}, nil
}
func (s *Service) Close() error {
if s.cache != nil {
return s.cache.Close()
}
return nil
}
func (s *Service) Register(ctx context.Context, email, password string, autoLogin bool, meta requestMeta) (AuthResult, string, error) {
email = normalizeEmail(email)
if email == "" || strings.TrimSpace(password) == "" {
return AuthResult{}, "", errInvalidCredentials
}
uid, err := s.store.registerUser(ctx, email, password)
if err != nil {
return AuthResult{}, "", err
}
if !autoLogin {
return AuthResult{}, "", nil
}
pair, err := s.issueTokenPair(ctx, uid, meta)
if err != nil {
return AuthResult{}, "", err
}
return AuthResult{
AccessToken: pair.AccessToken,
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
SessionID: pair.SessionID,
}, pair.RefreshToken, nil
}
func (s *Service) Login(ctx context.Context, email, password string, meta requestMeta) (AuthResult, string, error) {
email = normalizeEmail(email)
if email == "" || strings.TrimSpace(password) == "" {
return AuthResult{}, "", errInvalidCredentials
}
uid, err := s.store.verifyUser(ctx, email, password)
if err != nil {
return AuthResult{}, "", err
}
pair, err := s.issueTokenPair(ctx, uid, meta)
if err != nil {
return AuthResult{}, "", err
}
return AuthResult{
AccessToken: pair.AccessToken,
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
SessionID: pair.SessionID,
}, pair.RefreshToken, nil
}
func (s *Service) issueTokenPair(ctx context.Context, userID int64, meta requestMeta) (tokenPair, error) {
sid, err := randomID("sid")
if err != nil {
return tokenPair{}, err
}
rid, err := randomID("rid")
if err != nil {
return tokenPair{}, err
}
jti, err := randomID("jti")
if err != nil {
return tokenPair{}, err
}
now := time.Now().UTC()
sess := Session{
ID: sid,
UserID: userID,
DeviceInfo: strings.TrimSpace(meta.DeviceInfo),
IP: strings.TrimSpace(meta.IP),
UserAgent: strings.TrimSpace(meta.UserAgent),
CreatedAt: now,
ExpiresAt: now.Add(s.cfg.SessionTTL),
}
if err := s.store.createSession(ctx, sess); err != nil {
return tokenPair{}, err
}
refreshRaw, refreshExp, err := s.refreshToken.generateRefresh(userID, sid, rid)
if err != nil {
return tokenPair{}, err
}
refreshHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)
if err := s.store.createRefreshToken(ctx, rid, sid, refreshHash, refreshExp, nil); err != nil {
return tokenPair{}, err
}
accessRaw, accessExp, err := s.accessToken.generateAccess(userID, sid, jti)
if err != nil {
return tokenPair{}, err
}
if s.cache != nil {
_ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt))
_ = s.cache.SetRefreshHash(ctx, rid, refreshHash, time.Until(refreshExp))
_ = accessExp
}
return tokenPair{
AccessToken: accessRaw,
AccessTokenExpires: accessExp,
RefreshToken: refreshRaw,
SessionID: sid,
}, nil
}
func (s *Service) Refresh(ctx context.Context, refreshRaw string) (AuthResult, string, error) {
claims, err := s.refreshToken.parseRefresh(refreshRaw)
if err != nil {
return AuthResult{}, "", err
}
if ok, err := s.isSessionActive(ctx, claims.SessionID); err != nil {
return AuthResult{}, "", err
} else if !ok {
return AuthResult{}, "", errSessionRevoked
}
providedHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)
if s.cache != nil {
cached, err := s.cache.GetRefreshHash(ctx, claims.RefreshID)
if err != nil {
return AuthResult{}, "", err
}
if cached != "" && cached != providedHash {
_ = s.store.revokeSession(ctx, claims.SessionID)
_ = s.cache.DeleteSession(ctx, claims.SessionID)
return AuthResult{}, "", errSessionRevoked
}
}
newRID, err := randomID("rid")
if err != nil {
return AuthResult{}, "", err
}
newJTI, err := randomID("jti")
if err != nil {
return AuthResult{}, "", err
}
newRefreshRaw, refreshExp, err := s.refreshToken.generateRefresh(claims.UserID, claims.SessionID, newRID)
if err != nil {
return AuthResult{}, "", err
}
newRefreshHash := hashRefreshToken(newRefreshRaw, s.cfg.RefreshPepper)
if err := s.store.rotateRefreshToken(ctx, claims, providedHash, newRID, newRefreshHash, refreshExp); err != nil {
if errors.Is(err, errSessionRevoked) {
_ = s.cache.DeleteSession(ctx, claims.SessionID)
}
return AuthResult{}, "", err
}
if s.cache != nil {
_ = s.cache.DeleteRefresh(ctx, claims.RefreshID)
_ = s.cache.SetRefreshHash(ctx, newRID, newRefreshHash, time.Until(refreshExp))
}
accessRaw, _, err := s.accessToken.generateAccess(claims.UserID, claims.SessionID, newJTI)
if err != nil {
return AuthResult{}, "", err
}
return AuthResult{
AccessToken: accessRaw,
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
SessionID: claims.SessionID,
}, newRefreshRaw, nil
}
func (s *Service) Logout(ctx context.Context, accessRaw string) error {
claims, err := s.accessToken.parseAccess(accessRaw)
if err != nil {
return err
}
if err := s.store.revokeSession(ctx, claims.SessionID); err != nil {
return err
}
if s.cache != nil {
_ = s.cache.DeleteSession(ctx, claims.SessionID)
_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
}
return nil
}
func (s *Service) LogoutAll(ctx context.Context, accessRaw string) error {
claims, err := s.accessToken.parseAccess(accessRaw)
if err != nil {
return err
}
sids, err := s.store.revokeAllUserSessions(ctx, claims.UserID)
if err != nil {
return err
}
if s.cache != nil {
for _, sid := range sids {
_ = s.cache.DeleteSession(ctx, sid)
}
_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
}
return nil
}
func (s *Service) ListSessions(ctx context.Context, userID int64) ([]Session, error) {
return s.store.listSessions(ctx, userID)
}
func (s *Service) RevokeSession(ctx context.Context, userID int64, sid string) error {
sess, err := s.store.getSession(ctx, sid)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return errInvalidToken
}
return err
}
if sess.UserID != userID {
return errForbidden
}
if err := s.store.revokeSession(ctx, sid); err != nil {
return err
}
if s.cache != nil {
_ = s.cache.DeleteSession(ctx, sid)
}
return nil
}
func (s *Service) ValidateAccessToken(ctx context.Context, token string) (accessClaims, error) {
claims, err := s.accessToken.parseAccess(token)
if err != nil {
return accessClaims{}, err
}
if s.cache != nil {
denied, err := s.cache.IsAccessJTIDenied(ctx, claims.JTI)
if err != nil {
return accessClaims{}, err
}
if denied {
return accessClaims{}, errUnauthorized
}
}
active, err := s.isSessionActive(ctx, claims.SessionID)
if err != nil {
return accessClaims{}, err
}
if !active {
return accessClaims{}, errSessionRevoked
}
return claims, nil
}
func (s *Service) isSessionActive(ctx context.Context, sid string) (bool, error) {
if s.cache != nil {
ok, err := s.cache.IsSessionActive(ctx, sid)
if err != nil {
return false, err
}
if ok {
return true, nil
}
}
ok, err := s.store.isSessionActive(ctx, sid)
if err != nil {
return false, err
}
if ok && s.cache != nil {
sess, err := s.store.getSession(ctx, sid)
if err == nil {
_ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt))
}
}
return ok, nil
}

279
internal/iam/store_pg.go Normal file
View File

@@ -0,0 +1,279 @@
package iam
import (
"context"
"database/sql"
"errors"
"strings"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"golang.org/x/crypto/bcrypt"
)
type postgresStore struct {
pool *pgxpool.Pool
}
type refreshTokenRow struct {
ID string
SessionID string
TokenHash string
ExpiresAt time.Time
RotatedFrom sql.NullString
RevokedAt sql.NullTime
UsedAt sql.NullTime
CreatedAt time.Time
}
func newPostgresStore(pool *pgxpool.Pool) *postgresStore {
return &postgresStore{pool: pool}
}
func (s *postgresStore) initSchema(ctx context.Context) error {
stmts := []string{
`CREATE TABLE IF NOT EXISTS auth_sessions (
id TEXT PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
device_info TEXT,
ip TEXT,
user_agent TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
expires_at TIMESTAMPTZ NOT NULL,
revoked_at TIMESTAMPTZ
)`,
`CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_revoked ON auth_sessions(user_id, revoked_at)`,
`CREATE INDEX IF NOT EXISTS idx_auth_sessions_expires_at ON auth_sessions(expires_at)`,
`CREATE TABLE IF NOT EXISTS auth_refresh_tokens (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL REFERENCES auth_sessions(id) ON DELETE CASCADE,
token_hash TEXT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
rotated_from TEXT,
revoked_at TIMESTAMPTZ,
used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
)`,
`CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_session_revoked ON auth_refresh_tokens(session_id, revoked_at)`,
`CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_expires_at ON auth_refresh_tokens(expires_at)`,
`CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_rotated_from ON auth_refresh_tokens(rotated_from)`,
}
for _, stmt := range stmts {
if _, err := s.pool.Exec(ctx, stmt); err != nil {
return err
}
}
return nil
}
func (s *postgresStore) registerUser(ctx context.Context, email, password string) (int64, error) {
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return 0, err
}
var id int64
err = s.pool.QueryRow(ctx, `INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id`, email, string(hashed)).Scan(&id)
if err != nil {
if isUniqueViolation(err) {
return 0, errAlreadyExists
}
return 0, err
}
return id, nil
}
func (s *postgresStore) verifyUser(ctx context.Context, email, password string) (int64, error) {
var id int64
var hash string
err := s.pool.QueryRow(ctx, `SELECT id, password_hash FROM users WHERE email = $1`, email).Scan(&id, &hash)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return 0, errInvalidCredentials
}
return 0, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
return 0, errInvalidCredentials
}
return id, nil
}
func (s *postgresStore) createSession(ctx context.Context, sess Session) error {
_, err := s.pool.Exec(ctx, `
INSERT INTO auth_sessions (id, user_id, device_info, ip, user_agent, expires_at)
VALUES ($1, $2, $3, $4, $5, $6)
`, sess.ID, sess.UserID, sess.DeviceInfo, sess.IP, sess.UserAgent, sess.ExpiresAt)
return err
}
func (s *postgresStore) createRefreshToken(ctx context.Context, rid, sid, hash string, expiresAt time.Time, rotatedFrom *string) error {
_, err := s.pool.Exec(ctx, `
INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from)
VALUES ($1, $2, $3, $4, $5)
`, rid, sid, hash, expiresAt, rotatedFrom)
return err
}
func (s *postgresStore) getSession(ctx context.Context, sid string) (Session, error) {
var sess Session
var revokedAt sql.NullTime
err := s.pool.QueryRow(ctx, `
SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at
FROM auth_sessions WHERE id = $1
`, sid).Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt)
if err != nil {
return Session{}, err
}
if revokedAt.Valid {
t := revokedAt.Time
sess.RevokedAt = &t
}
return sess, nil
}
func (s *postgresStore) isSessionActive(ctx context.Context, sid string) (bool, error) {
var ok bool
err := s.pool.QueryRow(ctx, `
SELECT EXISTS(
SELECT 1 FROM auth_sessions
WHERE id = $1 AND revoked_at IS NULL AND expires_at > now()
)
`, sid).Scan(&ok)
return ok, err
}
func (s *postgresStore) revokeSession(ctx context.Context, sid string) error {
_, err := s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, sid)
if err != nil {
return err
}
_, err = s.pool.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, sid)
return err
}
func (s *postgresStore) revokeAllUserSessions(ctx context.Context, userID int64) ([]string, error) {
rows, err := s.pool.Query(ctx, `SELECT id FROM auth_sessions WHERE user_id = $1 AND revoked_at IS NULL`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var sid string
if err := rows.Scan(&sid); err != nil {
return nil, err
}
ids = append(ids, sid)
}
if err := rows.Err(); err != nil {
return nil, err
}
_, err = s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE user_id = $1 AND revoked_at IS NULL`, userID)
if err != nil {
return nil, err
}
_, err = s.pool.Exec(ctx, `
UPDATE auth_refresh_tokens rt
SET revoked_at = now()
FROM auth_sessions s
WHERE s.user_id = $1 AND rt.session_id = s.id AND rt.revoked_at IS NULL
`, userID)
if err != nil {
return nil, err
}
return ids, nil
}
func (s *postgresStore) listSessions(ctx context.Context, userID int64) ([]Session, error) {
rows, err := s.pool.Query(ctx, `
SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at
FROM auth_sessions WHERE user_id = $1 ORDER BY created_at DESC
`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
result := make([]Session, 0)
for rows.Next() {
var sess Session
var revokedAt sql.NullTime
if err := rows.Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt); err != nil {
return nil, err
}
if revokedAt.Valid {
t := revokedAt.Time
sess.RevokedAt = &t
}
result = append(result, sess)
}
return result, rows.Err()
}
func (s *postgresStore) getRefreshTokenForUpdate(ctx context.Context, tx pgx.Tx, rid string) (refreshTokenRow, error) {
var row refreshTokenRow
err := tx.QueryRow(ctx, `
SELECT id, session_id, token_hash, expires_at, rotated_from, revoked_at, used_at, created_at
FROM auth_refresh_tokens
WHERE id = $1
FOR UPDATE
`, rid).Scan(&row.ID, &row.SessionID, &row.TokenHash, &row.ExpiresAt, &row.RotatedFrom, &row.RevokedAt, &row.UsedAt, &row.CreatedAt)
return row, err
}
func (s *postgresStore) rotateRefreshToken(ctx context.Context, claims refreshClaims, providedHash, newRID, newHash string, newExp time.Time) error {
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{})
if err != nil {
return err
}
defer tx.Rollback(ctx)
row, err := s.getRefreshTokenForUpdate(ctx, tx, claims.RefreshID)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return errInvalidToken
}
return err
}
if row.SessionID != claims.SessionID || row.ExpiresAt.Before(time.Now().UTC()) {
return errInvalidToken
}
if row.RevokedAt.Valid || row.UsedAt.Valid || row.TokenHash != providedHash {
if _, err := tx.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil {
return err
}
if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil {
return err
}
if err := tx.Commit(ctx); err != nil {
return err
}
return errSessionRevoked
}
if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET used_at = now(), revoked_at = now() WHERE id = $1`, claims.RefreshID); err != nil {
return err
}
if _, err := tx.Exec(ctx, `
INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from)
VALUES ($1, $2, $3, $4, $5)
`, newRID, claims.SessionID, newHash, newExp, claims.RefreshID); err != nil {
return err
}
return tx.Commit(ctx)
}
func isUniqueViolation(err error) bool {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
return true
}
return false
}
func normalizeEmail(v string) string {
return strings.TrimSpace(strings.ToLower(v))
}

147
internal/iam/token.go Normal file
View File

@@ -0,0 +1,147 @@
package iam
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"strings"
"time"
)
type tokenManager struct {
secret []byte
ttl time.Duration
kind tokenKind
}
func newTokenManager(secret string, ttl time.Duration, kind tokenKind) *tokenManager {
return &tokenManager{secret: []byte(secret), ttl: ttl, kind: kind}
}
func (t *tokenManager) generateAccess(userID int64, sessionID, jti string) (string, time.Time, error) {
now := time.Now().UTC()
expiresAt := now.Add(t.ttl)
payload := accessClaims{
UserID: userID,
SessionID: sessionID,
JTI: jti,
IssuedAt: now.Unix(),
ExpiresAt: expiresAt.Unix(),
Type: t.kind,
}
token, err := t.sign(payload)
if err != nil {
return "", time.Time{}, err
}
return token, expiresAt, nil
}
func (t *tokenManager) parseAccess(token string) (accessClaims, error) {
var claims accessClaims
if err := t.verifyAndDecode(token, &claims); err != nil {
return claims, err
}
if claims.Type != tokenKindAccess {
return claims, errInvalidToken
}
if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.JTI) == "" {
return claims, errInvalidToken
}
if time.Now().UTC().Unix() > claims.ExpiresAt {
return claims, errTokenExpired
}
return claims, nil
}
func (t *tokenManager) generateRefresh(userID int64, sessionID, refreshID string) (string, time.Time, error) {
now := time.Now().UTC()
expiresAt := now.Add(t.ttl)
payload := refreshClaims{
UserID: userID,
SessionID: sessionID,
RefreshID: refreshID,
IssuedAt: now.Unix(),
ExpiresAt: expiresAt.Unix(),
Type: t.kind,
}
token, err := t.sign(payload)
if err != nil {
return "", time.Time{}, err
}
return token, expiresAt, nil
}
func (t *tokenManager) parseRefresh(token string) (refreshClaims, error) {
var claims refreshClaims
if err := t.verifyAndDecode(token, &claims); err != nil {
return claims, err
}
if claims.Type != tokenKindRefresh {
return claims, errInvalidToken
}
if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.RefreshID) == "" {
return claims, errInvalidToken
}
if time.Now().UTC().Unix() > claims.ExpiresAt {
return claims, errTokenExpired
}
return claims, nil
}
func (t *tokenManager) sign(v any) (string, error) {
raw, err := json.Marshal(v)
if err != nil {
return "", err
}
payload := base64.RawURLEncoding.EncodeToString(raw)
sig := t.signPayload(payload)
return payload + "." + sig, nil
}
func (t *tokenManager) verifyAndDecode(token string, out any) error {
parts := strings.Split(token, ".")
if len(parts) != 2 {
return errInvalidToken
}
payload, sig := parts[0], parts[1]
if !t.verifyPayload(payload, sig) {
return errInvalidToken
}
raw, err := base64.RawURLEncoding.DecodeString(payload)
if err != nil {
return errInvalidToken
}
if err := json.Unmarshal(raw, out); err != nil {
return errInvalidToken
}
return nil
}
func (t *tokenManager) signPayload(payload string) string {
mac := hmac.New(sha256.New, t.secret)
mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
func (t *tokenManager) verifyPayload(payload, signature string) bool {
expected := t.signPayload(payload)
return hmac.Equal([]byte(signature), []byte(expected))
}
func randomID(prefix string) (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return fmt.Sprintf("%s_%s", prefix, hex.EncodeToString(b)), nil
}
func hashRefreshToken(token, pepper string) string {
mac := hmac.New(sha256.New, []byte(pepper))
mac.Write([]byte(token))
return hex.EncodeToString(mac.Sum(nil))
}

77
internal/iam/types.go Normal file
View File

@@ -0,0 +1,77 @@
package iam
import (
"errors"
"time"
)
const (
ContextUserIDKey = "user_id"
ContextSessionIDKey = "session_id"
ContextJTIKey = "jti"
)
type tokenKind string
const (
tokenKindAccess tokenKind = "access"
tokenKindRefresh tokenKind = "refresh"
)
type accessClaims struct {
UserID int64 `json:"uid"`
SessionID string `json:"sid"`
JTI string `json:"jti"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp"`
Type tokenKind `json:"typ"`
}
type refreshClaims struct {
UserID int64 `json:"uid"`
SessionID string `json:"sid"`
RefreshID string `json:"rid"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp"`
Type tokenKind `json:"typ"`
}
type Session struct {
ID string `json:"id"`
UserID int64 `json:"user_id"`
DeviceInfo string `json:"device_info"`
IP string `json:"ip"`
UserAgent string `json:"user_agent"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
RevokedAt *time.Time `json:"revoked_at,omitempty"`
}
type tokenPair struct {
AccessToken string
AccessTokenExpires time.Time
RefreshToken string
SessionID string
}
type requestMeta struct {
IP string
UserAgent string
DeviceInfo string
}
type AuthResult struct {
AccessToken string `json:"access_token"`
ExpiresIn int64 `json:"expires_in"`
SessionID string `json:"session_id"`
}
var (
errInvalidToken = errors.New("invalid token")
errTokenExpired = errors.New("token expired")
errInvalidCredentials = errors.New("invalid credentials")
errAlreadyExists = errors.New("already exists")
errSessionRevoked = errors.New("session revoked")
errUnauthorized = errors.New("unauthorized")
errForbidden = errors.New("forbidden")
)