Extract ExecTx

This commit is contained in:
binwiederhier 2026-03-02 19:45:35 -05:00
parent 31f0234098
commit ea4739f79b
9 changed files with 222 additions and 260 deletions

View file

@ -91,3 +91,36 @@ func extractDurationParam(q url.Values, key string, defaultValue time.Duration)
} }
return d, nil return d, nil
} }
// ExecTx executes a function within a database transaction. If the function returns an error,
// the transaction is rolled back. Otherwise, the transaction is committed.
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
return tx.Commit()
}
// QueryTx executes a function within a database transaction and returns the result. If the function
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}

View file

@ -9,6 +9,7 @@ import (
"sync" "sync"
"time" "time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model" "heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util"
@ -334,17 +335,14 @@ func (c *Cache) Topics() ([]string, error) {
func (c *Cache) DeleteMessages(ids ...string) error { func (c *Cache) DeleteMessages(ids ...string) error {
c.maybeLock() c.maybeLock()
defer c.maybeUnlock() defer c.maybeUnlock()
tx, err := c.db.Begin() return db.ExecTx(c.db, func(tx *sql.Tx) error {
if err != nil { for _, id := range ids {
return err if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
} return err
defer tx.Rollback() }
for _, id := range ids {
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
return err
} }
} return nil
return tx.Commit() })
} }
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID. // DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
@ -352,54 +350,43 @@ func (c *Cache) DeleteMessages(ids ...string) error {
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) { func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
c.maybeLock() c.maybeLock()
defer c.maybeUnlock() defer c.maybeUnlock()
tx, err := c.db.Begin() return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) {
if err != nil { rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
return nil, err if err != nil {
}
defer tx.Rollback()
// First, get the message IDs of scheduled messages to be deleted
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
return nil, err return nil, err
} }
ids = append(ids, id) defer rows.Close()
} ids := make([]string, 0)
if err := rows.Err(); err != nil { for rows.Next() {
return nil, err var id string
} if err := rows.Scan(&id); err != nil {
rows.Close() // Close rows before executing delete in same transaction return nil, err
// Then delete the messages }
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil { ids = append(ids, id)
return nil, err }
} if err := rows.Err(); err != nil {
if err := tx.Commit(); err != nil { return nil, err
return nil, err }
} rows.Close() // Close rows before executing delete in same transaction
return ids, nil if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
return ids, nil
})
} }
// ExpireMessages marks messages in the given topics as expired // ExpireMessages marks messages in the given topics as expired
func (c *Cache) ExpireMessages(topics ...string) error { func (c *Cache) ExpireMessages(topics ...string) error {
c.maybeLock() c.maybeLock()
defer c.maybeUnlock() defer c.maybeUnlock()
tx, err := c.db.Begin() return db.ExecTx(c.db, func(tx *sql.Tx) error {
if err != nil { for _, t := range topics {
return err if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
} return err
defer tx.Rollback() }
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
} }
} return nil
return tx.Commit() })
} }
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted // AttachmentsExpired returns message IDs with expired attachments that have not been deleted
@ -427,17 +414,14 @@ func (c *Cache) AttachmentsExpired() ([]string, error) {
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error { func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
c.maybeLock() c.maybeLock()
defer c.maybeUnlock() defer c.maybeUnlock()
tx, err := c.db.Begin() return db.ExecTx(c.db, func(tx *sql.Tx) error {
if err != nil { for _, id := range ids {
return err if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
} return err
defer tx.Rollback() }
for _, id := range ids {
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
return err
} }
} return nil
return tx.Commit() })
} }
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender // AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender

View file

@ -3,6 +3,8 @@ package message
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"heckel.io/ntfy/v2/db"
) )
// Initial PostgreSQL schema // Initial PostgreSQL schema
@ -55,34 +57,29 @@ const (
// PostgreSQL schema management queries // PostgreSQL schema management queries
const ( const (
pgCurrentSchemaVersion = 14 postgresCurrentSchemaVersion = 14
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)` postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'` postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
) )
func setupPostgres(db *sql.DB) error { func setupPostgres(db *sql.DB) error {
var schemaVersion int var schemaVersion int
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion) if err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
if err != nil {
return setupNewPostgresDB(db) return setupNewPostgresDB(db)
} } else if schemaVersion > postgresCurrentSchemaVersion {
if schemaVersion > pgCurrentSchemaVersion { return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
} }
return nil return nil
} }
func setupNewPostgresDB(db *sql.DB) error { func setupNewPostgresDB(sqlDB *sql.DB) error {
tx, err := db.Begin() return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if err != nil { if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err return err
} }
defer tx.Rollback() if _, err := tx.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil { return err
return err }
} return nil
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil { })
return err
}
return tx.Commit()
} }

View file

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"time" "time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/log"
) )
@ -382,85 +383,70 @@ func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
return nil return nil
} }
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error { func sqliteMigrateFrom9(sqlDB *sql.DB, cacheDuration time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10") log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
tx, err := db.Begin() return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if err != nil { if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
return err return err
} }
defer tx.Rollback() if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil { return err
return err }
} if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil { return err
return err }
} return nil
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil { })
return err
}
return tx.Commit()
} }
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error { func sqliteMigrateFrom10(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11") log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
tx, err := db.Begin() return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if err != nil { if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
return err return err
} }
defer tx.Rollback() if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil { return err
return err }
} return nil
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil { })
return err
}
return tx.Commit()
} }
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error { func sqliteMigrateFrom11(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12") log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
tx, err := db.Begin() return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if err != nil { if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
return err return err
} }
defer tx.Rollback() if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil { return err
return err }
} return nil
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil { })
return err
}
return tx.Commit()
} }
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error { func sqliteMigrateFrom12(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13") log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
tx, err := db.Begin() return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if err != nil { if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
return err return err
} }
defer tx.Rollback() if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil { return err
return err }
} return nil
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil { })
return err
}
return tx.Commit()
} }
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error { func sqliteMigrateFrom13(sqlDB *sql.DB, _ time.Duration) error {
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14") log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
tx, err := db.Begin() return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if err != nil { if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
return err return err
} }
defer tx.Rollback() if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil { return err
return err }
} return nil
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil { })
return err
}
return tx.Commit()
} }

View file

@ -13,6 +13,7 @@ import (
"time" "time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log" "heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/payments" "heckel.io/ntfy/v2/payments"
"heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util"
@ -122,7 +123,7 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err
if err != nil { if err != nil {
return err return err
} }
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.addUserTx(tx, username, hash, role, false) return a.addUserTx(tx, username, hash, role, false)
}) })
} }
@ -150,7 +151,7 @@ func (a *Manager) RemoveUser(username string) error {
if err := a.CanChangeUser(username); err != nil { if err := a.CanChangeUser(username); err != nil {
return err return err
} }
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.removeUserTx(tx, username) return a.removeUserTx(tx, username)
}) })
} }
@ -173,7 +174,7 @@ func (a *Manager) MarkUserRemoved(user *User) error {
if !AllowedUsername(user.Name) { if !AllowedUsername(user.Name) {
return ErrInvalidArgument return ErrInvalidArgument
} }
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.resetUserAccessTx(tx, user.Name); err != nil { if err := a.resetUserAccessTx(tx, user.Name); err != nil {
return err return err
} }
@ -205,7 +206,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
if err != nil { if err != nil {
return err return err
} }
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.changePasswordHashTx(tx, username, hash) return a.changePasswordHashTx(tx, username, hash)
}) })
} }
@ -224,7 +225,7 @@ func (a *Manager) ChangeRole(username string, role Role) error {
if err := a.CanChangeUser(username); err != nil { if err := a.CanChangeUser(username); err != nil {
return err return err
} }
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.changeRoleTx(tx, username, role) return a.changeRoleTx(tx, username, role)
}) })
} }
@ -365,7 +366,7 @@ func (a *Manager) writeUserStatsQueue() error {
a.statsQueue = make(map[string]*Stats) a.statsQueue = make(map[string]*Stats)
a.mu.Unlock() a.mu.Unlock()
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue)) log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
for userID, update := range statsQueue { for userID, update := range statsQueue {
log. log.
@ -573,7 +574,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry // read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
// owner may either be a user (username), or the system (empty). // owner may either be a user (username), or the system (empty).
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error { func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.allowAccessTx(tx, username, topicPattern, permission, false) return a.allowAccessTx(tx, username, topicPattern, permission, false)
}) })
} }
@ -591,7 +592,7 @@ func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is // ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
// empty) for an entire user. The parameter topicPattern may include wildcards (*). // empty) for an entire user. The parameter topicPattern may include wildcards (*).
func (a *Manager) ResetAccess(username string, topicPattern string) error { func (a *Manager) ResetAccess(username string, topicPattern string) error {
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
return a.resetAccessTx(tx, username, topicPattern) return a.resetAccessTx(tx, username, topicPattern)
}) })
} }
@ -715,7 +716,7 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) { if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
return ErrInvalidArgument return ErrInvalidArgument
} }
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil { if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
return err return err
} }
@ -735,7 +736,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
return ErrInvalidArgument return ErrInvalidArgument
} }
} }
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
for _, topic := range topics { for _, topic := range topics {
if err := a.resetTopicAccessTx(tx, username, topic); err != nil { if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
return err return err
@ -874,7 +875,7 @@ func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string)
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
// given user, if there are too many of them. // given user, if there are too many of them.
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) { func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) { return db.QueryTx(a.db, func(tx *sql.Tx) (*Token, error) {
return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned) return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
}) })
} }
@ -1033,7 +1034,7 @@ func (a *Manager) writeTokenUpdateQueue() error {
a.tokenQueue = make(map[string]*TokenUpdate) a.tokenQueue = make(map[string]*TokenUpdate)
a.mu.Unlock() a.mu.Unlock()
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue)) log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
for tokenID, update := range tokenQueue { for tokenID, update := range tokenQueue {
log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix()) log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
@ -1254,7 +1255,7 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
if err != nil { if err != nil {
return err return err
} }
return execTx(a.db, func(tx *sql.Tx) error { return db.ExecTx(a.db, func(tx *sql.Tx) error {
if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil { if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
return fmt.Errorf("failed to provision users: %v", err) return fmt.Errorf("failed to provision users: %v", err)
} }

View file

@ -113,35 +113,3 @@ func escapeUnderscore(s string) string {
func unescapeUnderscore(s string) string { func unescapeUnderscore(s string) string {
return strings.ReplaceAll(s, "\\_", "_") return strings.ReplaceAll(s, "\\_", "_")
} }
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := f(tx); err != nil {
return err
}
return tx.Commit()
}
// queryTx executes a function in a transaction and returns the result. If the function
// returns an error, the transaction is rolled back.
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
tx, err := db.Begin()
if err != nil {
var zero T
return zero, err
}
defer tx.Rollback()
t, err := f(tx)
if err != nil {
return t, err
}
if err := tx.Commit(); err != nil {
return t, err
}
return t, nil
}

View file

@ -6,6 +6,7 @@ import (
"net/netip" "net/netip"
"time" "time"
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/util" "heckel.io/ntfy/v2/util"
) )
@ -46,41 +47,38 @@ type queries struct {
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID. // UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error { func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
tx, err := s.db.Begin() return db.ExecTx(s.db, func(tx *sql.Tx) error {
if err != nil { // Read number of subscriptions for subscriber IP address
return err var subscriptionCount int
} if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
defer tx.Rollback()
// Read number of subscriptions for subscriber IP address
var subscriptionCount int
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
return err
}
// Read existing subscription ID for endpoint (or create new ID)
var subscriptionID string
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} else if err != nil {
return err
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err = tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
return err return err
} }
} // Read existing subscription ID for endpoint (or create new ID)
return tx.Commit() var subscriptionID string
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
return ErrWebPushTooManySubscriptions
}
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
} else if err != nil {
return err
}
// Insert or update subscription
updatedAt, warnedAt := time.Now().Unix(), 0
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
return err
}
// Replace all subscription topics
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
return err
}
for _, topic := range topics {
if _, err := tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
return err
}
}
return nil
})
} }
// SubscriptionsForTopic returns all subscriptions for the given topic. // SubscriptionsForTopic returns all subscriptions for the given topic.
@ -105,17 +103,14 @@ func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription,
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon. // MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error { func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
tx, err := s.db.Begin() return db.ExecTx(s.db, func(tx *sql.Tx) error {
if err != nil { for _, subscription := range subscriptions {
return err if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
} return err
defer tx.Rollback() }
for _, subscription := range subscriptions {
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
return err
} }
} return nil
return tx.Commit() })
} }
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint. // RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.

View file

@ -3,6 +3,8 @@ package webpush
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"heckel.io/ntfy/v2/db"
) )
const ( const (
@ -107,17 +109,14 @@ func setupPostgres(db *sql.DB) error {
return nil return nil
} }
func setupNewPostgres(db *sql.DB) error { func setupNewPostgres(sqlDB *sql.DB) error {
tx, err := db.Begin() return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if err != nil { if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
return err return err
} }
defer tx.Rollback() if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil { return err
return err }
} return nil
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil { })
return err
}
return tx.Commit()
} }

View file

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3" // SQLite driver _ "github.com/mattn/go-sqlite3" // SQLite driver
"heckel.io/ntfy/v2/db"
) )
const ( const (
@ -119,19 +121,16 @@ func setupSQLite(db *sql.DB) error {
return nil return nil
} }
func setupNewSQLite(db *sql.DB) error { func setupNewSQLite(sqlDB *sql.DB) error {
tx, err := db.Begin() return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
if err != nil { if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
return err return err
} }
defer tx.Rollback() if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil { return err
return err }
} return nil
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil { })
return err
}
return tx.Commit()
} }
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error { func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {