package iam import ( "context" "errors" "strings" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) type Service struct { cfg Config store *postgresStore cache *tokenCache accessToken *tokenManager refreshToken *tokenManager } func NewService(ctx context.Context, pool *pgxpool.Pool, cfg Config) (*Service, error) { store := newPostgresStore(pool) if err := store.initSchema(ctx); err != nil { return nil, err } cache, err := newTokenCache(cfg.RedisAddr, cfg.RedisPassword, cfg.RedisDB) if err != nil { return nil, err } return &Service{ cfg: cfg, store: store, cache: cache, accessToken: newTokenManager(cfg.AccessSecret, cfg.AccessTTL, tokenKindAccess), refreshToken: newTokenManager(cfg.RefreshSecret, cfg.RefreshTTL, tokenKindRefresh), }, nil } func (s *Service) Close() error { if s.cache != nil { return s.cache.Close() } return nil } func (s *Service) Register(ctx context.Context, email, password string, autoLogin bool, meta requestMeta) (AuthResult, string, error) { email = normalizeEmail(email) if email == "" || strings.TrimSpace(password) == "" { return AuthResult{}, "", errInvalidCredentials } uid, err := s.store.registerUser(ctx, email, password) if err != nil { return AuthResult{}, "", err } if !autoLogin { return AuthResult{}, "", nil } pair, err := s.issueTokenPair(ctx, uid, meta) if err != nil { return AuthResult{}, "", err } return AuthResult{ AccessToken: pair.AccessToken, ExpiresIn: int64(s.cfg.AccessTTL.Seconds()), SessionID: pair.SessionID, }, pair.RefreshToken, nil } func (s *Service) Login(ctx context.Context, email, password string, meta requestMeta) (AuthResult, string, error) { email = normalizeEmail(email) if email == "" || strings.TrimSpace(password) == "" { return AuthResult{}, "", errInvalidCredentials } uid, err := s.store.verifyUser(ctx, email, password) if err != nil { return AuthResult{}, "", err } pair, err := s.issueTokenPair(ctx, uid, meta) if err != nil { return AuthResult{}, "", err } return AuthResult{ AccessToken: pair.AccessToken, ExpiresIn: int64(s.cfg.AccessTTL.Seconds()), SessionID: pair.SessionID, }, pair.RefreshToken, nil } func (s *Service) issueTokenPair(ctx context.Context, userID int64, meta requestMeta) (tokenPair, error) { sid, err := randomID("sid") if err != nil { return tokenPair{}, err } rid, err := randomID("rid") if err != nil { return tokenPair{}, err } jti, err := randomID("jti") if err != nil { return tokenPair{}, err } now := time.Now().UTC() sess := Session{ ID: sid, UserID: userID, DeviceInfo: strings.TrimSpace(meta.DeviceInfo), IP: strings.TrimSpace(meta.IP), UserAgent: strings.TrimSpace(meta.UserAgent), CreatedAt: now, ExpiresAt: now.Add(s.cfg.SessionTTL), } if err := s.store.createSession(ctx, sess); err != nil { return tokenPair{}, err } refreshRaw, refreshExp, err := s.refreshToken.generateRefresh(userID, sid, rid) if err != nil { return tokenPair{}, err } refreshHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper) if err := s.store.createRefreshToken(ctx, rid, sid, refreshHash, refreshExp, nil); err != nil { return tokenPair{}, err } accessRaw, accessExp, err := s.accessToken.generateAccess(userID, sid, jti) if err != nil { return tokenPair{}, err } if s.cache != nil { _ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt)) _ = s.cache.SetRefreshHash(ctx, rid, refreshHash, time.Until(refreshExp)) _ = accessExp } return tokenPair{ AccessToken: accessRaw, AccessTokenExpires: accessExp, RefreshToken: refreshRaw, SessionID: sid, }, nil } func (s *Service) Refresh(ctx context.Context, refreshRaw string) (AuthResult, string, error) { claims, err := s.refreshToken.parseRefresh(refreshRaw) if err != nil { return AuthResult{}, "", err } if ok, err := s.isSessionActive(ctx, claims.SessionID); err != nil { return AuthResult{}, "", err } else if !ok { return AuthResult{}, "", errSessionRevoked } providedHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper) if s.cache != nil { cached, err := s.cache.GetRefreshHash(ctx, claims.RefreshID) if err != nil { return AuthResult{}, "", err } if cached != "" && cached != providedHash { _ = s.store.revokeSession(ctx, claims.SessionID) _ = s.cache.DeleteSession(ctx, claims.SessionID) return AuthResult{}, "", errSessionRevoked } } newRID, err := randomID("rid") if err != nil { return AuthResult{}, "", err } newJTI, err := randomID("jti") if err != nil { return AuthResult{}, "", err } newRefreshRaw, refreshExp, err := s.refreshToken.generateRefresh(claims.UserID, claims.SessionID, newRID) if err != nil { return AuthResult{}, "", err } newRefreshHash := hashRefreshToken(newRefreshRaw, s.cfg.RefreshPepper) if err := s.store.rotateRefreshToken(ctx, claims, providedHash, newRID, newRefreshHash, refreshExp); err != nil { if errors.Is(err, errSessionRevoked) { _ = s.cache.DeleteSession(ctx, claims.SessionID) } return AuthResult{}, "", err } if s.cache != nil { _ = s.cache.DeleteRefresh(ctx, claims.RefreshID) _ = s.cache.SetRefreshHash(ctx, newRID, newRefreshHash, time.Until(refreshExp)) } accessRaw, _, err := s.accessToken.generateAccess(claims.UserID, claims.SessionID, newJTI) if err != nil { return AuthResult{}, "", err } return AuthResult{ AccessToken: accessRaw, ExpiresIn: int64(s.cfg.AccessTTL.Seconds()), SessionID: claims.SessionID, }, newRefreshRaw, nil } func (s *Service) Logout(ctx context.Context, accessRaw string) error { claims, err := s.accessToken.parseAccess(accessRaw) if err != nil { return err } if err := s.store.revokeSession(ctx, claims.SessionID); err != nil { return err } if s.cache != nil { _ = s.cache.DeleteSession(ctx, claims.SessionID) _ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL) } return nil } func (s *Service) LogoutAll(ctx context.Context, accessRaw string) error { claims, err := s.accessToken.parseAccess(accessRaw) if err != nil { return err } sids, err := s.store.revokeAllUserSessions(ctx, claims.UserID) if err != nil { return err } if s.cache != nil { for _, sid := range sids { _ = s.cache.DeleteSession(ctx, sid) } _ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL) } return nil } func (s *Service) ListSessions(ctx context.Context, userID int64) ([]Session, error) { return s.store.listSessions(ctx, userID) } func (s *Service) RevokeSession(ctx context.Context, userID int64, sid string) error { sess, err := s.store.getSession(ctx, sid) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return errInvalidToken } return err } if sess.UserID != userID { return errForbidden } if err := s.store.revokeSession(ctx, sid); err != nil { return err } if s.cache != nil { _ = s.cache.DeleteSession(ctx, sid) } return nil } func (s *Service) ValidateAccessToken(ctx context.Context, token string) (accessClaims, error) { claims, err := s.accessToken.parseAccess(token) if err != nil { return accessClaims{}, err } if s.cache != nil { denied, err := s.cache.IsAccessJTIDenied(ctx, claims.JTI) if err != nil { return accessClaims{}, err } if denied { return accessClaims{}, errUnauthorized } } active, err := s.isSessionActive(ctx, claims.SessionID) if err != nil { return accessClaims{}, err } if !active { return accessClaims{}, errSessionRevoked } return claims, nil } func (s *Service) isSessionActive(ctx context.Context, sid string) (bool, error) { if s.cache != nil { ok, err := s.cache.IsSessionActive(ctx, sid) if err != nil { return false, err } if ok { return true, nil } } ok, err := s.store.isSessionActive(ctx, sid) if err != nil { return false, err } if ok && s.cache != nil { sess, err := s.store.getSession(ctx, sid) if err == nil { _ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt)) } } return ok, nil }