ntfy-server/message/cache.go

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

567 lines
15 KiB
Go
Raw Normal View History

package message
import (
"database/sql"
"encoding/json"
"errors"
"net/netip"
"strings"
"sync"
"time"
2026-03-02 19:45:35 -05:00
"heckel.io/ntfy/v2/db"
"heckel.io/ntfy/v2/log"
"heckel.io/ntfy/v2/model"
"heckel.io/ntfy/v2/util"
)
// tagMessageCache is the log tag attached to every log line emitted by the message cache.
const tagMessageCache = "message_cache"

// errNoRows is returned when a query that is expected to yield exactly one row yields none.
var errNoRows = errors.New("no rows found")
2026-03-01 13:19:53 -05:00
// queries bundles every SQL statement the Cache executes. Each database backend
// supplies its own dialect-specific set of statements.
type queries struct {
	// Inserting messages, cancelling scheduled messages, expiring topics
	insertMessage                    string
	selectScheduledMessageIDsBySeqID string
	deleteScheduledBySequenceID      string
	updateMessagesForTopicExpiry     string

	// Reading messages
	selectMessagesByID               string
	selectMessagesSinceTime          string
	selectMessagesSinceTimeScheduled string
	selectMessagesSinceID            string
	selectMessagesSinceIDScheduled   string
	selectMessagesLatest             string
	selectMessagesDue                string

	// Pruning and bookkeeping
	deleteExpiredMessages  string
	updateMessagePublished string
	selectMessagesCount    string
	selectTopics           string

	// Attachments
	markExpiredAttachmentsDeleted string
	selectAttachmentsSizeBySender string
	selectAttachmentsSizeByUserID string
	selectAttachmentsWithSizes    string

	// Statistics, plus a test-only time override
	selectStats       string
	updateStats       string
	updateMessageTime string
}
2026-03-01 13:19:53 -05:00
// Cache stores published messages
type Cache struct {
2026-03-10 22:17:40 -04:00
db *db.DB
queue *util.BatchingQueue[*model.Message]
nop bool
2026-02-22 16:21:27 -05:00
mu *sync.Mutex // nil for PostgreSQL (concurrent writes supported), set for SQLite (single writer)
2026-03-01 13:19:53 -05:00
queries queries
}
2026-03-10 22:17:40 -04:00
func newCache(db *db.DB, queries queries, mu *sync.Mutex, batchSize int, batchTimeout time.Duration, nop bool) *Cache {
var queue *util.BatchingQueue[*model.Message]
if batchSize > 0 || batchTimeout > 0 {
queue = util.NewBatchingQueue[*model.Message](batchSize, batchTimeout)
}
2026-03-01 13:19:53 -05:00
c := &Cache{
db: db,
queue: queue,
nop: nop,
2026-02-22 16:21:27 -05:00
mu: mu,
queries: queries,
}
go c.processMessageBatches()
return c
}
2026-03-01 13:19:53 -05:00
func (c *Cache) maybeLock() {
2026-02-22 16:21:27 -05:00
if c.mu != nil {
c.mu.Lock()
}
}
2026-03-01 13:19:53 -05:00
func (c *Cache) maybeUnlock() {
2026-02-22 16:21:27 -05:00
if c.mu != nil {
c.mu.Unlock()
}
}
// AddMessage stores a message to the message cache synchronously, or queues it to be stored at a later date asynchronously.
2026-02-20 16:15:07 -05:00
// The message is queued only if "batchSize" or "batchTimeout" are passed to the constructor.
2026-03-01 13:19:53 -05:00
func (c *Cache) AddMessage(m *model.Message) error {
if c.queue != nil {
c.queue.Enqueue(m)
return nil
}
return c.addMessages([]*model.Message{m})
}
// AddMessages synchronously stores a batch of messages to the message cache
2026-03-01 13:19:53 -05:00
func (c *Cache) AddMessages(ms []*model.Message) error {
return c.addMessages(ms)
}
2026-03-01 13:19:53 -05:00
func (c *Cache) addMessages(ms []*model.Message) error {
2026-02-22 16:21:27 -05:00
c.maybeLock()
defer c.maybeUnlock()
if c.nop {
return nil
}
if len(ms) == 0 {
return nil
}
start := time.Now()
tx, err := c.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(c.queries.insertMessage)
if err != nil {
return err
}
defer stmt.Close()
for _, m := range ms {
if m.Event != model.MessageEvent && m.Event != model.MessageDeleteEvent && m.Event != model.MessageClearEvent {
return model.ErrUnexpectedMessageType
}
published := m.Time <= time.Now().Unix()
2026-03-15 21:03:18 -04:00
tags := util.SanitizeUTF8(strings.Join(m.Tags, ","))
var attachmentName, attachmentType, attachmentURL string
var attachmentSize, attachmentExpires int64
var attachmentDeleted bool
if m.Attachment != nil {
2026-03-15 21:03:18 -04:00
attachmentName = util.SanitizeUTF8(m.Attachment.Name)
attachmentType = util.SanitizeUTF8(m.Attachment.Type)
attachmentSize = m.Attachment.Size
attachmentExpires = m.Attachment.Expires
2026-03-15 21:03:18 -04:00
attachmentURL = util.SanitizeUTF8(m.Attachment.URL)
}
var actionsStr string
if len(m.Actions) > 0 {
actionsBytes, err := json.Marshal(m.Actions)
if err != nil {
return err
}
actionsStr = string(actionsBytes)
}
var sender string
if m.Sender.IsValid() {
sender = m.Sender.String()
}
_, err := stmt.Exec(
m.ID,
m.SequenceID,
m.Time,
m.Event,
m.Expires,
2026-03-15 21:03:18 -04:00
util.SanitizeUTF8(m.Topic),
util.SanitizeUTF8(m.Message),
util.SanitizeUTF8(m.Title),
m.Priority,
tags,
2026-03-15 21:03:18 -04:00
util.SanitizeUTF8(m.Click),
util.SanitizeUTF8(m.Icon),
actionsStr,
attachmentName,
attachmentType,
attachmentSize,
attachmentExpires,
attachmentURL,
attachmentDeleted, // Always zero
sender,
m.User,
2026-03-15 21:03:18 -04:00
util.SanitizeUTF8(m.ContentType),
m.Encoding,
published,
)
if err != nil {
return err
}
}
if err := tx.Commit(); err != nil {
log.Tag(tagMessageCache).Err(err).Error("Writing %d message(s) failed (took %v)", len(ms), time.Since(start))
return err
}
log.Tag(tagMessageCache).Debug("Wrote %d message(s) in %v", len(ms), time.Since(start))
return nil
}
2026-03-01 13:19:53 -05:00
// Messages returns the messages stored for topic, selected according to since:
//   - a "none" marker yields an empty, non-nil slice without querying,
//   - a "latest" marker runs the selectMessagesLatest query,
//   - an ID marker runs the since-ID query,
//   - any other marker runs the since-time query.
//
// For ID and time markers, scheduled selects the "...Scheduled" query variant,
// which presumably also includes not-yet-published messages.
func (c *Cache) Messages(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
	switch {
	case since.IsNone():
		return make([]*model.Message, 0), nil
	case since.IsLatest():
		return c.messagesLatest(topic)
	case since.IsID():
		return c.messagesSinceID(topic, since, scheduled)
	default:
		return c.messagesSinceTime(topic, since, scheduled)
	}
}
2026-03-01 13:19:53 -05:00
func (c *Cache) messagesSinceTime(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
var rows *sql.Rows
var err error
2026-03-10 22:17:40 -04:00
rdb := c.db.ReadOnly()
if scheduled {
2026-03-10 22:17:40 -04:00
rows, err = rdb.Query(c.queries.selectMessagesSinceTimeScheduled, topic, since.Time().Unix())
} else {
2026-03-10 22:17:40 -04:00
rows, err = rdb.Query(c.queries.selectMessagesSinceTime, topic, since.Time().Unix())
}
if err != nil {
return nil, err
}
return readMessages(rows)
}
2026-03-01 13:19:53 -05:00
func (c *Cache) messagesSinceID(topic string, since model.SinceMarker, scheduled bool) ([]*model.Message, error) {
var rows *sql.Rows
2026-02-22 16:21:27 -05:00
var err error
2026-03-10 22:17:40 -04:00
rdb := c.db.ReadOnly()
if scheduled {
2026-03-10 22:17:40 -04:00
rows, err = rdb.Query(c.queries.selectMessagesSinceIDScheduled, topic, since.ID())
} else {
2026-03-10 22:17:40 -04:00
rows, err = rdb.Query(c.queries.selectMessagesSinceID, topic, since.ID())
}
if err != nil {
return nil, err
}
return readMessages(rows)
}
2026-03-01 13:19:53 -05:00
func (c *Cache) messagesLatest(topic string) ([]*model.Message, error) {
2026-03-10 22:17:40 -04:00
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesLatest, topic)
if err != nil {
return nil, err
}
return readMessages(rows)
}
2026-03-01 13:19:53 -05:00
// MessagesDue returns all messages that are due for publishing as of now.
//
// NOTE(review): unlike the other readers, this queries the primary handle rather
// than ReadOnly(). It is presumably deliberate, so that freshly written scheduled
// and published state is visible. Confirm before changing.
func (c *Cache) MessagesDue() ([]*model.Message, error) {
	now := time.Now().Unix()
	rows, err := c.db.Query(c.queries.selectMessagesDue, now)
	if err != nil {
		return nil, err
	}
	return readMessages(rows)
}
2026-03-25 15:28:23 -04:00
// DeleteExpiredMessages deletes up to `limit` expired messages in a single query
// and returns the number of deleted rows.
func (c *Cache) DeleteExpiredMessages(limit int) (int64, error) {
c.maybeLock()
defer c.maybeUnlock()
result, err := c.db.Exec(c.queries.deleteExpiredMessages, time.Now().Unix(), limit)
if err != nil {
2026-03-25 15:28:23 -04:00
return 0, err
}
2026-03-25 15:28:23 -04:00
return result.RowsAffected()
}
2026-03-01 13:19:53 -05:00
// Message returns the message with the given ID, or ErrMessageNotFound if not found
func (c *Cache) Message(id string) (*model.Message, error) {
2026-03-10 22:17:40 -04:00
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesByID, id)
if err != nil {
return nil, err
}
defer rows.Close()
if !rows.Next() {
return nil, model.ErrMessageNotFound
}
return readMessage(rows)
}
// UpdateMessageTime updates the time column for a message by ID. This is only used for testing.
2026-03-01 13:19:53 -05:00
func (c *Cache) UpdateMessageTime(messageID string, timestamp int64) error {
2026-02-23 11:17:57 -05:00
c.maybeLock()
defer c.maybeUnlock()
_, err := c.db.Exec(c.queries.updateMessageTime, timestamp, messageID)
return err
}
2026-03-01 13:19:53 -05:00
// MarkPublished marks a message as published
func (c *Cache) MarkPublished(m *model.Message) error {
2026-02-22 16:21:27 -05:00
c.maybeLock()
defer c.maybeUnlock()
_, err := c.db.Exec(c.queries.updateMessagePublished, m.ID)
return err
}
2026-03-01 13:19:53 -05:00
// MessagesCount returns the total number of messages in the cache
func (c *Cache) MessagesCount() (int, error) {
2026-03-10 22:17:40 -04:00
rows, err := c.db.ReadOnly().Query(c.queries.selectMessagesCount)
if err != nil {
2026-02-22 16:21:27 -05:00
return 0, err
}
defer rows.Close()
2026-02-22 16:21:27 -05:00
if !rows.Next() {
return 0, errNoRows
}
var count int
2026-02-22 16:21:27 -05:00
if err := rows.Scan(&count); err != nil {
return 0, err
}
2026-02-22 16:21:27 -05:00
return count, nil
}
2026-03-01 13:19:53 -05:00
// Topics returns a list of all topics with messages in the cache
func (c *Cache) Topics() ([]string, error) {
2026-03-10 22:17:40 -04:00
rows, err := c.db.ReadOnly().Query(c.queries.selectTopics)
if err != nil {
return nil, err
}
defer rows.Close()
2026-03-17 20:53:41 -04:00
return readStrings(rows)
}
// DeleteScheduledBySequenceID deletes unpublished (scheduled) messages with the given topic and sequence ID.
// It returns the message IDs of the deleted messages, which can be used to clean up attachment files.
2026-03-01 13:19:53 -05:00
func (c *Cache) DeleteScheduledBySequenceID(topic, sequenceID string) ([]string, error) {
2026-02-22 16:21:27 -05:00
c.maybeLock()
defer c.maybeUnlock()
2026-03-02 19:45:35 -05:00
return db.QueryTx(c.db, func(tx *sql.Tx) ([]string, error) {
rows, err := tx.Query(c.queries.selectScheduledMessageIDsBySeqID, topic, sequenceID)
if err != nil {
return nil, err
}
2026-03-02 19:45:35 -05:00
defer rows.Close()
2026-03-17 20:53:41 -04:00
ids, err := readStrings(rows)
if err != nil {
2026-03-02 19:45:35 -05:00
return nil, err
}
rows.Close() // Close rows before executing delete in same transaction
if _, err := tx.Exec(c.queries.deleteScheduledBySequenceID, topic, sequenceID); err != nil {
return nil, err
}
return ids, nil
})
}
2026-03-01 13:19:53 -05:00
// ExpireMessages marks messages in the given topics as expired
func (c *Cache) ExpireMessages(topics ...string) error {
2026-02-22 16:21:27 -05:00
c.maybeLock()
defer c.maybeUnlock()
2026-03-02 19:45:35 -05:00
return db.ExecTx(c.db, func(tx *sql.Tx) error {
for _, t := range topics {
if _, err := tx.Exec(c.queries.updateMessagesForTopicExpiry, time.Now().Unix()-1, t); err != nil {
return err
}
}
2026-03-02 19:45:35 -05:00
return nil
})
}
2026-03-25 15:28:23 -04:00
// MarkExpiredAttachmentsDeleted marks up to `limit` expired attachments as deleted in a single
// query and returns the number of updated rows.
func (c *Cache) MarkExpiredAttachmentsDeleted(limit int) (int64, error) {
2026-02-22 16:21:27 -05:00
c.maybeLock()
defer c.maybeUnlock()
2026-03-25 15:28:23 -04:00
result, err := c.db.Exec(c.queries.markExpiredAttachmentsDeleted, time.Now().Unix(), limit)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
2026-03-01 13:19:53 -05:00
// AttachmentBytesUsedBySender returns the total size of active attachments sent by the given sender
func (c *Cache) AttachmentBytesUsedBySender(sender string) (int64, error) {
2026-03-10 22:17:40 -04:00
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeBySender, sender, time.Now().Unix())
if err != nil {
return 0, err
}
return c.readAttachmentBytesUsed(rows)
}
2026-03-01 13:19:53 -05:00
// AttachmentBytesUsedByUser returns the total size of active attachments for the given user
func (c *Cache) AttachmentBytesUsedByUser(userID string) (int64, error) {
2026-03-10 22:17:40 -04:00
rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsSizeByUserID, userID, time.Now().Unix())
if err != nil {
return 0, err
}
return c.readAttachmentBytesUsed(rows)
}
2026-03-23 12:44:40 -04:00
// AttachmentsWithSizes maps message ID to attachment size for every active
// (non-expired, non-deleted) attachment. The attachment store uses it to
// rebuild its size accounting on startup and during periodic sync.
func (c *Cache) AttachmentsWithSizes() (map[string]int64, error) {
	now := time.Now().Unix()
	rows, err := c.db.ReadOnly().Query(c.queries.selectAttachmentsWithSizes, now)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	sizes := make(map[string]int64)
	for rows.Next() {
		var messageID string
		var bytes int64
		if err := rows.Scan(&messageID, &bytes); err != nil {
			return nil, err
		}
		sizes[messageID] = bytes
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return sizes, nil
}
2026-03-01 13:19:53 -05:00
// readAttachmentBytesUsed scans the single int64 produced by an attachment-size
// query and always closes rows. It returns errNoRows (the package sentinel, so
// callers can match it with errors.Is) if the query produced no row, and the
// underlying error if reading the row failed.
func (c *Cache) readAttachmentBytesUsed(rows *sql.Rows) (int64, error) {
	defer rows.Close()
	if !rows.Next() {
		// Next also returns false when iteration fails; report that error
		// rather than pretending the result was empty.
		if err := rows.Err(); err != nil {
			return 0, err
		}
		return 0, errNoRows
	}
	var size int64
	if err := rows.Scan(&size); err != nil {
		return 0, err
	}
	if err := rows.Err(); err != nil {
		return 0, err
	}
	return size, nil
}
2026-03-01 13:19:53 -05:00
// UpdateStats updates the total message count statistic
func (c *Cache) UpdateStats(messages int64) error {
2026-02-22 16:21:27 -05:00
c.maybeLock()
defer c.maybeUnlock()
_, err := c.db.Exec(c.queries.updateStats, messages)
return err
}
2026-03-01 13:19:53 -05:00
// Stats returns the total message count statistic
func (c *Cache) Stats() (messages int64, err error) {
2026-03-10 22:17:40 -04:00
rows, err := c.db.ReadOnly().Query(c.queries.selectStats)
if err != nil {
return 0, err
}
defer rows.Close()
if !rows.Next() {
return 0, errNoRows
}
if err := rows.Scan(&messages); err != nil {
return 0, err
}
return messages, nil
}
2026-03-01 13:19:53 -05:00
// Close closes the underlying database connection.
//
// NOTE(review): this does not stop the batching queue or wait for
// processMessageBatches to drain it. Batches flushed after Close would be
// written against a closed database. Check whether util.BatchingQueue offers a
// stop/flush hook before relying on Close for a clean shutdown.
func (c *Cache) Close() error {
	err := c.db.Close()
	return err
}
2026-03-01 13:19:53 -05:00
// processMessageBatches writes every batch produced by the batching queue with
// addMessages. It returns immediately when batching is disabled. Otherwise it
// runs until the sequence returned by Dequeue ends. Write errors are logged
// rather than returned, since this runs as a background goroutine with no caller.
func (c *Cache) processMessageBatches() {
	if c.queue == nil {
		return
	}
	batches := c.queue.Dequeue()
	for batch := range batches {
		err := c.addMessages(batch)
		if err == nil {
			continue
		}
		log.Tag(tagMessageCache).Err(err).Error("Cannot write message batch")
	}
}
// readMessages decodes every row of rows with readMessage and always closes
// rows. On success it returns a non-nil slice, which is empty when there are no rows.
func readMessages(rows *sql.Rows) ([]*model.Message, error) {
	defer rows.Close()
	out := []*model.Message{}
	for rows.Next() {
		msg, err := readMessage(rows)
		if err != nil {
			return nil, err
		}
		out = append(out, msg)
	}
	err := rows.Err()
	if err != nil {
		return nil, err
	}
	return out, nil
}
// readMessage decodes the current row of rows into a model.Message. The row
// must contain, in this order: id, sequence ID, time, event, expires, topic,
// message, title, priority, comma-joined tags, click, icon, JSON-encoded
// actions, attachment name/type/size/expires/URL, sender, user, content type
// and encoding. An attachment is set only when both its name and URL are
// non-empty. A sender that does not parse as an IP address yields the zero
// (invalid) netip.Addr.
func readMessage(rows *sql.Rows) (*model.Message, error) {
	var id, sequenceID, event, topic, msg, title, tagsStr, click, icon, actionsStr string
	var attachmentName, attachmentType, attachmentURL, sender, user, contentType, encoding string
	var timestamp, expires, attachmentSize, attachmentExpires int64
	var priority int
	if err := rows.Scan(
		&id, &sequenceID, &timestamp, &event, &expires,
		&topic, &msg, &title, &priority, &tagsStr,
		&click, &icon, &actionsStr,
		&attachmentName, &attachmentType, &attachmentSize, &attachmentExpires, &attachmentURL,
		&sender, &user, &contentType, &encoding,
	); err != nil {
		return nil, err
	}
	m := &model.Message{
		ID:          id,
		SequenceID:  sequenceID,
		Time:        timestamp,
		Expires:     expires,
		Event:       event,
		Topic:       topic,
		Message:     msg,
		Title:       title,
		Priority:    priority,
		Click:       click,
		Icon:        icon,
		User:        user,
		ContentType: contentType,
		Encoding:    encoding,
	}
	if tagsStr != "" {
		m.Tags = strings.Split(tagsStr, ",")
	}
	if actionsStr != "" {
		var actions []*model.Action
		if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil {
			return nil, err
		}
		m.Actions = actions
	}
	// An empty or malformed sender column leaves Sender as the zero Addr.
	if addr, err := netip.ParseAddr(sender); err == nil {
		m.Sender = addr
	}
	if attachmentName != "" && attachmentURL != "" {
		m.Attachment = &model.Attachment{
			Name:    attachmentName,
			Type:    attachmentType,
			Size:    attachmentSize,
			Expires: attachmentExpires,
			URL:     attachmentURL,
		}
	}
	return m, nil
}
2026-03-17 20:53:41 -04:00
// readStrings collects the single string column of every row in rows. It does
// not close rows; that is the caller's job. On success the result is non-nil,
// even when there are no rows. On any error it returns a nil slice.
func readStrings(rows *sql.Rows) ([]string, error) {
	values := []string{}
	for rows.Next() {
		var v string
		err := rows.Scan(&v)
		if err != nil {
			return nil, err
		}
		values = append(values, v)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return values, nil
}