diff --git a/app/src/main/java/io/heckel/ntfy/msg/DownloadAttachmentWorker.kt b/app/src/main/java/io/heckel/ntfy/msg/DownloadAttachmentWorker.kt index 0adbbde9..19f3338a 100644 --- a/app/src/main/java/io/heckel/ntfy/msg/DownloadAttachmentWorker.kt +++ b/app/src/main/java/io/heckel/ntfy/msg/DownloadAttachmentWorker.kt @@ -7,7 +7,7 @@ import android.os.Looper import android.webkit.MimeTypeMap import android.widget.Toast import androidx.core.content.FileProvider -import androidx.work.Worker +import androidx.work.CoroutineWorker import androidx.work.WorkerParameters import io.heckel.ntfy.BuildConfig import io.heckel.ntfy.R @@ -27,8 +27,9 @@ import io.heckel.ntfy.util.extractBaseUrl import okhttp3.Response import java.io.File import java.util.concurrent.TimeUnit +import kotlin.coroutines.cancellation.CancellationException -class DownloadAttachmentWorker(private val context: Context, params: WorkerParameters) : Worker(context, params) { +class DownloadAttachmentWorker(private val context: Context, params: WorkerParameters) : CoroutineWorker(context, params) { private val notifier = NotificationService(context) private lateinit var repository: Repository private lateinit var subscription: Subscription @@ -36,7 +37,7 @@ class DownloadAttachmentWorker(private val context: Context, params: WorkerParam private lateinit var attachment: Attachment private var uri: Uri? = null - override fun doWork(): Result { + override suspend fun doWork(): Result { if (context.applicationContext !is Application) return Result.failure() val notificationId = inputData.getString(INPUT_DATA_ID) ?: return Result.failure() val userAction = inputData.getBoolean(INPUT_DATA_USER_ACTION, false) @@ -47,18 +48,17 @@ class DownloadAttachmentWorker(private val context: Context, params: WorkerParam attachment = notification.attachment ?: return Result.failure() try { downloadAttachment(userAction) + } catch (e: CancellationException) { + Log.d(TAG, "Attachment download was canceled") + maybeDeleteFile() + throw e // We must re-throw this to stop the worker } catch (e: Exception) { failed(e) } return Result.success() } - override fun onStopped() { - Log.d(TAG, "Attachment download was canceled") - maybeDeleteFile() - } - - private fun downloadAttachment(userAction: Boolean) { + private suspend fun downloadAttachment(userAction: Boolean) { Log.d(TAG, "Downloading attachment from ${attachment.url}") try { diff --git a/app/src/main/java/io/heckel/ntfy/msg/DownloadIconWorker.kt b/app/src/main/java/io/heckel/ntfy/msg/DownloadIconWorker.kt index 4b0b2d0a..b5487d19 100644 --- a/app/src/main/java/io/heckel/ntfy/msg/DownloadIconWorker.kt +++ b/app/src/main/java/io/heckel/ntfy/msg/DownloadIconWorker.kt @@ -3,7 +3,7 @@ package io.heckel.ntfy.msg import android.content.Context import android.net.Uri import androidx.core.content.FileProvider -import androidx.work.Worker +import androidx.work.CoroutineWorker import androidx.work.WorkerParameters import io.heckel.ntfy.BuildConfig import io.heckel.ntfy.app.Application @@ -18,8 +18,9 @@ import io.heckel.ntfy.util.sha256 import okhttp3.Response import java.io.File import java.util.Date +import kotlin.coroutines.cancellation.CancellationException -class DownloadIconWorker(private val context: Context, params: WorkerParameters) : Worker(context, params) { +class DownloadIconWorker(private val context: Context, params: WorkerParameters) : CoroutineWorker(context, params) { private val notifier = NotificationService(context) private lateinit var repository: Repository private lateinit var subscription: Subscription @@ -27,7 +28,7 @@ class DownloadIconWorker(private val context: Context, params: WorkerParameters) private lateinit var icon: Icon private var uri: Uri? = null - override fun doWork(): Result { + override suspend fun doWork(): Result { if (context.applicationContext !is Application) return Result.failure() val notificationId = inputData.getString(INPUT_DATA_ID) ?: return Result.failure() val app = context.applicationContext as Application @@ -43,21 +44,20 @@ class DownloadIconWorker(private val context: Context, params: WorkerParameters) } else { Log.d(TAG, "Loading icon from cache: $iconFile") val iconUri = createIconUri(iconFile) - this.uri = iconUri // Required for cleanup in onStopped() + this.uri = iconUri save(icon.copy(contentUri = iconUri.toString())) } + } catch (e: CancellationException) { + Log.d(TAG, "Icon download was canceled") + maybeDeleteFile() + throw e // We must re-throw this to stop the worker } catch (e: Exception) { failed(e) } return Result.success() } - override fun onStopped() { - Log.d(TAG, "Icon download was canceled") - maybeDeleteFile() - } - - private fun downloadIcon(iconFile: File) { + private suspend fun downloadIcon(iconFile: File) { Log.d(TAG, "Downloading icon from ${icon.url}") try { val request = HttpUtil.requestBuilder(icon.url).build() diff --git a/app/src/main/java/io/heckel/ntfy/msg/UserActionWorker.kt b/app/src/main/java/io/heckel/ntfy/msg/UserActionWorker.kt index 54f014e7..b8b84b2f 100644 --- a/app/src/main/java/io/heckel/ntfy/msg/UserActionWorker.kt +++ b/app/src/main/java/io/heckel/ntfy/msg/UserActionWorker.kt @@ -1,7 +1,7 @@ package io.heckel.ntfy.msg import android.content.Context -import androidx.work.Worker +import androidx.work.CoroutineWorker import androidx.work.WorkerParameters import io.heckel.ntfy.R import io.heckel.ntfy.app.Application @@ -20,7 +20,7 @@ import io.heckel.ntfy.util.extractBaseUrl import okhttp3.RequestBody.Companion.toRequestBody import java.util.Locale -class UserActionWorker(private val context: Context, params: WorkerParameters) : Worker(context, params) { +class UserActionWorker(private val context: Context, params: WorkerParameters) : CoroutineWorker(context, params) { private val notifier = NotificationService(context) private val broadcaster = BroadcastService(context) private lateinit var repository: Repository @@ -28,7 +28,7 @@ class UserActionWorker(private val context: Context, params: WorkerParameters) : private lateinit var notification: Notification private lateinit var action: Action - override fun doWork(): Result { + override suspend fun doWork(): Result { if (context.applicationContext !is Application) return Result.failure() val notificationId = inputData.getString(INPUT_DATA_NOTIFICATION_ID) ?: return Result.failure() val actionId = inputData.getString(INPUT_DATA_ACTION_ID) ?: return Result.failure() @@ -63,7 +63,7 @@ class UserActionWorker(private val context: Context, params: WorkerParameters) : } } - private fun performHttpAction(action: Action) { + private suspend fun performHttpAction(action: Action) { save(action.copy(progress = ACTION_PROGRESS_ONGOING, error = null)) val url = action.url ?: return diff --git a/app/src/main/java/io/heckel/ntfy/service/Connection.kt b/app/src/main/java/io/heckel/ntfy/service/Connection.kt index 9994dbe4..cb298b3a 100644 --- a/app/src/main/java/io/heckel/ntfy/service/Connection.kt +++ b/app/src/main/java/io/heckel/ntfy/service/Connection.kt @@ -15,5 +15,7 @@ data class ConnectionId( val topicsToSubscriptionIds: Map, val connectionProtocol: String, val credentialsHash: Int, // Hash of "username:password" or 0 if no user - val headersHash: Int // Hash of sorted headers or 0 if none + val headersHash: Int, // Hash of sorted headers or 0 if none + val trustedCertsHash: Int, // Hash of trusted certificates or 0 if none + val clientCertHash: Int // Hash of client certificate or 0 if none ) diff --git a/app/src/main/java/io/heckel/ntfy/service/SubscriberService.kt b/app/src/main/java/io/heckel/ntfy/service/SubscriberService.kt index 69d02025..630ec31a 100644 --- a/app/src/main/java/io/heckel/ntfy/service/SubscriberService.kt +++ b/app/src/main/java/io/heckel/ntfy/service/SubscriberService.kt @@ -20,6 +20,7 @@ import io.heckel.ntfy.msg.ApiService import io.heckel.ntfy.msg.NotificationDispatcher import io.heckel.ntfy.ui.Colors import io.heckel.ntfy.ui.MainActivity +import io.heckel.ntfy.util.HttpUtil import io.heckel.ntfy.util.Log import io.heckel.ntfy.util.topicUrl import kotlinx.coroutines.CoroutineScope @@ -205,6 +206,9 @@ class SubscriberService : Service() { .filter { s -> s.instant } val activeConnectionIds = connections.keys().toList().toSet() val connectionProtocol = repository.getConnectionProtocol() + val trustedCertsHash = repository.getTrustedCertificates() + .joinToString(",") { it.fingerprint } + .hashCode() val desiredConnectionIds = instantSubscriptions // Set .groupBy { s -> s.baseUrl } .map { (baseUrl, subs) -> @@ -217,12 +221,15 @@ class SubscriberService : Service() { .sortedBy { "${it.name}:${it.value}" } .joinToString(",") { "${it.name}:${it.value}" } .hashCode() + val clientCertHash = repository.getClientCertificate(baseUrl)?.hashCode() ?: 0 ConnectionId( baseUrl = baseUrl, topicsToSubscriptionIds = subs.associate { s -> s.topic to s.id }, connectionProtocol = connectionProtocol, credentialsHash = credentialsHash, - headersHash = headersHash + headersHash = headersHash, + trustedCertsHash = trustedCertsHash, + clientCertHash = clientCertHash ) } .toSet() @@ -256,7 +263,8 @@ class SubscriberService : Service() { val customHeaders = repository.getCustomHeaders(connectionId.baseUrl) val connection = if (connectionId.connectionProtocol == Repository.CONNECTION_PROTOCOL_WS) { val alarmManager = getSystemService(ALARM_SERVICE) as AlarmManager - WsConnection(this, connectionId, repository, user, customHeaders, since, ::onStateChanged, ::onNotificationReceived, alarmManager) + val httpClient = HttpUtil.wsClient(this, connectionId.baseUrl) + WsConnection(connectionId, repository, httpClient, user, customHeaders, since, ::onStateChanged, ::onNotificationReceived, alarmManager) } else { JsonConnection(connectionId, scope, repository, api, user, since, ::onStateChanged, ::onNotificationReceived, serviceActive) } diff --git a/app/src/main/java/io/heckel/ntfy/service/WsConnection.kt b/app/src/main/java/io/heckel/ntfy/service/WsConnection.kt index 5174f98b..b228b5df 100644 --- a/app/src/main/java/io/heckel/ntfy/service/WsConnection.kt +++ b/app/src/main/java/io/heckel/ntfy/service/WsConnection.kt @@ -34,9 +34,9 @@ import kotlin.random.Random * https://github.com/gotify/android/blob/master/app/src/main/java/com/github/gotify/service/WebSocketConnection.java */ class WsConnection( - private val context: Context, private val connectionId: ConnectionId, private val repository: Repository, + private val httpClient: OkHttpClient, private val user: User?, private val customHeaders: List, private val sinceId: String?, @@ -45,9 +45,6 @@ class WsConnection( private val alarmManager: AlarmManager ) : Connection { private val parser = NotificationParser() - private val client: OkHttpClient by lazy { - HttpUtil.wsClient(context, connectionId.baseUrl) - } private var errorCount = 0 private var webSocket: WebSocket? = null private var state: State? = null @@ -83,7 +80,7 @@ class WsConnection( val urlWithSince = topicUrlWs(baseUrl, topicsStr, sinceVal) val request = HttpUtil.requestBuilder(urlWithSince, user, customHeaders).build() Log.d(TAG, "$shortUrl (gid=$globalId): Opening $urlWithSince with listener ID $nextListenerId ...") - webSocket = client.newWebSocket(request, Listener(nextListenerId)) + webSocket = httpClient.newWebSocket(request, Listener(nextListenerId)) } @Synchronized diff --git a/app/src/main/java/io/heckel/ntfy/util/CertUtil.kt b/app/src/main/java/io/heckel/ntfy/util/CertUtil.kt index c26324a3..6a79750c 100644 --- a/app/src/main/java/io/heckel/ntfy/util/CertUtil.kt +++ b/app/src/main/java/io/heckel/ntfy/util/CertUtil.kt @@ -36,9 +36,9 @@ class CertUtil private constructor(context: Context) { private val appContext: Context = context.applicationContext private val repository: Repository by lazy { Repository.getInstance(appContext) } - fun withTLSConfig(builder: OkHttpClient.Builder, baseUrl: String): OkHttpClient.Builder { + suspend fun withTLSConfig(builder: OkHttpClient.Builder, baseUrl: String): OkHttpClient.Builder { try { - val trustedCerts = runBlocking { repository.getTrustedCertificates() } + val trustedCerts = repository.getTrustedCertificates() val userX509 = trustedCerts.mapNotNull { try { parsePemCertificate(it.pem) @@ -48,7 +48,7 @@ class CertUtil private constructor(context: Context) { } } - val clientCert = runBlocking { repository.getClientCertificate(baseUrl) } + val clientCert = repository.getClientCertificate(baseUrl) val keyManagers = clientCert?.let { createKeyManagers(it) } // Always include system trust; add user trust if present. diff --git a/app/src/main/java/io/heckel/ntfy/util/HttpUtil.kt b/app/src/main/java/io/heckel/ntfy/util/HttpUtil.kt index f26d2bd0..c4ac9c7f 100644 --- a/app/src/main/java/io/heckel/ntfy/util/HttpUtil.kt +++ b/app/src/main/java/io/heckel/ntfy/util/HttpUtil.kt @@ -21,16 +21,16 @@ object HttpUtil { /** * Client for regular API calls (auth, poll, etc.). */ - fun defaultClient(context: Context, baseUrl: String): OkHttpClient { - return defaultBuilder(context, baseUrl).build() + suspend fun defaultClient(context: Context, baseUrl: String): OkHttpClient { + return defaultClientBuilder(context, baseUrl).build() } /** * Client with a longer call timeout (5 minutes). * Allows for large file uploads or downloads. */ - fun longCallClient(context: Context, baseUrl: String): OkHttpClient { - return defaultBuilder(context, baseUrl) + suspend fun longCallClient(context: Context, baseUrl: String): OkHttpClient { + return defaultClientBuilder(context, baseUrl) .callTimeout(5, TimeUnit.MINUTES) .build() } @@ -38,8 +38,8 @@ object HttpUtil { /** * Client for long-polling/streaming subscriptions. */ - fun subscriberClient(context: Context, baseUrl: String): OkHttpClient { - return emptyBuilder(context, baseUrl) + suspend fun subscriberClient(context: Context, baseUrl: String): OkHttpClient { + return emptyClientBuilder(context, baseUrl) .readTimeout(77, TimeUnit.SECONDS) // Long enough to allow for server-side keepalive messages .build() } @@ -48,28 +48,14 @@ object HttpUtil { * Client for WebSocket connections. * No read timeout, 1 minute ping interval, 10s connect timeout. */ - fun wsClient(context: Context, baseUrl: String): OkHttpClient { - return emptyBuilder(context, baseUrl) + suspend fun wsClient(context: Context, baseUrl: String): OkHttpClient { + return emptyClientBuilder(context, baseUrl) .readTimeout(0, TimeUnit.MILLISECONDS) .pingInterval(1, TimeUnit.MINUTES) // Technically not necessary, the server also pings us .connectTimeout(10, TimeUnit.SECONDS) .build() } - fun defaultBuilder(context: Context, baseUrl: String): OkHttpClient.Builder { - return emptyBuilder(context, baseUrl) - .callTimeout(1, TimeUnit.MINUTES) // Increased to 1min (from 15s) to reduce client variance - .connectTimeout(15, TimeUnit.SECONDS) - .readTimeout(15, TimeUnit.SECONDS) - .writeTimeout(15, TimeUnit.SECONDS) - } - - fun emptyBuilder(context: Context, baseUrl: String): OkHttpClient.Builder { - return CertUtil - .getInstance(context) - .withTLSConfig(OkHttpClient.Builder(), baseUrl) - } - fun requestBuilder(url: String, user: User? = null, customHeaders: List = emptyList()): Request.Builder { val builder = Request.Builder() .url(url) @@ -82,5 +68,19 @@ object HttpUtil { } return builder } + + private suspend fun emptyClientBuilder(context: Context, baseUrl: String): OkHttpClient.Builder { + return CertUtil + .getInstance(context) + .withTLSConfig(OkHttpClient.Builder(), baseUrl) + } + + private suspend fun defaultClientBuilder(context: Context, baseUrl: String): OkHttpClient.Builder { + return emptyClientBuilder(context, baseUrl) + .callTimeout(1, TimeUnit.MINUTES) // Increased to 1min (from 15s) to reduce client variance + .connectTimeout(15, TimeUnit.SECONDS) + .readTimeout(15, TimeUnit.SECONDS) + .writeTimeout(15, TimeUnit.SECONDS) + } }