303 lines
7.9 KiB
Go
303 lines
7.9 KiB
Go
package iam
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
type Service struct {
|
|
cfg Config
|
|
store *postgresStore
|
|
cache *tokenCache
|
|
accessToken *tokenManager
|
|
refreshToken *tokenManager
|
|
}
|
|
|
|
// NewService wires up the IAM service.
//
// It does three things, in order:
//  1. builds the Postgres store on pool and ensures its schema exists;
//  2. connects the token cache using cfg's Redis settings;
//  3. creates the access- and refresh-token managers from their own secrets
//     and TTLs.
//
// The first error encountered is returned unchanged. Close on the returned
// Service only releases the token cache; it does not close pool.
func NewService(ctx context.Context, pool *pgxpool.Pool, cfg Config) (*Service, error) {
	st := newPostgresStore(pool)
	if initErr := st.initSchema(ctx); initErr != nil {
		return nil, initErr
	}

	tc, cacheErr := newTokenCache(cfg.RedisAddr, cfg.RedisPassword, cfg.RedisDB)
	if cacheErr != nil {
		return nil, cacheErr
	}

	svc := &Service{cfg: cfg, store: st, cache: tc}
	svc.accessToken = newTokenManager(cfg.AccessSecret, cfg.AccessTTL, tokenKindAccess)
	svc.refreshToken = newTokenManager(cfg.RefreshSecret, cfg.RefreshTTL, tokenKindRefresh)
	return svc, nil
}
|
|
|
|
func (s *Service) Close() error {
|
|
if s.cache != nil {
|
|
return s.cache.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Register creates a user account for email/password.
//
// The email is normalized first. An empty email or a blank (whitespace-only)
// password is rejected with errInvalidCredentials.
//
// When autoLogin is false, the account is created and a zero AuthResult with
// an empty refresh token is returned. When autoLogin is true, a new session
// and token pair are issued for the new user. The raw refresh token is then
// returned as the second value, separately from the AuthResult.
func (s *Service) Register(ctx context.Context, email, password string, autoLogin bool, meta requestMeta) (AuthResult, string, error) {
	email = normalizeEmail(email)
	if email == "" || strings.TrimSpace(password) == "" {
		return AuthResult{}, "", errInvalidCredentials
	}

	userID, err := s.store.registerUser(ctx, email, password)
	switch {
	case err != nil:
		return AuthResult{}, "", err
	case !autoLogin:
		return AuthResult{}, "", nil
	}

	pair, err := s.issueTokenPair(ctx, userID, meta)
	if err != nil {
		return AuthResult{}, "", err
	}

	res := AuthResult{
		AccessToken: pair.AccessToken,
		ExpiresIn:   int64(s.cfg.AccessTTL.Seconds()),
		SessionID:   pair.SessionID,
	}
	return res, pair.RefreshToken, nil
}
|
|
|
|
func (s *Service) Login(ctx context.Context, email, password string, meta requestMeta) (AuthResult, string, error) {
|
|
email = normalizeEmail(email)
|
|
if email == "" || strings.TrimSpace(password) == "" {
|
|
return AuthResult{}, "", errInvalidCredentials
|
|
}
|
|
uid, err := s.store.verifyUser(ctx, email, password)
|
|
if err != nil {
|
|
return AuthResult{}, "", err
|
|
}
|
|
pair, err := s.issueTokenPair(ctx, uid, meta)
|
|
if err != nil {
|
|
return AuthResult{}, "", err
|
|
}
|
|
return AuthResult{
|
|
AccessToken: pair.AccessToken,
|
|
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
|
|
SessionID: pair.SessionID,
|
|
}, pair.RefreshToken, nil
|
|
}
|
|
|
|
// issueTokenPair starts a new session for userID and mints its first
// refresh/access token pair.
//
// Steps, in order:
//  1. generate random IDs for the session (sid), refresh token (rid) and
//     access token (jti);
//  2. sign the refresh and access tokens;
//  3. persist the session, expiring cfg.SessionTTL from now, with trimmed
//     device/IP/user-agent metadata;
//  4. persist the refresh token's peppered hash (the raw refresh token is
//     not stored here);
//  5. best-effort: warm the cache with the session and the refresh hash.
//
// If persisting the refresh token fails, the just-created session is revoked
// (best-effort). Otherwise a session that no token was ever issued for would
// be left active.
func (s *Service) issueTokenPair(ctx context.Context, userID int64, meta requestMeta) (tokenPair, error) {
	sid, err := randomID("sid")
	if err != nil {
		return tokenPair{}, err
	}
	rid, err := randomID("rid")
	if err != nil {
		return tokenPair{}, err
	}
	jti, err := randomID("jti")
	if err != nil {
		return tokenPair{}, err
	}

	// Sign both tokens before any database write. Signing only needs the IDs
	// above, so a signing failure here leaves no persisted state behind.
	// NOTE(review): assumes generateRefresh/generateAccess are side-effect-free
	// signing operations — confirm in tokenManager.
	refreshRaw, refreshExp, err := s.refreshToken.generateRefresh(userID, sid, rid)
	if err != nil {
		return tokenPair{}, err
	}
	accessRaw, accessExp, err := s.accessToken.generateAccess(userID, sid, jti)
	if err != nil {
		return tokenPair{}, err
	}
	refreshHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)

	now := time.Now().UTC()
	sess := Session{
		ID:         sid,
		UserID:     userID,
		DeviceInfo: strings.TrimSpace(meta.DeviceInfo),
		IP:         strings.TrimSpace(meta.IP),
		UserAgent:  strings.TrimSpace(meta.UserAgent),
		CreatedAt:  now,
		ExpiresAt:  now.Add(s.cfg.SessionTTL),
	}
	if err := s.store.createSession(ctx, sess); err != nil {
		return tokenPair{}, err
	}
	if err := s.store.createRefreshToken(ctx, rid, sid, refreshHash, refreshExp, nil); err != nil {
		// Best-effort rollback so no token-less session stays active (it could
		// otherwise appear as a live session to the user). The original error
		// is what the caller needs to see.
		_ = s.store.revokeSession(ctx, sid)
		return tokenPair{}, err
	}

	if s.cache != nil {
		// Cache writes are best-effort: the store is the source of truth, and
		// isSessionActive falls back to it on a cache miss. Non-positive TTLs
		// (e.g. SessionTTL unset) are skipped. Such an entry would already be
		// expired, and depending on the cache implementation it might be
		// stored without any expiry.
		if ttl := time.Until(sess.ExpiresAt); ttl > 0 {
			_ = s.cache.SetSessionActive(ctx, sid, ttl)
		}
		if ttl := time.Until(refreshExp); ttl > 0 {
			_ = s.cache.SetRefreshHash(ctx, rid, refreshHash, ttl)
		}
	}

	return tokenPair{
		AccessToken:        accessRaw,
		AccessTokenExpires: accessExp,
		RefreshToken:       refreshRaw,
		SessionID:          sid,
	}, nil
}
|
|
|
|
// Refresh exchanges a refresh token for a new access token and a new,
// rotated refresh token in the same session.
//
// The presented token must parse and its session must still be active;
// otherwise errSessionRevoked is returned. Its peppered hash is compared with
// the cached hash for its refresh ID, when one is cached. A mismatch is
// treated as token reuse: the whole session is revoked and errSessionRevoked
// is returned. The store then performs the authoritative rotation. If the
// store reports errSessionRevoked, the cached session entry is dropped too.
//
// On success the new raw refresh token is returned as the second value.
func (s *Service) Refresh(ctx context.Context, refreshRaw string) (AuthResult, string, error) {
	claims, err := s.refreshToken.parseRefresh(refreshRaw)
	if err != nil {
		return AuthResult{}, "", err
	}

	active, err := s.isSessionActive(ctx, claims.SessionID)
	if err != nil {
		return AuthResult{}, "", err
	}
	if !active {
		return AuthResult{}, "", errSessionRevoked
	}

	providedHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)
	if s.cache != nil {
		cached, err := s.cache.GetRefreshHash(ctx, claims.RefreshID)
		if err != nil {
			return AuthResult{}, "", err
		}
		// An empty value means nothing is cached for this refresh ID. The
		// provided hash is still checked by the store's rotation below.
		if cached != "" && cached != providedHash {
			_ = s.store.revokeSession(ctx, claims.SessionID)
			_ = s.cache.DeleteSession(ctx, claims.SessionID)
			return AuthResult{}, "", errSessionRevoked
		}
	}

	newRID, err := randomID("rid")
	if err != nil {
		return AuthResult{}, "", err
	}
	newJTI, err := randomID("jti")
	if err != nil {
		return AuthResult{}, "", err
	}

	newRefreshRaw, refreshExp, err := s.refreshToken.generateRefresh(claims.UserID, claims.SessionID, newRID)
	if err != nil {
		return AuthResult{}, "", err
	}
	// Sign the access token before rotating. Rotation presumably invalidates
	// the presented refresh token, so a signing failure after it would leave
	// the client with no usable refresh token.
	accessRaw, _, err := s.accessToken.generateAccess(claims.UserID, claims.SessionID, newJTI)
	if err != nil {
		return AuthResult{}, "", err
	}

	newRefreshHash := hashRefreshToken(newRefreshRaw, s.cfg.RefreshPepper)
	if err := s.store.rotateRefreshToken(ctx, claims, providedHash, newRID, newRefreshHash, refreshExp); err != nil {
		// The cache is optional: guard like every other cache access.
		if errors.Is(err, errSessionRevoked) && s.cache != nil {
			_ = s.cache.DeleteSession(ctx, claims.SessionID)
		}
		return AuthResult{}, "", err
	}

	if s.cache != nil {
		// Best-effort: the store has already committed the rotation.
		_ = s.cache.DeleteRefresh(ctx, claims.RefreshID)
		_ = s.cache.SetRefreshHash(ctx, newRID, newRefreshHash, time.Until(refreshExp))
	}

	return AuthResult{
		AccessToken: accessRaw,
		ExpiresIn:   int64(s.cfg.AccessTTL.Seconds()),
		SessionID:   claims.SessionID,
	}, newRefreshRaw, nil
}
|
|
|
|
// Logout ends the session that accessRaw belongs to.
//
// The session is revoked in the store first, and any parse or store error is
// returned. Cache cleanup is best-effort and its errors are ignored. It drops
// the cached session and deny-lists the token's JTI for cfg.AccessTTL.
func (s *Service) Logout(ctx context.Context, accessRaw string) error {
	claims, err := s.accessToken.parseAccess(accessRaw)
	if err != nil {
		return err
	}
	if err = s.store.revokeSession(ctx, claims.SessionID); err != nil {
		return err
	}
	if s.cache == nil {
		return nil
	}
	_ = s.cache.DeleteSession(ctx, claims.SessionID)
	_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
	return nil
}
|
|
|
|
// LogoutAll revokes every session of the user who owns accessRaw.
//
// Parse or store errors are returned. Cache cleanup is best-effort: each
// revoked session's cache entry is dropped, and the presented token's JTI is
// deny-listed for cfg.AccessTTL. Access tokens of the other sessions are not
// deny-listed individually. ValidateAccessToken rejects them once their
// session is no longer active.
func (s *Service) LogoutAll(ctx context.Context, accessRaw string) error {
	claims, err := s.accessToken.parseAccess(accessRaw)
	if err != nil {
		return err
	}
	revoked, err := s.store.revokeAllUserSessions(ctx, claims.UserID)
	if err != nil {
		return err
	}
	if s.cache == nil {
		return nil
	}
	for _, id := range revoked {
		_ = s.cache.DeleteSession(ctx, id)
	}
	_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
	return nil
}
|
|
|
|
// ListSessions returns the sessions the store holds for userID, passing the
// store's result and error through unchanged.
func (s *Service) ListSessions(ctx context.Context, userID int64) ([]Session, error) {
	sessions, err := s.store.listSessions(ctx, userID)
	return sessions, err
}
|
|
|
|
// RevokeSession revokes session sid on behalf of userID.
//
// It returns:
//   - errInvalidToken if no such session exists (pgx.ErrNoRows from the store);
//   - errForbidden if the session belongs to a different user;
//   - any other store error unchanged.
//
// On success the cached session entry is dropped best-effort.
func (s *Service) RevokeSession(ctx context.Context, userID int64, sid string) error {
	sess, err := s.store.getSession(ctx, sid)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return errInvalidToken
	case err != nil:
		return err
	case sess.UserID != userID:
		return errForbidden
	}

	if err := s.store.revokeSession(ctx, sid); err != nil {
		return err
	}
	if s.cache != nil {
		_ = s.cache.DeleteSession(ctx, sid)
	}
	return nil
}
|
|
|
|
// ValidateAccessToken parses token and checks that it is still usable.
//
// Checks, in order:
//  1. the token must parse;
//  2. if a cache is configured, its JTI must not be deny-listed
//     (errUnauthorized otherwise);
//  3. its session must be active (errSessionRevoked otherwise).
//
// Parse, cache and store errors are returned unchanged. On success the parsed
// claims are returned.
func (s *Service) ValidateAccessToken(ctx context.Context, token string) (accessClaims, error) {
	claims, err := s.accessToken.parseAccess(token)
	if err != nil {
		return accessClaims{}, err
	}

	if s.cache != nil {
		denied, denyErr := s.cache.IsAccessJTIDenied(ctx, claims.JTI)
		switch {
		case denyErr != nil:
			return accessClaims{}, denyErr
		case denied:
			return accessClaims{}, errUnauthorized
		}
	}

	active, err := s.isSessionActive(ctx, claims.SessionID)
	switch {
	case err != nil:
		return accessClaims{}, err
	case !active:
		return accessClaims{}, errSessionRevoked
	}
	return claims, nil
}
|
|
|
|
// isSessionActive reports whether session sid is still active.
//
// A positive answer from the cache is trusted as-is. On a cache miss, or when
// no cache is configured, the store decides. If the store says the session
// is active, the cache is re-warmed with the session's remaining lifetime as
// TTL. Re-warming is best-effort: a failed lookup or cache write does not
// change the result.
//
// Errors from the cache or store activity checks themselves are returned.
func (s *Service) isSessionActive(ctx context.Context, sid string) (bool, error) {
	if s.cache != nil {
		cached, err := s.cache.IsSessionActive(ctx, sid)
		if err != nil {
			return false, err
		}
		if cached {
			return true, nil
		}
	}

	active, err := s.store.isSessionActive(ctx, sid)
	if err != nil {
		return false, err
	}

	if active && s.cache != nil {
		// NOTE(review): this is a second store round trip just to learn
		// ExpiresAt. A store query returning both facts would avoid it.
		if sess, getErr := s.store.getSession(ctx, sid); getErr == nil {
			// Skip non-positive TTLs. An already-expired entry is useless, and
			// depending on the cache implementation a zero/negative TTL may be
			// stored without any expiry — confirm in tokenCache.SetSessionActive.
			if ttl := time.Until(sess.ExpiresAt); ttl > 0 {
				_ = s.cache.SetSessionActive(ctx, sid, ttl)
			}
		}
	}
	return active, nil
}
|