Refine, log unhealthy replica

This commit is contained in:
binwiederhier 2026-03-11 21:07:58 -04:00
parent ac65df1e83
commit 85bdfc61ce
10 changed files with 86 additions and 60 deletions

View file

@ -380,11 +380,11 @@ func createUserManager(c *cli.Context) (*user.Manager, error) {
QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval, QueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
} }
if databaseURL != "" { if databaseURL != "" {
pool, dbErr := pg.Open(databaseURL) host, dbErr := pg.Open(databaseURL)
if dbErr != nil { if dbErr != nil {
return nil, dbErr return nil, dbErr
} }
return user.NewPostgresManager(db.NewDB(pool, nil), authConfig) return user.NewPostgresManager(db.New(host, nil), authConfig)
} else if authFile != "" { } else if authFile != "" {
if !util.FileExists(authFile) { if !util.FileExists(authFile) {
return nil, errors.New("auth-file does not exist; please start the server at least once to create it") return nil, errors.New("auth-file does not exist; please start the server at least once to create it")

View file

@ -10,8 +10,9 @@ import (
) )
const ( const (
replicaHealthCheckInterval = 30 * time.Second replicaHealthCheckInitialDelay = 5 * time.Second
replicaHealthCheckTimeout = 2 * time.Second replicaHealthCheckInterval = 30 * time.Second
replicaHealthCheckTimeout = 10 * time.Second
) )
// Beginner is an interface for types that can begin a database transaction. // Beginner is an interface for types that can begin a database transaction.
@ -24,30 +25,29 @@ type Beginner interface {
// delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica // delegate to the primary. The ReadOnly() method returns a *sql.DB from a healthy replica
// (round-robin), falling back to the primary if no replicas are configured or all are unhealthy. // (round-robin), falling back to the primary if no replicas are configured or all are unhealthy.
type DB struct { type DB struct {
primary *sql.DB primary *Host
replicas []*replica replicas []*Host
counter atomic.Uint64 counter atomic.Uint64
cancel context.CancelFunc cancel context.CancelFunc
} }
type replica struct { // Host pairs a *sql.DB with the host:port it was opened against.
db *sql.DB type Host struct {
Addr string // "host:port"
DB *sql.DB
healthy atomic.Bool healthy atomic.Bool
} }
// NewDB creates a new DB that wraps the given primary and optional replica connections. // New creates a new DB that wraps the given primary and optional replica connections.
// If replicas is nil or empty, ReadOnly() simply returns the primary. // If replicas is nil or empty, ReadOnly() simply returns the primary.
// Replicas start unhealthy and are checked immediately by a background goroutine. // Replicas start unhealthy and are checked immediately by a background goroutine.
func NewDB(primary *sql.DB, replicas []*sql.DB) *DB { func New(primary *Host, replicas []*Host) *DB {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
d := &DB{ d := &DB{
primary: primary, primary: primary,
replicas: make([]*replica, len(replicas)), replicas: replicas,
cancel: cancel, cancel: cancel,
} }
for i, r := range replicas {
d.replicas[i] = &replica{db: r} // healthy defaults to false
}
if len(d.replicas) > 0 { if len(d.replicas) > 0 {
go d.healthCheckLoop(ctx) go d.healthCheckLoop(ctx)
} }
@ -57,63 +57,68 @@ func NewDB(primary *sql.DB, replicas []*sql.DB) *DB {
// Primary returns the underlying primary *sql.DB. This is only intended for // Primary returns the underlying primary *sql.DB. This is only intended for
// one-time schema setup during store initialization, not for regular queries. // one-time schema setup during store initialization, not for regular queries.
func (d *DB) Primary() *sql.DB { func (d *DB) Primary() *sql.DB {
return d.primary return d.primary.DB
} }
// Query delegates to the primary database. // Query delegates to the primary database.
func (d *DB) Query(query string, args ...any) (*sql.Rows, error) { func (d *DB) Query(query string, args ...any) (*sql.Rows, error) {
return d.primary.Query(query, args...) return d.primary.DB.Query(query, args...)
} }
// QueryRow delegates to the primary database. // QueryRow delegates to the primary database.
func (d *DB) QueryRow(query string, args ...any) *sql.Row { func (d *DB) QueryRow(query string, args ...any) *sql.Row {
return d.primary.QueryRow(query, args...) return d.primary.DB.QueryRow(query, args...)
} }
// Exec delegates to the primary database. // Exec delegates to the primary database.
func (d *DB) Exec(query string, args ...any) (sql.Result, error) { func (d *DB) Exec(query string, args ...any) (sql.Result, error) {
return d.primary.Exec(query, args...) return d.primary.DB.Exec(query, args...)
} }
// Begin delegates to the primary database. // Begin delegates to the primary database.
func (d *DB) Begin() (*sql.Tx, error) { func (d *DB) Begin() (*sql.Tx, error) {
return d.primary.Begin() return d.primary.DB.Begin()
} }
// Ping delegates to the primary database. // Ping delegates to the primary database.
func (d *DB) Ping() error { func (d *DB) Ping() error {
return d.primary.Ping() return d.primary.DB.Ping()
} }
// Close closes the primary database and all replicas, and stops the health-check goroutine. // Close closes the primary database and all replicas, and stops the health-check goroutine.
func (d *DB) Close() error { func (d *DB) Close() error {
d.cancel() d.cancel()
for _, r := range d.replicas { for _, r := range d.replicas {
r.db.Close() r.DB.Close()
} }
return d.primary.Close() return d.primary.DB.Close()
} }
// ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy // ReadOnly returns a *sql.DB suitable for read-only queries. It round-robins across healthy
// replicas. If all replicas are unhealthy or none are configured, the primary is returned. // replicas. If all replicas are unhealthy or none are configured, the primary is returned.
func (d *DB) ReadOnly() *sql.DB { func (d *DB) ReadOnly() *sql.DB {
if len(d.replicas) == 0 { if len(d.replicas) == 0 {
return d.primary return d.primary.DB
} }
n := len(d.replicas) n := len(d.replicas)
start := int(d.counter.Add(1) - 1) start := int(d.counter.Add(1) - 1)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
r := d.replicas[(start+i)%n] r := d.replicas[(start+i)%n]
if r.healthy.Load() { if r.healthy.Load() {
return r.db return r.DB
} }
} }
return d.primary return d.primary.DB
} }
// healthCheckLoop checks replicas immediately, then periodically on a ticker. // healthCheckLoop checks replicas immediately, then periodically on a ticker.
func (d *DB) healthCheckLoop(ctx context.Context) { func (d *DB) healthCheckLoop(ctx context.Context) {
d.checkReplicas(ctx) select {
case <-ctx.Done():
return
case <-time.After(replicaHealthCheckInitialDelay):
d.checkReplicas(ctx)
}
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@ -129,17 +134,17 @@ func (d *DB) checkReplicas(ctx context.Context) {
for _, r := range d.replicas { for _, r := range d.replicas {
wasHealthy := r.healthy.Load() wasHealthy := r.healthy.Load()
pingCtx, cancel := context.WithTimeout(ctx, replicaHealthCheckTimeout) pingCtx, cancel := context.WithTimeout(ctx, replicaHealthCheckTimeout)
err := r.db.PingContext(pingCtx) err := r.DB.PingContext(pingCtx)
cancel() cancel()
if err != nil { if err != nil {
r.healthy.Store(false) r.healthy.Store(false)
if wasHealthy { if wasHealthy {
log.Error("Database replica is now unhealthy: %s", err) log.Error("Database replica %s is unhealthy: %s", r.Addr, err)
} }
} else { } else {
r.healthy.Store(true) r.healthy.Store(true)
if !wasHealthy { if !wasHealthy {
log.Info("Database replica is now healthy again") log.Info("Database replica %s is healthy", r.Addr)
} }
} }
} }

View file

@ -9,6 +9,8 @@ import (
"time" "time"
_ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver _ "github.com/jackc/pgx/v5/stdlib" // PostgreSQL driver
"heckel.io/ntfy/v2/db"
) )
const ( const (
@ -20,11 +22,30 @@ const (
defaultMaxOpenConns = 10 defaultMaxOpenConns = 10
) )
// Open opens a PostgreSQL database connection pool from a DSN string. It supports custom // Open opens a PostgreSQL connection pool for a primary database. It pings the database
// to verify connectivity before returning.
func Open(dsn string) (*db.Host, error) {
d, err := open(dsn)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
if err := d.DB.Ping(); err != nil {
return nil, fmt.Errorf("database ping failed on %v: %w", d.Addr, err)
}
return d, nil
}
// OpenReplica opens a PostgreSQL connection pool for a read replica. Unlike Open, it does
// not ping the database, since replicas are health-checked in the background by db.DB.
func OpenReplica(dsn string) (*db.Host, error) {
return open(dsn)
}
// open opens a PostgreSQL database connection pool from a DSN string. It supports custom
// query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns, // query parameters for pool configuration: pool_max_conns (default 10), pool_max_idle_conns,
// pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from // pool_conn_max_lifetime, and pool_conn_max_idle_time. These parameters are stripped from
// the DSN before passing it to the driver. // the DSN before passing it to the driver.
func Open(dsn string) (*sql.DB, error) { func open(dsn string) (*db.Host, error) {
u, err := url.Parse(dsn) u, err := url.Parse(dsn)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid database URL: %w", err) return nil, fmt.Errorf("invalid database URL: %w", err)
@ -53,24 +74,24 @@ func Open(dsn string) (*sql.DB, error) {
return nil, err return nil, err
} }
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
db, err := sql.Open("pgx", u.String()) d, err := sql.Open("pgx", u.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
db.SetMaxOpenConns(maxOpenConns) d.SetMaxOpenConns(maxOpenConns)
if maxIdleConns > 0 { if maxIdleConns > 0 {
db.SetMaxIdleConns(maxIdleConns) d.SetMaxIdleConns(maxIdleConns)
} }
if connMaxLifetime > 0 { if connMaxLifetime > 0 {
db.SetConnMaxLifetime(connMaxLifetime) d.SetConnMaxLifetime(connMaxLifetime)
} }
if connMaxIdleTime > 0 { if connMaxIdleTime > 0 {
db.SetConnMaxIdleTime(connMaxIdleTime) d.SetConnMaxIdleTime(connMaxIdleTime)
} }
if err := db.Ping(); err != nil { return &db.Host{
return nil, fmt.Errorf("database ping failed (URL: %s): %w", censorPassword(u), err) Addr: u.Host,
} DB: d,
return db, nil }, nil
} }
func extractIntParam(q url.Values, key string, defaultValue int) (int, error) { func extractIntParam(q url.Values, key string, defaultValue int) (int, error) {

View file

@ -30,19 +30,19 @@ func CreateTestPostgresSchema(t *testing.T) string {
q.Set("pool_max_conns", testPoolMaxConns) q.Set("pool_max_conns", testPoolMaxConns)
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
dsn = u.String() dsn = u.String()
setupDB, err := pg.Open(dsn) setupHost, err := pg.Open(dsn)
require.Nil(t, err) require.Nil(t, err)
_, err = setupDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema)) _, err = setupHost.DB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schema))
require.Nil(t, err) require.Nil(t, err)
require.Nil(t, setupDB.Close()) require.Nil(t, setupHost.DB.Close())
q.Set("search_path", schema) q.Set("search_path", schema)
u.RawQuery = q.Encode() u.RawQuery = q.Encode()
schemaDSN := u.String() schemaDSN := u.String()
t.Cleanup(func() { t.Cleanup(func() {
cleanDB, err := pg.Open(dsn) cleanHost, err := pg.Open(dsn)
if err == nil { if err == nil {
cleanDB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema)) cleanHost.DB.Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", schema))
cleanDB.Close() cleanHost.DB.Close()
} }
}) })
return schemaDSN return schemaDSN
@ -54,9 +54,9 @@ func CreateTestPostgresSchema(t *testing.T) string {
func CreateTestPostgres(t *testing.T) *db.DB { func CreateTestPostgres(t *testing.T) *db.DB {
t.Helper() t.Helper()
schemaDSN := CreateTestPostgresSchema(t) schemaDSN := CreateTestPostgresSchema(t)
testDB, err := pg.Open(schemaDSN) testHost, err := pg.Open(schemaDSN)
require.Nil(t, err) require.Nil(t, err)
d := db.NewDB(testDB, nil) d := db.New(testHost, nil)
t.Cleanup(func() { t.Cleanup(func() {
d.Close() d.Close()
}) })

View file

@ -118,7 +118,7 @@ func NewSQLiteStore(filename, startupQueries string, cacheDuration time.Duration
if err := setupSQLite(sqlDB, startupQueries, cacheDuration); err != nil { if err := setupSQLite(sqlDB, startupQueries, cacheDuration); err != nil {
return nil, err return nil, err
} }
return newCache(db.NewDB(sqlDB, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil return newCache(db.New(&db.Host{DB: sqlDB}, nil), sqliteQueries, &sync.Mutex{}, batchSize, batchTimeout, nop), nil
} }
// NewMemStore creates an in-memory cache // NewMemStore creates an in-memory cache

View file

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/sha256" "crypto/sha256"
"database/sql"
"embed" "embed"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
@ -186,20 +185,20 @@ func New(conf *Config) (*Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
var replicas []*sql.DB var replicas []*db.Host
for _, replicaURL := range conf.DatabaseReplicaURLs { for _, replicaURL := range conf.DatabaseReplicaURLs {
r, err := pg.Open(replicaURL) r, err := pg.OpenReplica(replicaURL)
if err != nil { if err != nil {
// Close already-opened replicas before returning // Close already-opened replicas before returning
for _, opened := range replicas { for _, opened := range replicas {
opened.Close() opened.DB.Close()
} }
primary.Close() primary.DB.Close()
return nil, fmt.Errorf("failed to open database replica: %w", err) return nil, fmt.Errorf("failed to open database replica: %w", err)
} }
replicas = append(replicas, r) replicas = append(replicas, r)
} }
pool = db.NewDB(primary, replicas) pool = db.New(primary, replicas)
} }
messageCache, err := createMessageCache(conf, pool) messageCache, err := createMessageCache(conf, pool)
if err != nil { if err != nil {

View file

@ -236,10 +236,11 @@ func execImport(c *cli.Context) error {
} }
fmt.Println() fmt.Println()
pgDB, err := pg.Open(databaseURL) pgHost, err := pg.Open(databaseURL)
if err != nil { if err != nil {
return fmt.Errorf("cannot connect to PostgreSQL: %w", err) return fmt.Errorf("cannot connect to PostgreSQL: %w", err)
} }
pgDB := pgHost.DB
defer pgDB.Close() defer pgDB.Close()
if c.Bool("create-schema") { if c.Bool("create-schema") {

View file

@ -291,5 +291,5 @@ func NewSQLiteManager(filename, startupQueries string, config *Config) (*Manager
if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil { if err := runSQLiteStartupQueries(sqlDB, startupQueries); err != nil {
return nil, err return nil, err
} }
return newManager(db.NewDB(sqlDB, nil), sqliteQueries, config) return newManager(db.New(&db.Host{DB: sqlDB}, nil), sqliteQueries, config)
} }

View file

@ -37,9 +37,9 @@ func forEachBackend(t *testing.T, f func(t *testing.T, newManager newManagerFunc
t.Run("postgres", func(t *testing.T) { t.Run("postgres", func(t *testing.T) {
schemaDSN := dbtest.CreateTestPostgresSchema(t) schemaDSN := dbtest.CreateTestPostgresSchema(t)
f(t, func(config *Config) *Manager { f(t, func(config *Config) *Manager {
pool, err := pg.Open(schemaDSN) host, err := pg.Open(schemaDSN)
require.Nil(t, err) require.Nil(t, err)
a, err := NewPostgresManager(db.NewDB(pool, nil), config) a, err := NewPostgresManager(db.New(host, nil), config)
require.Nil(t, err) require.Nil(t, err)
return a return a
}) })

View file

@ -90,7 +90,7 @@ func NewSQLiteStore(filename, startupQueries string) (*Store, error) {
return nil, err return nil, err
} }
return &Store{ return &Store{
db: db.NewDB(sqlDB, nil), db: db.New(&db.Host{DB: sqlDB}, nil),
queries: queries{ queries: queries{
selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery, selectSubscriptionIDByEndpoint: sqliteSelectSubscriptionIDByEndpointQuery,
selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery, selectSubscriptionCountBySubscriberIP: sqliteSelectSubscriptionCountBySubscriberIPQuery,