refactor(auth): split IAM module and add access/refresh session flow
This commit is contained in:
@@ -2,9 +2,6 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
@@ -14,13 +11,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"wolves.top/todo/internal/iam"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/jackc/pgconn"
|
"github.com/jackc/pgconn"
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
"github.com/segmentio/kafka-go"
|
"github.com/segmentio/kafka-go"
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Task struct {
|
type Task struct {
|
||||||
@@ -35,90 +32,18 @@ type Task struct {
|
|||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID int64 `json:"id"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type tokenManager struct {
|
|
||||||
secret []byte
|
|
||||||
ttl time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
type postgresStore struct {
|
type postgresStore struct {
|
||||||
pool *pgxpool.Pool
|
pool *pgxpool.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
type tokenCache struct {
|
|
||||||
client *redis.Client
|
|
||||||
prefix string
|
|
||||||
}
|
|
||||||
|
|
||||||
type taskEmitter struct {
|
type taskEmitter struct {
|
||||||
writer *kafka.Writer
|
writer *kafka.Writer
|
||||||
topic string
|
topic string
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTokenManager(secret string, ttl time.Duration) *tokenManager {
|
var (
|
||||||
return &tokenManager{
|
errAlreadyExists = errors.New("already exists")
|
||||||
secret: []byte(secret),
|
)
|
||||||
ttl: ttl,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tokenManager) Generate(user User) (string, error) {
|
|
||||||
payload := struct {
|
|
||||||
UserID int64 `json:"uid"`
|
|
||||||
Exp int64 `json:"exp"`
|
|
||||||
}{
|
|
||||||
UserID: user.ID,
|
|
||||||
Exp: time.Now().Add(t.ttl).Unix(),
|
|
||||||
}
|
|
||||||
raw, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
encoded := base64.RawURLEncoding.EncodeToString(raw)
|
|
||||||
sig := t.sign(encoded)
|
|
||||||
return encoded + "." + sig, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tokenManager) Validate(token string) (int64, bool) {
|
|
||||||
parts := strings.Split(token, ".")
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
payload, sig := parts[0], parts[1]
|
|
||||||
if !t.verify(payload, sig) {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
raw, err := base64.RawURLEncoding.DecodeString(payload)
|
|
||||||
if err != nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
var data struct {
|
|
||||||
UserID int64 `json:"uid"`
|
|
||||||
Exp int64 `json:"exp"`
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal(raw, &data); err != nil {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
if data.UserID == 0 || time.Now().Unix() > data.Exp {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return data.UserID, true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tokenManager) sign(payload string) string {
|
|
||||||
mac := hmac.New(sha256.New, t.secret)
|
|
||||||
mac.Write([]byte(payload))
|
|
||||||
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *tokenManager) verify(payload, signature string) bool {
|
|
||||||
expected := t.sign(payload)
|
|
||||||
return hmac.Equal([]byte(signature), []byte(expected))
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPostgresStore(ctx context.Context, url string) (*postgresStore, error) {
|
func newPostgresStore(ctx context.Context, url string) (*postgresStore, error) {
|
||||||
pool, err := pgxpool.New(ctx, url)
|
pool, err := pgxpool.New(ctx, url)
|
||||||
@@ -166,38 +91,6 @@ func (s *postgresStore) initSchema(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *postgresStore) Register(ctx context.Context, email, password string) (User, error) {
|
|
||||||
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
||||||
if err != nil {
|
|
||||||
return User{}, err
|
|
||||||
}
|
|
||||||
var id int64
|
|
||||||
err = s.pool.QueryRow(ctx, `INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id`, email, string(hashed)).Scan(&id)
|
|
||||||
if err != nil {
|
|
||||||
if isUniqueViolation(err) {
|
|
||||||
return User{}, errAlreadyExists
|
|
||||||
}
|
|
||||||
return User{}, err
|
|
||||||
}
|
|
||||||
return User{ID: id, Email: email}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *postgresStore) Login(ctx context.Context, email, password string) (User, error) {
|
|
||||||
var user User
|
|
||||||
var hash string
|
|
||||||
err := s.pool.QueryRow(ctx, `SELECT id, email, password_hash FROM users WHERE email = $1`, email).Scan(&user.ID, &user.Email, &hash)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
|
||||||
return User{}, errInvalidCredentials
|
|
||||||
}
|
|
||||||
return User{}, err
|
|
||||||
}
|
|
||||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
|
|
||||||
return User{}, errInvalidCredentials
|
|
||||||
}
|
|
||||||
return user, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *postgresStore) List(ctx context.Context, userID int64) ([]Task, error) {
|
func (s *postgresStore) List(ctx context.Context, userID int64) ([]Task, error) {
|
||||||
rows, err := s.pool.Query(ctx, `SELECT id, title, description, status, due_at, priority, tags, created_at, updated_at FROM tasks WHERE user_id = $1 ORDER BY id DESC`, userID)
|
rows, err := s.pool.Query(ctx, `SELECT id, title, description, status, due_at, priority, tags, created_at, updated_at FROM tasks WHERE user_id = $1 ORDER BY id DESC`, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -304,37 +197,8 @@ func (s *postgresStore) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTokenCache(addr, password string, db int) (*tokenCache, error) {
|
func (s *postgresStore) Pool() *pgxpool.Pool {
|
||||||
if strings.TrimSpace(addr) == "" {
|
return s.pool
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
client := redis.NewClient(&redis.Options{
|
|
||||||
Addr: addr,
|
|
||||||
Password: password,
|
|
||||||
DB: db,
|
|
||||||
})
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := client.Ping(ctx).Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &tokenCache{client: client, prefix: "auth:token:"}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tokenCache) Save(ctx context.Context, token string, ttl time.Duration) error {
|
|
||||||
return c.client.Set(ctx, c.prefix+token, "1", ttl).Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tokenCache) Delete(ctx context.Context, token string) error {
|
|
||||||
return c.client.Del(ctx, c.prefix+token).Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tokenCache) Exists(ctx context.Context, token string) (bool, error) {
|
|
||||||
count, err := c.client.Exists(ctx, c.prefix+token).Result()
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
return count == 1, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTaskEmitter(brokers []string, topic string) *taskEmitter {
|
func newTaskEmitter(brokers []string, topic string) *taskEmitter {
|
||||||
@@ -367,10 +231,7 @@ func (e *taskEmitter) Emit(ctx context.Context, eventType string, task Task, use
|
|||||||
}
|
}
|
||||||
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
writeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := e.writer.WriteMessages(writeCtx, kafka.Message{
|
if err := e.writer.WriteMessages(writeCtx, kafka.Message{Key: []byte(strconv.FormatInt(task.ID, 10)), Value: data}); err != nil {
|
||||||
Key: []byte(strconv.FormatInt(task.ID, 10)),
|
|
||||||
Value: data,
|
|
||||||
}); err != nil {
|
|
||||||
log.Printf("kafka write failed: %v", err)
|
log.Printf("kafka write failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -384,11 +245,6 @@ func (e *taskEmitter) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
|
||||||
errAlreadyExists = errors.New("already exists")
|
|
||||||
errInvalidCredentials = errors.New("invalid credentials")
|
|
||||||
)
|
|
||||||
|
|
||||||
func isUniqueViolation(err error) bool {
|
func isUniqueViolation(err error) bool {
|
||||||
var pgErr *pgconn.PgError
|
var pgErr *pgconn.PgError
|
||||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||||
@@ -398,11 +254,9 @@ func isUniqueViolation(err error) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
store, tokens, cache, emitter := buildDependencies()
|
store, iamSvc, emitter := buildDependencies()
|
||||||
defer store.Close()
|
defer store.Close()
|
||||||
if cache != nil {
|
defer iamSvc.Close()
|
||||||
defer cache.client.Close()
|
|
||||||
}
|
|
||||||
defer emitter.Close()
|
defer emitter.Close()
|
||||||
|
|
||||||
gin.SetMode(gin.DebugMode)
|
gin.SetMode(gin.DebugMode)
|
||||||
@@ -410,135 +264,21 @@ func main() {
|
|||||||
router.RedirectTrailingSlash = false
|
router.RedirectTrailingSlash = false
|
||||||
router.RedirectFixedPath = false
|
router.RedirectFixedPath = false
|
||||||
|
|
||||||
router.Use(func(c *gin.Context) {
|
router.Use(corsMiddleware())
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
|
||||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
|
||||||
if c.Request.Method == http.MethodOptions {
|
|
||||||
c.AbortWithStatus(http.StatusNoContent)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Next()
|
|
||||||
})
|
|
||||||
|
|
||||||
router.GET("/api/health", func(c *gin.Context) {
|
router.GET("/api/health", func(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
})
|
})
|
||||||
|
|
||||||
auth := router.Group("/api/v1/auth")
|
cfg := iam.LoadConfig()
|
||||||
{
|
iam.NewHandler(iamSvc, cfg).RegisterRoutes(router)
|
||||||
auth.POST("/register", func(c *gin.Context) {
|
|
||||||
var input struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
input.Email = strings.TrimSpace(input.Email)
|
|
||||||
if input.Email == "" || input.Password == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "email and password required"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
user, err := store.Register(c.Request.Context(), input.Email, input.Password)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, errAlreadyExists) {
|
|
||||||
c.JSON(http.StatusConflict, gin.H{"error": "user already exists"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "registration failed"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusCreated, gin.H{"id": user.ID, "email": user.Email})
|
|
||||||
})
|
|
||||||
auth.POST("/login", func(c *gin.Context) {
|
|
||||||
var input struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
input.Email = strings.TrimSpace(input.Email)
|
|
||||||
if input.Email == "" || input.Password == "" {
|
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "email and password required"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
user, err := store.Login(c.Request.Context(), input.Email, input.Password)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, errInvalidCredentials) {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid credentials"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "login failed"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
token, err := tokens.Generate(user)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "token generation failed"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if cache != nil {
|
|
||||||
if err := cache.Save(c.Request.Context(), token, tokens.ttl); err != nil {
|
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "token cache unavailable"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.JSON(http.StatusOK, gin.H{"token": token})
|
|
||||||
})
|
|
||||||
auth.POST("/logout", func(c *gin.Context) {
|
|
||||||
token := extractBearerToken(c)
|
|
||||||
if token == "" {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if _, ok := tokens.Validate(token); !ok {
|
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if cache != nil {
|
|
||||||
if err := cache.Delete(c.Request.Context(), token); err != nil {
|
|
||||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "token cache unavailable"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Status(http.StatusNoContent)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
api := router.Group("/api/v1")
|
api := router.Group("/api/v1")
|
||||||
api.Use(func(c *gin.Context) {
|
api.Use(iamSvc.RequireAccess())
|
||||||
token := extractBearerToken(c)
|
|
||||||
if token == "" {
|
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
userID, ok := tokens.Validate(token)
|
|
||||||
if !ok {
|
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if cache != nil {
|
|
||||||
exists, err := cache.Exists(c.Request.Context(), token)
|
|
||||||
if err != nil {
|
|
||||||
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "token cache unavailable"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.Set("user_id", userID)
|
|
||||||
c.Next()
|
|
||||||
})
|
|
||||||
|
|
||||||
tasks := api.Group("/tasks")
|
tasks := api.Group("/tasks")
|
||||||
{
|
{
|
||||||
tasks.GET("", func(c *gin.Context) {
|
tasks.GET("", func(c *gin.Context) {
|
||||||
userID := c.GetInt64("user_id")
|
userID := c.GetInt64(iam.ContextUserIDKey)
|
||||||
items, err := store.List(c.Request.Context(), userID)
|
items, err := store.List(c.Request.Context(), userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load tasks"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load tasks"})
|
||||||
@@ -552,10 +292,10 @@ func main() {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userID := c.GetInt64("user_id")
|
userID := c.GetInt64(iam.ContextUserIDKey)
|
||||||
created, err := store.Create(c.Request.Context(), userID, input)
|
created, err := store.Create(c.Request.Context(), userID, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create task"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save task"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
emitter.Emit(c.Request.Context(), "task.created", created, userID)
|
emitter.Emit(c.Request.Context(), "task.created", created, userID)
|
||||||
@@ -567,8 +307,8 @@ func main() {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userID := c.GetInt64("user_id")
|
userID := c.GetInt64(iam.ContextUserIDKey)
|
||||||
task, err := store.Get(c.Request.Context(), userID, id)
|
item, err := store.Get(c.Request.Context(), userID, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||||
@@ -577,7 +317,7 @@ func main() {
|
|||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load task"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to load task"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, task)
|
c.JSON(http.StatusOK, item)
|
||||||
})
|
})
|
||||||
tasks.PUT(":id", func(c *gin.Context) {
|
tasks.PUT(":id", func(c *gin.Context) {
|
||||||
id, err := parseID(c.Param("id"))
|
id, err := parseID(c.Param("id"))
|
||||||
@@ -590,7 +330,7 @@ func main() {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userID := c.GetInt64("user_id")
|
userID := c.GetInt64(iam.ContextUserIDKey)
|
||||||
updated, err := store.Update(c.Request.Context(), userID, id, input)
|
updated, err := store.Update(c.Request.Context(), userID, id, input)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
@@ -609,7 +349,7 @@ func main() {
|
|||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid id"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userID := c.GetInt64("user_id")
|
userID := c.GetInt64(iam.ContextUserIDKey)
|
||||||
if err := store.Delete(c.Request.Context(), userID, id); err != nil {
|
if err := store.Delete(c.Request.Context(), userID, id); err != nil {
|
||||||
if errors.Is(err, pgx.ErrNoRows) {
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
c.JSON(http.StatusNotFound, gin.H{"error": "not found"})
|
||||||
@@ -628,7 +368,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildDependencies() (*postgresStore, *tokenManager, *tokenCache, *taskEmitter) {
|
func buildDependencies() (*postgresStore, *iam.Service, *taskEmitter) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -640,19 +380,10 @@ func buildDependencies() (*postgresStore, *tokenManager, *tokenCache, *taskEmitt
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("postgres connection failed: %v", err)
|
log.Fatalf("postgres connection failed: %v", err)
|
||||||
}
|
}
|
||||||
|
iamCfg := iam.LoadConfig()
|
||||||
secret := strings.TrimSpace(os.Getenv("AUTH_SECRET"))
|
iamSvc, err := iam.NewService(ctx, store.Pool(), iamCfg)
|
||||||
if secret == "" {
|
|
||||||
secret = "dev-secret-change-me"
|
|
||||||
}
|
|
||||||
tokens := newTokenManager(secret, 24*time.Hour)
|
|
||||||
|
|
||||||
redisAddr := strings.TrimSpace(os.Getenv("REDIS_ADDR"))
|
|
||||||
redisPassword := os.Getenv("REDIS_PASSWORD")
|
|
||||||
redisDB := parseEnvInt("REDIS_DB", 0)
|
|
||||||
cache, err := newTokenCache(redisAddr, redisPassword, redisDB)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("redis connection failed: %v", err)
|
log.Fatalf("iam initialization failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
brokers := splitCSV(os.Getenv("KAFKA_BROKERS"))
|
brokers := splitCSV(os.Getenv("KAFKA_BROKERS"))
|
||||||
@@ -661,34 +392,13 @@ func buildDependencies() (*postgresStore, *tokenManager, *tokenCache, *taskEmitt
|
|||||||
topic = "todo.tasks"
|
topic = "todo.tasks"
|
||||||
}
|
}
|
||||||
emitter := newTaskEmitter(brokers, topic)
|
emitter := newTaskEmitter(brokers, topic)
|
||||||
|
return store, iamSvc, emitter
|
||||||
return store, tokens, cache, emitter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseID(value string) (int64, error) {
|
func parseID(value string) (int64, error) {
|
||||||
return strconv.ParseInt(value, 10, 64)
|
return strconv.ParseInt(value, 10, 64)
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractBearerToken(c *gin.Context) string {
|
|
||||||
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
|
|
||||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
|
||||||
return strings.TrimSpace(authHeader[7:])
|
|
||||||
}
|
|
||||||
return authHeader
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseEnvInt(key string, fallback int) int {
|
|
||||||
value := strings.TrimSpace(os.Getenv(key))
|
|
||||||
if value == "" {
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
parsed, err := strconv.Atoi(value)
|
|
||||||
if err != nil {
|
|
||||||
return fallback
|
|
||||||
}
|
|
||||||
return parsed
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitCSV(value string) []string {
|
func splitCSV(value string) []string {
|
||||||
if strings.TrimSpace(value) == "" {
|
if strings.TrimSpace(value) == "" {
|
||||||
return nil
|
return nil
|
||||||
@@ -703,3 +413,23 @@ func splitCSV(value string) []string {
|
|||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func corsMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
origin := strings.TrimSpace(c.GetHeader("Origin"))
|
||||||
|
if origin != "" {
|
||||||
|
c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
|
||||||
|
c.Writer.Header().Set("Vary", "Origin")
|
||||||
|
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||||
|
} else {
|
||||||
|
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||||
|
c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
|
||||||
|
if c.Request.Method == http.MethodOptions {
|
||||||
|
c.AbortWithStatus(http.StatusNoContent)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -56,3 +56,23 @@ services:
|
|||||||
KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: kafka:9092
|
KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: kafka:9092
|
||||||
depends_on:
|
depends_on:
|
||||||
- kafka
|
- kafka
|
||||||
|
|
||||||
|
api:
|
||||||
|
image: golang:1.26.0
|
||||||
|
container_name: todo-api
|
||||||
|
working_dir: /app
|
||||||
|
command: go run ./cmd/server
|
||||||
|
volumes:
|
||||||
|
- .:/app
|
||||||
|
ports:
|
||||||
|
- "8080:8080"
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgres://todo:todo@postgres:5432/todo?sslmode=disable
|
||||||
|
REDIS_ADDR: redis:6379
|
||||||
|
KAFKA_BROKERS: kafka:9092
|
||||||
|
KAFKA_TOPIC: todo.tasks
|
||||||
|
AUTH_SECRET: dev-secret-change-me
|
||||||
|
depends_on:
|
||||||
|
- postgres
|
||||||
|
- redis
|
||||||
|
- kafka
|
||||||
|
|||||||
104
internal/iam/cache_redis.go
Normal file
104
internal/iam/cache_redis.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package iam
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tokenCache struct {
|
||||||
|
client *redis.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTokenCache(addr, password string, db int) (*tokenCache, error) {
|
||||||
|
if addr == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
client := redis.NewClient(&redis.Options{Addr: addr, Password: password, DB: db})
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := client.Ping(ctx).Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &tokenCache{client: client}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) Close() error {
|
||||||
|
if c == nil || c.client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) SetSessionActive(ctx context.Context, sid string, ttl time.Duration) error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.client.Set(ctx, "auth:session:"+sid, "active", ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) IsSessionActive(ctx context.Context, sid string) (bool, error) {
|
||||||
|
if c == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
count, err := c.client.Exists(ctx, "auth:session:"+sid).Result()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return count == 1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) DeleteSession(ctx context.Context, sid string) error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.client.Del(ctx, "auth:session:"+sid).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) SetRefreshHash(ctx context.Context, rid, tokenHash string, ttl time.Duration) error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.client.Set(ctx, "auth:refresh:"+rid, tokenHash, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) GetRefreshHash(ctx context.Context, rid string) (string, error) {
|
||||||
|
if c == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
v, err := c.client.Get(ctx, "auth:refresh:"+rid).Result()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, redis.Nil) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return v, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) DeleteRefresh(ctx context.Context, rid string) error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.client.Del(ctx, "auth:refresh:"+rid).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) DenyAccessJTI(ctx context.Context, jti string, ttl time.Duration) error {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.client.Set(ctx, "auth:deny:access:"+jti, "1", ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *tokenCache) IsAccessJTIDenied(ctx context.Context, jti string) (bool, error) {
|
||||||
|
if c == nil {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
count, err := c.client.Exists(ctx, "auth:deny:access:"+jti).Result()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return count == 1, nil
|
||||||
|
}
|
||||||
95
internal/iam/config.go
Normal file
95
internal/iam/config.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package iam
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
AccessSecret string
|
||||||
|
RefreshSecret string
|
||||||
|
RefreshPepper string
|
||||||
|
AccessTTL time.Duration
|
||||||
|
RefreshTTL time.Duration
|
||||||
|
SessionTTL time.Duration
|
||||||
|
RefreshCookieName string
|
||||||
|
CookieSecure bool
|
||||||
|
CookieDomain string
|
||||||
|
CookiePath string
|
||||||
|
CookieSameSite string
|
||||||
|
RedisAddr string
|
||||||
|
RedisPassword string
|
||||||
|
RedisDB int
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfig() Config {
|
||||||
|
cfg := Config{
|
||||||
|
AccessSecret: firstNonEmpty(strings.TrimSpace(os.Getenv("ACCESS_SECRET")), strings.TrimSpace(os.Getenv("AUTH_SECRET")), "dev-access-secret-change-me"),
|
||||||
|
RefreshSecret: firstNonEmpty(strings.TrimSpace(os.Getenv("REFRESH_SECRET")), strings.TrimSpace(os.Getenv("AUTH_SECRET")), "dev-refresh-secret-change-me"),
|
||||||
|
RefreshPepper: firstNonEmpty(strings.TrimSpace(os.Getenv("REFRESH_PEPPER")), "dev-refresh-pepper-change-me"),
|
||||||
|
AccessTTL: parseDuration("ACCESS_TTL", 60*time.Minute),
|
||||||
|
RefreshTTL: parseDuration("REFRESH_TTL", 7*24*time.Hour),
|
||||||
|
RefreshCookieName: firstNonEmpty(strings.TrimSpace(os.Getenv("REFRESH_COOKIE_NAME")), "rt"),
|
||||||
|
CookieSecure: parseBool("COOKIE_SECURE", isProduction()),
|
||||||
|
CookieDomain: strings.TrimSpace(os.Getenv("COOKIE_DOMAIN")),
|
||||||
|
CookiePath: firstNonEmpty(strings.TrimSpace(os.Getenv("COOKIE_PATH")), "/"),
|
||||||
|
CookieSameSite: firstNonEmpty(strings.TrimSpace(os.Getenv("COOKIE_SAMESITE")), "Lax"),
|
||||||
|
RedisAddr: strings.TrimSpace(os.Getenv("REDIS_ADDR")),
|
||||||
|
RedisPassword: os.Getenv("REDIS_PASSWORD"),
|
||||||
|
RedisDB: parseInt("REDIS_DB", 0),
|
||||||
|
}
|
||||||
|
cfg.SessionTTL = parseDuration("SESSION_TTL", cfg.RefreshTTL)
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseDuration(key string, fallback time.Duration) time.Duration {
|
||||||
|
value := strings.TrimSpace(os.Getenv(key))
|
||||||
|
if value == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
d, err := time.ParseDuration(value)
|
||||||
|
if err != nil || d <= 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBool(key string, fallback bool) bool {
|
||||||
|
value := strings.TrimSpace(os.Getenv(key))
|
||||||
|
if value == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
parsed, err := strconv.ParseBool(value)
|
||||||
|
if err != nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInt(key string, fallback int) int {
|
||||||
|
value := strings.TrimSpace(os.Getenv(key))
|
||||||
|
if value == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
parsed, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProduction() bool {
|
||||||
|
env := strings.ToLower(strings.TrimSpace(os.Getenv("APP_ENV")))
|
||||||
|
return env == "prod" || env == "production"
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstNonEmpty(values ...string) string {
|
||||||
|
for _, v := range values {
|
||||||
|
if strings.TrimSpace(v) != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
192
internal/iam/http_handler.go
Normal file
192
internal/iam/http_handler.go
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
package iam
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Handler struct {
|
||||||
|
service *Service
|
||||||
|
cfg Config
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewHandler(service *Service, cfg Config) *Handler {
|
||||||
|
return &Handler{service: service, cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) RegisterRoutes(router *gin.Engine) {
|
||||||
|
auth := router.Group("/api/v1/auth")
|
||||||
|
auth.POST("/register", h.register)
|
||||||
|
auth.POST("/login", h.login)
|
||||||
|
auth.POST("/refresh", h.refresh)
|
||||||
|
|
||||||
|
protected := auth.Group("")
|
||||||
|
protected.Use(h.service.RequireAccess())
|
||||||
|
protected.POST("/logout", h.logout)
|
||||||
|
protected.POST("/logout-all", h.logoutAll)
|
||||||
|
protected.GET("/sessions", h.listSessions)
|
||||||
|
protected.DELETE("/sessions/:sid", h.revokeSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) register(c *gin.Context) {
|
||||||
|
var input struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
AutoLogin *bool `json:"auto_login"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
autoLogin := true
|
||||||
|
if input.AutoLogin != nil {
|
||||||
|
autoLogin = *input.AutoLogin
|
||||||
|
}
|
||||||
|
result, refreshToken, err := h.service.Register(c.Request.Context(), input.Email, input.Password, autoLogin, requestMetaFromContext(c))
|
||||||
|
if err != nil {
|
||||||
|
h.writeAuthError(c, err, "registration failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !autoLogin {
|
||||||
|
c.JSON(http.StatusCreated, gin.H{"email": strings.TrimSpace(strings.ToLower(input.Email))})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.setRefreshCookie(c, refreshToken)
|
||||||
|
c.JSON(http.StatusCreated, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) login(c *gin.Context) {
|
||||||
|
var input struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
DeviceInfo string `json:"device_info"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindJSON(&input); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
meta := requestMetaFromContext(c)
|
||||||
|
meta.DeviceInfo = input.DeviceInfo
|
||||||
|
result, refreshToken, err := h.service.Login(c.Request.Context(), input.Email, input.Password, meta)
|
||||||
|
if err != nil {
|
||||||
|
h.writeAuthError(c, err, "login failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.setRefreshCookie(c, refreshToken)
|
||||||
|
c.JSON(http.StatusOK, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) refresh(c *gin.Context) {
|
||||||
|
refreshToken, err := c.Cookie(h.cfg.RefreshCookieName)
|
||||||
|
if err != nil || strings.TrimSpace(refreshToken) == "" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing refresh token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
result, newRefreshToken, err := h.service.Refresh(c.Request.Context(), refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
h.writeAuthError(c, err, "refresh failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.setRefreshCookie(c, newRefreshToken)
|
||||||
|
c.JSON(http.StatusOK, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) logout(c *gin.Context) {
|
||||||
|
token := extractBearerToken(c.GetHeader("Authorization"))
|
||||||
|
if token == "" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.service.Logout(c.Request.Context(), token); err != nil {
|
||||||
|
h.writeAuthError(c, err, "logout failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.clearRefreshCookie(c)
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) logoutAll(c *gin.Context) {
|
||||||
|
token := extractBearerToken(c.GetHeader("Authorization"))
|
||||||
|
if token == "" {
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.service.LogoutAll(c.Request.Context(), token); err != nil {
|
||||||
|
h.writeAuthError(c, err, "logout-all failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.clearRefreshCookie(c)
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) listSessions(c *gin.Context) {
|
||||||
|
uid := c.GetInt64(ContextUserIDKey)
|
||||||
|
sessions, err := h.service.ListSessions(c.Request.Context(), uid)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to list sessions"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.JSON(http.StatusOK, sessions)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) revokeSession(c *gin.Context) {
|
||||||
|
uid := c.GetInt64(ContextUserIDKey)
|
||||||
|
sid := strings.TrimSpace(c.Param("sid"))
|
||||||
|
if sid == "" {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid sid"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := h.service.RevokeSession(c.Request.Context(), uid, sid); err != nil {
|
||||||
|
h.writeAuthError(c, err, "revoke session failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestMetaFromContext(c *gin.Context) requestMeta {
|
||||||
|
return requestMeta{
|
||||||
|
IP: c.ClientIP(),
|
||||||
|
UserAgent: c.Request.UserAgent(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) writeAuthError(c *gin.Context, err error, fallback string) {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, errInvalidCredentials):
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid credentials"})
|
||||||
|
case errors.Is(err, errAlreadyExists):
|
||||||
|
c.JSON(http.StatusConflict, gin.H{"error": "user already exists"})
|
||||||
|
case errors.Is(err, errInvalidToken), errors.Is(err, errTokenExpired), errors.Is(err, errUnauthorized), errors.Is(err, errSessionRevoked):
|
||||||
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||||
|
case errors.Is(err, errForbidden):
|
||||||
|
c.JSON(http.StatusForbidden, gin.H{"error": "forbidden"})
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": fallback})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) setRefreshCookie(c *gin.Context, token string) {
|
||||||
|
h.applySameSite(c)
|
||||||
|
maxAge := int(h.cfg.RefreshTTL / time.Second)
|
||||||
|
c.SetCookie(h.cfg.RefreshCookieName, token, maxAge, h.cfg.CookiePath, h.cfg.CookieDomain, h.cfg.CookieSecure, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) clearRefreshCookie(c *gin.Context) {
|
||||||
|
h.applySameSite(c)
|
||||||
|
c.SetCookie(h.cfg.RefreshCookieName, "", -1, h.cfg.CookiePath, h.cfg.CookieDomain, h.cfg.CookieSecure, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) applySameSite(c *gin.Context) {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(h.cfg.CookieSameSite)) {
|
||||||
|
case "strict":
|
||||||
|
c.SetSameSite(http.SameSiteStrictMode)
|
||||||
|
case "none":
|
||||||
|
c.SetSameSite(http.SameSiteNoneMode)
|
||||||
|
default:
|
||||||
|
c.SetSameSite(http.SameSiteLaxMode)
|
||||||
|
}
|
||||||
|
}
|
||||||
40
internal/iam/middleware.go
Normal file
40
internal/iam/middleware.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package iam
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Service) RequireAccess() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
token := extractBearerToken(c.GetHeader("Authorization"))
|
||||||
|
if token == "" {
|
||||||
|
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
claims, err := s.ValidateAccessToken(c.Request.Context(), token)
|
||||||
|
if err != nil {
|
||||||
|
status := http.StatusUnauthorized
|
||||||
|
if errors.Is(err, errSessionRevoked) {
|
||||||
|
status = http.StatusUnauthorized
|
||||||
|
}
|
||||||
|
c.AbortWithStatusJSON(status, gin.H{"error": "invalid token"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(ContextUserIDKey, claims.UserID)
|
||||||
|
c.Set(ContextSessionIDKey, claims.SessionID)
|
||||||
|
c.Set(ContextJTIKey, claims.JTI)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractBearerToken(header string) string {
|
||||||
|
authHeader := strings.TrimSpace(header)
|
||||||
|
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
|
return strings.TrimSpace(authHeader[7:])
|
||||||
|
}
|
||||||
|
return authHeader
|
||||||
|
}
|
||||||
302
internal/iam/service.go
Normal file
302
internal/iam/service.go
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
package iam
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Service struct {
|
||||||
|
cfg Config
|
||||||
|
store *postgresStore
|
||||||
|
cache *tokenCache
|
||||||
|
accessToken *tokenManager
|
||||||
|
refreshToken *tokenManager
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewService(ctx context.Context, pool *pgxpool.Pool, cfg Config) (*Service, error) {
|
||||||
|
store := newPostgresStore(pool)
|
||||||
|
if err := store.initSchema(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cache, err := newTokenCache(cfg.RedisAddr, cfg.RedisPassword, cfg.RedisDB)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &Service{
|
||||||
|
cfg: cfg,
|
||||||
|
store: store,
|
||||||
|
cache: cache,
|
||||||
|
accessToken: newTokenManager(cfg.AccessSecret, cfg.AccessTTL, tokenKindAccess),
|
||||||
|
refreshToken: newTokenManager(cfg.RefreshSecret, cfg.RefreshTTL, tokenKindRefresh),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Close() error {
|
||||||
|
if s.cache != nil {
|
||||||
|
return s.cache.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Register(ctx context.Context, email, password string, autoLogin bool, meta requestMeta) (AuthResult, string, error) {
|
||||||
|
email = normalizeEmail(email)
|
||||||
|
if email == "" || strings.TrimSpace(password) == "" {
|
||||||
|
return AuthResult{}, "", errInvalidCredentials
|
||||||
|
}
|
||||||
|
uid, err := s.store.registerUser(ctx, email, password)
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
if !autoLogin {
|
||||||
|
return AuthResult{}, "", nil
|
||||||
|
}
|
||||||
|
pair, err := s.issueTokenPair(ctx, uid, meta)
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
return AuthResult{
|
||||||
|
AccessToken: pair.AccessToken,
|
||||||
|
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
|
||||||
|
SessionID: pair.SessionID,
|
||||||
|
}, pair.RefreshToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Login(ctx context.Context, email, password string, meta requestMeta) (AuthResult, string, error) {
|
||||||
|
email = normalizeEmail(email)
|
||||||
|
if email == "" || strings.TrimSpace(password) == "" {
|
||||||
|
return AuthResult{}, "", errInvalidCredentials
|
||||||
|
}
|
||||||
|
uid, err := s.store.verifyUser(ctx, email, password)
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
pair, err := s.issueTokenPair(ctx, uid, meta)
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
return AuthResult{
|
||||||
|
AccessToken: pair.AccessToken,
|
||||||
|
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
|
||||||
|
SessionID: pair.SessionID,
|
||||||
|
}, pair.RefreshToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) issueTokenPair(ctx context.Context, userID int64, meta requestMeta) (tokenPair, error) {
|
||||||
|
sid, err := randomID("sid")
|
||||||
|
if err != nil {
|
||||||
|
return tokenPair{}, err
|
||||||
|
}
|
||||||
|
rid, err := randomID("rid")
|
||||||
|
if err != nil {
|
||||||
|
return tokenPair{}, err
|
||||||
|
}
|
||||||
|
jti, err := randomID("jti")
|
||||||
|
if err != nil {
|
||||||
|
return tokenPair{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
sess := Session{
|
||||||
|
ID: sid,
|
||||||
|
UserID: userID,
|
||||||
|
DeviceInfo: strings.TrimSpace(meta.DeviceInfo),
|
||||||
|
IP: strings.TrimSpace(meta.IP),
|
||||||
|
UserAgent: strings.TrimSpace(meta.UserAgent),
|
||||||
|
CreatedAt: now,
|
||||||
|
ExpiresAt: now.Add(s.cfg.SessionTTL),
|
||||||
|
}
|
||||||
|
if err := s.store.createSession(ctx, sess); err != nil {
|
||||||
|
return tokenPair{}, err
|
||||||
|
}
|
||||||
|
refreshRaw, refreshExp, err := s.refreshToken.generateRefresh(userID, sid, rid)
|
||||||
|
if err != nil {
|
||||||
|
return tokenPair{}, err
|
||||||
|
}
|
||||||
|
refreshHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)
|
||||||
|
if err := s.store.createRefreshToken(ctx, rid, sid, refreshHash, refreshExp, nil); err != nil {
|
||||||
|
return tokenPair{}, err
|
||||||
|
}
|
||||||
|
accessRaw, accessExp, err := s.accessToken.generateAccess(userID, sid, jti)
|
||||||
|
if err != nil {
|
||||||
|
return tokenPair{}, err
|
||||||
|
}
|
||||||
|
if s.cache != nil {
|
||||||
|
_ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt))
|
||||||
|
_ = s.cache.SetRefreshHash(ctx, rid, refreshHash, time.Until(refreshExp))
|
||||||
|
_ = accessExp
|
||||||
|
}
|
||||||
|
return tokenPair{
|
||||||
|
AccessToken: accessRaw,
|
||||||
|
AccessTokenExpires: accessExp,
|
||||||
|
RefreshToken: refreshRaw,
|
||||||
|
SessionID: sid,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Refresh(ctx context.Context, refreshRaw string) (AuthResult, string, error) {
|
||||||
|
claims, err := s.refreshToken.parseRefresh(refreshRaw)
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok, err := s.isSessionActive(ctx, claims.SessionID); err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
} else if !ok {
|
||||||
|
return AuthResult{}, "", errSessionRevoked
|
||||||
|
}
|
||||||
|
|
||||||
|
providedHash := hashRefreshToken(refreshRaw, s.cfg.RefreshPepper)
|
||||||
|
if s.cache != nil {
|
||||||
|
cached, err := s.cache.GetRefreshHash(ctx, claims.RefreshID)
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
if cached != "" && cached != providedHash {
|
||||||
|
_ = s.store.revokeSession(ctx, claims.SessionID)
|
||||||
|
_ = s.cache.DeleteSession(ctx, claims.SessionID)
|
||||||
|
return AuthResult{}, "", errSessionRevoked
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newRID, err := randomID("rid")
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
newJTI, err := randomID("jti")
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
newRefreshRaw, refreshExp, err := s.refreshToken.generateRefresh(claims.UserID, claims.SessionID, newRID)
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
newRefreshHash := hashRefreshToken(newRefreshRaw, s.cfg.RefreshPepper)
|
||||||
|
if err := s.store.rotateRefreshToken(ctx, claims, providedHash, newRID, newRefreshHash, refreshExp); err != nil {
|
||||||
|
if errors.Is(err, errSessionRevoked) {
|
||||||
|
_ = s.cache.DeleteSession(ctx, claims.SessionID)
|
||||||
|
}
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
if s.cache != nil {
|
||||||
|
_ = s.cache.DeleteRefresh(ctx, claims.RefreshID)
|
||||||
|
_ = s.cache.SetRefreshHash(ctx, newRID, newRefreshHash, time.Until(refreshExp))
|
||||||
|
}
|
||||||
|
accessRaw, _, err := s.accessToken.generateAccess(claims.UserID, claims.SessionID, newJTI)
|
||||||
|
if err != nil {
|
||||||
|
return AuthResult{}, "", err
|
||||||
|
}
|
||||||
|
return AuthResult{
|
||||||
|
AccessToken: accessRaw,
|
||||||
|
ExpiresIn: int64(s.cfg.AccessTTL.Seconds()),
|
||||||
|
SessionID: claims.SessionID,
|
||||||
|
}, newRefreshRaw, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) Logout(ctx context.Context, accessRaw string) error {
|
||||||
|
claims, err := s.accessToken.parseAccess(accessRaw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := s.store.revokeSession(ctx, claims.SessionID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.cache != nil {
|
||||||
|
_ = s.cache.DeleteSession(ctx, claims.SessionID)
|
||||||
|
_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) LogoutAll(ctx context.Context, accessRaw string) error {
|
||||||
|
claims, err := s.accessToken.parseAccess(accessRaw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sids, err := s.store.revokeAllUserSessions(ctx, claims.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.cache != nil {
|
||||||
|
for _, sid := range sids {
|
||||||
|
_ = s.cache.DeleteSession(ctx, sid)
|
||||||
|
}
|
||||||
|
_ = s.cache.DenyAccessJTI(ctx, claims.JTI, s.cfg.AccessTTL)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) ListSessions(ctx context.Context, userID int64) ([]Session, error) {
|
||||||
|
return s.store.listSessions(ctx, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) RevokeSession(ctx context.Context, userID int64, sid string) error {
|
||||||
|
sess, err := s.store.getSession(ctx, sid)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return errInvalidToken
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if sess.UserID != userID {
|
||||||
|
return errForbidden
|
||||||
|
}
|
||||||
|
if err := s.store.revokeSession(ctx, sid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.cache != nil {
|
||||||
|
_ = s.cache.DeleteSession(ctx, sid)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) ValidateAccessToken(ctx context.Context, token string) (accessClaims, error) {
|
||||||
|
claims, err := s.accessToken.parseAccess(token)
|
||||||
|
if err != nil {
|
||||||
|
return accessClaims{}, err
|
||||||
|
}
|
||||||
|
if s.cache != nil {
|
||||||
|
denied, err := s.cache.IsAccessJTIDenied(ctx, claims.JTI)
|
||||||
|
if err != nil {
|
||||||
|
return accessClaims{}, err
|
||||||
|
}
|
||||||
|
if denied {
|
||||||
|
return accessClaims{}, errUnauthorized
|
||||||
|
}
|
||||||
|
}
|
||||||
|
active, err := s.isSessionActive(ctx, claims.SessionID)
|
||||||
|
if err != nil {
|
||||||
|
return accessClaims{}, err
|
||||||
|
}
|
||||||
|
if !active {
|
||||||
|
return accessClaims{}, errSessionRevoked
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) isSessionActive(ctx context.Context, sid string) (bool, error) {
|
||||||
|
if s.cache != nil {
|
||||||
|
ok, err := s.cache.IsSessionActive(ctx, sid)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ok, err := s.store.isSessionActive(ctx, sid)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if ok && s.cache != nil {
|
||||||
|
sess, err := s.store.getSession(ctx, sid)
|
||||||
|
if err == nil {
|
||||||
|
_ = s.cache.SetSessionActive(ctx, sid, time.Until(sess.ExpiresAt))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
279
internal/iam/store_pg.go
Normal file
279
internal/iam/store_pg.go
Normal file
@@ -0,0 +1,279 @@
|
|||||||
|
package iam
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgconn"
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type postgresStore struct {
|
||||||
|
pool *pgxpool.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
type refreshTokenRow struct {
|
||||||
|
ID string
|
||||||
|
SessionID string
|
||||||
|
TokenHash string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
RotatedFrom sql.NullString
|
||||||
|
RevokedAt sql.NullTime
|
||||||
|
UsedAt sql.NullTime
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPostgresStore(pool *pgxpool.Pool) *postgresStore {
|
||||||
|
return &postgresStore{pool: pool}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) initSchema(ctx context.Context) error {
|
||||||
|
stmts := []string{
|
||||||
|
`CREATE TABLE IF NOT EXISTS auth_sessions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||||
|
device_info TEXT,
|
||||||
|
ip TEXT,
|
||||||
|
user_agent TEXT,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
revoked_at TIMESTAMPTZ
|
||||||
|
)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_revoked ON auth_sessions(user_id, revoked_at)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_auth_sessions_expires_at ON auth_sessions(expires_at)`,
|
||||||
|
`CREATE TABLE IF NOT EXISTS auth_refresh_tokens (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
session_id TEXT NOT NULL REFERENCES auth_sessions(id) ON DELETE CASCADE,
|
||||||
|
token_hash TEXT NOT NULL,
|
||||||
|
expires_at TIMESTAMPTZ NOT NULL,
|
||||||
|
rotated_from TEXT,
|
||||||
|
revoked_at TIMESTAMPTZ,
|
||||||
|
used_at TIMESTAMPTZ,
|
||||||
|
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||||
|
)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_session_revoked ON auth_refresh_tokens(session_id, revoked_at)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_expires_at ON auth_refresh_tokens(expires_at)`,
|
||||||
|
`CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_rotated_from ON auth_refresh_tokens(rotated_from)`,
|
||||||
|
}
|
||||||
|
for _, stmt := range stmts {
|
||||||
|
if _, err := s.pool.Exec(ctx, stmt); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) registerUser(ctx context.Context, email, password string) (int64, error) {
|
||||||
|
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
var id int64
|
||||||
|
err = s.pool.QueryRow(ctx, `INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id`, email, string(hashed)).Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
if isUniqueViolation(err) {
|
||||||
|
return 0, errAlreadyExists
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) verifyUser(ctx context.Context, email, password string) (int64, error) {
|
||||||
|
var id int64
|
||||||
|
var hash string
|
||||||
|
err := s.pool.QueryRow(ctx, `SELECT id, password_hash FROM users WHERE email = $1`, email).Scan(&id, &hash)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return 0, errInvalidCredentials
|
||||||
|
}
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
|
||||||
|
return 0, errInvalidCredentials
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) createSession(ctx context.Context, sess Session) error {
|
||||||
|
_, err := s.pool.Exec(ctx, `
|
||||||
|
INSERT INTO auth_sessions (id, user_id, device_info, ip, user_agent, expires_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
|
`, sess.ID, sess.UserID, sess.DeviceInfo, sess.IP, sess.UserAgent, sess.ExpiresAt)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) createRefreshToken(ctx context.Context, rid, sid, hash string, expiresAt time.Time, rotatedFrom *string) error {
|
||||||
|
_, err := s.pool.Exec(ctx, `
|
||||||
|
INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from)
|
||||||
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
`, rid, sid, hash, expiresAt, rotatedFrom)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) getSession(ctx context.Context, sid string) (Session, error) {
|
||||||
|
var sess Session
|
||||||
|
var revokedAt sql.NullTime
|
||||||
|
err := s.pool.QueryRow(ctx, `
|
||||||
|
SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at
|
||||||
|
FROM auth_sessions WHERE id = $1
|
||||||
|
`, sid).Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt)
|
||||||
|
if err != nil {
|
||||||
|
return Session{}, err
|
||||||
|
}
|
||||||
|
if revokedAt.Valid {
|
||||||
|
t := revokedAt.Time
|
||||||
|
sess.RevokedAt = &t
|
||||||
|
}
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) isSessionActive(ctx context.Context, sid string) (bool, error) {
|
||||||
|
var ok bool
|
||||||
|
err := s.pool.QueryRow(ctx, `
|
||||||
|
SELECT EXISTS(
|
||||||
|
SELECT 1 FROM auth_sessions
|
||||||
|
WHERE id = $1 AND revoked_at IS NULL AND expires_at > now()
|
||||||
|
)
|
||||||
|
`, sid).Scan(&ok)
|
||||||
|
return ok, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) revokeSession(ctx context.Context, sid string) error {
|
||||||
|
_, err := s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, sid)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = s.pool.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, sid)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) revokeAllUserSessions(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
rows, err := s.pool.Query(ctx, `SELECT id FROM auth_sessions WHERE user_id = $1 AND revoked_at IS NULL`, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
ids := make([]string, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var sid string
|
||||||
|
if err := rows.Scan(&sid); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ids = append(ids, sid)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE user_id = $1 AND revoked_at IS NULL`, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err = s.pool.Exec(ctx, `
|
||||||
|
UPDATE auth_refresh_tokens rt
|
||||||
|
SET revoked_at = now()
|
||||||
|
FROM auth_sessions s
|
||||||
|
WHERE s.user_id = $1 AND rt.session_id = s.id AND rt.revoked_at IS NULL
|
||||||
|
`, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) listSessions(ctx context.Context, userID int64) ([]Session, error) {
|
||||||
|
rows, err := s.pool.Query(ctx, `
|
||||||
|
SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at
|
||||||
|
FROM auth_sessions WHERE user_id = $1 ORDER BY created_at DESC
|
||||||
|
`, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
result := make([]Session, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
var sess Session
|
||||||
|
var revokedAt sql.NullTime
|
||||||
|
if err := rows.Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if revokedAt.Valid {
|
||||||
|
t := revokedAt.Time
|
||||||
|
sess.RevokedAt = &t
|
||||||
|
}
|
||||||
|
result = append(result, sess)
|
||||||
|
}
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) getRefreshTokenForUpdate(ctx context.Context, tx pgx.Tx, rid string) (refreshTokenRow, error) {
|
||||||
|
var row refreshTokenRow
|
||||||
|
err := tx.QueryRow(ctx, `
|
||||||
|
SELECT id, session_id, token_hash, expires_at, rotated_from, revoked_at, used_at, created_at
|
||||||
|
FROM auth_refresh_tokens
|
||||||
|
WHERE id = $1
|
||||||
|
FOR UPDATE
|
||||||
|
`, rid).Scan(&row.ID, &row.SessionID, &row.TokenHash, &row.ExpiresAt, &row.RotatedFrom, &row.RevokedAt, &row.UsedAt, &row.CreatedAt)
|
||||||
|
return row, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *postgresStore) rotateRefreshToken(ctx context.Context, claims refreshClaims, providedHash, newRID, newHash string, newExp time.Time) error {
|
||||||
|
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback(ctx)
|
||||||
|
|
||||||
|
row, err := s.getRefreshTokenForUpdate(ctx, tx, claims.RefreshID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, pgx.ErrNoRows) {
|
||||||
|
return errInvalidToken
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if row.SessionID != claims.SessionID || row.ExpiresAt.Before(time.Now().UTC()) {
|
||||||
|
return errInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
if row.RevokedAt.Valid || row.UsedAt.Valid || row.TokenHash != providedHash {
|
||||||
|
if _, err := tx.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := tx.Commit(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return errSessionRevoked
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET used_at = now(), revoked_at = now() WHERE id = $1`, claims.RefreshID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec(ctx, `
|
||||||
|
INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from)
|
||||||
|
VALUES ($1, $2, $3, $4, $5)
|
||||||
|
`, newRID, claims.SessionID, newHash, newExp, claims.RefreshID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isUniqueViolation(err error) bool {
|
||||||
|
var pgErr *pgconn.PgError
|
||||||
|
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeEmail(v string) string {
|
||||||
|
return strings.TrimSpace(strings.ToLower(v))
|
||||||
|
}
|
||||||
147
internal/iam/token.go
Normal file
147
internal/iam/token.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package iam
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tokenManager struct {
|
||||||
|
secret []byte
|
||||||
|
ttl time.Duration
|
||||||
|
kind tokenKind
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTokenManager(secret string, ttl time.Duration, kind tokenKind) *tokenManager {
|
||||||
|
return &tokenManager{secret: []byte(secret), ttl: ttl, kind: kind}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokenManager) generateAccess(userID int64, sessionID, jti string) (string, time.Time, error) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
expiresAt := now.Add(t.ttl)
|
||||||
|
payload := accessClaims{
|
||||||
|
UserID: userID,
|
||||||
|
SessionID: sessionID,
|
||||||
|
JTI: jti,
|
||||||
|
IssuedAt: now.Unix(),
|
||||||
|
ExpiresAt: expiresAt.Unix(),
|
||||||
|
Type: t.kind,
|
||||||
|
}
|
||||||
|
token, err := t.sign(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", time.Time{}, err
|
||||||
|
}
|
||||||
|
return token, expiresAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokenManager) parseAccess(token string) (accessClaims, error) {
|
||||||
|
var claims accessClaims
|
||||||
|
if err := t.verifyAndDecode(token, &claims); err != nil {
|
||||||
|
return claims, err
|
||||||
|
}
|
||||||
|
if claims.Type != tokenKindAccess {
|
||||||
|
return claims, errInvalidToken
|
||||||
|
}
|
||||||
|
if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.JTI) == "" {
|
||||||
|
return claims, errInvalidToken
|
||||||
|
}
|
||||||
|
if time.Now().UTC().Unix() > claims.ExpiresAt {
|
||||||
|
return claims, errTokenExpired
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokenManager) generateRefresh(userID int64, sessionID, refreshID string) (string, time.Time, error) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
expiresAt := now.Add(t.ttl)
|
||||||
|
payload := refreshClaims{
|
||||||
|
UserID: userID,
|
||||||
|
SessionID: sessionID,
|
||||||
|
RefreshID: refreshID,
|
||||||
|
IssuedAt: now.Unix(),
|
||||||
|
ExpiresAt: expiresAt.Unix(),
|
||||||
|
Type: t.kind,
|
||||||
|
}
|
||||||
|
token, err := t.sign(payload)
|
||||||
|
if err != nil {
|
||||||
|
return "", time.Time{}, err
|
||||||
|
}
|
||||||
|
return token, expiresAt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokenManager) parseRefresh(token string) (refreshClaims, error) {
|
||||||
|
var claims refreshClaims
|
||||||
|
if err := t.verifyAndDecode(token, &claims); err != nil {
|
||||||
|
return claims, err
|
||||||
|
}
|
||||||
|
if claims.Type != tokenKindRefresh {
|
||||||
|
return claims, errInvalidToken
|
||||||
|
}
|
||||||
|
if claims.UserID == 0 || strings.TrimSpace(claims.SessionID) == "" || strings.TrimSpace(claims.RefreshID) == "" {
|
||||||
|
return claims, errInvalidToken
|
||||||
|
}
|
||||||
|
if time.Now().UTC().Unix() > claims.ExpiresAt {
|
||||||
|
return claims, errTokenExpired
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokenManager) sign(v any) (string, error) {
|
||||||
|
raw, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
payload := base64.RawURLEncoding.EncodeToString(raw)
|
||||||
|
sig := t.signPayload(payload)
|
||||||
|
return payload + "." + sig, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokenManager) verifyAndDecode(token string, out any) error {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return errInvalidToken
|
||||||
|
}
|
||||||
|
payload, sig := parts[0], parts[1]
|
||||||
|
if !t.verifyPayload(payload, sig) {
|
||||||
|
return errInvalidToken
|
||||||
|
}
|
||||||
|
raw, err := base64.RawURLEncoding.DecodeString(payload)
|
||||||
|
if err != nil {
|
||||||
|
return errInvalidToken
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(raw, out); err != nil {
|
||||||
|
return errInvalidToken
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokenManager) signPayload(payload string) string {
|
||||||
|
mac := hmac.New(sha256.New, t.secret)
|
||||||
|
mac.Write([]byte(payload))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokenManager) verifyPayload(payload, signature string) bool {
|
||||||
|
expected := t.signPayload(payload)
|
||||||
|
return hmac.Equal([]byte(signature), []byte(expected))
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomID(prefix string) (string, error) {
|
||||||
|
b := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s_%s", prefix, hex.EncodeToString(b)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func hashRefreshToken(token, pepper string) string {
|
||||||
|
mac := hmac.New(sha256.New, []byte(pepper))
|
||||||
|
mac.Write([]byte(token))
|
||||||
|
return hex.EncodeToString(mac.Sum(nil))
|
||||||
|
}
|
||||||
77
internal/iam/types.go
Normal file
77
internal/iam/types.go
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
package iam
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ContextUserIDKey = "user_id"
|
||||||
|
ContextSessionIDKey = "session_id"
|
||||||
|
ContextJTIKey = "jti"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tokenKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
tokenKindAccess tokenKind = "access"
|
||||||
|
tokenKindRefresh tokenKind = "refresh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type accessClaims struct {
|
||||||
|
UserID int64 `json:"uid"`
|
||||||
|
SessionID string `json:"sid"`
|
||||||
|
JTI string `json:"jti"`
|
||||||
|
IssuedAt int64 `json:"iat"`
|
||||||
|
ExpiresAt int64 `json:"exp"`
|
||||||
|
Type tokenKind `json:"typ"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type refreshClaims struct {
|
||||||
|
UserID int64 `json:"uid"`
|
||||||
|
SessionID string `json:"sid"`
|
||||||
|
RefreshID string `json:"rid"`
|
||||||
|
IssuedAt int64 `json:"iat"`
|
||||||
|
ExpiresAt int64 `json:"exp"`
|
||||||
|
Type tokenKind `json:"typ"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID int64 `json:"user_id"`
|
||||||
|
DeviceInfo string `json:"device_info"`
|
||||||
|
IP string `json:"ip"`
|
||||||
|
UserAgent string `json:"user_agent"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
RevokedAt *time.Time `json:"revoked_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenPair struct {
|
||||||
|
AccessToken string
|
||||||
|
AccessTokenExpires time.Time
|
||||||
|
RefreshToken string
|
||||||
|
SessionID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type requestMeta struct {
|
||||||
|
IP string
|
||||||
|
UserAgent string
|
||||||
|
DeviceInfo string
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthResult struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
SessionID string `json:"session_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errInvalidToken = errors.New("invalid token")
|
||||||
|
errTokenExpired = errors.New("token expired")
|
||||||
|
errInvalidCredentials = errors.New("invalid credentials")
|
||||||
|
errAlreadyExists = errors.New("already exists")
|
||||||
|
errSessionRevoked = errors.New("session revoked")
|
||||||
|
errUnauthorized = errors.New("unauthorized")
|
||||||
|
errForbidden = errors.New("forbidden")
|
||||||
|
)
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
import type { Credentials, Task } from '../types'
|
import type { Credentials, Task } from '../types'
|
||||||
import { session } from '../stores/session'
|
import { clearSession, session, setSessionToken } from '../stores/session'
|
||||||
|
|
||||||
const API_BASE = 'http://localhost:8080/api/v1'
|
const API_BASE = 'http://localhost:8080/api/v1'
|
||||||
|
|
||||||
function headers() {
|
function headers(extra?: Record<string, string>) {
|
||||||
const h: Record<string, string> = { 'Content-Type': 'application/json' }
|
const h: Record<string, string> = { 'Content-Type': 'application/json', ...(extra ?? {}) }
|
||||||
if (session.token) h.Authorization = `Bearer ${session.token}`
|
if (session.token) h.Authorization = `Bearer ${session.token}`
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
@@ -18,59 +18,100 @@ async function handle<T>(response: Response): Promise<T> {
|
|||||||
return (await response.json()) as T
|
return (await response.json()) as T
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function register(credentials: Credentials) {
|
async function refreshAccessToken(): Promise<boolean> {
|
||||||
const response = await fetch(`${API_BASE}/auth/register`, {
|
try {
|
||||||
|
const response = await fetch(`${API_BASE}/auth/refresh`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
credentials: 'include',
|
||||||
headers: headers(),
|
headers: headers(),
|
||||||
body: JSON.stringify(credentials),
|
|
||||||
})
|
})
|
||||||
return handle<{ id: number; email: string }>(response)
|
if (!response.ok) return false
|
||||||
|
const data = (await response.json()) as { access_token: string }
|
||||||
|
if (!data?.access_token) return false
|
||||||
|
setSessionToken(data.access_token)
|
||||||
|
return true
|
||||||
|
} catch {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function request<T>(path: string, init: RequestInit = {}, allowRefresh = true): Promise<T> {
|
||||||
|
const merged: RequestInit = {
|
||||||
|
credentials: 'include',
|
||||||
|
...init,
|
||||||
|
headers: headers((init.headers as Record<string, string> | undefined) ?? {}),
|
||||||
|
}
|
||||||
|
const response = await fetch(`${API_BASE}${path}`, merged)
|
||||||
|
if (response.status === 401 && allowRefresh && !path.startsWith('/auth/')) {
|
||||||
|
const refreshed = await refreshAccessToken()
|
||||||
|
if (refreshed) {
|
||||||
|
return request<T>(path, init, false)
|
||||||
|
}
|
||||||
|
clearSession()
|
||||||
|
}
|
||||||
|
return handle<T>(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function register(credentials: Credentials, autoLogin = true) {
|
||||||
|
return request<{ email: string; access_token?: string; expires_in?: number; session_id?: string }>('/auth/register', {
|
||||||
|
method: 'POST',
|
||||||
|
body: JSON.stringify({ ...credentials, auto_login: autoLogin }),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function login(credentials: Credentials) {
|
export async function login(credentials: Credentials) {
|
||||||
const response = await fetch(`${API_BASE}/auth/login`, {
|
return request<{ access_token: string; expires_in: number; session_id: string }>('/auth/login', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: headers(),
|
|
||||||
body: JSON.stringify(credentials),
|
body: JSON.stringify(credentials),
|
||||||
})
|
})
|
||||||
return handle<{ token: string }>(response)
|
}
|
||||||
|
|
||||||
|
export async function refresh() {
|
||||||
|
return request<{ access_token: string; expires_in: number; session_id: string }>('/auth/refresh', {
|
||||||
|
method: 'POST',
|
||||||
|
}, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function logout() {
|
export async function logout() {
|
||||||
const response = await fetch(`${API_BASE}/auth/logout`, {
|
return request<void>('/auth/logout', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: headers(),
|
}, false)
|
||||||
})
|
}
|
||||||
return handle<void>(response)
|
|
||||||
|
export async function logoutAll() {
|
||||||
|
return request<void>('/auth/logout-all', {
|
||||||
|
method: 'POST',
|
||||||
|
}, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function listSessions() {
|
||||||
|
return request<Array<{ id: string; device_info: string; ip: string; user_agent: string; created_at: string; expires_at: string; revoked_at?: string }>>('/auth/sessions')
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function revokeSession(id: string) {
|
||||||
|
return request<void>(`/auth/sessions/${id}`, { method: 'DELETE' }, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function listTasks() {
|
export async function listTasks() {
|
||||||
const response = await fetch(`${API_BASE}/tasks`, { headers: headers() })
|
return request<Task[]>('/tasks')
|
||||||
return handle<Task[]>(response)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function createTask(payload: Partial<Task>) {
|
export async function createTask(payload: Partial<Task>) {
|
||||||
const response = await fetch(`${API_BASE}/tasks`, {
|
return request<Task>('/tasks', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: headers(),
|
|
||||||
body: JSON.stringify(payload),
|
body: JSON.stringify(payload),
|
||||||
})
|
})
|
||||||
return handle<Task>(response)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function updateTask(id: number, payload: Partial<Task>) {
|
export async function updateTask(id: number, payload: Partial<Task>) {
|
||||||
const response = await fetch(`${API_BASE}/tasks/${id}`, {
|
return request<Task>(`/tasks/${id}`, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
headers: headers(),
|
|
||||||
body: JSON.stringify(payload),
|
body: JSON.stringify(payload),
|
||||||
})
|
})
|
||||||
return handle<Task>(response)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function deleteTask(id: number) {
|
export async function deleteTask(id: number) {
|
||||||
const response = await fetch(`${API_BASE}/tasks/${id}`, {
|
return request<void>(`/tasks/${id}`, {
|
||||||
method: 'DELETE',
|
method: 'DELETE',
|
||||||
headers: headers(),
|
|
||||||
})
|
})
|
||||||
return handle<void>(response)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { reactive } from 'vue'
|
import { reactive } from 'vue'
|
||||||
|
|
||||||
const tokenKey = 'supertodo_token'
|
const tokenKey = 'supertodo_access_token'
|
||||||
|
|
||||||
export const session = reactive({
|
export const session = reactive({
|
||||||
token: localStorage.getItem(tokenKey) ?? '',
|
token: localStorage.getItem(tokenKey) ?? '',
|
||||||
@@ -14,6 +14,11 @@ export function setSession(token: string, email: string) {
|
|||||||
localStorage.setItem('supertodo_email', email)
|
localStorage.setItem('supertodo_email', email)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function setSessionToken(token: string) {
|
||||||
|
session.token = token
|
||||||
|
localStorage.setItem(tokenKey, token)
|
||||||
|
}
|
||||||
|
|
||||||
export function clearSession() {
|
export function clearSession() {
|
||||||
session.token = ''
|
session.token = ''
|
||||||
session.email = ''
|
session.email = ''
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ async function submit() {
|
|||||||
await api.register(form)
|
await api.register(form)
|
||||||
}
|
}
|
||||||
const data = await api.login(form)
|
const data = await api.login(form)
|
||||||
setSession(data.token, form.email)
|
setSession(data.access_token, form.email)
|
||||||
router.push('/todos')
|
router.push('/todos')
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
error.value = t('auth_failed')
|
error.value = t('auth_failed')
|
||||||
|
|||||||
Reference in New Issue
Block a user