Make HttpUtil stuff suspend because of the CertUtil stuff

This commit is contained in:
Philipp Heckel 2026-01-04 23:07:08 -05:00
parent c001055cc9
commit 7a132642fb
8 changed files with 63 additions and 56 deletions

View file

@ -7,7 +7,7 @@ import android.os.Looper
import android.webkit.MimeTypeMap
import android.widget.Toast
import androidx.core.content.FileProvider
import androidx.work.Worker
import androidx.work.CoroutineWorker
import androidx.work.WorkerParameters
import io.heckel.ntfy.BuildConfig
import io.heckel.ntfy.R
@ -27,8 +27,9 @@ import io.heckel.ntfy.util.extractBaseUrl
import okhttp3.Response
import java.io.File
import java.util.concurrent.TimeUnit
import kotlin.coroutines.cancellation.CancellationException
class DownloadAttachmentWorker(private val context: Context, params: WorkerParameters) : Worker(context, params) {
class DownloadAttachmentWorker(private val context: Context, params: WorkerParameters) : CoroutineWorker(context, params) {
private val notifier = NotificationService(context)
private lateinit var repository: Repository
private lateinit var subscription: Subscription
@ -36,7 +37,7 @@ class DownloadAttachmentWorker(private val context: Context, params: WorkerParam
private lateinit var attachment: Attachment
private var uri: Uri? = null
override fun doWork(): Result {
override suspend fun doWork(): Result {
if (context.applicationContext !is Application) return Result.failure()
val notificationId = inputData.getString(INPUT_DATA_ID) ?: return Result.failure()
val userAction = inputData.getBoolean(INPUT_DATA_USER_ACTION, false)
@ -47,18 +48,17 @@ class DownloadAttachmentWorker(private val context: Context, params: WorkerParam
attachment = notification.attachment ?: return Result.failure()
try {
downloadAttachment(userAction)
} catch (e: CancellationException) {
Log.d(TAG, "Attachment download was canceled")
maybeDeleteFile()
throw e // We must re-throw this to stop the worker
} catch (e: Exception) {
failed(e)
}
return Result.success()
}
override fun onStopped() {
Log.d(TAG, "Attachment download was canceled")
maybeDeleteFile()
}
private fun downloadAttachment(userAction: Boolean) {
private suspend fun downloadAttachment(userAction: Boolean) {
Log.d(TAG, "Downloading attachment from ${attachment.url}")
try {

View file

@ -3,7 +3,7 @@ package io.heckel.ntfy.msg
import android.content.Context
import android.net.Uri
import androidx.core.content.FileProvider
import androidx.work.Worker
import androidx.work.CoroutineWorker
import androidx.work.WorkerParameters
import io.heckel.ntfy.BuildConfig
import io.heckel.ntfy.app.Application
@ -18,8 +18,9 @@ import io.heckel.ntfy.util.sha256
import okhttp3.Response
import java.io.File
import java.util.Date
import kotlin.coroutines.cancellation.CancellationException
class DownloadIconWorker(private val context: Context, params: WorkerParameters) : Worker(context, params) {
class DownloadIconWorker(private val context: Context, params: WorkerParameters) : CoroutineWorker(context, params) {
private val notifier = NotificationService(context)
private lateinit var repository: Repository
private lateinit var subscription: Subscription
@ -27,7 +28,7 @@ class DownloadIconWorker(private val context: Context, params: WorkerParameters)
private lateinit var icon: Icon
private var uri: Uri? = null
override fun doWork(): Result {
override suspend fun doWork(): Result {
if (context.applicationContext !is Application) return Result.failure()
val notificationId = inputData.getString(INPUT_DATA_ID) ?: return Result.failure()
val app = context.applicationContext as Application
@ -43,21 +44,20 @@ class DownloadIconWorker(private val context: Context, params: WorkerParameters)
} else {
Log.d(TAG, "Loading icon from cache: $iconFile")
val iconUri = createIconUri(iconFile)
this.uri = iconUri // Required for cleanup in onStopped()
this.uri = iconUri
save(icon.copy(contentUri = iconUri.toString()))
}
} catch (e: CancellationException) {
Log.d(TAG, "Icon download was canceled")
maybeDeleteFile()
throw e // We must re-throw this to stop the worker
} catch (e: Exception) {
failed(e)
}
return Result.success()
}
override fun onStopped() {
Log.d(TAG, "Icon download was canceled")
maybeDeleteFile()
}
private fun downloadIcon(iconFile: File) {
private suspend fun downloadIcon(iconFile: File) {
Log.d(TAG, "Downloading icon from ${icon.url}")
try {
val request = HttpUtil.requestBuilder(icon.url).build()

View file

@ -1,7 +1,7 @@
package io.heckel.ntfy.msg
import android.content.Context
import androidx.work.Worker
import androidx.work.CoroutineWorker
import androidx.work.WorkerParameters
import io.heckel.ntfy.R
import io.heckel.ntfy.app.Application
@ -20,7 +20,7 @@ import io.heckel.ntfy.util.extractBaseUrl
import okhttp3.RequestBody.Companion.toRequestBody
import java.util.Locale
class UserActionWorker(private val context: Context, params: WorkerParameters) : Worker(context, params) {
class UserActionWorker(private val context: Context, params: WorkerParameters) : CoroutineWorker(context, params) {
private val notifier = NotificationService(context)
private val broadcaster = BroadcastService(context)
private lateinit var repository: Repository
@ -28,7 +28,7 @@ class UserActionWorker(private val context: Context, params: WorkerParameters) :
private lateinit var notification: Notification
private lateinit var action: Action
override fun doWork(): Result {
override suspend fun doWork(): Result {
if (context.applicationContext !is Application) return Result.failure()
val notificationId = inputData.getString(INPUT_DATA_NOTIFICATION_ID) ?: return Result.failure()
val actionId = inputData.getString(INPUT_DATA_ACTION_ID) ?: return Result.failure()
@ -63,7 +63,7 @@ class UserActionWorker(private val context: Context, params: WorkerParameters) :
}
}
private fun performHttpAction(action: Action) {
private suspend fun performHttpAction(action: Action) {
save(action.copy(progress = ACTION_PROGRESS_ONGOING, error = null))
val url = action.url ?: return

View file

@ -15,5 +15,7 @@ data class ConnectionId(
val topicsToSubscriptionIds: Map<String, Long>,
val connectionProtocol: String,
val credentialsHash: Int, // Hash of "username:password" or 0 if no user
val headersHash: Int // Hash of sorted headers or 0 if none
val headersHash: Int, // Hash of sorted headers or 0 if none
val trustedCertsHash: Int, // Hash of trusted certificates or 0 if none
val clientCertHash: Int // Hash of client certificate or 0 if none
)

View file

@ -20,6 +20,7 @@ import io.heckel.ntfy.msg.ApiService
import io.heckel.ntfy.msg.NotificationDispatcher
import io.heckel.ntfy.ui.Colors
import io.heckel.ntfy.ui.MainActivity
import io.heckel.ntfy.util.HttpUtil
import io.heckel.ntfy.util.Log
import io.heckel.ntfy.util.topicUrl
import kotlinx.coroutines.CoroutineScope
@ -205,6 +206,9 @@ class SubscriberService : Service() {
.filter { s -> s.instant }
val activeConnectionIds = connections.keys().toList().toSet()
val connectionProtocol = repository.getConnectionProtocol()
val trustedCertsHash = repository.getTrustedCertificates()
.joinToString(",") { it.fingerprint }
.hashCode()
val desiredConnectionIds = instantSubscriptions // Set<ConnectionId>
.groupBy { s -> s.baseUrl }
.map { (baseUrl, subs) ->
@ -217,12 +221,15 @@ class SubscriberService : Service() {
.sortedBy { "${it.name}:${it.value}" }
.joinToString(",") { "${it.name}:${it.value}" }
.hashCode()
val clientCertHash = repository.getClientCertificate(baseUrl)?.hashCode() ?: 0
ConnectionId(
baseUrl = baseUrl,
topicsToSubscriptionIds = subs.associate { s -> s.topic to s.id },
connectionProtocol = connectionProtocol,
credentialsHash = credentialsHash,
headersHash = headersHash
headersHash = headersHash,
trustedCertsHash = trustedCertsHash,
clientCertHash = clientCertHash
)
}
.toSet()
@ -256,7 +263,8 @@ class SubscriberService : Service() {
val customHeaders = repository.getCustomHeaders(connectionId.baseUrl)
val connection = if (connectionId.connectionProtocol == Repository.CONNECTION_PROTOCOL_WS) {
val alarmManager = getSystemService(ALARM_SERVICE) as AlarmManager
WsConnection(this, connectionId, repository, user, customHeaders, since, ::onStateChanged, ::onNotificationReceived, alarmManager)
val httpClient = HttpUtil.wsClient(this, connectionId.baseUrl)
WsConnection(connectionId, repository, httpClient, user, customHeaders, since, ::onStateChanged, ::onNotificationReceived, alarmManager)
} else {
JsonConnection(connectionId, scope, repository, api, user, since, ::onStateChanged, ::onNotificationReceived, serviceActive)
}

View file

@ -34,9 +34,9 @@ import kotlin.random.Random
* https://github.com/gotify/android/blob/master/app/src/main/java/com/github/gotify/service/WebSocketConnection.java
*/
class WsConnection(
private val context: Context,
private val connectionId: ConnectionId,
private val repository: Repository,
private val httpClient: OkHttpClient,
private val user: User?,
private val customHeaders: List<CustomHeader>,
private val sinceId: String?,
@ -45,9 +45,6 @@ class WsConnection(
private val alarmManager: AlarmManager
) : Connection {
private val parser = NotificationParser()
private val client: OkHttpClient by lazy {
HttpUtil.wsClient(context, connectionId.baseUrl)
}
private var errorCount = 0
private var webSocket: WebSocket? = null
private var state: State? = null
@ -83,7 +80,7 @@ class WsConnection(
val urlWithSince = topicUrlWs(baseUrl, topicsStr, sinceVal)
val request = HttpUtil.requestBuilder(urlWithSince, user, customHeaders).build()
Log.d(TAG, "$shortUrl (gid=$globalId): Opening $urlWithSince with listener ID $nextListenerId ...")
webSocket = client.newWebSocket(request, Listener(nextListenerId))
webSocket = httpClient.newWebSocket(request, Listener(nextListenerId))
}
@Synchronized

View file

@ -36,9 +36,9 @@ class CertUtil private constructor(context: Context) {
private val appContext: Context = context.applicationContext
private val repository: Repository by lazy { Repository.getInstance(appContext) }
fun withTLSConfig(builder: OkHttpClient.Builder, baseUrl: String): OkHttpClient.Builder {
suspend fun withTLSConfig(builder: OkHttpClient.Builder, baseUrl: String): OkHttpClient.Builder {
try {
val trustedCerts = runBlocking { repository.getTrustedCertificates() }
val trustedCerts = repository.getTrustedCertificates()
val userX509 = trustedCerts.mapNotNull {
try {
parsePemCertificate(it.pem)
@ -48,7 +48,7 @@ class CertUtil private constructor(context: Context) {
}
}
val clientCert = runBlocking { repository.getClientCertificate(baseUrl) }
val clientCert = repository.getClientCertificate(baseUrl)
val keyManagers = clientCert?.let { createKeyManagers(it) }
// Always include system trust; add user trust if present.

View file

@ -21,16 +21,16 @@ object HttpUtil {
/**
* Client for regular API calls (auth, poll, etc.).
*/
fun defaultClient(context: Context, baseUrl: String): OkHttpClient {
return defaultBuilder(context, baseUrl).build()
suspend fun defaultClient(context: Context, baseUrl: String): OkHttpClient {
return defaultClientBuilder(context, baseUrl).build()
}
/**
* Client with a longer call timeout (5 minutes).
* Allows for large file uploads or downloads.
*/
fun longCallClient(context: Context, baseUrl: String): OkHttpClient {
return defaultBuilder(context, baseUrl)
suspend fun longCallClient(context: Context, baseUrl: String): OkHttpClient {
return defaultClientBuilder(context, baseUrl)
.callTimeout(5, TimeUnit.MINUTES)
.build()
}
@ -38,8 +38,8 @@ object HttpUtil {
/**
* Client for long-polling/streaming subscriptions.
*/
fun subscriberClient(context: Context, baseUrl: String): OkHttpClient {
return emptyBuilder(context, baseUrl)
suspend fun subscriberClient(context: Context, baseUrl: String): OkHttpClient {
return emptyClientBuilder(context, baseUrl)
.readTimeout(77, TimeUnit.SECONDS) // Long enough to allow for server-side keepalive messages
.build()
}
@ -48,28 +48,14 @@ object HttpUtil {
* Client for WebSocket connections.
* No read timeout, 1 minute ping interval, 10s connect timeout.
*/
fun wsClient(context: Context, baseUrl: String): OkHttpClient {
return emptyBuilder(context, baseUrl)
suspend fun wsClient(context: Context, baseUrl: String): OkHttpClient {
return emptyClientBuilder(context, baseUrl)
.readTimeout(0, TimeUnit.MILLISECONDS)
.pingInterval(1, TimeUnit.MINUTES) // Technically not necessary, the server also pings us
.connectTimeout(10, TimeUnit.SECONDS)
.build()
}
fun defaultBuilder(context: Context, baseUrl: String): OkHttpClient.Builder {
return emptyBuilder(context, baseUrl)
.callTimeout(1, TimeUnit.MINUTES) // Increased to 1min (from 15s) to reduce client variance
.connectTimeout(15, TimeUnit.SECONDS)
.readTimeout(15, TimeUnit.SECONDS)
.writeTimeout(15, TimeUnit.SECONDS)
}
fun emptyBuilder(context: Context, baseUrl: String): OkHttpClient.Builder {
return CertUtil
.getInstance(context)
.withTLSConfig(OkHttpClient.Builder(), baseUrl)
}
fun requestBuilder(url: String, user: User? = null, customHeaders: List<CustomHeader> = emptyList()): Request.Builder {
val builder = Request.Builder()
.url(url)
@ -82,5 +68,19 @@ object HttpUtil {
}
return builder
}
private suspend fun emptyClientBuilder(context: Context, baseUrl: String): OkHttpClient.Builder {
return CertUtil
.getInstance(context)
.withTLSConfig(OkHttpClient.Builder(), baseUrl)
}
private suspend fun defaultClientBuilder(context: Context, baseUrl: String): OkHttpClient.Builder {
return emptyClientBuilder(context, baseUrl)
.callTimeout(1, TimeUnit.MINUTES) // Increased to 1min (from 15s) to reduce client variance
.connectTimeout(15, TimeUnit.SECONDS)
.readTimeout(15, TimeUnit.SECONDS)
.writeTimeout(15, TimeUnit.SECONDS)
}
}