refactor(auth): split IAM module and add access/refresh session flow
This commit is contained in:
302
internal/iam/service.go
Normal file
302
internal/iam/service.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package iam
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
cfg Config
|
||||
store *postgresStore
|
||||
cache *tokenCache
|
||||
accessToken *tokenManager
|
||||
refreshToken *tokenManager
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, pool *pgxpool.Pool, cfg Config) (*Service, error) {
|
||||
store := newPostgresStore(pool)
|
||||
if err := store.initSchema(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cache, err := newTokenCache(cfg.RedisAddr, cfg.RedisPassword, cfg.RedisDB)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Service{
|
||||
cfg: cfg,
|
||||
store: store,
|
||||
cache: cache,
|
||||
accessToken: newTokenManager(cfg.AccessSecret, cfg.AccessTTL, tokenKindAccess),
|
||||
refreshToken: newTokenManager(cfg.RefreshSecret, cfg.RefreshTTL, tokenKindRefresh),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
if s.cache != nil {
|
||||
return s.cache.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Register(ctx context.Context, email, password string, autoLogin bool, meta requestMeta) (AuthResult, string, error) {
|
||||
email = normalizeEmail(email)
|
||||
if email == "" || strings.TrimSpace(password) == "" {
|
||||
return AuthResult{}, "", errInvalidCredentials
|
||||
}
|
||||
uid, err := s.store.registerUser(ctx, email, password)
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
if !autoLogin {
|
||||
return AuthResult{}, "", nil
|
||||
}
|
||||
pair, err := s.issueTokenPair(ctx, uid, meta)
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
return AuthResult{
|
||||
AccessToken: pair.AccessToken,
|
||||
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
|
||||
SessionID: pair.SessionID,
|
||||
}, pair.RefreshToken, nil
|
||||
}
|
||||
|
||||
func (s *Service) Login(ctx context.Context, email, password string, meta requestMeta) (AuthResult, string, error) {
|
||||
email = normalizeEmail(email)
|
||||
if email == "" || strings.TrimSpace(password) == "" {
|
||||
return AuthResult{}, "", errInvalidCredentials
|
||||
}
|
||||
uid, err := s.store.verifyUser(ctx, email, password)
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
pair, err := s.issueTokenPair(ctx, uid, meta)
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
return AuthResult{
|
||||
AccessToken: pair.AccessToken,
|
||||
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
|
||||
SessionID: pair.SessionID,
|
||||
}, pair.RefreshToken, nil
|
||||
}
|
||||
|
||||
func (s *Service) issueTokenPair(ctx context.Context, userID int64, meta requestMeta) (tokenPair, error) {
|
||||
sid, err := randomID("sid")
|
||||
if err != nil {
|
||||
return tokenPair{}, err
|
||||
}
|
||||
rid, err := randomID("rid")
|
||||
if err != nil {
|
||||
return tokenPair{}, err
|
||||
}
|
||||
jti, err := randomID("jti")
|
||||
if err != nil {
|
||||
return tokenPair{}, err
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
sess := Session{
|
||||
ID: sid,
|
||||
UserID: userID,
|
||||
DeviceInfo: strings.TrimSpace(meta.DeviceInfo),
|
||||
IP: strings.TrimSpace(meta.IP),
|
||||
UserAgent: strings.TrimSpace(meta.UserAgent),
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.cfg.SessionTTL),
|
||||
}
|
||||
if err := s.store.createSession(ctx, sess); err != nil {
|
||||
return tokenPair{}, err
|
||||
}
|
||||
refreshRaw, refreshExp, err := s.refreshToken.generateRefresh(userID, sid, rid)
|
||||
if err != nil {
|
||||
return tokenPair{}, err
|
||||
}
|
||||
refreshHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)
|
||||
if err := s.store.createRefreshToken(ctx, rid, sid, refreshHash, refreshExp, nil); err != nil {
|
||||
return tokenPair{}, err
|
||||
}
|
||||
accessRaw, accessExp, err := s.accessToken.generateAccess(userID, sid, jti)
|
||||
if err != nil {
|
||||
return tokenPair{}, err
|
||||
}
|
||||
if s.cache != nil {
|
||||
_ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt))
|
||||
_ = s.cache.SetRefreshHash(ctx, rid, refreshHash, time.Until(refreshExp))
|
||||
_ = accessExp
|
||||
}
|
||||
return tokenPair{
|
||||
AccessToken: accessRaw,
|
||||
AccessTokenExpires: accessExp,
|
||||
RefreshToken: refreshRaw,
|
||||
SessionID: sid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Refresh(ctx context.Context, refreshRaw string) (AuthResult, string, error) {
|
||||
claims, err := s.refreshToken.parseRefresh(refreshRaw)
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
|
||||
if ok, err := s.isSessionActive(ctx, claims.SessionID); err != nil {
|
||||
return AuthResult{}, "", err
|
||||
} else if !ok {
|
||||
return AuthResult{}, "", errSessionRevoked
|
||||
}
|
||||
|
||||
providedHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)
|
||||
if s.cache != nil {
|
||||
cached, err := s.cache.GetRefreshHash(ctx, claims.RefreshID)
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
if cached != "" && cached != providedHash {
|
||||
_ = s.store.revokeSession(ctx, claims.SessionID)
|
||||
_ = s.cache.DeleteSession(ctx, claims.SessionID)
|
||||
return AuthResult{}, "", errSessionRevoked
|
||||
}
|
||||
}
|
||||
|
||||
newRID, err := randomID("rid")
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
newJTI, err := randomID("jti")
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
newRefreshRaw, refreshExp, err := s.refreshToken.generateRefresh(claims.UserID, claims.SessionID, newRID)
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
newRefreshHash := hashRefreshToken(newRefreshRaw, s.cfg.RefreshPepper)
|
||||
if err := s.store.rotateRefreshToken(ctx, claims, providedHash, newRID, newRefreshHash, refreshExp); err != nil {
|
||||
if errors.Is(err, errSessionRevoked) {
|
||||
_ = s.cache.DeleteSession(ctx, claims.SessionID)
|
||||
}
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteRefresh(ctx, claims.RefreshID)
|
||||
_ = s.cache.SetRefreshHash(ctx, newRID, newRefreshHash, time.Until(refreshExp))
|
||||
}
|
||||
accessRaw, _, err := s.accessToken.generateAccess(claims.UserID, claims.SessionID, newJTI)
|
||||
if err != nil {
|
||||
return AuthResult{}, "", err
|
||||
}
|
||||
return AuthResult{
|
||||
AccessToken: accessRaw,
|
||||
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
|
||||
SessionID: claims.SessionID,
|
||||
}, newRefreshRaw, nil
|
||||
}
|
||||
|
||||
func (s *Service) Logout(ctx context.Context, accessRaw string) error {
|
||||
claims, err := s.accessToken.parseAccess(accessRaw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.store.revokeSession(ctx, claims.SessionID); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteSession(ctx, claims.SessionID)
|
||||
_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) LogoutAll(ctx context.Context, accessRaw string) error {
|
||||
claims, err := s.accessToken.parseAccess(accessRaw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sids, err := s.store.revokeAllUserSessions(ctx, claims.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.cache != nil {
|
||||
for _, sid := range sids {
|
||||
_ = s.cache.DeleteSession(ctx, sid)
|
||||
}
|
||||
_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) ListSessions(ctx context.Context, userID int64) ([]Session, error) {
|
||||
return s.store.listSessions(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *Service) RevokeSession(ctx context.Context, userID int64, sid string) error {
|
||||
sess, err := s.store.getSession(ctx, sid)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return errInvalidToken
|
||||
}
|
||||
return err
|
||||
}
|
||||
if sess.UserID != userID {
|
||||
return errForbidden
|
||||
}
|
||||
if err := s.store.revokeSession(ctx, sid); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteSession(ctx, sid)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) ValidateAccessToken(ctx context.Context, token string) (accessClaims, error) {
|
||||
claims, err := s.accessToken.parseAccess(token)
|
||||
if err != nil {
|
||||
return accessClaims{}, err
|
||||
}
|
||||
if s.cache != nil {
|
||||
denied, err := s.cache.IsAccessJTIDenied(ctx, claims.JTI)
|
||||
if err != nil {
|
||||
return accessClaims{}, err
|
||||
}
|
||||
if denied {
|
||||
return accessClaims{}, errUnauthorized
|
||||
}
|
||||
}
|
||||
active, err := s.isSessionActive(ctx, claims.SessionID)
|
||||
if err != nil {
|
||||
return accessClaims{}, err
|
||||
}
|
||||
if !active {
|
||||
return accessClaims{}, errSessionRevoked
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *Service) isSessionActive(ctx context.Context, sid string) (bool, error) {
|
||||
if s.cache != nil {
|
||||
ok, err := s.cache.IsSessionActive(ctx, sid)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if ok {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
ok, err := s.store.isSessionActive(ctx, sid)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if ok && s.cache != nil {
|
||||
sess, err := s.store.getSession(ctx, sid)
|
||||
if err == nil {
|
||||
_ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt))
|
||||
}
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
Reference in New Issue
Block a user