package iam import ( "context" "database/sql" "errors" "strings" "time" "github.com/jackc/pgconn" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "golang.org/x/crypto/bcrypt" ) type postgresStore struct { pool *pgxpool.Pool } type refreshTokenRow struct { ID string SessionID string TokenHash string ExpiresAt time.Time RotatedFrom sql.NullString RevokedAt sql.NullTime UsedAt sql.NullTime CreatedAt time.Time } func newPostgresStore(pool *pgxpool.Pool) *postgresStore { return &postgresStore{pool: pool} } func (s *postgresStore) initSchema(ctx context.Context) error { stmts := []string{ `CREATE TABLE IF NOT EXISTS auth_sessions ( id TEXT PRIMARY KEY, user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, device_info TEXT, ip TEXT, user_agent TEXT, created_at TIMESTAMPTZ NOT NULL DEFAULT now(), expires_at TIMESTAMPTZ NOT NULL, revoked_at TIMESTAMPTZ )`, `CREATE INDEX IF NOT EXISTS idx_auth_sessions_user_revoked ON auth_sessions(user_id, revoked_at)`, `CREATE INDEX IF NOT EXISTS idx_auth_sessions_expires_at ON auth_sessions(expires_at)`, `CREATE TABLE IF NOT EXISTS auth_refresh_tokens ( id TEXT PRIMARY KEY, session_id TEXT NOT NULL REFERENCES auth_sessions(id) ON DELETE CASCADE, token_hash TEXT NOT NULL, expires_at TIMESTAMPTZ NOT NULL, rotated_from TEXT, revoked_at TIMESTAMPTZ, used_at TIMESTAMPTZ, created_at TIMESTAMPTZ NOT NULL DEFAULT now() )`, `CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_session_revoked ON auth_refresh_tokens(session_id, revoked_at)`, `CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_expires_at ON auth_refresh_tokens(expires_at)`, `CREATE INDEX IF NOT EXISTS idx_auth_refresh_tokens_rotated_from ON auth_refresh_tokens(rotated_from)`, } for _, stmt := range stmts { if _, err := s.pool.Exec(ctx, stmt); err != nil { return err } } return nil } func (s *postgresStore) registerUser(ctx context.Context, email, password string) (int64, error) { hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return 0, err } var id int64 err = s.pool.QueryRow(ctx, `INSERT INTO users (email, password_hash) VALUES ($1, $2) RETURNING id`, email, string(hashed)).Scan(&id) if err != nil { if isUniqueViolation(err) { return 0, errAlreadyExists } return 0, err } return id, nil } func (s *postgresStore) verifyUser(ctx context.Context, email, password string) (int64, error) { var id int64 var hash string err := s.pool.QueryRow(ctx, `SELECT id, password_hash FROM users WHERE email = $1`, email).Scan(&id, &hash) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return 0, errInvalidCredentials } return 0, err } if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)); err != nil { return 0, errInvalidCredentials } return id, nil } func (s *postgresStore) createSession(ctx context.Context, sess Session) error { _, err := s.pool.Exec(ctx, ` INSERT INTO auth_sessions (id, user_id, device_info, ip, user_agent, expires_at) VALUES ($1, $2, $3, $4, $5, $6) `, sess.ID, sess.UserID, sess.DeviceInfo, sess.IP, sess.UserAgent, sess.ExpiresAt) return err } func (s *postgresStore) createRefreshToken(ctx context.Context, rid, sid, hash string, expiresAt time.Time, rotatedFrom *string) error { _, err := s.pool.Exec(ctx, ` INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from) VALUES ($1, $2, $3, $4, $5) `, rid, sid, hash, expiresAt, rotatedFrom) return err } func (s *postgresStore) getSession(ctx context.Context, sid string) (Session, error) { var sess Session var revokedAt sql.NullTime err := s.pool.QueryRow(ctx, ` SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at FROM auth_sessions WHERE id = $1 `, sid).Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt) if err != nil { return Session{}, err } if revokedAt.Valid { t := revokedAt.Time sess.RevokedAt = &t } return sess, nil } func (s *postgresStore) isSessionActive(ctx context.Context, sid string) (bool, error) { var ok bool err := s.pool.QueryRow(ctx, ` SELECT EXISTS( SELECT 1 FROM auth_sessions WHERE id = $1 AND revoked_at IS NULL AND expires_at > now() ) `, sid).Scan(&ok) return ok, err } func (s *postgresStore) revokeSession(ctx context.Context, sid string) error { _, err := s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, sid) if err != nil { return err } _, err = s.pool.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, sid) return err } func (s *postgresStore) revokeAllUserSessions(ctx context.Context, userID int64) ([]string, error) { rows, err := s.pool.Query(ctx, `SELECT id FROM auth_sessions WHERE user_id = $1 AND revoked_at IS NULL`, userID) if err != nil { return nil, err } defer rows.Close() ids := make([]string, 0) for rows.Next() { var sid string if err := rows.Scan(&sid); err != nil { return nil, err } ids = append(ids, sid) } if err := rows.Err(); err != nil { return nil, err } _, err = s.pool.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE user_id = $1 AND revoked_at IS NULL`, userID) if err != nil { return nil, err } _, err = s.pool.Exec(ctx, ` UPDATE auth_refresh_tokens rt SET revoked_at = now() FROM auth_sessions s WHERE s.user_id = $1 AND rt.session_id = s.id AND rt.revoked_at IS NULL `, userID) if err != nil { return nil, err } return ids, nil } func (s *postgresStore) listSessions(ctx context.Context, userID int64) ([]Session, error) { rows, err := s.pool.Query(ctx, ` SELECT id, user_id, COALESCE(device_info,''), COALESCE(ip,''), COALESCE(user_agent,''), created_at, expires_at, revoked_at FROM auth_sessions WHERE user_id = $1 ORDER BY created_at DESC `, userID) if err != nil { return nil, err } defer rows.Close() result := make([]Session, 0) for rows.Next() { var sess Session var revokedAt sql.NullTime if err := rows.Scan(&sess.ID, &sess.UserID, &sess.DeviceInfo, &sess.IP, &sess.UserAgent, &sess.CreatedAt, &sess.ExpiresAt, &revokedAt); err != nil { return nil, err } if revokedAt.Valid { t := revokedAt.Time sess.RevokedAt = &t } result = append(result, sess) } return result, rows.Err() } func (s *postgresStore) getRefreshTokenForUpdate(ctx context.Context, tx pgx.Tx, rid string) (refreshTokenRow, error) { var row refreshTokenRow err := tx.QueryRow(ctx, ` SELECT id, session_id, token_hash, expires_at, rotated_from, revoked_at, used_at, created_at FROM auth_refresh_tokens WHERE id = $1 FOR UPDATE `, rid).Scan(&row.ID, &row.SessionID, &row.TokenHash, &row.ExpiresAt, &row.RotatedFrom, &row.RevokedAt, &row.UsedAt, &row.CreatedAt) return row, err } func (s *postgresStore) rotateRefreshToken(ctx context.Context, claims refreshClaims, providedHash, newRID, newHash string, newExp time.Time) error { tx, err := s.pool.BeginTx(ctx, pgx.TxOptions{}) if err != nil { return err } defer tx.Rollback(ctx) row, err := s.getRefreshTokenForUpdate(ctx, tx, claims.RefreshID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return errInvalidToken } return err } if row.SessionID != claims.SessionID || row.ExpiresAt.Before(time.Now().UTC()) { return errInvalidToken } if row.RevokedAt.Valid || row.UsedAt.Valid || row.TokenHash != providedHash { if _, err := tx.Exec(ctx, `UPDATE auth_sessions SET revoked_at = now() WHERE id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil { return err } if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET revoked_at = now() WHERE session_id = $1 AND revoked_at IS NULL`, claims.SessionID); err != nil { return err } if err := tx.Commit(ctx); err != nil { return err } return errSessionRevoked } if _, err := tx.Exec(ctx, `UPDATE auth_refresh_tokens SET used_at = now(), revoked_at = now() WHERE id = $1`, claims.RefreshID); err != nil { return err } if _, err := tx.Exec(ctx, ` INSERT INTO auth_refresh_tokens (id, session_id, token_hash, expires_at, rotated_from) VALUES ($1, $2, $3, $4, $5) `, newRID, claims.SessionID, newHash, newExp, claims.RefreshID); err != nil { return err } return tx.Commit(ctx) } func isUniqueViolation(err error) bool { var pgErr *pgconn.PgError if errors.As(err, &pgErr) && pgErr.Code == "23505" { return true } return false } func normalizeEmail(v string) string { return strings.TrimSpace(strings.ToLower(v)) }