refactor(auth): split IAM module and add access/refresh session flow
This commit is contained in:
279
internal/iam/store_pg.go
Normal file
279
internal/iam/store_pg.go
Normal file
@@ -0,0 +1,279 @@
|
||||
package iam
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type postgresStore struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
type refreshTokenRow struct {
|
||||
ID string
|
||||
SessionID string
|
||||
TokenHash string
|
||||
ExpiresAt time.Time
|
||||
RotatedFrom sql.NullString
|
||||
RevokedAt sql.NullTime
|
||||
UsedAt sql.NullTime
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func newPostgresStore(pool *pgxpool.Pool) *postgresStore {
|
||||
return &postgresStore{pool: pool}
|
||||
}
|
||||
|
||||
func (s *postgresStore) initSchema(ctx context.Context) error {
|
||||
stmts := []string{
|
||||
`CREATE TABLE IF NOT EXISTS auth_sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
device_info TEXT,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
revoked_at TIMESTAMPTZ
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_revoked ON auth_sessions(user_id, revoked_at)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_auth_sessions_expires_at ON auth_sessions(expires_at)`,
|
||||
`CREATE TABLE IF NOT EXISTS auth_refresh_tokens (
|
||||
id TEXT PRIMARY KEY,
|
||||
session_id TEXT NOT NULL REFERENCES auth_sessions(id) ON DELETE CASCADE,
|
||||
token_hash TEXT NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
rotated_from TEXT,
|
||||
revoked_at TIMESTAMPTZ,
|
||||
used_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_session_revoked ON auth_refresh_tokens(session_id, revoked_at)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_expires_at ON auth_refresh_tokens(expires_at)`,
|
||||
`CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_rotated_from ON auth_refresh_tokens(rotated_from)`,
|
||||
}
|
||||
for _, stmt := range stmts {
|
||||
if _, err := s.pool.Exec(ctx, stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *postgresStore) registerUser(ctx context.Context, email, password string) (int64, error) {
|
||||
hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var id int64
|
||||
err = s.pool.QueryRow(ctx, `INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id`, email, string(hashed)).Scan(&id)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return 0, errAlreadyExists
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (s *postgresStore) verifyUser(ctx context.Context, email, password string) (int64, error) {
|
||||
var id int64
|
||||
var hash string
|
||||
err := s.pool.QueryRow(ctx, `SELECT id, password_hash FROM users WHERE email = $1`, email).Scan(&id, &hash)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return 0, errInvalidCredentials
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil {
|
||||
return 0, errInvalidCredentials
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (s *postgresStore) createSession(ctx context.Context, sess Session) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
INSERT INTO auth_sessions (id, user_id, device_info, ip, user_agent, expires_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
`, sess.ID, sess.UserID, sess.DeviceInfo, sess.IP, sess.UserAgent, sess.ExpiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *postgresStore) createRefreshToken(ctx context.Context, rid, sid, hash string, expiresAt time.Time, rotatedFrom *string) error {
|
||||
_, err := s.pool.Exec(ctx, `
|
||||
INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`, rid, sid, hash, expiresAt, rotatedFrom)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *postgresStore) getSession(ctx context.Context, sid string) (Session, error) {
|
||||
var sess Session
|
||||
var revokedAt sql.NullTime
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at
|
||||
FROM auth_sessions WHERE id = $1
|
||||
`, sid).Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt)
|
||||
if err != nil {
|
||||
return Session{}, err
|
||||
}
|
||||
if revokedAt.Valid {
|
||||
t := revokedAt.Time
|
||||
sess.RevokedAt = &t
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *postgresStore) isSessionActive(ctx context.Context, sid string) (bool, error) {
|
||||
var ok bool
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM auth_sessions
|
||||
WHERE id = $1 AND revoked_at IS NULL AND expires_at > now()
|
||||
)
|
||||
`, sid).Scan(&ok)
|
||||
return ok, err
|
||||
}
|
||||
|
||||
func (s *postgresStore) revokeSession(ctx context.Context, sid string) error {
|
||||
_, err := s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, sid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.pool.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, sid)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *postgresStore) revokeAllUserSessions(ctx context.Context, userID int64) ([]string, error) {
|
||||
rows, err := s.pool.Query(ctx, `SELECT id FROM auth_sessions WHERE user_id = $1 AND revoked_at IS NULL`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
ids := make([]string, 0)
|
||||
for rows.Next() {
|
||||
var sid string
|
||||
if err := rows.Scan(&sid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids = append(ids, sid)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE user_id = $1 AND revoked_at IS NULL`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = s.pool.Exec(ctx, `
|
||||
UPDATE auth_refresh_tokens rt
|
||||
SET revoked_at = now()
|
||||
FROM auth_sessions s
|
||||
WHERE s.user_id = $1 AND rt.session_id = s.id AND rt.revoked_at IS NULL
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (s *postgresStore) listSessions(ctx context.Context, userID int64) ([]Session, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at
|
||||
FROM auth_sessions WHERE user_id = $1 ORDER BY created_at DESC
|
||||
`, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
result := make([]Session, 0)
|
||||
for rows.Next() {
|
||||
var sess Session
|
||||
var revokedAt sql.NullTime
|
||||
if err := rows.Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if revokedAt.Valid {
|
||||
t := revokedAt.Time
|
||||
sess.RevokedAt = &t
|
||||
}
|
||||
result = append(result, sess)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *postgresStore) getRefreshTokenForUpdate(ctx context.Context, tx pgx.Tx, rid string) (refreshTokenRow, error) {
|
||||
var row refreshTokenRow
|
||||
err := tx.QueryRow(ctx, `
|
||||
SELECT id, session_id, token_hash, expires_at, rotated_from, revoked_at, used_at, created_at
|
||||
FROM auth_refresh_tokens
|
||||
WHERE id = $1
|
||||
FOR UPDATE
|
||||
`, rid).Scan(&row.ID, &row.SessionID, &row.TokenHash, &row.ExpiresAt, &row.RotatedFrom, &row.RevokedAt, &row.UsedAt, &row.CreatedAt)
|
||||
return row, err
|
||||
}
|
||||
|
||||
func (s *postgresStore) rotateRefreshToken(ctx context.Context, claims refreshClaims, providedHash, newRID, newHash string, newExp time.Time) error {
|
||||
tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
row, err := s.getRefreshTokenForUpdate(ctx, tx, claims.RefreshID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return errInvalidToken
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if row.SessionID != claims.SessionID || row.ExpiresAt.Before(time.Now().UTC()) {
|
||||
return errInvalidToken
|
||||
}
|
||||
|
||||
if row.RevokedAt.Valid || row.UsedAt.Valid || row.TokenHash != providedHash {
|
||||
if _, err := tx.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return errSessionRevoked
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET used_at = now(), revoked_at = now() WHERE id = $1`, claims.RefreshID); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
`, newRID, claims.SessionID, newHash, newExp, claims.RefreshID); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func isUniqueViolation(err error) bool {
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) && pgErr.Code == "23505" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeEmail(v string) string {
|
||||
return strings.TrimSpace(strings.ToLower(v))
|
||||
}
|
||||
Reference in New Issue
Block a user