Replace CertUtil.withTLS
This commit is contained in:
parent
6a00d9221e
commit
c001055cc9
5 changed files with 141 additions and 152 deletions
|
|
@ -229,7 +229,7 @@ class Backuper(val context: Context) {
|
|||
}
|
||||
certificates.forEach { c ->
|
||||
try {
|
||||
val cert = CertUtil.parseCertificate(c.pem)
|
||||
val cert = CertUtil.parsePemCertificate(c.pem)
|
||||
val fingerprint = CertUtil.calculateFingerprint(cert)
|
||||
repository.addTrustedCertificate(fingerprint, c.pem)
|
||||
} catch (e: Exception) {
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class CertificateSettingsFragment : BasePreferenceFragment(),
|
|||
|
||||
certs.forEach { trustedCert ->
|
||||
try {
|
||||
val cert = CertUtil.parseCertificate(trustedCert.pem)
|
||||
val cert = CertUtil.parsePemCertificate(trustedCert.pem)
|
||||
val issuer = parseCommonName(cert.issuerX500Principal.name)
|
||||
val pref = Preference(preferenceScreen.context)
|
||||
pref.title = parseCommonName(cert.subjectX500Principal.name)
|
||||
|
|
@ -146,7 +146,7 @@ class CertificateSettingsFragment : BasePreferenceFragment(),
|
|||
it.bufferedReader().readText()
|
||||
}
|
||||
if (content != null && content.contains("-----BEGIN CERTIFICATE-----")) {
|
||||
val cert = CertUtil.parseCertificate(content)
|
||||
val cert = CertUtil.parsePemCertificate(content)
|
||||
TrustedCertificateFragment.newInstanceAdd(cert)
|
||||
.show(childFragmentManager, TrustedCertificateFragment.TAG)
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ class TrustedCertificateFragment : DialogFragment() {
|
|||
val trustedCert = repository.getTrustedCertificates().find { it.fingerprint == fingerprint }
|
||||
if (trustedCert != null) {
|
||||
try {
|
||||
val x509Cert = CertUtil.parseCertificate(trustedCert.pem)
|
||||
val x509Cert = CertUtil.parsePemCertificate(trustedCert.pem)
|
||||
cert = x509Cert
|
||||
withContext(Dispatchers.Main) {
|
||||
displayCertificateDetails(x509Cert)
|
||||
|
|
@ -231,9 +231,9 @@ class TrustedCertificateFragment : DialogFragment() {
|
|||
private fun trustCertificate() {
|
||||
val certificate = cert ?: return
|
||||
lifecycleScope.launch(Dispatchers.IO) {
|
||||
val fp = CertUtil.calculateFingerprint(certificate)
|
||||
val pem = CertUtil.encodeToPem(certificate)
|
||||
repository.addTrustedCertificate(fp, pem)
|
||||
val fingerprint = CertUtil.calculateFingerprint(certificate)
|
||||
val pem = CertUtil.encodeCertificateToPem(certificate)
|
||||
repository.addTrustedCertificate(fingerprint, pem)
|
||||
withContext(Dispatchers.Main) {
|
||||
Toast.makeText(context, R.string.trusted_certificate_dialog_added_toast, Toast.LENGTH_SHORT).show()
|
||||
listener?.onCertificateTrusted(certificate)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ import android.content.Context
|
|||
import android.util.Base64
|
||||
import io.heckel.ntfy.db.ClientCertificate
|
||||
import io.heckel.ntfy.db.Repository
|
||||
import io.heckel.ntfy.db.TrustedCertificate
|
||||
import kotlinx.coroutines.runBlocking
|
||||
import okhttp3.OkHttpClient
|
||||
import okhttp3.internal.tls.OkHostnameVerifier
|
||||
import java.io.ByteArrayInputStream
|
||||
import java.net.URL
|
||||
import java.security.KeyStore
|
||||
|
|
@ -15,7 +15,7 @@ import java.security.MessageDigest
|
|||
import java.security.SecureRandom
|
||||
import java.security.cert.CertificateFactory
|
||||
import java.security.cert.X509Certificate
|
||||
import javax.net.ssl.HttpsURLConnection
|
||||
import javax.net.ssl.HostnameVerifier
|
||||
import javax.net.ssl.KeyManager
|
||||
import javax.net.ssl.KeyManagerFactory
|
||||
import javax.net.ssl.SSLContext
|
||||
|
|
@ -24,84 +24,48 @@ import javax.net.ssl.SSLSocket
|
|||
import javax.net.ssl.TrustManager
|
||||
import javax.net.ssl.TrustManagerFactory
|
||||
import javax.net.ssl.X509TrustManager
|
||||
import kotlin.collections.addAll
|
||||
|
||||
/**
|
||||
* Manages SSL/TLS configuration for OkHttpClient instances.
|
||||
*
|
||||
* Supports:
|
||||
* 1. Global trusted CA certificates (for self-signed servers)
|
||||
* 2. Per-URL client certificates for mTLS (PKCS#12 format)
|
||||
*
|
||||
* Uses standard TrustManagerFactory and KeyManagerFactory (not custom implementations).
|
||||
* TLS config:
|
||||
* - Trust system roots
|
||||
* - Also trust user-added certs (leaf and/or CA; chains to user-added CAs)
|
||||
* - Hostname verify ONLY when chain is system-trusted; skip when only user-trusted
|
||||
* - Optional mTLS via per-baseUrl PKCS#12 client cert
|
||||
*/
|
||||
class CertUtil private constructor(context: Context) {
|
||||
private val appContext: Context = context.applicationContext
|
||||
private val repository: Repository by lazy { Repository.getInstance(appContext) }
|
||||
|
||||
fun withTLSConfig(builder: OkHttpClient.Builder, baseUrl: String): OkHttpClient.Builder {
|
||||
fun withTLSConfig(builder: OkHttpClient.Builder, baseUrl: String): OkHttpClient.Builder {
|
||||
try {
|
||||
val trustManagers = mutableListOf<TrustManager>()
|
||||
val keyManagers = mutableListOf<KeyManager>()
|
||||
|
||||
// Get all user-trusted certificates from database
|
||||
val trustedCerts = runBlocking { repository.getTrustedCertificates() }
|
||||
val trustedFingerprints = trustedCerts.map { it.fingerprint }.toSet()
|
||||
if (trustedCerts.isNotEmpty()) {
|
||||
trustManagers.addAll(createCombinedTrustManagers(trustedCerts))
|
||||
}
|
||||
|
||||
// Get client certificate for mTLS
|
||||
val clientCert = runBlocking { repository.getClientCertificate(baseUrl) }
|
||||
if (clientCert != null) {
|
||||
createKeyManagers(clientCert)?.let { keyManagers.addAll(it.toList()) }
|
||||
}
|
||||
|
||||
// Apply SSL configuration if we have custom trust or key managers
|
||||
if (trustManagers.isNotEmpty() || keyManagers.isNotEmpty()) {
|
||||
// Fall back to system trust if no custom trust managers
|
||||
if (trustManagers.isEmpty()) {
|
||||
trustManagers.addAll(getSystemTrustManagers().toList())
|
||||
val userX509 = trustedCerts.mapNotNull {
|
||||
try {
|
||||
parsePemCertificate(it.pem)
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "Failed to parse trusted certificate: ${it.fingerprint}", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
val sslContext = SSLContext.getInstance("TLS")
|
||||
sslContext.init(
|
||||
keyManagers.toTypedArray().ifEmpty { null },
|
||||
trustManagers.toTypedArray(),
|
||||
SecureRandom()
|
||||
)
|
||||
builder.sslSocketFactory(
|
||||
sslContext.socketFactory,
|
||||
trustManagers.filterIsInstance<X509TrustManager>().first()
|
||||
)
|
||||
val clientCert = runBlocking { repository.getClientCertificate(baseUrl) }
|
||||
val keyManagers = clientCert?.let { createKeyManagers(it) }
|
||||
|
||||
// Custom hostname verifier that bypasses only for user-trusted certs
|
||||
if (trustedFingerprints.isNotEmpty()) {
|
||||
builder.hostnameVerifier { hostname, session ->
|
||||
val defaultVerifier = HttpsURLConnection.getDefaultHostnameVerifier()
|
||||
if (defaultVerifier.verify(hostname, session)) {
|
||||
return@hostnameVerifier true
|
||||
}
|
||||
// Always include system trust; add user trust if present.
|
||||
val systemTm = systemTrustManager()
|
||||
val customTm = if (userX509.isNotEmpty()) trustManagerForAddedCerts(userX509) else null
|
||||
val compositeTm = if (customTm != null) compositeTrustManager(systemTm, customTm) else systemTm
|
||||
|
||||
// Check if the server's certificate is user-trusted
|
||||
try {
|
||||
val serverCerts = session.peerCertificates
|
||||
if (serverCerts.isNotEmpty()) {
|
||||
val serverCert = serverCerts[0] as? X509Certificate
|
||||
if (serverCert != null) {
|
||||
val serverFingerprint = calculateFingerprint(serverCert)
|
||||
if (trustedFingerprints.contains(serverFingerprint)) {
|
||||
Log.d(TAG, "Hostname verification bypassed for $hostname - certificate is user-trusted")
|
||||
return@hostnameVerifier true
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "Failed to check server certificate fingerprint", e)
|
||||
}
|
||||
// Only override SSL config if we actually have something to add (user trust or mTLS).
|
||||
if (customTm != null || keyManagers != null) {
|
||||
val sslContext = SSLContext.getInstance("TLS").apply {
|
||||
init(keyManagers, arrayOf<TrustManager>(compositeTm), SecureRandom())
|
||||
}
|
||||
builder.sslSocketFactory(sslContext.socketFactory, compositeTm)
|
||||
|
||||
false
|
||||
}
|
||||
// Hostname rules only matter if we have custom trust. If not, keep default verifier.
|
||||
if (customTm != null) {
|
||||
builder.hostnameVerifier(selectiveHostnameVerifier(systemTm, customTm))
|
||||
}
|
||||
}
|
||||
} catch (e: Exception) {
|
||||
|
|
@ -120,20 +84,17 @@ class CertUtil private constructor(context: Context) {
|
|||
|
||||
val trustManager = object : X509TrustManager {
|
||||
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {}
|
||||
|
||||
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
|
||||
if (!chain.isNullOrEmpty()) {
|
||||
capturedCert = chain[0]
|
||||
}
|
||||
// Always throw to prevent actual connection
|
||||
if (!chain.isNullOrEmpty()) capturedCert = chain[0]
|
||||
throw SSLException("Certificate captured for inspection")
|
||||
}
|
||||
|
||||
override fun getAcceptedIssuers(): Array<X509Certificate> = arrayOf()
|
||||
}
|
||||
|
||||
val sslContext = SSLContext.getInstance("TLS")
|
||||
sslContext.init(null, arrayOf(trustManager), null)
|
||||
val sslContext = SSLContext.getInstance("TLS").apply {
|
||||
init(null, arrayOf(trustManager), null)
|
||||
}
|
||||
|
||||
try {
|
||||
val url = URL(baseUrl)
|
||||
|
|
@ -145,11 +106,11 @@ class CertUtil private constructor(context: Context) {
|
|||
}
|
||||
|
||||
val socket = sslContext.socketFactory.createSocket(host, port) as SSLSocket
|
||||
socket.soTimeout = 10000
|
||||
socket.soTimeout = 10_000
|
||||
try {
|
||||
socket.startHandshake()
|
||||
} catch (_: Exception) {
|
||||
// Expected - we throw from the trust manager
|
||||
// expected
|
||||
} finally {
|
||||
socket.close()
|
||||
}
|
||||
|
|
@ -161,84 +122,112 @@ class CertUtil private constructor(context: Context) {
|
|||
}
|
||||
|
||||
/**
|
||||
* Create TrustManagers that trust both user-added certs and system CAs.
|
||||
* Uses TrustManagerFactory (standard approach).
|
||||
*/
|
||||
private fun createCombinedTrustManagers(trustedCerts: List<TrustedCertificate>): Array<TrustManager> {
|
||||
// Create a KeyStore with all certificates
|
||||
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType()).apply { load(null) }
|
||||
|
||||
// Add user-trusted certificates
|
||||
trustedCerts.forEachIndexed { index, trustedCert ->
|
||||
try {
|
||||
val cert = parseCertificate(trustedCert.pem)
|
||||
keyStore.setCertificateEntry("user$index", cert)
|
||||
} catch (e: Exception) {
|
||||
Log.w(TAG, "Failed to parse trusted certificate: ${trustedCert.fingerprint}", e)
|
||||
}
|
||||
}
|
||||
|
||||
// Add system CA certificates for combined trust
|
||||
getSystemTrustManager().acceptedIssuers.forEachIndexed { index, cert ->
|
||||
keyStore.setCertificateEntry("system$index", cert)
|
||||
}
|
||||
|
||||
val trustManagerFactory = TrustManagerFactory.getInstance(
|
||||
TrustManagerFactory.getDefaultAlgorithm()
|
||||
)
|
||||
trustManagerFactory.init(keyStore)
|
||||
return trustManagerFactory.trustManagers
|
||||
}
|
||||
|
||||
/**
|
||||
* Create KeyManagers for mTLS client authentication using PKCS#12 data from database.
|
||||
* Uses KeyManagerFactory (standard approach).
|
||||
* mTLS client auth via PKCS#12 from DB.
|
||||
*/
|
||||
private fun createKeyManagers(clientCert: ClientCertificate): Array<KeyManager>? {
|
||||
return try {
|
||||
val p12Data = Base64.decode(clientCert.p12Base64, Base64.DEFAULT)
|
||||
val keyStore = KeyStore.getInstance("PKCS12")
|
||||
ByteArrayInputStream(p12Data).use { keyStore.load(it, clientCert.password.toCharArray()) }
|
||||
|
||||
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
|
||||
keyManagerFactory.init(keyStore, clientCert.password.toCharArray())
|
||||
keyManagerFactory.keyManagers
|
||||
val keyStore = KeyStore.getInstance("PKCS12").apply {
|
||||
ByteArrayInputStream(p12Data).use { load(it, clientCert.password.toCharArray()) }
|
||||
}
|
||||
val kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()).apply {
|
||||
init(keyStore, clientCert.password.toCharArray())
|
||||
}
|
||||
kmf.keyManagers
|
||||
} catch (e: Exception) {
|
||||
Log.e(TAG, "Failed to load PKCS#12 client certificate for ${clientCert.baseUrl}", e)
|
||||
null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the default system TrustManagers
|
||||
*/
|
||||
private fun getSystemTrustManagers(): Array<TrustManager> {
|
||||
val trustManagerFactory = TrustManagerFactory.getInstance(
|
||||
TrustManagerFactory.getDefaultAlgorithm()
|
||||
)
|
||||
trustManagerFactory.init(null as KeyStore?)
|
||||
return trustManagerFactory.trustManagers
|
||||
private fun systemTrustManager(): X509TrustManager {
|
||||
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
|
||||
init(null as KeyStore?)
|
||||
}
|
||||
return tmf.trustManagers.first { it is X509TrustManager } as X509TrustManager
|
||||
}
|
||||
|
||||
private fun trustManagerForAddedCerts(added: List<X509Certificate>): X509TrustManager {
|
||||
val ks = KeyStore.getInstance(KeyStore.getDefaultType()).apply { load(null, null) }
|
||||
added.forEachIndexed { idx, cert ->
|
||||
ks.setCertificateEntry("added-$idx", cert)
|
||||
}
|
||||
val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()).apply {
|
||||
init(ks)
|
||||
}
|
||||
return tmf.trustManagers.first { it is X509TrustManager } as X509TrustManager
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the system X509TrustManager
|
||||
* Trust if either system OR custom accepts the chain.
|
||||
*/
|
||||
private fun getSystemTrustManager(): X509TrustManager {
|
||||
return getSystemTrustManagers().filterIsInstance<X509TrustManager>().first()
|
||||
private fun compositeTrustManager(system: X509TrustManager, custom: X509TrustManager): X509TrustManager =
|
||||
object : X509TrustManager {
|
||||
override fun getAcceptedIssuers(): Array<X509Certificate> =
|
||||
(system.acceptedIssuers + custom.acceptedIssuers)
|
||||
.distinctBy { it.subjectX500Principal.name }
|
||||
.toTypedArray()
|
||||
|
||||
override fun checkClientTrusted(chain: Array<out X509Certificate>?, authType: String?) {
|
||||
try {
|
||||
system.checkClientTrusted(chain, authType)
|
||||
} catch (_: Exception) {
|
||||
custom.checkClientTrusted(chain, authType)
|
||||
}
|
||||
}
|
||||
|
||||
override fun checkServerTrusted(chain: Array<out X509Certificate>?, authType: String?) {
|
||||
try {
|
||||
system.checkServerTrusted(chain, authType)
|
||||
} catch (_: Exception) {
|
||||
custom.checkServerTrusted(chain, authType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Hostname verification:
|
||||
* - if system-trusted => enforce hostname verification
|
||||
* - else if custom-trusted => skip hostname verification
|
||||
* - else => fail
|
||||
*/
|
||||
private fun selectiveHostnameVerifier(system: X509TrustManager, custom: X509TrustManager) =
|
||||
HostnameVerifier { hostname, session ->
|
||||
val chain = try {
|
||||
session.peerCertificates.map { it as X509Certificate }.toTypedArray()
|
||||
} catch (_: Exception) {
|
||||
return@HostnameVerifier false
|
||||
}
|
||||
|
||||
if (isTrustedBy(system, chain)) {
|
||||
OkHostnameVerifier.verify(hostname, session)
|
||||
} else {
|
||||
isTrustedBy(custom, chain)
|
||||
}
|
||||
}
|
||||
|
||||
private fun isTrustedBy(tm: X509TrustManager, chain: Array<X509Certificate>): Boolean {
|
||||
// authType not reliably available here; try common ones.
|
||||
return try {
|
||||
tm.checkServerTrusted(chain, "RSA"); true
|
||||
} catch (_: Exception) {
|
||||
try {
|
||||
tm.checkServerTrusted(chain, "EC"); true
|
||||
} catch (_: Exception) {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val TAG = "NtfySSLManager"
|
||||
|
||||
@Volatile
|
||||
@SuppressLint("StaticFieldLeak") // Only holds applicationContext
|
||||
@SuppressLint("StaticFieldLeak")
|
||||
private var instance: CertUtil? = null
|
||||
|
||||
fun getInstance(context: Context): CertUtil {
|
||||
return instance ?: synchronized(this) {
|
||||
instance ?: CertUtil(context).also { instance = it }
|
||||
}
|
||||
}
|
||||
fun getInstance(context: Context): CertUtil =
|
||||
instance ?: synchronized(this) { instance ?: CertUtil(context).also { instance = it } }
|
||||
|
||||
fun calculateFingerprint(cert: X509Certificate): String {
|
||||
val md = MessageDigest.getInstance("SHA-256")
|
||||
|
|
@ -246,24 +235,24 @@ class CertUtil private constructor(context: Context) {
|
|||
return digest.joinToString(":") { "%02X".format(it) }
|
||||
}
|
||||
|
||||
fun parseCertificate(pem: String): X509Certificate {
|
||||
fun parsePemCertificate(pem: String): X509Certificate {
|
||||
val factory = CertificateFactory.getInstance("X.509")
|
||||
return factory.generateCertificate(pem.byteInputStream()) as X509Certificate
|
||||
}
|
||||
|
||||
fun encodeToPem(cert: X509Certificate): String {
|
||||
fun encodeCertificateToPem(cert: X509Certificate): String {
|
||||
val base64 = Base64.encodeToString(cert.encoded, Base64.NO_WRAP)
|
||||
val sb = StringBuilder()
|
||||
sb.append("-----BEGIN CERTIFICATE-----\n")
|
||||
var i = 0
|
||||
while (i < base64.length) {
|
||||
val end = minOf(i + 64, base64.length)
|
||||
sb.append(base64.substring(i, end))
|
||||
sb.append("\n")
|
||||
i += 64
|
||||
return buildString {
|
||||
append("-----BEGIN CERTIFICATE-----\n")
|
||||
var i = 0
|
||||
while (i < base64.length) {
|
||||
val end = minOf(i + 64, base64.length)
|
||||
append(base64.substring(i, end))
|
||||
append("\n")
|
||||
i += 64
|
||||
}
|
||||
append("-----END CERTIFICATE-----")
|
||||
}
|
||||
sb.append("-----END CERTIFICATE-----")
|
||||
return sb.toString()
|
||||
}
|
||||
|
||||
fun parsePkcs12Certificate(p12Base64: String, password: String): X509Certificate {
|
||||
|
|
@ -274,4 +263,4 @@ class CertUtil private constructor(context: Context) {
|
|||
return keyStore.getCertificate(alias) as X509Certificate
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,7 +58,7 @@ object HttpUtil {
|
|||
|
||||
fun defaultBuilder(context: Context, baseUrl: String): OkHttpClient.Builder {
|
||||
return emptyBuilder(context, baseUrl)
|
||||
.callTimeout(60, TimeUnit.SECONDS) // Increased to 60s (from 15s) to reduce client variance
|
||||
.callTimeout(1, TimeUnit.MINUTES) // Increased to 1min (from 15s) to reduce client variance
|
||||
.connectTimeout(15, TimeUnit.SECONDS)
|
||||
.readTimeout(15, TimeUnit.SECONDS)
|
||||
.writeTimeout(15, TimeUnit.SECONDS)
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue