refactor(auth): split IAM module and add access/refresh session flow

This commit is contained in:
2026-03-01 21:26:37 +08:00
parent 6a2d2c9724
commit 57c27e9102
13 changed files with 1377 additions and 345 deletions

302
internal/iam/service.go Normal file
View File

@@ -0,0 +1,302 @@
package iam
import (
"context"
"errors"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
type Service struct {
cfg Config
store *postgresStore
cache *tokenCache
accessToken *tokenManager
refreshToken *tokenManager
}
func NewService(ctx context.Context, pool *pgxpool.Pool, cfg Config) (*Service, error) {
store := newPostgresStore(pool)
if err := store.initSchema(ctx); err != nil {
return nil, err
}
cache, err := newTokenCache(cfg.RedisAddr, cfg.RedisPassword, cfg.RedisDB)
if err != nil {
return nil, err
}
return &Service{
cfg: cfg,
store: store,
cache: cache,
accessToken: newTokenManager(cfg.AccessSecret, cfg.AccessTTL, tokenKindAccess),
refreshToken: newTokenManager(cfg.RefreshSecret, cfg.RefreshTTL, tokenKindRefresh),
}, nil
}
func (s *Service) Close() error {
if s.cache != nil {
return s.cache.Close()
}
return nil
}
func (s *Service) Register(ctx context.Context, email, password string, autoLogin bool, meta requestMeta) (AuthResult, string, error) {
email = normalizeEmail(email)
if email == "" || strings.TrimSpace(password) == "" {
return AuthResult{}, "", errInvalidCredentials
}
uid, err := s.store.registerUser(ctx, email, password)
if err != nil {
return AuthResult{}, "", err
}
if !autoLogin {
return AuthResult{}, "", nil
}
pair, err := s.issueTokenPair(ctx, uid, meta)
if err != nil {
return AuthResult{}, "", err
}
return AuthResult{
AccessToken: pair.AccessToken,
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
SessionID: pair.SessionID,
}, pair.RefreshToken, nil
}
func (s *Service) Login(ctx context.Context, email, password string, meta requestMeta) (AuthResult, string, error) {
email = normalizeEmail(email)
if email == "" || strings.TrimSpace(password) == "" {
return AuthResult{}, "", errInvalidCredentials
}
uid, err := s.store.verifyUser(ctx, email, password)
if err != nil {
return AuthResult{}, "", err
}
pair, err := s.issueTokenPair(ctx, uid, meta)
if err != nil {
return AuthResult{}, "", err
}
return AuthResult{
AccessToken: pair.AccessToken,
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
SessionID: pair.SessionID,
}, pair.RefreshToken, nil
}
func (s *Service) issueTokenPair(ctx context.Context, userID int64, meta requestMeta) (tokenPair, error) {
sid, err := randomID("sid")
if err != nil {
return tokenPair{}, err
}
rid, err := randomID("rid")
if err != nil {
return tokenPair{}, err
}
jti, err := randomID("jti")
if err != nil {
return tokenPair{}, err
}
now := time.Now().UTC()
sess := Session{
ID: sid,
UserID: userID,
DeviceInfo: strings.TrimSpace(meta.DeviceInfo),
IP: strings.TrimSpace(meta.IP),
UserAgent: strings.TrimSpace(meta.UserAgent),
CreatedAt: now,
ExpiresAt: now.Add(s.cfg.SessionTTL),
}
if err := s.store.createSession(ctx, sess); err != nil {
return tokenPair{}, err
}
refreshRaw, refreshExp, err := s.refreshToken.generateRefresh(userID, sid, rid)
if err != nil {
return tokenPair{}, err
}
refreshHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)
if err := s.store.createRefreshToken(ctx, rid, sid, refreshHash, refreshExp, nil); err != nil {
return tokenPair{}, err
}
accessRaw, accessExp, err := s.accessToken.generateAccess(userID, sid, jti)
if err != nil {
return tokenPair{}, err
}
if s.cache != nil {
_ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt))
_ = s.cache.SetRefreshHash(ctx, rid, refreshHash, time.Until(refreshExp))
_ = accessExp
}
return tokenPair{
AccessToken: accessRaw,
AccessTokenExpires: accessExp,
RefreshToken: refreshRaw,
SessionID: sid,
}, nil
}
func (s *Service) Refresh(ctx context.Context, refreshRaw string) (AuthResult, string, error) {
claims, err := s.refreshToken.parseRefresh(refreshRaw)
if err != nil {
return AuthResult{}, "", err
}
if ok, err := s.isSessionActive(ctx, claims.SessionID); err != nil {
return AuthResult{}, "", err
} else if !ok {
return AuthResult{}, "", errSessionRevoked
}
providedHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)
if s.cache != nil {
cached, err := s.cache.GetRefreshHash(ctx, claims.RefreshID)
if err != nil {
return AuthResult{}, "", err
}
if cached != "" && cached != providedHash {
_ = s.store.revokeSession(ctx, claims.SessionID)
_ = s.cache.DeleteSession(ctx, claims.SessionID)
return AuthResult{}, "", errSessionRevoked
}
}
newRID, err := randomID("rid")
if err != nil {
return AuthResult{}, "", err
}
newJTI, err := randomID("jti")
if err != nil {
return AuthResult{}, "", err
}
newRefreshRaw, refreshExp, err := s.refreshToken.generateRefresh(claims.UserID, claims.SessionID, newRID)
if err != nil {
return AuthResult{}, "", err
}
newRefreshHash := hashRefreshToken(newRefreshRaw, s.cfg.RefreshPepper)
if err := s.store.rotateRefreshToken(ctx, claims, providedHash, newRID, newRefreshHash, refreshExp); err != nil {
if errors.Is(err, errSessionRevoked) {
_ = s.cache.DeleteSession(ctx, claims.SessionID)
}
return AuthResult{}, "", err
}
if s.cache != nil {
_ = s.cache.DeleteRefresh(ctx, claims.RefreshID)
_ = s.cache.SetRefreshHash(ctx, newRID, newRefreshHash, time.Until(refreshExp))
}
accessRaw, _, err := s.accessToken.generateAccess(claims.UserID, claims.SessionID, newJTI)
if err != nil {
return AuthResult{}, "", err
}
return AuthResult{
AccessToken: accessRaw,
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
SessionID: claims.SessionID,
}, newRefreshRaw, nil
}
func (s *Service) Logout(ctx context.Context, accessRaw string) error {
claims, err := s.accessToken.parseAccess(accessRaw)
if err != nil {
return err
}
if err := s.store.revokeSession(ctx, claims.SessionID); err != nil {
return err
}
if s.cache != nil {
_ = s.cache.DeleteSession(ctx, claims.SessionID)
_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
}
return nil
}
func (s *Service) LogoutAll(ctx context.Context, accessRaw string) error {
claims, err := s.accessToken.parseAccess(accessRaw)
if err != nil {
return err
}
sids, err := s.store.revokeAllUserSessions(ctx, claims.UserID)
if err != nil {
return err
}
if s.cache != nil {
for _, sid := range sids {
_ = s.cache.DeleteSession(ctx, sid)
}
_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
}
return nil
}
func (s *Service) ListSessions(ctx context.Context, userID int64) ([]Session, error) {
return s.store.listSessions(ctx, userID)
}
func (s *Service) RevokeSession(ctx context.Context, userID int64, sid string) error {
sess, err := s.store.getSession(ctx, sid)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return errInvalidToken
}
return err
}
if sess.UserID != userID {
return errForbidden
}
if err := s.store.revokeSession(ctx, sid); err != nil {
return err
}
if s.cache != nil {
_ = s.cache.DeleteSession(ctx, sid)
}
return nil
}
func (s *Service) ValidateAccessToken(ctx context.Context, token string) (accessClaims, error) {
claims, err := s.accessToken.parseAccess(token)
if err != nil {
return accessClaims{}, err
}
if s.cache != nil {
denied, err := s.cache.IsAccessJTIDenied(ctx, claims.JTI)
if err != nil {
return accessClaims{}, err
}
if denied {
return accessClaims{}, errUnauthorized
}
}
active, err := s.isSessionActive(ctx, claims.SessionID)
if err != nil {
return accessClaims{}, err
}
if !active {
return accessClaims{}, errSessionRevoked
}
return claims, nil
}
func (s *Service) isSessionActive(ctx context.Context, sid string) (bool, error) {
if s.cache != nil {
ok, err := s.cache.IsSessionActive(ctx, sid)
if err != nil {
return false, err
}
if ok {
return true, nil
}
}
ok, err := s.store.isSessionActive(ctx, sid)
if err != nil {
return false, err
}
if ok && s.cache != nil {
sess, err := s.store.getSession(ctx, sid)
if err == nil {
_ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt))
}
}
return ok, nil
}