Replace CertUtil.withTLS

This commit is contained in:
Philipp Heckel 2026-01-04 22:26:13 -05:00
parent 6a00d9221e
commit c001055cc9
5 changed files with 141 additions and 152 deletions

View file

@ -229,7 +229,7 @@ class Backuper(val context: Context) {
}
certificates.forEach { c ->
try {
val cert = CertUtil.parseCertificate(c.pem)
val cert = CertUtil.parsePemCertificate(c.pem)
val fingerprint = CertUtil.calculateFingerprint(cert)
repository.addTrustedCertificate(fingerprint, c.pem)
} catch (e: Exception) {

View file

@ -68,7 +68,7 @@ class CertificateSettingsFragment : BasePreferenceFragment(),
certs.forEach { trustedCert ->
try {
val cert = CertUtil.parseCertificate(trustedCert.pem)
val cert = CertUtil.parsePemCertificate(trustedCert.pem)
val issuer = parseCommonName(cert.issuerX500Principal.name)
val pref = Preference(preferenceScreen.context)
pref.title = parseCommonName(cert.subjectX500Principal.name)
@ -146,7 +146,7 @@ class CertificateSettingsFragment : BasePreferenceFragment(),
it.bufferedReader().readText()
}
if (content != null && content.contains("-----BEGIN CERTIFICATE-----")) {
val cert = CertUtil.parseCertificate(content)
val cert = CertUtil.parsePemCertificate(content)
TrustedCertificateFragment.newInstanceAdd(cert)
.show(childFragmentManager, TrustedCertificateFragment.TAG)
} else {

View file

@ -184,7 +184,7 @@ class TrustedCertificateFragment : DialogFragment() {
val trustedCert = repository.getTrustedCertificates().find { it.fingerprint == fingerprint }
if (trustedCert != null) {
try {
val x509Cert = CertUtil.parseCertificate(trustedCert.pem)
val x509Cert = CertUtil.parsePemCertificate(trustedCert.pem)
cert = x509Cert
withContext(Dispatchers.Main) {
displayCertificateDetails(x509Cert)
@ -231,9 +231,9 @@ class TrustedCertificateFragment : DialogFragment() {
private fun trustCertificate() {
val certificate = cert ?: return
lifecycleScope.launch(Dispatchers.IO) {
val fp = CertUtil.calculateFingerprint(certificate)
val pem = CertUtil.encodeToPem(certificate)
repository.addTrustedCertificate(fp, pem)
val fingerprint = CertUtil.calculateFingerprint(certificate)
val pem = CertUtil.encodeCertificateToPem(certificate)
repository.addTrustedCertificate(fingerprint, pem)
withContext(Dispatchers.Main) {
Toast.makeText(context, R.string.trusted_certificate_dialog_added_toast, Toast.LENGTH_SHORT).show()
listener?.onCertificateTrusted(certificate)

View file

@ -5,9 +5,9 @@ import android.content.Context
import android.util.Base64
import io.heckel.ntfy.db.ClientCertificate
import io.heckel.ntfy.db.Repository
import io.heckel.ntfy.db.TrustedCertificate
import kotlinx.coroutines.runBlocking
import okhttp3.OkHttpClient
import okhttp3.internal.tls.OkHostnameVerifier
import java.io.ByteArrayInputStream
import java.net.URL
import java.security.KeyStore
@ -15,7 +15,7 @@ import java.security.MessageDigest
import java.security.SecureRandom
import java.security.cert.CertificateFactory
import java.security.cert.X509Certificate
import javax.net.ssl.HttpsURLConnection
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.KeyManager
import javax.net.ssl.KeyManagerFactory
import javax.net.ssl.SSLContext
@ -24,84 +24,48 @@ import javax.net.ssl.SSLSocket
import javax.net.ssl.TrustManager
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager
import kotlin.collections.addAll
/**
* Manages SSL/TLS configuration for OkHttpClient instances.
*
* Supports:
* 1. Global trusted CA certificates (for self-signed servers)
* 2. Per-URL client certificates for mTLS (PKCS#12 format)
*
* Uses standard TrustManagerFactory and KeyManagerFactory (not custom implementations).
* TLS config:
* - Trust system roots
* - Also trust user-added certs (leaf and/or CA; chains to user-added CAs)
* - Hostname verify ONLY when chain is system-trusted; skip when only user-trusted
* - Optional mTLS via per-baseUrl PKCS#12 client cert
*/
class CertUtil private constructor(context: Context) {
private val appContext: Context = context.applicationContext
private val repository: Repository by lazy { Repository.getInstance(appContext) }
fun withTLSConfig(builder: OkHttpClient.Builder, baseUrl: String): OkHttpClient.Builder {
fun withTLSConfig(builder: OkHttpClient.Builder, baseUrl: String): OkHttpClient.Builder {
try {
val trustManagers = mutableListOf<TrustManager>()
val keyManagers = mutableListOf<KeyManager>()
// Get all user-trusted certificates from database
val trustedCerts = runBlocking { repository.getTrustedCertificates() }
val trustedFingerprints = trustedCerts.map { it.fingerprint }.toSet()
if (trustedCerts.isNotEmpty()) {
trustManagers.addAll(createCombinedTrustManagers(trustedCerts))
}
// Get client certificate for mTLS
val clientCert = runBlocking { repository.getClientCertificate(baseUrl) }
if (clientCert != null) {
createKeyManagers(clientCert)?.let { keyManagers.addAll(it.toList()) }
}
// Apply SSL configuration if we have custom trust or key managers
if (trustManagers.isNotEmpty() || keyManagers.isNotEmpty()) {
// Fall back to system trust if no custom trust managers
if (trustManagers.isEmpty()) {
trustManagers.addAll(getSystemTrustManagers().toList())
val userX509 = trustedCerts.mapNotNull {
try {
parsePemCertificate(it.pem)
} catch (e: Exception) {
Log.w(TAG, "Failed to parse trusted certificate: ${it.fingerprint}", e)
null
}
}
val sslContext = SSLContext.getInstance("TLS")
sslContext.init(
keyManagers.toTypedArray().ifEmpty { null },
trustManagers.toTypedArray(),
SecureRandom()
)
builder.sslSocketFactory(
sslContext.socketFactory,
trustManagers.filterIsInstance<X509TrustManager>().first()
)
val clientCert = runBlocking { repository.getClientCertificate(baseUrl) }
val keyManagers = clientCert?.let { createKeyManagers(it) }
// Custom hostname verifier that bypasses only for user-trusted certs
if (trustedFingerprints.isNotEmpty()) {
builder.hostnameVerifier { hostname, session ->
val defaultVerifier = HttpsURLConnection.getDefaultHostnameVerifier()
if (defaultVerifier.verify(hostname, session)) {
return@hostnameVerifier true
}
// Always include system trust; add user trust if present.
val systemTm = systemTrustManager()
val customTm = if (userX509.isNotEmpty()) trustManagerForAddedCerts(userX509) else null
val compositeTm = if (customTm != null) compositeTrustManager(systemTm, customTm) else systemTm
// Check if the server's certificate is user-trusted
try {
val serverCerts = session.peerCertificates
if (serverCerts.isNotEmpty()) {
val serverCert = serverCerts[0] as? X509Certificate
if (serverCert != null) {
val serverFingerprint = calculateFingerprint(serverCert)
if (trustedFingerprints.contains(serverFingerprint)) {
Log.d(TAG, "Hostname verification bypassed for $hostname - certificate is user-trusted")
return@hostnameVerifier true
}
}
}
} catch (e: Exception) {
Log.w(TAG, "Failed to check server certificate fingerprint", e)
}
// Only override SSL config if we actually have something to add (user trust or mTLS).
if (customTm != null || keyManagers != null) {
val sslContext = SSLContext.getInstance("TLS").apply {
init(keyManagers, arrayOf<TrustManager>(compositeTm), SecureRandom())
}
builder.sslSocketFactory(sslContext.socketFactory, compositeTm)
false
}
// Hostname rules only matter if we have custom trust. If not, keep default verifier.
if (customTm != null) {
builder.hostnameVerifier(selectiveHostnameVerifier(systemTm, customTm))
}
}
} catch (e: Exception) {
@ -120,20 +84,17 @@ class CertUtil private constructor(context: Context) {
val trustManager = object : X509TrustManager {
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
if (!chain.isNullOrEmpty()) {
capturedCert = chain[0]
}
// Always throw to prevent actual connection
if (!chain.isNullOrEmpty()) capturedCert = chain[0]
throw SSLException("Certificate captured for inspection")
}
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
}
val sslContext = SSLContext.getInstance("TLS")
sslContext.init(null, arrayOf(trustManager), null)
val sslContext = SSLContext.getInstance("TLS").apply {
init(null, arrayOf(trustManager), null)
}
try {
val url = URL(baseUrl)
@ -145,11 +106,11 @@ class CertUtil private constructor(context: Context) {
}
val socket = sslContext.socketFactory.createSocket(host, port) as SSLSocket
socket.soTimeout = 10000
socket.soTimeout = 10_000
try {
socket.startHandshake()
} catch (_: Exception) {
// Expected - we throw from the trust manager
// expected
} finally {
socket.close()
}
@ -161,84 +122,112 @@ class CertUtil private constructor(context: Context) {
}
/**
* Create TrustManagers that trust both user-added certs and system CAs.
* Uses TrustManagerFactory (standard approach).
*/
private fun createCombinedTrustManagers(trustedCerts: List<TrustedCertificate>): Array<TrustManager> {
// Create a KeyStore with all certificates
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType()).apply { load(null) }
// Add user-trusted certificates
trustedCerts.forEachIndexed { index, trustedCert ->
try {
val cert = parseCertificate(trustedCert.pem)
keyStore.setCertificateEntry("user$index", cert)
} catch (e: Exception) {
Log.w(TAG, "Failed to parse trusted certificate: ${trustedCert.fingerprint}", e)
}
}
// Add system CA certificates for combined trust
getSystemTrustManager().acceptedIssuers.forEachIndexed { index, cert ->
keyStore.setCertificateEntry("system$index", cert)
}
val trustManagerFactory = TrustManagerFactory.getInstance(
TrustManagerFactory.getDefaultAlgorithm()
)
trustManagerFactory.init(keyStore)
return trustManagerFactory.trustManagers
}
/**
* Create KeyManagers for mTLS client authentication using PKCS#12 data from database.
* Uses KeyManagerFactory (standard approach).
* mTLS client auth via PKCS#12 from DB.
*/
private fun createKeyManagers(clientCert: ClientCertificate): Array<KeyManager>? {
return try {
val p12Data = Base64.decode(clientCert.p12Base64, Base64.DEFAULT)
val keyStore = KeyStore.getInstance("PKCS12")
ByteArrayInputStream(p12Data).use { keyStore.load(it, clientCert.password.toCharArray()) }
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
keyManagerFactory.init(keyStore, clientCert.password.toCharArray())
keyManagerFactory.keyManagers
val keyStore = KeyStore.getInstance("PKCS12").apply {
ByteArrayInputStream(p12Data).use { load(it, clientCert.password.toCharArray()) }
}
val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()).apply {
init(keyStore, clientCert.password.toCharArray())
}
kmf.keyManagers
} catch (e: Exception) {
Log.e(TAG, "Failed to load PKCS#12 client certificate for ${clientCert.baseUrl}", e)
null
}
}
/**
* Get the default system TrustManagers
*/
private fun getSystemTrustManagers(): Array<TrustManager> {
val trustManagerFactory = TrustManagerFactory.getInstance(
TrustManagerFactory.getDefaultAlgorithm()
)
trustManagerFactory.init(null as KeyStore?)
return trustManagerFactory.trustManagers
private fun systemTrustManager(): X509TrustManager {
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
init(null as KeyStore?)
}
return tmf.trustManagers.first { it is X509TrustManager } as X509TrustManager
}
private fun trustManagerForAddedCerts(added: List<X509Certificate>): X509TrustManager {
val ks = KeyStore.getInstance(KeyStore.getDefaultType()).apply { load(null, null) }
added.forEachIndexed { idx, cert ->
ks.setCertificateEntry("added-$idx", cert)
}
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
init(ks)
}
return tmf.trustManagers.first { it is X509TrustManager } as X509TrustManager
}
/**
* Get the system X509TrustManager
* Trust if either system OR custom accepts the chain.
*/
private fun getSystemTrustManager(): X509TrustManager {
return getSystemTrustManagers().filterIsInstance<X509TrustManager>().first()
private fun compositeTrustManager(system: X509TrustManager, custom: X509TrustManager): X509TrustManager =
object : X509TrustManager {
override fun getAcceptedIssuers(): Array<X509Certificate> =
(system.acceptedIssuers + custom.acceptedIssuers)
.distinctBy { it.subjectX500Principal.name }
.toTypedArray()
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {
try {
system.checkClientTrusted(chain, authType)
} catch (_: Exception) {
custom.checkClientTrusted(chain, authType)
}
}
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
try {
system.checkServerTrusted(chain, authType)
} catch (_: Exception) {
custom.checkServerTrusted(chain, authType)
}
}
}
/**
* Hostname verification:
* - if system-trusted => enforce hostname verification
* - else if custom-trusted => skip hostname verification
* - else => fail
*/
private fun selectiveHostnameVerifier(system: X509TrustManager, custom: X509TrustManager) =
HostnameVerifier { hostname, session ->
val chain = try {
session.peerCertificates.map { it as X509Certificate }.toTypedArray()
} catch (_: Exception) {
return@HostnameVerifier false
}
if (isTrustedBy(system, chain)) {
OkHostnameVerifier.verify(hostname, session)
} else {
isTrustedBy(custom, chain)
}
}
private fun isTrustedBy(tm: X509TrustManager, chain: Array<X509Certificate>): Boolean {
// authType not reliably available here; try common ones.
return try {
tm.checkServerTrusted(chain, "RSA"); true
} catch (_: Exception) {
try {
tm.checkServerTrusted(chain, "EC"); true
} catch (_: Exception) {
false
}
}
}
companion object {
private const val TAG = "NtfySSLManager"
@Volatile
@SuppressLint("StaticFieldLeak") // Only holds applicationContext
@SuppressLint("StaticFieldLeak")
private var instance: CertUtil? = null
fun getInstance(context: Context): CertUtil {
return instance ?: synchronized(this) {
instance ?: CertUtil(context).also { instance = it }
}
}
fun getInstance(context: Context): CertUtil =
instance ?: synchronized(this) { instance ?: CertUtil(context).also { instance = it } }
fun calculateFingerprint(cert: X509Certificate): String {
val md = MessageDigest.getInstance("SHA-256")
@ -246,24 +235,24 @@ class CertUtil private constructor(context: Context) {
return digest.joinToString(":") { "%02X".format(it) }
}
fun parseCertificate(pem: String): X509Certificate {
fun parsePemCertificate(pem: String): X509Certificate {
val factory = CertificateFactory.getInstance("X.509")
return factory.generateCertificate(pem.byteInputStream()) as X509Certificate
}
fun encodeToPem(cert: X509Certificate): String {
fun encodeCertificateToPem(cert: X509Certificate): String {
val base64 = Base64.encodeToString(cert.encoded, Base64.NO_WRAP)
val sb = StringBuilder()
sb.append("-----BEGIN CERTIFICATE-----\n")
var i = 0
while (i < base64.length) {
val end = minOf(i + 64, base64.length)
sb.append(base64.substring(i, end))
sb.append("\n")
i += 64
return buildString {
append("-----BEGIN CERTIFICATE-----\n")
var i = 0
while (i < base64.length) {
val end = minOf(i + 64, base64.length)
append(base64.substring(i, end))
append("\n")
i += 64
}
append("-----END CERTIFICATE-----")
}
sb.append("-----END CERTIFICATE-----")
return sb.toString()
}
fun parsePkcs12Certificate(p12Base64: String, password: String): X509Certificate {
@ -274,4 +263,4 @@ class CertUtil private constructor(context: Context) {
return keyStore.getCertificate(alias) as X509Certificate
}
}
}
}

View file

@ -58,7 +58,7 @@ object HttpUtil {
fun defaultBuilder(context: Context, baseUrl: String): OkHttpClient.Builder {
return emptyBuilder(context, baseUrl)
.callTimeout(60, TimeUnit.SECONDS) // Increased to 60s (from 15s) to reduce client variance
.callTimeout(1, TimeUnit.MINUTES) // Increased to 1min (from 15s) to reduce client variance
.connectTimeout(15, TimeUnit.SECONDS)
.readTimeout(15, TimeUnit.SECONDS)
.writeTimeout(15, TimeUnit.SECONDS)