refactor(auth): split IAM module and add access/refresh session flow

This commit is contained in:
2026-03-01 21:26:37 +08:00
parent 6a2d2c9724
commit 57c27e9102
13 changed files with 1377 additions and 345 deletions

View File

@@ -2,9 +2,6 @@ package main
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"log"
@@ -14,13 +11,13 @@ import (
"strings"
"time"
"wolves.top/todo/internal/iam"
"github.com/gin-gonic/gin"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"github.com/segmentio/kafka-go"
"golang.org/x/crypto/bcrypt"
)
type Task struct {
@@ -35,90 +32,18 @@ type Task struct {
UpdatedAt time.Time `json:"updated_at"`
}
type User struct {
ID int64 `json:"id"`
Email string `json:"email"`
}
type tokenManager struct {
secret []byte
ttl time.Duration
}
type postgresStore struct {
pool *pgxpool.Pool
}
type tokenCache struct {
client *redis.Client
prefix string
}
type taskEmitter struct {
writer *kafka.Writer
topic string
}
func newTokenManager(secret string, ttl time.Duration) *tokenManager {
return &tokenManager{
secret: []byte(secret),
ttl: ttl,
}
}
func (t *tokenManager) Generate(user User) (string, error) {
payload := struct {
UserID int64 `json:"uid"`
Exp int64 `json:"exp"`
}{
UserID: user.ID,
Exp: time.Now().Add(t.ttl).Unix(),
}
raw, err := json.Marshal(payload)
if err != nil {
return "", err
}
encoded := base64.RawURLEncoding.EncodeToString(raw)
sig := t.sign(encoded)
return encoded + "." + sig, nil
}
func (t *tokenManager) Validate(token string) (int64, bool) {
parts := strings.Split(token, ".")
if len(parts) != 2 {
return 0, false
}
payload, sig := parts[0], parts[1]
if !t.verify(payload, sig) {
return 0, false
}
raw, err := base64.RawURLEncoding.DecodeString(payload)
if err != nil {
return 0, false
}
var data struct {
UserID int64 `json:"uid"`
Exp int64 `json:"exp"`
}
if err := json.Unmarshal(raw, &data); err != nil {
return 0, false
}
if data.UserID == 0 || time.Now().Unix() > data.Exp {
return 0, false
}
return data.UserID, true
}
func (t *tokenManager) sign(payload string) string {
mac := hmac.New(sha256.New, t.secret)
mac.Write([]byte(payload))
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
func (t *tokenManager) verify(payload, signature string) bool {
expected := t.sign(payload)
return hmac.Equal([]byte(signature), []byte(expected))
}
var (
errAlreadyExists = errors.New("already exists")
)
func newPostgresStore(ctx context.Context, url string) (*postgresStore, error) {
pool, err := pgxpool.New(ctx, url)
@@ -166,38 +91,6 @@ func (s *postgresStore) initSchema(ctx context.Context) error {
return nil
}
func (s *postgresStore) Register(ctx context.Context, email, password string) (User, error) {
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return User{}, err
}
var id int64
err = s.pool.QueryRow(ctx, `INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id`, email, string(hashed)).Scan(&id)
if err != nil {
if isUniqueViolation(err) {
return User{}, errAlreadyExists
}
return User{}, err
}
return User{ID: id, Email: email}, nil
}
func (s *postgresStore) Login(ctx context.Context, email, password string) (User, error) {
var user User
var hash string
err := s.pool.QueryRow(ctx, `SELECT id, email, password_hash FROM users WHERE email = $1`, email).Scan(&user.ID, &user.Email, &hash)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return User{}, errInvalidCredentials
}
return User{}, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
return User{}, errInvalidCredentials
}
return user, nil
}
func (s *postgresStore) List(ctx context.Context, userID int64) ([]Task, error) {
rows, err := s.pool.Query(ctx, `SELECT id, title, description, status, due_at, priority, tags, created_at, updated_at FROM tasks WHERE user_id = $1 ORDER BY id DESC`, userID)
if err != nil {
@@ -304,37 +197,8 @@ func (s *postgresStore) Close() {
}
}
func newTokenCache(addr, password string, db int) (*tokenCache, error) {
if strings.TrimSpace(addr) == "" {
return nil, nil
}
client := redis.NewClient(&redis.Options{
Addr: addr,
Password: password,
DB: db,
})
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := client.Ping(ctx).Err(); err != nil {
return nil, err
}
return &tokenCache{client: client, prefix: "auth:token:"}, nil
}
func (c *tokenCache) Save(ctx context.Context, token string, ttl time.Duration) error {
return c.client.Set(ctx, c.prefix+token, "1", ttl).Err()
}
func (c *tokenCache) Delete(ctx context.Context, token string) error {
return c.client.Del(ctx, c.prefix+token).Err()
}
func (c *tokenCache) Exists(ctx context.Context, token string) (bool, error) {
count, err := c.client.Exists(ctx, c.prefix+token).Result()
if err != nil {
return false, err
}
return count == 1, nil
func (s *postgresStore) Pool() *pgxpool.Pool {
return s.pool
}
func newTaskEmitter(brokers []string, topic string) *taskEmitter {
@@ -367,10 +231,7 @@ func (e *taskEmitter) Emit(ctx context.Context, eventType string, task Task, use
}
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
defer cancel()
if err := e.writer.WriteMessages(writeCtx, kafka.Message{
Key: []byte(strconv.FormatInt(task.ID, 10)),
Value: data,
}); err != nil {
if err := e.writer.WriteMessages(writeCtx, kafka.Message{Key: []byte(strconv.FormatInt(task.ID, 10)), Value: data}); err != nil {
log.Printf("kafka write failed: %v", err)
}
}
@@ -384,11 +245,6 @@ func (e *taskEmitter) Close() {
}
}
var (
errAlreadyExists = errors.New("already exists")
errInvalidCredentials = errors.New("invalid credentials")
)
func isUniqueViolation(err error) bool {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
@@ -398,11 +254,9 @@ func isUniqueViolation(err error) bool {
}
func main() {
store, tokens, cache, emitter := buildDependencies()
store, iamSvc, emitter := buildDependencies()
defer store.Close()
if cache != nil {
defer cache.client.Close()
}
defer iamSvc.Close()
defer emitter.Close()
gin.SetMode(gin.DebugMode)
@@ -410,135 +264,21 @@ func main() {
router.RedirectTrailingSlash = false
router.RedirectFixedPath = false
router.Use(func(c *gin.Context) {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
})
router.Use(corsMiddleware())
router.GET("/api/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
})
auth := router.Group("/api/v1/auth")
{
auth.POST("/register", func(c *gin.Context) {
var input struct {
Email string `json:"email"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
input.Email = strings.TrimSpace(input.Email)
if input.Email == "" || input.Password == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "email and password required"})
return
}
user, err := store.Register(c.Request.Context(), input.Email, input.Password)
if err != nil {
if errors.Is(err, errAlreadyExists) {
c.JSON(http.StatusConflict, gin.H{"error": "user already exists"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "registration failed"})
return
}
c.JSON(http.StatusCreated, gin.H{"id": user.ID, "email": user.Email})
})
auth.POST("/login", func(c *gin.Context) {
var input struct {
Email string `json:"email"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
input.Email = strings.TrimSpace(input.Email)
if input.Email == "" || input.Password == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "email and password required"})
return
}
user, err := store.Login(c.Request.Context(), input.Email, input.Password)
if err != nil {
if errors.Is(err, errInvalidCredentials) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid credentials"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "login failed"})
return
}
token, err := tokens.Generate(user)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "token generation failed"})
return
}
if cache != nil {
if err := cache.Save(c.Request.Context(), token, tokens.ttl); err != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "token cache unavailable"})
return
}
}
c.JSON(http.StatusOK, gin.H{"token": token})
})
auth.POST("/logout", func(c *gin.Context) {
token := extractBearerToken(c)
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
return
}
if _, ok := tokens.Validate(token); !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if cache != nil {
if err := cache.Delete(c.Request.Context(), token); err != nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "token cache unavailable"})
return
}
}
c.Status(http.StatusNoContent)
})
}
cfg := iam.LoadConfig()
iam.NewHandler(iamSvc, cfg).RegisterRoutes(router)
api := router.Group("/api/v1")
api.Use(func(c *gin.Context) {
token := extractBearerToken(c)
if token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
return
}
userID, ok := tokens.Validate(token)
if !ok {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if cache != nil {
exists, err := cache.Exists(c.Request.Context(), token)
if err != nil {
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "token cache unavailable"})
return
}
if !exists {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
}
c.Set("user_id", userID)
c.Next()
})
api.Use(iamSvc.RequireAccess())
tasks := api.Group("/tasks")
{
tasks.GET("", func(c *gin.Context) {
userID := c.GetInt64("user_id")
userID := c.GetInt64(iam.ContextUserIDKey)
items, err := store.List(c.Request.Context(), userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load tasks"})
@@ -552,10 +292,10 @@ func main() {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
userID := c.GetInt64("user_id")
userID := c.GetInt64(iam.ContextUserIDKey)
created, err := store.Create(c.Request.Context(), userID, input)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create task"})
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save task"})
return
}
emitter.Emit(c.Request.Context(), "task.created", created, userID)
@@ -567,8 +307,8 @@ func main() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
return
}
userID := c.GetInt64("user_id")
task, err := store.Get(c.Request.Context(), userID, id)
userID := c.GetInt64(iam.ContextUserIDKey)
item, err := store.Get(c.Request.Context(), userID, id)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
@@ -577,7 +317,7 @@ func main() {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load task"})
return
}
c.JSON(http.StatusOK, task)
c.JSON(http.StatusOK, item)
})
tasks.PUT(":id", func(c *gin.Context) {
id, err := parseID(c.Param("id"))
@@ -590,7 +330,7 @@ func main() {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
userID := c.GetInt64("user_id")
userID := c.GetInt64(iam.ContextUserIDKey)
updated, err := store.Update(c.Request.Context(), userID, id, input)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
@@ -609,7 +349,7 @@ func main() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
return
}
userID := c.GetInt64("user_id")
userID := c.GetInt64(iam.ContextUserIDKey)
if err := store.Delete(c.Request.Context(), userID, id); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
@@ -628,7 +368,7 @@ func main() {
}
}
func buildDependencies() (*postgresStore, *tokenManager, *tokenCache, *taskEmitter) {
func buildDependencies() (*postgresStore, *iam.Service, *taskEmitter) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@@ -640,19 +380,10 @@ func buildDependencies() (*postgresStore, *tokenManager, *tokenCache, *taskEmitt
if err != nil {
log.Fatalf("postgres connection failed: %v", err)
}
secret := strings.TrimSpace(os.Getenv("AUTH_SECRET"))
if secret == "" {
secret = "dev-secret-change-me"
}
tokens := newTokenManager(secret, 24*time.Hour)
redisAddr := strings.TrimSpace(os.Getenv("REDIS_ADDR"))
redisPassword := os.Getenv("REDIS_PASSWORD")
redisDB := parseEnvInt("REDIS_DB", 0)
cache, err := newTokenCache(redisAddr, redisPassword, redisDB)
iamCfg := iam.LoadConfig()
iamSvc, err := iam.NewService(ctx, store.Pool(), iamCfg)
if err != nil {
log.Fatalf("redis connection failed: %v", err)
log.Fatalf("iam initialization failed: %v", err)
}
brokers := splitCSV(os.Getenv("KAFKA_BROKERS"))
@@ -661,34 +392,13 @@ func buildDependencies() (*postgresStore, *tokenManager, *tokenCache, *taskEmitt
topic = "todo.tasks"
}
emitter := newTaskEmitter(brokers, topic)
return store, tokens, cache, emitter
return store, iamSvc, emitter
}
func parseID(value string) (int64, error) {
return strconv.ParseInt(value, 10, 64)
}
func extractBearerToken(c *gin.Context) string {
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
if strings.HasPrefix(authHeader, "Bearer ") {
return strings.TrimSpace(authHeader[7:])
}
return authHeader
}
func parseEnvInt(key string, fallback int) int {
value := strings.TrimSpace(os.Getenv(key))
if value == "" {
return fallback
}
parsed, err := strconv.Atoi(value)
if err != nil {
return fallback
}
return parsed
}
func splitCSV(value string) []string {
if strings.TrimSpace(value) == "" {
return nil
@@ -703,3 +413,23 @@ func splitCSV(value string) []string {
}
return result
}
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
origin := strings.TrimSpace(c.GetHeader("Origin"))
if origin != "" {
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
c.Writer.Header().Set("Vary", "Origin")
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
} else {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
}
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}