From 57c27e91025759b3fc67fa5140e2c53c47b14203 Mon Sep 17 00:00:00 2001 From: wolves Date: Sun, 1 Mar 2026 21:26:37 +0800 Subject: [PATCH] refactor(auth): split IAM module and add access/refresh session flow --- cmd/server/main.go | 364 +++++------------------------------ docker-compose.yml | 20 ++ internal/iam/cache_redis.go | 104 ++++++++++ internal/iam/config.go | 95 +++++++++ internal/iam/http_handler.go | 192 ++++++++++++++++++ internal/iam/middleware.go | 40 ++++ internal/iam/service.go | 302 +++++++++++++++++++++++++++++ internal/iam/store_pg.go | 279 +++++++++++++++++++++++++++ internal/iam/token.go | 147 ++++++++++++++ internal/iam/types.go | 77 ++++++++ web/src/services/api.ts | 93 ++++++--- web/src/stores/session.ts | 7 +- web/src/views/LoginView.vue | 2 +- 13 files changed, 1377 insertions(+), 345 deletions(-) create mode 100644 internal/iam/cache_redis.go create mode 100644 internal/iam/config.go create mode 100644 internal/iam/http_handler.go create mode 100644 internal/iam/middleware.go create mode 100644 internal/iam/service.go create mode 100644 internal/iam/store_pg.go create mode 100644 internal/iam/token.go create mode 100644 internal/iam/types.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 6562532..c104bfb 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,9 +2,6 @@ package main import ( "context" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" "encoding/json" "errors" "log" @@ -14,13 +11,13 @@ import ( "strings" "time" + "wolves.top/todo/internal/iam" + "github.com/gin-gonic/gin" "github.com/jackc/pgconn" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" - "github.com/redis/go-redis/v9" "github.com/segmentio/kafka-go" - "golang.org/x/crypto/bcrypt" ) type Task struct { @@ -35,90 +32,18 @@ type Task struct { UpdatedAt time.Time `json:"updated_at"` } -type User struct { - ID int64 `json:"id"` - Email string `json:"email"` -} - -type tokenManager struct { - secret []byte - ttl time.Duration -} - type postgresStore struct { pool *pgxpool.Pool } -type tokenCache struct { - client *redis.Client - prefix string -} - type taskEmitter struct { writer *kafka.Writer topic string } -func newTokenManager(secret string, ttl time.Duration) *tokenManager { - return &tokenManager{ - secret: []byte(secret), - ttl: ttl, - } -} - -func (t *tokenManager) Generate(user User) (string, error) { - payload := struct { - UserID int64 `json:"uid"` - Exp int64 `json:"exp"` - }{ - UserID: user.ID, - Exp: time.Now().Add(t.ttl).Unix(), - } - raw, err := json.Marshal(payload) - if err != nil { - return "", err - } - encoded := base64.RawURLEncoding.EncodeToString(raw) - sig := t.sign(encoded) - return encoded + "." + sig, nil -} - -func (t *tokenManager) Validate(token string) (int64, bool) { - parts := strings.Split(token, ".") - if len(parts) != 2 { - return 0, false - } - payload, sig := parts[0], parts[1] - if !t.verify(payload, sig) { - return 0, false - } - raw, err := base64.RawURLEncoding.DecodeString(payload) - if err != nil { - return 0, false - } - var data struct { - UserID int64 `json:"uid"` - Exp int64 `json:"exp"` - } - if err := json.Unmarshal(raw, &data); err != nil { - return 0, false - } - if data.UserID == 0 || time.Now().Unix() > data.Exp { - return 0, false - } - return data.UserID, true -} - -func (t *tokenManager) sign(payload string) string { - mac := hmac.New(sha256.New, t.secret) - mac.Write([]byte(payload)) - return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) -} - -func (t *tokenManager) verify(payload, signature string) bool { - expected := t.sign(payload) - return hmac.Equal([]byte(signature), []byte(expected)) -} +var ( + errAlreadyExists = errors.New("already exists") +) func newPostgresStore(ctx context.Context, url string) (*postgresStore, error) { pool, err := pgxpool.New(ctx, url) @@ -166,38 +91,6 @@ func (s *postgresStore) initSchema(ctx context.Context) error { return nil } -func (s *postgresStore) Register(ctx context.Context, email, password string) (User, error) { - hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) - if err != nil { - return User{}, err - } - var id int64 - err = s.pool.QueryRow(ctx, `INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id`, email, string(hashed)).Scan(&id) - if err != nil { - if isUniqueViolation(err) { - return User{}, errAlreadyExists - } - return User{}, err - } - return User{ID: id, Email: email}, nil -} - -func (s *postgresStore) Login(ctx context.Context, email, password string) (User, error) { - var user User - var hash string - err := s.pool.QueryRow(ctx, `SELECT id, email, password_hash FROM users WHERE email = $1`, email).Scan(&user.ID, &user.Email, &hash) - if err != nil { - if errors.Is(err, pgx.ErrNoRows) { - return User{}, errInvalidCredentials - } - return User{}, err - } - if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil { - return User{}, errInvalidCredentials - } - return user, nil -} - func (s *postgresStore) List(ctx context.Context, userID int64) ([]Task, error) { rows, err := s.pool.Query(ctx, `SELECT id, title, description, status, due_at, priority, tags, created_at, updated_at FROM tasks WHERE user_id = $1 ORDER BY id DESC`, userID) if err != nil { @@ -304,37 +197,8 @@ func (s *postgresStore) Close() { } } -func newTokenCache(addr, password string, db int) (*tokenCache, error) { - if strings.TrimSpace(addr) == "" { - return nil, nil - } - client := redis.NewClient(&redis.Options{ - Addr: addr, - Password: password, - DB: db, - }) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := client.Ping(ctx).Err(); err != nil { - return nil, err - } - return &tokenCache{client: client, prefix: "auth:token:"}, nil -} - -func (c *tokenCache) Save(ctx context.Context, token string, ttl time.Duration) error { - return c.client.Set(ctx, c.prefix+token, "1", ttl).Err() -} - -func (c *tokenCache) Delete(ctx context.Context, token string) error { - return c.client.Del(ctx, c.prefix+token).Err() -} - -func (c *tokenCache) Exists(ctx context.Context, token string) (bool, error) { - count, err := c.client.Exists(ctx, c.prefix+token).Result() - if err != nil { - return false, err - } - return count == 1, nil +func (s *postgresStore) Pool() *pgxpool.Pool { + return s.pool } func newTaskEmitter(brokers []string, topic string) *taskEmitter { @@ -367,10 +231,7 @@ func (e *taskEmitter) Emit(ctx context.Context, eventType string, task Task, use } writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second) defer cancel() - if err := e.writer.WriteMessages(writeCtx, kafka.Message{ - Key: []byte(strconv.FormatInt(task.ID, 10)), - Value: data, - }); err != nil { + if err := e.writer.WriteMessages(writeCtx, kafka.Message{Key: []byte(strconv.FormatInt(task.ID, 10)), Value: data}); err != nil { log.Printf("kafka write failed: %v", err) } } @@ -384,11 +245,6 @@ func (e *taskEmitter) Close() { } } -var ( - errAlreadyExists = errors.New("already exists") - errInvalidCredentials = errors.New("invalid credentials") -) - func isUniqueViolation(err error) bool { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && pgErr.Code == "23505" { @@ -398,11 +254,9 @@ func isUniqueViolation(err error) bool { } func main() { - store, tokens, cache, emitter := buildDependencies() + store, iamSvc, emitter := buildDependencies() defer store.Close() - if cache != nil { - defer cache.client.Close() - } + defer iamSvc.Close() defer emitter.Close() gin.SetMode(gin.DebugMode) @@ -410,135 +264,21 @@ func main() { router.RedirectTrailingSlash = false router.RedirectFixedPath = false - router.Use(func(c *gin.Context) { - c.Writer.Header().Set("Access-Control-Allow-Origin", "*") - c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") - c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") - if c.Request.Method == http.MethodOptions { - c.AbortWithStatus(http.StatusNoContent) - return - } - c.Next() - }) + router.Use(corsMiddleware()) router.GET("/api/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) }) - auth := router.Group("/api/v1/auth") - { - auth.POST("/register", func(c *gin.Context) { - var input struct { - Email string `json:"email"` - Password string `json:"password"` - } - if err := c.ShouldBindJSON(&input); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - input.Email = strings.TrimSpace(input.Email) - if input.Email == "" || input.Password == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "email and password required"}) - return - } - user, err := store.Register(c.Request.Context(), input.Email, input.Password) - if err != nil { - if errors.Is(err, errAlreadyExists) { - c.JSON(http.StatusConflict, gin.H{"error": "user already exists"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": "registration failed"}) - return - } - c.JSON(http.StatusCreated, gin.H{"id": user.ID, "email": user.Email}) - }) - auth.POST("/login", func(c *gin.Context) { - var input struct { - Email string `json:"email"` - Password string `json:"password"` - } - if err := c.ShouldBindJSON(&input); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - input.Email = strings.TrimSpace(input.Email) - if input.Email == "" || input.Password == "" { - c.JSON(http.StatusBadRequest, gin.H{"error": "email and password required"}) - return - } - user, err := store.Login(c.Request.Context(), input.Email, input.Password) - if err != nil { - if errors.Is(err, errInvalidCredentials) { - c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid credentials"}) - return - } - c.JSON(http.StatusInternalServerError, gin.H{"error": "login failed"}) - return - } - token, err := tokens.Generate(user) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "token generation failed"}) - return - } - if cache != nil { - if err := cache.Save(c.Request.Context(), token, tokens.ttl); err != nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "token cache unavailable"}) - return - } - } - c.JSON(http.StatusOK, gin.H{"token": token}) - }) - auth.POST("/logout", func(c *gin.Context) { - token := extractBearerToken(c) - if token == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"}) - return - } - if _, ok := tokens.Validate(token); !ok { - c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) - return - } - if cache != nil { - if err := cache.Delete(c.Request.Context(), token); err != nil { - c.JSON(http.StatusServiceUnavailable, gin.H{"error": "token cache unavailable"}) - return - } - } - c.Status(http.StatusNoContent) - }) - } + cfg := iam.LoadConfig() + iam.NewHandler(iamSvc, cfg).RegisterRoutes(router) api := router.Group("/api/v1") - api.Use(func(c *gin.Context) { - token := extractBearerToken(c) - if token == "" { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"}) - return - } - userID, ok := tokens.Validate(token) - if !ok { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) - return - } - if cache != nil { - exists, err := cache.Exists(c.Request.Context(), token) - if err != nil { - c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "token cache unavailable"}) - return - } - if !exists { - c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) - return - } - } - c.Set("user_id", userID) - c.Next() - }) - + api.Use(iamSvc.RequireAccess()) tasks := api.Group("/tasks") { tasks.GET("", func(c *gin.Context) { - userID := c.GetInt64("user_id") + userID := c.GetInt64(iam.ContextUserIDKey) items, err := store.List(c.Request.Context(), userID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load tasks"}) @@ -552,10 +292,10 @@ func main() { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - userID := c.GetInt64("user_id") + userID := c.GetInt64(iam.ContextUserIDKey) created, err := store.Create(c.Request.Context(), userID, input) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create task"}) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save task"}) return } emitter.Emit(c.Request.Context(), "task.created", created, userID) @@ -567,8 +307,8 @@ func main() { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) return } - userID := c.GetInt64("user_id") - task, err := store.Get(c.Request.Context(), userID, id) + userID := c.GetInt64(iam.ContextUserIDKey) + item, err := store.Get(c.Request.Context(), userID, id) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) @@ -577,7 +317,7 @@ func main() { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load task"}) return } - c.JSON(http.StatusOK, task) + c.JSON(http.StatusOK, item) }) tasks.PUT(":id", func(c *gin.Context) { id, err := parseID(c.Param("id")) @@ -590,7 +330,7 @@ func main() { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } - userID := c.GetInt64("user_id") + userID := c.GetInt64(iam.ContextUserIDKey) updated, err := store.Update(c.Request.Context(), userID, id, input) if err != nil { if errors.Is(err, pgx.ErrNoRows) { @@ -609,7 +349,7 @@ func main() { c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"}) return } - userID := c.GetInt64("user_id") + userID := c.GetInt64(iam.ContextUserIDKey) if err := store.Delete(c.Request.Context(), userID, id); err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "not found"}) @@ -628,7 +368,7 @@ func main() { } } -func buildDependencies() (*postgresStore, *tokenManager, *tokenCache, *taskEmitter) { +func buildDependencies() (*postgresStore, *iam.Service, *taskEmitter) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -640,19 +380,10 @@ func buildDependencies() (*postgresStore, *tokenManager, *tokenCache, *taskEmitt if err != nil { log.Fatalf("postgres connection failed: %v", err) } - - secret := strings.TrimSpace(os.Getenv("AUTH_SECRET")) - if secret == "" { - secret = "dev-secret-change-me" - } - tokens := newTokenManager(secret, 24*time.Hour) - - redisAddr := strings.TrimSpace(os.Getenv("REDIS_ADDR")) - redisPassword := os.Getenv("REDIS_PASSWORD") - redisDB := parseEnvInt("REDIS_DB", 0) - cache, err := newTokenCache(redisAddr, redisPassword, redisDB) + iamCfg := iam.LoadConfig() + iamSvc, err := iam.NewService(ctx, store.Pool(), iamCfg) if err != nil { - log.Fatalf("redis connection failed: %v", err) + log.Fatalf("iam initialization failed: %v", err) } brokers := splitCSV(os.Getenv("KAFKA_BROKERS")) @@ -661,34 +392,13 @@ func buildDependencies() (*postgresStore, *tokenManager, *tokenCache, *taskEmitt topic = "todo.tasks" } emitter := newTaskEmitter(brokers, topic) - - return store, tokens, cache, emitter + return store, iamSvc, emitter } func parseID(value string) (int64, error) { return strconv.ParseInt(value, 10, 64) } -func extractBearerToken(c *gin.Context) string { - authHeader := strings.TrimSpace(c.GetHeader("Authorization")) - if strings.HasPrefix(authHeader, "Bearer ") { - return strings.TrimSpace(authHeader[7:]) - } - return authHeader -} - -func parseEnvInt(key string, fallback int) int { - value := strings.TrimSpace(os.Getenv(key)) - if value == "" { - return fallback - } - parsed, err := strconv.Atoi(value) - if err != nil { - return fallback - } - return parsed -} - func splitCSV(value string) []string { if strings.TrimSpace(value) == "" { return nil @@ -703,3 +413,23 @@ func splitCSV(value string) []string { } return result } + +func corsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + origin := strings.TrimSpace(c.GetHeader("Origin")) + if origin != "" { + c.Writer.Header().Set("Access-Control-Allow-Origin", origin) + c.Writer.Header().Set("Vary", "Origin") + c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") + } else { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + } + c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") + c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") + if c.Request.Method == http.MethodOptions { + c.AbortWithStatus(http.StatusNoContent) + return + } + c.Next() + } +} diff --git a/docker-compose.yml b/docker-compose.yml index aea930a..f901d15 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -56,3 +56,23 @@ services: KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: kafka:9092 depends_on: - kafka + + api: + image: golang:1.26.0 + container_name: todo-api + working_dir: /app + command: go run ./cmd/server + volumes: + - .:/app + ports: + - "8080:8080" + environment: + DATABASE_URL: postgres://todo:todo@postgres:5432/todo?sslmode=disable + REDIS_ADDR: redis:6379 + KAFKA_BROKERS: kafka:9092 + KAFKA_TOPIC: todo.tasks + AUTH_SECRET: dev-secret-change-me + depends_on: + - postgres + - redis + - kafka diff --git a/internal/iam/cache_redis.go b/internal/iam/cache_redis.go new file mode 100644 index 0000000..576411b --- /dev/null +++ b/internal/iam/cache_redis.go @@ -0,0 +1,104 @@ +package iam + +import ( + "context" + "errors" + "time" + + "github.com/redis/go-redis/v9" +) + +type tokenCache struct { + client *redis.Client +} + +func newTokenCache(addr, password string, db int) (*tokenCache, error) { + if addr == "" { + return nil, nil + } + client := redis.NewClient(&redis.Options{Addr: addr, Password: password, DB: db}) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := client.Ping(ctx).Err(); err != nil { + return nil, err + } + return &tokenCache{client: client}, nil +} + +func (c *tokenCache) Close() error { + if c == nil || c.client == nil { + return nil + } + return c.client.Close() +} + +func (c *tokenCache) SetSessionActive(ctx context.Context, sid string, ttl time.Duration) error { + if c == nil { + return nil + } + return c.client.Set(ctx, "auth:session:"+sid, "active", ttl).Err() +} + +func (c *tokenCache) IsSessionActive(ctx context.Context, sid string) (bool, error) { + if c == nil { + return false, nil + } + count, err := c.client.Exists(ctx, "auth:session:"+sid).Result() + if err != nil { + return false, err + } + return count == 1, nil +} + +func (c *tokenCache) DeleteSession(ctx context.Context, sid string) error { + if c == nil { + return nil + } + return c.client.Del(ctx, "auth:session:"+sid).Err() +} + +func (c *tokenCache) SetRefreshHash(ctx context.Context, rid, tokenHash string, ttl time.Duration) error { + if c == nil { + return nil + } + return c.client.Set(ctx, "auth:refresh:"+rid, tokenHash, ttl).Err() +} + +func (c *tokenCache) GetRefreshHash(ctx context.Context, rid string) (string, error) { + if c == nil { + return "", nil + } + v, err := c.client.Get(ctx, "auth:refresh:"+rid).Result() + if err != nil { + if errors.Is(err, redis.Nil) { + return "", nil + } + return "", err + } + return v, nil +} + +func (c *tokenCache) DeleteRefresh(ctx context.Context, rid string) error { + if c == nil { + return nil + } + return c.client.Del(ctx, "auth:refresh:"+rid).Err() +} + +func (c *tokenCache) DenyAccessJTI(ctx context.Context, jti string, ttl time.Duration) error { + if c == nil { + return nil + } + return c.client.Set(ctx, "auth:deny:access:"+jti, "1", ttl).Err() +} + +func (c *tokenCache) IsAccessJTIDenied(ctx context.Context, jti string) (bool, error) { + if c == nil { + return false, nil + } + count, err := c.client.Exists(ctx, "auth:deny:access:"+jti).Result() + if err != nil { + return false, err + } + return count == 1, nil +} diff --git a/internal/iam/config.go b/internal/iam/config.go new file mode 100644 index 0000000..5d33a6e --- /dev/null +++ b/internal/iam/config.go @@ -0,0 +1,95 @@ +package iam + +import ( + "os" + "strconv" + "strings" + "time" +) + +type Config struct { + AccessSecret string + RefreshSecret string + RefreshPepper string + AccessTTL time.Duration + RefreshTTL time.Duration + SessionTTL time.Duration + RefreshCookieName string + CookieSecure bool + CookieDomain string + CookiePath string + CookieSameSite string + RedisAddr string + RedisPassword string + RedisDB int +} + +func LoadConfig() Config { + cfg := Config{ + AccessSecret: firstNonEmpty(strings.TrimSpace(os.Getenv("ACCESS_SECRET")), strings.TrimSpace(os.Getenv("AUTH_SECRET")), "dev-access-secret-change-me"), + RefreshSecret: firstNonEmpty(strings.TrimSpace(os.Getenv("REFRESH_SECRET")), strings.TrimSpace(os.Getenv("AUTH_SECRET")), "dev-refresh-secret-change-me"), + RefreshPepper: firstNonEmpty(strings.TrimSpace(os.Getenv("REFRESH_PEPPER")), "dev-refresh-pepper-change-me"), + AccessTTL: parseDuration("ACCESS_TTL", 60*time.Minute), + RefreshTTL: parseDuration("REFRESH_TTL", 7*24*time.Hour), + RefreshCookieName: firstNonEmpty(strings.TrimSpace(os.Getenv("REFRESH_COOKIE_NAME")), "rt"), + CookieSecure: parseBool("COOKIE_SECURE", isProduction()), + CookieDomain: strings.TrimSpace(os.Getenv("COOKIE_DOMAIN")), + CookiePath: firstNonEmpty(strings.TrimSpace(os.Getenv("COOKIE_PATH")), "/"), + CookieSameSite: firstNonEmpty(strings.TrimSpace(os.Getenv("COOKIE_SAMESITE")), "Lax"), + RedisAddr: strings.TrimSpace(os.Getenv("REDIS_ADDR")), + RedisPassword: os.Getenv("REDIS_PASSWORD"), + RedisDB: parseInt("REDIS_DB", 0), + } + cfg.SessionTTL = parseDuration("SESSION_TTL", cfg.RefreshTTL) + return cfg +} + +func parseDuration(key string, fallback time.Duration) time.Duration { + value := strings.TrimSpace(os.Getenv(key)) + if value == "" { + return fallback + } + d, err := time.ParseDuration(value) + if err != nil || d <= 0 { + return fallback + } + return d +} + +func parseBool(key string, fallback bool) bool { + value := strings.TrimSpace(os.Getenv(key)) + if value == "" { + return fallback + } + parsed, err := strconv.ParseBool(value) + if err != nil { + return fallback + } + return parsed +} + +func parseInt(key string, fallback int) int { + value := strings.TrimSpace(os.Getenv(key)) + if value == "" { + return fallback + } + parsed, err := strconv.Atoi(value) + if err != nil { + return fallback + } + return parsed +} + +func isProduction() bool { + env := strings.ToLower(strings.TrimSpace(os.Getenv("APP_ENV"))) + return env == "prod" || env == "production" +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + if strings.TrimSpace(v) != "" { + return v + } + } + return "" +} diff --git a/internal/iam/http_handler.go b/internal/iam/http_handler.go new file mode 100644 index 0000000..807f8f0 --- /dev/null +++ b/internal/iam/http_handler.go @@ -0,0 +1,192 @@ +package iam + +import ( + "errors" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +type Handler struct { + service *Service + cfg Config +} + +func NewHandler(service *Service, cfg Config) *Handler { + return &Handler{service: service, cfg: cfg} +} + +func (h *Handler) RegisterRoutes(router *gin.Engine) { + auth := router.Group("/api/v1/auth") + auth.POST("/register", h.register) + auth.POST("/login", h.login) + auth.POST("/refresh", h.refresh) + + protected := auth.Group("") + protected.Use(h.service.RequireAccess()) + protected.POST("/logout", h.logout) + protected.POST("/logout-all", h.logoutAll) + protected.GET("/sessions", h.listSessions) + protected.DELETE("/sessions/:sid", h.revokeSession) +} + +func (h *Handler) register(c *gin.Context) { + var input struct { + Email string `json:"email"` + Password string `json:"password"` + AutoLogin *bool `json:"auto_login"` + } + if err := c.ShouldBindJSON(&input); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + autoLogin := true + if input.AutoLogin != nil { + autoLogin = *input.AutoLogin + } + result, refreshToken, err := h.service.Register(c.Request.Context(), input.Email, input.Password, autoLogin, requestMetaFromContext(c)) + if err != nil { + h.writeAuthError(c, err, "registration failed") + return + } + if !autoLogin { + c.JSON(http.StatusCreated, gin.H{"email": strings.TrimSpace(strings.ToLower(input.Email))}) + return + } + h.setRefreshCookie(c, refreshToken) + c.JSON(http.StatusCreated, result) +} + +func (h *Handler) login(c *gin.Context) { + var input struct { + Email string `json:"email"` + Password string `json:"password"` + DeviceInfo string `json:"device_info"` + } + if err := c.ShouldBindJSON(&input); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + meta := requestMetaFromContext(c) + meta.DeviceInfo = input.DeviceInfo + result, refreshToken, err := h.service.Login(c.Request.Context(), input.Email, input.Password, meta) + if err != nil { + h.writeAuthError(c, err, "login failed") + return + } + h.setRefreshCookie(c, refreshToken) + c.JSON(http.StatusOK, result) +} + +func (h *Handler) refresh(c *gin.Context) { + refreshToken, err := c.Cookie(h.cfg.RefreshCookieName) + if err != nil || strings.TrimSpace(refreshToken) == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "missing refresh token"}) + return + } + result, newRefreshToken, err := h.service.Refresh(c.Request.Context(), refreshToken) + if err != nil { + h.writeAuthError(c, err, "refresh failed") + return + } + h.setRefreshCookie(c, newRefreshToken) + c.JSON(http.StatusOK, result) +} + +func (h *Handler) logout(c *gin.Context) { + token := extractBearerToken(c.GetHeader("Authorization")) + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"}) + return + } + if err := h.service.Logout(c.Request.Context(), token); err != nil { + h.writeAuthError(c, err, "logout failed") + return + } + h.clearRefreshCookie(c) + c.Status(http.StatusNoContent) +} + +func (h *Handler) logoutAll(c *gin.Context) { + token := extractBearerToken(c.GetHeader("Authorization")) + if token == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"}) + return + } + if err := h.service.LogoutAll(c.Request.Context(), token); err != nil { + h.writeAuthError(c, err, "logout-all failed") + return + } + h.clearRefreshCookie(c) + c.Status(http.StatusNoContent) +} + +func (h *Handler) listSessions(c *gin.Context) { + uid := c.GetInt64(ContextUserIDKey) + sessions, err := h.service.ListSessions(c.Request.Context(), uid) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list sessions"}) + return + } + c.JSON(http.StatusOK, sessions) +} + +func (h *Handler) revokeSession(c *gin.Context) { + uid := c.GetInt64(ContextUserIDKey) + sid := strings.TrimSpace(c.Param("sid")) + if sid == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid sid"}) + return + } + if err := h.service.RevokeSession(c.Request.Context(), uid, sid); err != nil { + h.writeAuthError(c, err, "revoke session failed") + return + } + c.Status(http.StatusNoContent) +} + +func requestMetaFromContext(c *gin.Context) requestMeta { + return requestMeta{ + IP: c.ClientIP(), + UserAgent: c.Request.UserAgent(), + } +} + +func (h *Handler) writeAuthError(c *gin.Context, err error, fallback string) { + switch { + case errors.Is(err, errInvalidCredentials): + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid credentials"}) + case errors.Is(err, errAlreadyExists): + c.JSON(http.StatusConflict, gin.H{"error": "user already exists"}) + case errors.Is(err, errInvalidToken), errors.Is(err, errTokenExpired), errors.Is(err, errUnauthorized), errors.Is(err, errSessionRevoked): + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"}) + case errors.Is(err, errForbidden): + c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"}) + default: + c.JSON(http.StatusInternalServerError, gin.H{"error": fallback}) + } +} + +func (h *Handler) setRefreshCookie(c *gin.Context, token string) { + h.applySameSite(c) + maxAge := int(h.cfg.RefreshTTL / time.Second) + c.SetCookie(h.cfg.RefreshCookieName, token, maxAge, h.cfg.CookiePath, h.cfg.CookieDomain, h.cfg.CookieSecure, true) +} + +func (h *Handler) clearRefreshCookie(c *gin.Context) { + h.applySameSite(c) + c.SetCookie(h.cfg.RefreshCookieName, "", -1, h.cfg.CookiePath, h.cfg.CookieDomain, h.cfg.CookieSecure, true) +} + +func (h *Handler) applySameSite(c *gin.Context) { + switch strings.ToLower(strings.TrimSpace(h.cfg.CookieSameSite)) { + case "strict": + c.SetSameSite(http.SameSiteStrictMode) + case "none": + c.SetSameSite(http.SameSiteNoneMode) + default: + c.SetSameSite(http.SameSiteLaxMode) + } +} diff --git a/internal/iam/middleware.go b/internal/iam/middleware.go new file mode 100644 index 0000000..68cebfb --- /dev/null +++ b/internal/iam/middleware.go @@ -0,0 +1,40 @@ +package iam + +import ( + "errors" + "net/http" + "strings" + + "github.com/gin-gonic/gin" +) + +func (s *Service) RequireAccess() gin.HandlerFunc { + return func(c *gin.Context) { + token := extractBearerToken(c.GetHeader("Authorization")) + if token == "" { + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"}) + return + } + claims, err := s.ValidateAccessToken(c.Request.Context(), token) + if err != nil { + status := http.StatusUnauthorized + if errors.Is(err, errSessionRevoked) { + status = http.StatusUnauthorized + } + c.AbortWithStatusJSON(status, gin.H{"error": "invalid token"}) + return + } + c.Set(ContextUserIDKey, claims.UserID) + c.Set(ContextSessionIDKey, claims.SessionID) + c.Set(ContextJTIKey, claims.JTI) + c.Next() + } +} + +func extractBearerToken(header string) string { + authHeader := strings.TrimSpace(header) + if strings.HasPrefix(authHeader, "Bearer ") { + return strings.TrimSpace(authHeader[7:]) + } + return authHeader +} diff --git a/internal/iam/service.go b/internal/iam/service.go new file mode 100644 index 0000000..0d52528 --- /dev/null +++ b/internal/iam/service.go @@ -0,0 +1,302 @@ +package iam + +import ( + "context" + "errors" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +type Service struct { + cfg Config + store *postgresStore + cache *tokenCache + accessToken *tokenManager + refreshToken *tokenManager +} + +func NewService(ctx context.Context, pool *pgxpool.Pool, cfg Config) (*Service, error) { + store := newPostgresStore(pool) + if err := store.initSchema(ctx); err != nil { + return nil, err + } + cache, err := newTokenCache(cfg.RedisAddr, cfg.RedisPassword, cfg.RedisDB) + if err != nil { + return nil, err + } + return &Service{ + cfg: cfg, + store: store, + cache: cache, + accessToken: newTokenManager(cfg.AccessSecret, cfg.AccessTTL, tokenKindAccess), + refreshToken: newTokenManager(cfg.RefreshSecret, cfg.RefreshTTL, tokenKindRefresh), + }, nil +} + +func (s *Service) Close() error { + if s.cache != nil { + return s.cache.Close() + } + return nil +} + +func (s *Service) Register(ctx context.Context, email, password string, autoLogin bool, meta requestMeta) (AuthResult, string, error) { + email = normalizeEmail(email) + if email == "" || strings.TrimSpace(password) == "" { + return AuthResult{}, "", errInvalidCredentials + } + uid, err := s.store.registerUser(ctx, email, password) + if err != nil { + return AuthResult{}, "", err + } + if !autoLogin { + return AuthResult{}, "", nil + } + pair, err := s.issueTokenPair(ctx, uid, meta) + if err != nil { + return AuthResult{}, "", err + } + return AuthResult{ + AccessToken: pair.AccessToken, + ExpiresIn: int64(s.cfg.AccessTTL.Seconds()), + SessionID: pair.SessionID, + }, pair.RefreshToken, nil +} + +func (s *Service) Login(ctx context.Context, email, password string, meta requestMeta) (AuthResult, string, error) { + email = normalizeEmail(email) + if email == "" || strings.TrimSpace(password) == "" { + return AuthResult{}, "", errInvalidCredentials + } + uid, err := s.store.verifyUser(ctx, email, password) + if err != nil { + return AuthResult{}, "", err + } + pair, err := s.issueTokenPair(ctx, uid, meta) + if err != nil { + return AuthResult{}, "", err + } + return AuthResult{ + AccessToken: pair.AccessToken, + ExpiresIn: int64(s.cfg.AccessTTL.Seconds()), + SessionID: pair.SessionID, + }, pair.RefreshToken, nil +} + +func (s *Service) issueTokenPair(ctx context.Context, userID int64, meta requestMeta) (tokenPair, error) { + sid, err := randomID("sid") + if err != nil { + return tokenPair{}, err + } + rid, err := randomID("rid") + if err != nil { + return tokenPair{}, err + } + jti, err := randomID("jti") + if err != nil { + return tokenPair{}, err + } + + now := time.Now().UTC() + sess := Session{ + ID: sid, + UserID: userID, + DeviceInfo: strings.TrimSpace(meta.DeviceInfo), + IP: strings.TrimSpace(meta.IP), + UserAgent: strings.TrimSpace(meta.UserAgent), + CreatedAt: now, + ExpiresAt: now.Add(s.cfg.SessionTTL), + } + if err := s.store.createSession(ctx, sess); err != nil { + return tokenPair{}, err + } + refreshRaw, refreshExp, err := s.refreshToken.generateRefresh(userID, sid, rid) + if err != nil { + return tokenPair{}, err + } + refreshHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper) + if err := s.store.createRefreshToken(ctx, rid, sid, refreshHash, refreshExp, nil); err != nil { + return tokenPair{}, err + } + accessRaw, accessExp, err := s.accessToken.generateAccess(userID, sid, jti) + if err != nil { + return tokenPair{}, err + } + if s.cache != nil { + _ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt)) + _ = s.cache.SetRefreshHash(ctx, rid, refreshHash, time.Until(refreshExp)) + _ = accessExp + } + return tokenPair{ + AccessToken: accessRaw, + AccessTokenExpires: accessExp, + RefreshToken: refreshRaw, + SessionID: sid, + }, nil +} + +func (s *Service) Refresh(ctx context.Context, refreshRaw string) (AuthResult, string, error) { + claims, err := s.refreshToken.parseRefresh(refreshRaw) + if err != nil { + return AuthResult{}, "", err + } + + if ok, err := s.isSessionActive(ctx, claims.SessionID); err != nil { + return AuthResult{}, "", err + } else if !ok { + return AuthResult{}, "", errSessionRevoked + } + + providedHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper) + if s.cache != nil { + cached, err := s.cache.GetRefreshHash(ctx, claims.RefreshID) + if err != nil { + return AuthResult{}, "", err + } + if cached != "" && cached != providedHash { + _ = s.store.revokeSession(ctx, claims.SessionID) + _ = s.cache.DeleteSession(ctx, claims.SessionID) + return AuthResult{}, "", errSessionRevoked + } + } + + newRID, err := randomID("rid") + if err != nil { + return AuthResult{}, "", err + } + newJTI, err := randomID("jti") + if err != nil { + return AuthResult{}, "", err + } + newRefreshRaw, refreshExp, err := s.refreshToken.generateRefresh(claims.UserID, claims.SessionID, newRID) + if err != nil { + return AuthResult{}, "", err + } + newRefreshHash := hashRefreshToken(newRefreshRaw, s.cfg.RefreshPepper) + if err := s.store.rotateRefreshToken(ctx, claims, providedHash, newRID, newRefreshHash, refreshExp); err != nil { + if errors.Is(err, errSessionRevoked) { + _ = s.cache.DeleteSession(ctx, claims.SessionID) + } + return AuthResult{}, "", err + } + if s.cache != nil { + _ = s.cache.DeleteRefresh(ctx, claims.RefreshID) + _ = s.cache.SetRefreshHash(ctx, newRID, newRefreshHash, time.Until(refreshExp)) + } + accessRaw, _, err := s.accessToken.generateAccess(claims.UserID, claims.SessionID, newJTI) + if err != nil { + return AuthResult{}, "", err + } + return AuthResult{ + AccessToken: accessRaw, + ExpiresIn: int64(s.cfg.AccessTTL.Seconds()), + SessionID: claims.SessionID, + }, newRefreshRaw, nil +} + +func (s *Service) Logout(ctx context.Context, accessRaw string) error { + claims, err := s.accessToken.parseAccess(accessRaw) + if err != nil { + return err + } + if err := s.store.revokeSession(ctx, claims.SessionID); err != nil { + return err + } + if s.cache != nil { + _ = s.cache.DeleteSession(ctx, claims.SessionID) + _ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL) + } + return nil +} + +func (s *Service) LogoutAll(ctx context.Context, accessRaw string) error { + claims, err := s.accessToken.parseAccess(accessRaw) + if err != nil { + return err + } + sids, err := s.store.revokeAllUserSessions(ctx, claims.UserID) + if err != nil { + return err + } + if s.cache != nil { + for _, sid := range sids { + _ = s.cache.DeleteSession(ctx, sid) + } + _ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL) + } + return nil +} + +func (s *Service) ListSessions(ctx context.Context, userID int64) ([]Session, error) { + return s.store.listSessions(ctx, userID) +} + +func (s *Service) RevokeSession(ctx context.Context, userID int64, sid string) error { + sess, err := s.store.getSession(ctx, sid) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return errInvalidToken + } + return err + } + if sess.UserID != userID { + return errForbidden + } + if err := s.store.revokeSession(ctx, sid); err != nil { + return err + } + if s.cache != nil { + _ = s.cache.DeleteSession(ctx, sid) + } + return nil +} + +func (s *Service) ValidateAccessToken(ctx context.Context, token string) (accessClaims, error) { + claims, err := s.accessToken.parseAccess(token) + if err != nil { + return accessClaims{}, err + } + if s.cache != nil { + denied, err := s.cache.IsAccessJTIDenied(ctx, claims.JTI) + if err != nil { + return accessClaims{}, err + } + if denied { + return accessClaims{}, errUnauthorized + } + } + active, err := s.isSessionActive(ctx, claims.SessionID) + if err != nil { + return accessClaims{}, err + } + if !active { + return accessClaims{}, errSessionRevoked + } + return claims, nil +} + +func (s *Service) isSessionActive(ctx context.Context, sid string) (bool, error) { + if s.cache != nil { + ok, err := s.cache.IsSessionActive(ctx, sid) + if err != nil { + return false, err + } + if ok { + return true, nil + } + } + ok, err := s.store.isSessionActive(ctx, sid) + if err != nil { + return false, err + } + if ok && s.cache != nil { + sess, err := s.store.getSession(ctx, sid) + if err == nil { + _ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt)) + } + } + return ok, nil +} diff --git a/internal/iam/store_pg.go b/internal/iam/store_pg.go new file mode 100644 index 0000000..d0d1bf1 --- /dev/null +++ b/internal/iam/store_pg.go @@ -0,0 +1,279 @@ +package iam + +import ( + "context" + "database/sql" + "errors" + "strings" + "time" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "golang.org/x/crypto/bcrypt" +) + +type postgresStore struct { + pool *pgxpool.Pool +} + +type refreshTokenRow struct { + ID string + SessionID string + TokenHash string + ExpiresAt time.Time + RotatedFrom sql.NullString + RevokedAt sql.NullTime + UsedAt sql.NullTime + CreatedAt time.Time +} + +func newPostgresStore(pool *pgxpool.Pool) *postgresStore { + return &postgresStore{pool: pool} +} + +func (s *postgresStore) initSchema(ctx context.Context) error { + stmts := []string{ + `CREATE TABLE IF NOT EXISTS auth_sessions ( + id TEXT PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + device_info TEXT, + ip TEXT, + user_agent TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + expires_at TIMESTAMPTZ NOT NULL, + revoked_at TIMESTAMPTZ + )`, + `CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_revoked ON auth_sessions(user_id, revoked_at)`, + `CREATE INDEX IF NOT EXISTS idx_auth_sessions_expires_at ON auth_sessions(expires_at)`, + `CREATE TABLE IF NOT EXISTS auth_refresh_tokens ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL REFERENCES auth_sessions(id) ON DELETE CASCADE, + token_hash TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + rotated_from TEXT, + revoked_at TIMESTAMPTZ, + used_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT now() + )`, + `CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_session_revoked ON auth_refresh_tokens(session_id, revoked_at)`, + `CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_expires_at ON auth_refresh_tokens(expires_at)`, + `CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_rotated_from ON auth_refresh_tokens(rotated_from)`, + } + for _, stmt := range stmts { + if _, err := s.pool.Exec(ctx, stmt); err != nil { + return err + } + } + return nil +} + +func (s *postgresStore) registerUser(ctx context.Context, email, password string) (int64, error) { + hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return 0, err + } + var id int64 + err = s.pool.QueryRow(ctx, `INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id`, email, string(hashed)).Scan(&id) + if err != nil { + if isUniqueViolation(err) { + return 0, errAlreadyExists + } + return 0, err + } + return id, nil +} + +func (s *postgresStore) verifyUser(ctx context.Context, email, password string) (int64, error) { + var id int64 + var hash string + err := s.pool.QueryRow(ctx, `SELECT id, password_hash FROM users WHERE email = $1`, email).Scan(&id, &hash) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return 0, errInvalidCredentials + } + return 0, err + } + if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil { + return 0, errInvalidCredentials + } + return id, nil +} + +func (s *postgresStore) createSession(ctx context.Context, sess Session) error { + _, err := s.pool.Exec(ctx, ` + INSERT INTO auth_sessions (id, user_id, device_info, ip, user_agent, expires_at) + VALUES ($1, $2, $3, $4, $5, $6) + `, sess.ID, sess.UserID, sess.DeviceInfo, sess.IP, sess.UserAgent, sess.ExpiresAt) + return err +} + +func (s *postgresStore) createRefreshToken(ctx context.Context, rid, sid, hash string, expiresAt time.Time, rotatedFrom *string) error { + _, err := s.pool.Exec(ctx, ` + INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from) + VALUES ($1, $2, $3, $4, $5) + `, rid, sid, hash, expiresAt, rotatedFrom) + return err +} + +func (s *postgresStore) getSession(ctx context.Context, sid string) (Session, error) { + var sess Session + var revokedAt sql.NullTime + err := s.pool.QueryRow(ctx, ` + SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at + FROM auth_sessions WHERE id = $1 + `, sid).Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt) + if err != nil { + return Session{}, err + } + if revokedAt.Valid { + t := revokedAt.Time + sess.RevokedAt = &t + } + return sess, nil +} + +func (s *postgresStore) isSessionActive(ctx context.Context, sid string) (bool, error) { + var ok bool + err := s.pool.QueryRow(ctx, ` + SELECT EXISTS( + SELECT 1 FROM auth_sessions + WHERE id = $1 AND revoked_at IS NULL AND expires_at > now() + ) + `, sid).Scan(&ok) + return ok, err +} + +func (s *postgresStore) revokeSession(ctx context.Context, sid string) error { + _, err := s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, sid) + if err != nil { + return err + } + _, err = s.pool.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, sid) + return err +} + +func (s *postgresStore) revokeAllUserSessions(ctx context.Context, userID int64) ([]string, error) { + rows, err := s.pool.Query(ctx, `SELECT id FROM auth_sessions WHERE user_id = $1 AND revoked_at IS NULL`, userID) + if err != nil { + return nil, err + } + defer rows.Close() + ids := make([]string, 0) + for rows.Next() { + var sid string + if err := rows.Scan(&sid); err != nil { + return nil, err + } + ids = append(ids, sid) + } + if err := rows.Err(); err != nil { + return nil, err + } + _, err = s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE user_id = $1 AND revoked_at IS NULL`, userID) + if err != nil { + return nil, err + } + _, err = s.pool.Exec(ctx, ` + UPDATE auth_refresh_tokens rt + SET revoked_at = now() + FROM auth_sessions s + WHERE s.user_id = $1 AND rt.session_id = s.id AND rt.revoked_at IS NULL + `, userID) + if err != nil { + return nil, err + } + return ids, nil +} + +func (s *postgresStore) listSessions(ctx context.Context, userID int64) ([]Session, error) { + rows, err := s.pool.Query(ctx, ` + SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at + FROM auth_sessions WHERE user_id = $1 ORDER BY created_at DESC + `, userID) + if err != nil { + return nil, err + } + defer rows.Close() + result := make([]Session, 0) + for rows.Next() { + var sess Session + var revokedAt sql.NullTime + if err := rows.Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt); err != nil { + return nil, err + } + if revokedAt.Valid { + t := revokedAt.Time + sess.RevokedAt = &t + } + result = append(result, sess) + } + return result, rows.Err() +} + +func (s *postgresStore) getRefreshTokenForUpdate(ctx context.Context, tx pgx.Tx, rid string) (refreshTokenRow, error) { + var row refreshTokenRow + err := tx.QueryRow(ctx, ` + SELECT id, session_id, token_hash, expires_at, rotated_from, revoked_at, used_at, created_at + FROM auth_refresh_tokens + WHERE id = $1 + FOR UPDATE + `, rid).Scan(&row.ID, &row.SessionID, &row.TokenHash, &row.ExpiresAt, &row.RotatedFrom, &row.RevokedAt, &row.UsedAt, &row.CreatedAt) + return row, err +} + +func (s *postgresStore) rotateRefreshToken(ctx context.Context, claims refreshClaims, providedHash, newRID, newHash string, newExp time.Time) error { + tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + row, err := s.getRefreshTokenForUpdate(ctx, tx, claims.RefreshID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return errInvalidToken + } + return err + } + + if row.SessionID != claims.SessionID || row.ExpiresAt.Before(time.Now().UTC()) { + return errInvalidToken + } + + if row.RevokedAt.Valid || row.UsedAt.Valid || row.TokenHash != providedHash { + if _, err := tx.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil { + return err + } + if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil { + return err + } + if err := tx.Commit(ctx); err != nil { + return err + } + return errSessionRevoked + } + + if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET used_at = now(), revoked_at = now() WHERE id = $1`, claims.RefreshID); err != nil { + return err + } + if _, err := tx.Exec(ctx, ` + INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from) + VALUES ($1, $2, $3, $4, $5) + `, newRID, claims.SessionID, newHash, newExp, claims.RefreshID); err != nil { + return err + } + return tx.Commit(ctx) +} + +func isUniqueViolation(err error) bool { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) && pgErr.Code == "23505" { + return true + } + return false +} + +func normalizeEmail(v string) string { + return strings.TrimSpace(strings.ToLower(v)) +} diff --git a/internal/iam/token.go b/internal/iam/token.go new file mode 100644 index 0000000..1c50671 --- /dev/null +++ b/internal/iam/token.go @@ -0,0 +1,147 @@ +package iam + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "strings" + "time" +) + +type tokenManager struct { + secret []byte + ttl time.Duration + kind tokenKind +} + +func newTokenManager(secret string, ttl time.Duration, kind tokenKind) *tokenManager { + return &tokenManager{secret: []byte(secret), ttl: ttl, kind: kind} +} + +func (t *tokenManager) generateAccess(userID int64, sessionID, jti string) (string, time.Time, error) { + now := time.Now().UTC() + expiresAt := now.Add(t.ttl) + payload := accessClaims{ + UserID: userID, + SessionID: sessionID, + JTI: jti, + IssuedAt: now.Unix(), + ExpiresAt: expiresAt.Unix(), + Type: t.kind, + } + token, err := t.sign(payload) + if err != nil { + return "", time.Time{}, err + } + return token, expiresAt, nil +} + +func (t *tokenManager) parseAccess(token string) (accessClaims, error) { + var claims accessClaims + if err := t.verifyAndDecode(token, &claims); err != nil { + return claims, err + } + if claims.Type != tokenKindAccess { + return claims, errInvalidToken + } + if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.JTI) == "" { + return claims, errInvalidToken + } + if time.Now().UTC().Unix() > claims.ExpiresAt { + return claims, errTokenExpired + } + return claims, nil +} + +func (t *tokenManager) generateRefresh(userID int64, sessionID, refreshID string) (string, time.Time, error) { + now := time.Now().UTC() + expiresAt := now.Add(t.ttl) + payload := refreshClaims{ + UserID: userID, + SessionID: sessionID, + RefreshID: refreshID, + IssuedAt: now.Unix(), + ExpiresAt: expiresAt.Unix(), + Type: t.kind, + } + token, err := t.sign(payload) + if err != nil { + return "", time.Time{}, err + } + return token, expiresAt, nil +} + +func (t *tokenManager) parseRefresh(token string) (refreshClaims, error) { + var claims refreshClaims + if err := t.verifyAndDecode(token, &claims); err != nil { + return claims, err + } + if claims.Type != tokenKindRefresh { + return claims, errInvalidToken + } + if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.RefreshID) == "" { + return claims, errInvalidToken + } + if time.Now().UTC().Unix() > claims.ExpiresAt { + return claims, errTokenExpired + } + return claims, nil +} + +func (t *tokenManager) sign(v any) (string, error) { + raw, err := json.Marshal(v) + if err != nil { + return "", err + } + payload := base64.RawURLEncoding.EncodeToString(raw) + sig := t.signPayload(payload) + return payload + "." + sig, nil +} + +func (t *tokenManager) verifyAndDecode(token string, out any) error { + parts := strings.Split(token, ".") + if len(parts) != 2 { + return errInvalidToken + } + payload, sig := parts[0], parts[1] + if !t.verifyPayload(payload, sig) { + return errInvalidToken + } + raw, err := base64.RawURLEncoding.DecodeString(payload) + if err != nil { + return errInvalidToken + } + if err := json.Unmarshal(raw, out); err != nil { + return errInvalidToken + } + return nil +} + +func (t *tokenManager) signPayload(payload string) string { + mac := hmac.New(sha256.New, t.secret) + mac.Write([]byte(payload)) + return base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) +} + +func (t *tokenManager) verifyPayload(payload, signature string) bool { + expected := t.signPayload(payload) + return hmac.Equal([]byte(signature), []byte(expected)) +} + +func randomID(prefix string) (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", err + } + return fmt.Sprintf("%s_%s", prefix, hex.EncodeToString(b)), nil +} + +func hashRefreshToken(token, pepper string) string { + mac := hmac.New(sha256.New, []byte(pepper)) + mac.Write([]byte(token)) + return hex.EncodeToString(mac.Sum(nil)) +} diff --git a/internal/iam/types.go b/internal/iam/types.go new file mode 100644 index 0000000..15759e2 --- /dev/null +++ b/internal/iam/types.go @@ -0,0 +1,77 @@ +package iam + +import ( + "errors" + "time" +) + +const ( + ContextUserIDKey = "user_id" + ContextSessionIDKey = "session_id" + ContextJTIKey = "jti" +) + +type tokenKind string + +const ( + tokenKindAccess tokenKind = "access" + tokenKindRefresh tokenKind = "refresh" +) + +type accessClaims struct { + UserID int64 `json:"uid"` + SessionID string `json:"sid"` + JTI string `json:"jti"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"exp"` + Type tokenKind `json:"typ"` +} + +type refreshClaims struct { + UserID int64 `json:"uid"` + SessionID string `json:"sid"` + RefreshID string `json:"rid"` + IssuedAt int64 `json:"iat"` + ExpiresAt int64 `json:"exp"` + Type tokenKind `json:"typ"` +} + +type Session struct { + ID string `json:"id"` + UserID int64 `json:"user_id"` + DeviceInfo string `json:"device_info"` + IP string `json:"ip"` + UserAgent string `json:"user_agent"` + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` + RevokedAt *time.Time `json:"revoked_at,omitempty"` +} + +type tokenPair struct { + AccessToken string + AccessTokenExpires time.Time + RefreshToken string + SessionID string +} + +type requestMeta struct { + IP string + UserAgent string + DeviceInfo string +} + +type AuthResult struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + SessionID string `json:"session_id"` +} + +var ( + errInvalidToken = errors.New("invalid token") + errTokenExpired = errors.New("token expired") + errInvalidCredentials = errors.New("invalid credentials") + errAlreadyExists = errors.New("already exists") + errSessionRevoked = errors.New("session revoked") + errUnauthorized = errors.New("unauthorized") + errForbidden = errors.New("forbidden") +) diff --git a/web/src/services/api.ts b/web/src/services/api.ts index 29c5f0f..687b2e8 100644 --- a/web/src/services/api.ts +++ b/web/src/services/api.ts @@ -1,10 +1,10 @@ import type { Credentials, Task } from '../types' -import { session } from '../stores/session' +import { clearSession, session, setSessionToken } from '../stores/session' const API_BASE = 'http://localhost:8080/api/v1' -function headers() { - const h: Record = { 'Content-Type': 'application/json' } +function headers(extra?: Record) { + const h: Record = { 'Content-Type': 'application/json', ...(extra ?? {}) } if (session.token) h.Authorization = `Bearer ${session.token}` return h } @@ -18,59 +18,100 @@ async function handle(response: Response): Promise { return (await response.json()) as T } -export async function register(credentials: Credentials) { - const response = await fetch(`${API_BASE}/auth/register`, { +async function refreshAccessToken(): Promise { + try { + const response = await fetch(`${API_BASE}/auth/refresh`, { + method: 'POST', + credentials: 'include', + headers: headers(), + }) + if (!response.ok) return false + const data = (await response.json()) as { access_token: string } + if (!data?.access_token) return false + setSessionToken(data.access_token) + return true + } catch { + return false + } +} + +async function request(path: string, init: RequestInit = {}, allowRefresh = true): Promise { + const merged: RequestInit = { + credentials: 'include', + ...init, + headers: headers((init.headers as Record | undefined) ?? {}), + } + const response = await fetch(`${API_BASE}${path}`, merged) + if (response.status === 401 && allowRefresh && !path.startsWith('/auth/')) { + const refreshed = await refreshAccessToken() + if (refreshed) { + return request(path, init, false) + } + clearSession() + } + return handle(response) +} + +export async function register(credentials: Credentials, autoLogin = true) { + return request<{ email: string; access_token?: string; expires_in?: number; session_id?: string }>('/auth/register', { method: 'POST', - headers: headers(), - body: JSON.stringify(credentials), + body: JSON.stringify({ ...credentials, auto_login: autoLogin }), }) - return handle<{ id: number; email: string }>(response) } export async function login(credentials: Credentials) { - const response = await fetch(`${API_BASE}/auth/login`, { + return request<{ access_token: string; expires_in: number; session_id: string }>('/auth/login', { method: 'POST', - headers: headers(), body: JSON.stringify(credentials), }) - return handle<{ token: string }>(response) +} + +export async function refresh() { + return request<{ access_token: string; expires_in: number; session_id: string }>('/auth/refresh', { + method: 'POST', + }, false) } export async function logout() { - const response = await fetch(`${API_BASE}/auth/logout`, { + return request('/auth/logout', { method: 'POST', - headers: headers(), - }) - return handle(response) + }, false) +} + +export async function logoutAll() { + return request('/auth/logout-all', { + method: 'POST', + }, false) +} + +export async function listSessions() { + return request>('/auth/sessions') +} + +export async function revokeSession(id: string) { + return request(`/auth/sessions/${id}`, { method: 'DELETE' }, false) } export async function listTasks() { - const response = await fetch(`${API_BASE}/tasks`, { headers: headers() }) - return handle(response) + return request('/tasks') } export async function createTask(payload: Partial) { - const response = await fetch(`${API_BASE}/tasks`, { + return request('/tasks', { method: 'POST', - headers: headers(), body: JSON.stringify(payload), }) - return handle(response) } export async function updateTask(id: number, payload: Partial) { - const response = await fetch(`${API_BASE}/tasks/${id}`, { + return request(`/tasks/${id}`, { method: 'PUT', - headers: headers(), body: JSON.stringify(payload), }) - return handle(response) } export async function deleteTask(id: number) { - const response = await fetch(`${API_BASE}/tasks/${id}`, { + return request(`/tasks/${id}`, { method: 'DELETE', - headers: headers(), }) - return handle(response) } diff --git a/web/src/stores/session.ts b/web/src/stores/session.ts index 1b81490..5157474 100644 --- a/web/src/stores/session.ts +++ b/web/src/stores/session.ts @@ -1,6 +1,6 @@ import { reactive } from 'vue' -const tokenKey = 'supertodo_token' +const tokenKey = 'supertodo_access_token' export const session = reactive({ token: localStorage.getItem(tokenKey) ?? '', @@ -14,6 +14,11 @@ export function setSession(token: string, email: string) { localStorage.setItem('supertodo_email', email) } +export function setSessionToken(token: string) { + session.token = token + localStorage.setItem(tokenKey, token) +} + export function clearSession() { session.token = '' session.email = '' diff --git a/web/src/views/LoginView.vue b/web/src/views/LoginView.vue index aea49b2..9519f4c 100644 --- a/web/src/views/LoginView.vue +++ b/web/src/views/LoginView.vue @@ -19,7 +19,7 @@ async function submit() { await api.register(form) } const data = await api.login(form) - setSession(data.token, form.email) + setSession(data.access_token, form.email) router.push('/todos') } catch (e) { error.value = t('auth_failed')