Extract ExecTx
This commit is contained in:
parent
31f0234098
commit
ea4739f79b
9 changed files with 222 additions and 260 deletions
33
db/db.go
33
db/db.go
|
|
@ -91,3 +91,36 @@ func extractDurationParam(q url.Values, key string, defaultValue time.Duration)
|
||||||
}
|
}
|
||||||
return d, nil
|
return d, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecTx executes a function within a database transaction. If the function returns an error,
|
||||||
|
// the transaction is rolled back. Otherwise, the transaction is committed.
|
||||||
|
func ExecTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
if err := f(tx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryTx executes a function within a database transaction and returns the result. If the function
|
||||||
|
// returns an error, the transaction is rolled back. Otherwise, the transaction is committed.
|
||||||
|
func QueryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
var zero T
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
t, err := f(tx)
|
||||||
|
if err != nil {
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return t, err
|
||||||
|
}
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
|
||||||
102
message/cache.go
102
message/cache.go
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
"heckel.io/ntfy/v2/model"
|
"heckel.io/ntfy/v2/model"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
|
|
@ -334,17 +335,14 @@ func (c *Cache) Topics() ([]string, error) {
|
||||||
func (c *Cache) DeleteMessages(ids ...string) error {
|
func (c *Cache) DeleteMessages(ids ...string) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
return db.ExecTx(c.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
for _, id := range ids {
|
||||||
return err
|
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
||||||
}
|
return err
|
||||||
defer tx.Rollback()
|
}
|
||||||
for _, id := range ids {
|
|
||||||
if _, err := tx.Exec(c.queries.deleteMessage, id); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return tx.Commit()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
|
||||||
|
|
@ -352,54 +350,43 @@ func (c *Cache) DeleteMessages(ids ...string) error {
|
||||||
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) {
|
||||||
if err != nil {
|
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
||||||
return nil, err
|
if err != nil {
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
// First, get the message IDs of scheduled messages to be deleted
|
|
||||||
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
ids := make([]string, 0)
|
|
||||||
for rows.Next() {
|
|
||||||
var id string
|
|
||||||
if err := rows.Scan(&id); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ids = append(ids, id)
|
defer rows.Close()
|
||||||
}
|
ids := make([]string, 0)
|
||||||
if err := rows.Err(); err != nil {
|
for rows.Next() {
|
||||||
return nil, err
|
var id string
|
||||||
}
|
if err := rows.Scan(&id); err != nil {
|
||||||
rows.Close() // Close rows before executing delete in same transaction
|
return nil, err
|
||||||
// Then delete the messages
|
}
|
||||||
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
ids = append(ids, id)
|
||||||
return nil, err
|
}
|
||||||
}
|
if err := rows.Err(); err != nil {
|
||||||
if err := tx.Commit(); err != nil {
|
return nil, err
|
||||||
return nil, err
|
}
|
||||||
}
|
rows.Close() // Close rows before executing delete in same transaction
|
||||||
return ids, nil
|
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExpireMessages marks messages in the given topics as expired
|
// ExpireMessages marks messages in the given topics as expired
|
||||||
func (c *Cache) ExpireMessages(topics ...string) error {
|
func (c *Cache) ExpireMessages(topics ...string) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
return db.ExecTx(c.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
for _, t := range topics {
|
||||||
return err
|
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
||||||
}
|
return err
|
||||||
defer tx.Rollback()
|
}
|
||||||
for _, t := range topics {
|
|
||||||
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return tx.Commit()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
|
// AttachmentsExpired returns message IDs with expired attachments that have not been deleted
|
||||||
|
|
@ -427,17 +414,14 @@ func (c *Cache) AttachmentsExpired() ([]string, error) {
|
||||||
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
|
func (c *Cache) MarkAttachmentsDeleted(ids ...string) error {
|
||||||
c.maybeLock()
|
c.maybeLock()
|
||||||
defer c.maybeUnlock()
|
defer c.maybeUnlock()
|
||||||
tx, err := c.db.Begin()
|
return db.ExecTx(c.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
for _, id := range ids {
|
||||||
return err
|
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
||||||
}
|
return err
|
||||||
defer tx.Rollback()
|
}
|
||||||
for _, id := range ids {
|
|
||||||
if _, err := tx.Exec(c.queries.updateAttachmentDeleted, id); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return tx.Commit()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ package message
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Initial PostgreSQL schema
|
// Initial PostgreSQL schema
|
||||||
|
|
@ -55,34 +57,29 @@ const (
|
||||||
|
|
||||||
// PostgreSQL schema management queries
|
// PostgreSQL schema management queries
|
||||||
const (
|
const (
|
||||||
pgCurrentSchemaVersion = 14
|
postgresCurrentSchemaVersion = 14
|
||||||
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
|
postgresInsertSchemaVersionQuery = `INSERT INTO schema_version (store, version) VALUES ('message', $1)`
|
||||||
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
|
postgresSelectSchemaVersionQuery = `SELECT version FROM schema_version WHERE store = 'message'`
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupPostgres(db *sql.DB) error {
|
func setupPostgres(db *sql.DB) error {
|
||||||
var schemaVersion int
|
var schemaVersion int
|
||||||
err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion)
|
if err := db.QueryRow(postgresSelectSchemaVersionQuery).Scan(&schemaVersion); err != nil {
|
||||||
if err != nil {
|
|
||||||
return setupNewPostgresDB(db)
|
return setupNewPostgresDB(db)
|
||||||
}
|
} else if schemaVersion > postgresCurrentSchemaVersion {
|
||||||
if schemaVersion > pgCurrentSchemaVersion {
|
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, postgresCurrentSchemaVersion)
|
||||||
return fmt.Errorf("unexpected schema version: version %d is higher than current version %d", schemaVersion, pgCurrentSchemaVersion)
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupNewPostgresDB(db *sql.DB) error {
|
func setupNewPostgresDB(sqlDB *sql.DB) error {
|
||||||
tx, err := db.Begin()
|
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, postgresCurrentSchemaVersion); err != nil {
|
||||||
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -382,85 +383,70 @@ func sqliteMigrateFrom8(db *sql.DB, _ time.Duration) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqliteMigrateFrom9(db *sql.DB, cacheDuration time.Duration) error {
|
func sqliteMigrateFrom9(sqlDB *sql.DB, cacheDuration time.Duration) error {
|
||||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 9 to 10")
|
||||||
tx, err := db.Begin()
|
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
||||||
if _, err := tx.Exec(sqliteMigrate9To10AlterMessagesTableQuery); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
|
||||||
if _, err := tx.Exec(sqliteMigrate9To10UpdateMessageExpiryQuery, int64(cacheDuration.Seconds())); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 10); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqliteMigrateFrom10(db *sql.DB, _ time.Duration) error {
|
func sqliteMigrateFrom10(sqlDB *sql.DB, _ time.Duration) error {
|
||||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 10 to 11")
|
||||||
tx, err := db.Begin()
|
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
|
||||||
if _, err := tx.Exec(sqliteMigrate10To11AlterMessagesTableQuery); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 11); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqliteMigrateFrom11(db *sql.DB, _ time.Duration) error {
|
func sqliteMigrateFrom11(sqlDB *sql.DB, _ time.Duration) error {
|
||||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 11 to 12")
|
||||||
tx, err := db.Begin()
|
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
|
||||||
if _, err := tx.Exec(sqliteMigrate11To12AlterMessagesTableQuery); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 12); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqliteMigrateFrom12(db *sql.DB, _ time.Duration) error {
|
func sqliteMigrateFrom12(sqlDB *sql.DB, _ time.Duration) error {
|
||||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 12 to 13")
|
||||||
tx, err := db.Begin()
|
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
|
||||||
if _, err := tx.Exec(sqliteMigrate12To13AlterMessagesTableQuery); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 13); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqliteMigrateFrom13(db *sql.DB, _ time.Duration) error {
|
func sqliteMigrateFrom13(sqlDB *sql.DB, _ time.Duration) error {
|
||||||
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
|
log.Tag(tagMessageCache).Info("Migrating cache database schema: from 13 to 14")
|
||||||
tx, err := db.Begin()
|
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
|
||||||
if _, err := tx.Exec(sqliteMigrate13To14AlterMessagesTableQuery); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
if _, err := tx.Exec(sqliteUpdateSchemaVersionQuery, 14); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/log"
|
"heckel.io/ntfy/v2/log"
|
||||||
"heckel.io/ntfy/v2/payments"
|
"heckel.io/ntfy/v2/payments"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
|
|
@ -122,7 +123,7 @@ func (a *Manager) AddUser(username, password string, role Role, hashed bool) err
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
return a.addUserTx(tx, username, hash, role, false)
|
return a.addUserTx(tx, username, hash, role, false)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -150,7 +151,7 @@ func (a *Manager) RemoveUser(username string) error {
|
||||||
if err := a.CanChangeUser(username); err != nil {
|
if err := a.CanChangeUser(username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
return a.removeUserTx(tx, username)
|
return a.removeUserTx(tx, username)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -173,7 +174,7 @@ func (a *Manager) MarkUserRemoved(user *User) error {
|
||||||
if !AllowedUsername(user.Name) {
|
if !AllowedUsername(user.Name) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
if err := a.resetUserAccessTx(tx, user.Name); err != nil {
|
if err := a.resetUserAccessTx(tx, user.Name); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -205,7 +206,7 @@ func (a *Manager) ChangePassword(username, password string, hashed bool) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
return a.changePasswordHashTx(tx, username, hash)
|
return a.changePasswordHashTx(tx, username, hash)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -224,7 +225,7 @@ func (a *Manager) ChangeRole(username string, role Role) error {
|
||||||
if err := a.CanChangeUser(username); err != nil {
|
if err := a.CanChangeUser(username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
return a.changeRoleTx(tx, username, role)
|
return a.changeRoleTx(tx, username, role)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -365,7 +366,7 @@ func (a *Manager) writeUserStatsQueue() error {
|
||||||
a.statsQueue = make(map[string]*Stats)
|
a.statsQueue = make(map[string]*Stats)
|
||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
|
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
|
log.Tag(tag).Debug("Writing user stats queue for %d user(s)", len(statsQueue))
|
||||||
for userID, update := range statsQueue {
|
for userID, update := range statsQueue {
|
||||||
log.
|
log.
|
||||||
|
|
@ -573,7 +574,7 @@ func (a *Manager) resolvePerms(base, perm Permission) error {
|
||||||
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
|
// read/write access to a topic. The parameter topicPattern may include wildcards (*). The ACL entry
|
||||||
// owner may either be a user (username), or the system (empty).
|
// owner may either be a user (username), or the system (empty).
|
||||||
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
|
func (a *Manager) AllowAccess(username string, topicPattern string, permission Permission) error {
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
return a.allowAccessTx(tx, username, topicPattern, permission, false)
|
return a.allowAccessTx(tx, username, topicPattern, permission, false)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -591,7 +592,7 @@ func (a *Manager) allowAccessTx(tx *sql.Tx, username string, topicPattern string
|
||||||
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
|
// ResetAccess removes an access control list entry for a specific username/topic, or (if topic is
|
||||||
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
|
// empty) for an entire user. The parameter topicPattern may include wildcards (*).
|
||||||
func (a *Manager) ResetAccess(username string, topicPattern string) error {
|
func (a *Manager) ResetAccess(username string, topicPattern string) error {
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
return a.resetAccessTx(tx, username, topicPattern)
|
return a.resetAccessTx(tx, username, topicPattern)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -715,7 +716,7 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
|
||||||
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
if !AllowedUsername(username) || username == Everyone || !AllowedTopic(topic) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
|
if err := a.addReservationAccessTx(tx, username, topic, true, true, username); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -735,7 +736,7 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
for _, topic := range topics {
|
for _, topic := range topics {
|
||||||
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
if err := a.resetTopicAccessTx(tx, username, topic); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
@ -874,7 +875,7 @@ func (a *Manager) resetTopicAccessTx(tx *sql.Tx, username, topicPattern string)
|
||||||
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
||||||
// given user, if there are too many of them.
|
// given user, if there are too many of them.
|
||||||
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr, provisioned bool) (*Token, error) {
|
||||||
return queryTx(a.db, func(tx *sql.Tx) (*Token, error) {
|
return db.QueryTx(a.db, func(tx *sql.Tx) (*Token, error) {
|
||||||
return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
|
return a.createTokenTx(tx, userID, GenerateToken(), label, time.Now(), origin, expires, tokenMaxCount, provisioned)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
@ -1033,7 +1034,7 @@ func (a *Manager) writeTokenUpdateQueue() error {
|
||||||
a.tokenQueue = make(map[string]*TokenUpdate)
|
a.tokenQueue = make(map[string]*TokenUpdate)
|
||||||
a.mu.Unlock()
|
a.mu.Unlock()
|
||||||
|
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
|
log.Tag(tag).Debug("Writing token update queue for %d token(s)", len(tokenQueue))
|
||||||
for tokenID, update := range tokenQueue {
|
for tokenID, update := range tokenQueue {
|
||||||
log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
|
log.Tag(tag).Trace("Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
|
||||||
|
|
@ -1254,7 +1255,7 @@ func (a *Manager) maybeProvisionUsersAccessAndTokens() error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return execTx(a.db, func(tx *sql.Tx) error {
|
return db.ExecTx(a.db, func(tx *sql.Tx) error {
|
||||||
if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
|
if err := a.maybeProvisionUsers(tx, provisionUsernames, existingUsers); err != nil {
|
||||||
return fmt.Errorf("failed to provision users: %v", err)
|
return fmt.Errorf("failed to provision users: %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
32
user/util.go
32
user/util.go
|
|
@ -113,35 +113,3 @@ func escapeUnderscore(s string) string {
|
||||||
func unescapeUnderscore(s string) string {
|
func unescapeUnderscore(s string) string {
|
||||||
return strings.ReplaceAll(s, "\\_", "_")
|
return strings.ReplaceAll(s, "\\_", "_")
|
||||||
}
|
}
|
||||||
|
|
||||||
// execTx executes a function in a transaction. If the function returns an error, the transaction is rolled back.
|
|
||||||
func execTx(db *sql.DB, f func(tx *sql.Tx) error) error {
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
if err := f(tx); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryTx executes a function in a transaction and returns the result. If the function
|
|
||||||
// returns an error, the transaction is rolled back.
|
|
||||||
func queryTx[T any](db *sql.DB, f func(tx *sql.Tx) (T, error)) (T, error) {
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
var zero T
|
|
||||||
return zero, err
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
t, err := f(tx)
|
|
||||||
if err != nil {
|
|
||||||
return t, err
|
|
||||||
}
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return t, err
|
|
||||||
}
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
"heckel.io/ntfy/v2/util"
|
"heckel.io/ntfy/v2/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -46,41 +47,38 @@ type queries struct {
|
||||||
|
|
||||||
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
|
// UpsertSubscription adds or updates Web Push subscriptions for the given topics and user ID.
|
||||||
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
func (s *Store) UpsertSubscription(endpoint string, auth, p256dh, userID string, subscriberIP netip.Addr, topics []string) error {
|
||||||
tx, err := s.db.Begin()
|
return db.ExecTx(s.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
// Read number of subscriptions for subscriber IP address
|
||||||
return err
|
var subscriptionCount int
|
||||||
}
|
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
|
||||||
defer tx.Rollback()
|
|
||||||
// Read number of subscriptions for subscriber IP address
|
|
||||||
var subscriptionCount int
|
|
||||||
if err := tx.QueryRow(s.queries.selectSubscriptionCountBySubscriberIP, subscriberIP.String()).Scan(&subscriptionCount); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Read existing subscription ID for endpoint (or create new ID)
|
|
||||||
var subscriptionID string
|
|
||||||
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
|
|
||||||
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
|
||||||
return ErrWebPushTooManySubscriptions
|
|
||||||
}
|
|
||||||
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Insert or update subscription
|
|
||||||
updatedAt, warnedAt := time.Now().Unix(), 0
|
|
||||||
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// Replace all subscription topics
|
|
||||||
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for _, topic := range topics {
|
|
||||||
if _, err = tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
// Read existing subscription ID for endpoint (or create new ID)
|
||||||
return tx.Commit()
|
var subscriptionID string
|
||||||
|
if err := tx.QueryRow(s.queries.selectSubscriptionIDByEndpoint, endpoint).Scan(&subscriptionID); errors.Is(err, sql.ErrNoRows) {
|
||||||
|
if subscriptionCount >= subscriptionEndpointLimitPerSubscriberIP {
|
||||||
|
return ErrWebPushTooManySubscriptions
|
||||||
|
}
|
||||||
|
subscriptionID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
|
||||||
|
} else if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Insert or update subscription
|
||||||
|
updatedAt, warnedAt := time.Now().Unix(), 0
|
||||||
|
if _, err := tx.Exec(s.queries.upsertSubscription, subscriptionID, endpoint, auth, p256dh, userID, subscriberIP.String(), updatedAt, warnedAt); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Replace all subscription topics
|
||||||
|
if _, err := tx.Exec(s.queries.deleteSubscriptionTopicAll, subscriptionID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, topic := range topics {
|
||||||
|
if _, err := tx.Exec(s.queries.insertSubscriptionTopic, subscriptionID, topic); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
// SubscriptionsForTopic returns all subscriptions for the given topic.
|
||||||
|
|
@ -105,17 +103,14 @@ func (s *Store) SubscriptionsExpiring(warnAfter time.Duration) ([]*Subscription,
|
||||||
|
|
||||||
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
|
// MarkExpiryWarningSent marks the given subscriptions as having received a warning about expiring soon.
|
||||||
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
|
func (s *Store) MarkExpiryWarningSent(subscriptions []*Subscription) error {
|
||||||
tx, err := s.db.Begin()
|
return db.ExecTx(s.db, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
for _, subscription := range subscriptions {
|
||||||
return err
|
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
|
||||||
}
|
return err
|
||||||
defer tx.Rollback()
|
}
|
||||||
for _, subscription := range subscriptions {
|
|
||||||
if _, err := tx.Exec(s.queries.updateSubscriptionWarningSent, time.Now().Unix(), subscription.ID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return tx.Commit()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
|
// RemoveSubscriptionsByEndpoint removes the subscription for the given endpoint.
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ package webpush
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
@ -107,17 +109,14 @@ func setupPostgres(db *sql.DB) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupNewPostgres(db *sql.DB) error {
|
func setupNewPostgres(sqlDB *sql.DB) error {
|
||||||
tx, err := db.Begin()
|
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
|
||||||
if _, err := tx.Exec(postgresCreateTablesQuery); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
if _, err := tx.Exec(postgresInsertSchemaVersionQuery, pgCurrentSchemaVersion); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||||
|
|
||||||
|
"heckel.io/ntfy/v2/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|
@ -119,19 +121,16 @@ func setupSQLite(db *sql.DB) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupNewSQLite(db *sql.DB) error {
|
func setupNewSQLite(sqlDB *sql.DB) error {
|
||||||
tx, err := db.Begin()
|
return db.ExecTx(sqlDB, func(tx *sql.Tx) error {
|
||||||
if err != nil {
|
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
|
||||||
if _, err := tx.Exec(sqliteCreateTablesQuery); err != nil {
|
return err
|
||||||
return err
|
}
|
||||||
}
|
return nil
|
||||||
if _, err := tx.Exec(sqliteInsertSchemaVersionQuery, sqliteCurrentSchemaVersion); err != nil {
|
})
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
func runSQLiteStartupQueries(db *sql.DB, startupQueries string) error {
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue