Skip to content

Instantly share code, notes, and snippets.

@jboxx
Created August 19, 2025 18:48
Show Gist options
  • Select an option

  • Save jboxx/59bdad1bcf6b0937b9f661cf03742908 to your computer and use it in GitHub Desktop.

Select an option

Save jboxx/59bdad1bcf6b0937b9f661cf03742908 to your computer and use it in GitHub Desktop.

Revisions

  1. jboxx created this gist Aug 19, 2025.
    702 changes: 702 additions & 0 deletions ConnectionHelper.kt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,702 @@
    import android.annotation.SuppressLint
    import android.content.Context
    import android.net.ConnectivityManager
    import android.net.LinkProperties
    import android.net.Network
    import android.net.NetworkCapabilities
    import android.net.NetworkRequest
    import android.net.wifi.WifiManager
    import android.os.Build
    import android.util.Log
    import kotlinx.coroutines.CoroutineName
    import kotlinx.coroutines.CoroutineScope
    import kotlinx.coroutines.Dispatchers
    import kotlinx.coroutines.Job
    import kotlinx.coroutines.SupervisorJob
    import kotlinx.coroutines.async
    import kotlinx.coroutines.awaitAll
    import kotlinx.coroutines.cancel
    import kotlinx.coroutines.channels.Channel
    import kotlinx.coroutines.delay
    import kotlinx.coroutines.flow.Flow
    import kotlinx.coroutines.flow.MutableStateFlow
    import kotlinx.coroutines.flow.StateFlow
    import kotlinx.coroutines.flow.asStateFlow
    import kotlinx.coroutines.flow.distinctUntilChanged
    import kotlinx.coroutines.flow.launchIn
    import kotlinx.coroutines.flow.map
    import kotlinx.coroutines.flow.onEach
    import kotlinx.coroutines.isActive
    import kotlinx.coroutines.launch
    import kotlinx.coroutines.sync.Semaphore
    import kotlinx.coroutines.withContext
    import okhttp3.OkHttpClient
    import okhttp3.Request
    import java.io.IOException
    import java.net.InetSocketAddress
    import java.net.NetworkInterface
    import java.net.Socket
    import java.net.URI
    import java.util.concurrent.TimeUnit
    import java.util.concurrent.atomic.AtomicBoolean
    import kotlin.math.roundToLong

    /* ---------------------------
    Public models
    --------------------------- */

    sealed class NetworkState {
    object Unknown : NetworkState()
    object Disconnected : NetworkState()
    data class Connected(val connectionInfo: ConnectionInfo) : NetworkState()
    data class Error(val throwable: Throwable) : NetworkState()
    }

    data class ConnectionInfo(
    val type: ConnectionType,
    val isVpnActive: Boolean,
    val hasInternet: Boolean = false,
    val vpnInfo: VpnInfo? = null,
    val details: NetworkDetails,
    val internetTestResults: InternetTestResults = InternetTestResults()
    )

    data class VpnInfo(
    val interfaceName: String?,
    val probableType: VpnType = VpnType.UNKNOWN,
    val gateway: String? = null,
    val isBlockingInternet: Boolean = false
    )

    data class NetworkDetails(
    val networkId: String,
    val ipAddresses: List<String> = emptyList(),
    val dnsServers: List<String> = emptyList(),
    val gateway: String? = null,
    val ssid: String? = null,
    val isMetered: Boolean = false,
    val downstreamKbps: Int? = null,
    val upstreamKbps: Int? = null,
    val signalStrength: Int? = null
    )

    data class InternetTestResults(
    val successfulEndpoints: List<String> = emptyList(),
    val failedEndpoints: List<String> = emptyList(),
    val averageLatencyMs: Long = 0,
    val packetLossPercent: Float = 0f,
    val lastTestAtMs: Long = System.currentTimeMillis()
    )

    enum class ConnectionType { WIFI, CELLULAR, ETHERNET, VPN, BLUETOOTH, USB, UNKNOWN }
    enum class VpnType { WIREGUARD, OPENVPN, IPSEC, PPTP, L2TP, SSTP, UNKNOWN }

    /* ---------------------------
    Configuration
    --------------------------- */

    data class ConnectionHelperConfig(
    val enablePeriodicInternetTest: Boolean = true,
    val periodicIntervalMs: Long = 30_000L,
    val testEndpoints: List<String> = listOf(
    "https://dns.google/resolve",
    "https://cloudflare-dns.com/dns-query",
    "1.1.1.1:53",
    "8.8.8.8:53"
    ),
    val connectionTimeoutMs: Long = 5_000L,
    val maxRetries: Int = 2,
    val maxConcurrentTests: Int = 2,
    val useTcpFallback: Boolean = true,
    val enableDetailedLogging: Boolean = false,
    val debounceDelayMs: Long = 500L // Add debouncing for rapid network changes
    )

    /* ---------------------------
    Implementation
    --------------------------- */

    @Suppress("MemberVisibilityCanBePrivate")
    class ConnectionHelper(
    private val context: Context,
    private val config: ConnectionHelperConfig = ConnectionHelperConfig(),
    externalScope: CoroutineScope? = null
    ) {

    companion object {
    private const val TAG = "ConnectionHelper"
    }

    // Proper scope management
    private val internalScope = CoroutineScope(
    SupervisorJob() +
    Dispatchers.Default +
    CoroutineName("ConnectionHelper")
    )
    private val scope = externalScope ?: internalScope
    private val ownsScope = externalScope == null

    private val cm = context.applicationContext.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager

    private val _state = MutableStateFlow<NetworkState>(NetworkState.Unknown)
    val state: StateFlow<NetworkState> = _state.asStateFlow()

    // Convenience: current ConnectionInfo (nullable)
    val currentConnectionInfo: Flow<ConnectionInfo?> = state.map {
    (it as? NetworkState.Connected)?.connectionInfo
    }.distinctUntilChanged()

    // Better HTTP client configuration
    private val okHttpClient: OkHttpClient = OkHttpClient.Builder()
    .connectTimeout(config.connectionTimeoutMs, TimeUnit.MILLISECONDS)
    .readTimeout(config.connectionTimeoutMs, TimeUnit.MILLISECONDS)
    .callTimeout(config.connectionTimeoutMs * 2, TimeUnit.MILLISECONDS)
    .retryOnConnectionFailure(true)
    .followRedirects(false) // Don't follow redirects for connectivity tests
    .build()

    // Concurrency control
    private val testSemaphore = Semaphore(config.maxConcurrentTests)

    // State management
    private val isStarted = AtomicBoolean(false)
    private val isDisposed = AtomicBoolean(false)

    // Jobs
    private var periodicJob: Job? = null
    private var debounceJob: Job? = null

    // Debouncing channel for network changes
    private val networkChangeChannel = Channel<Unit>(Channel.UNLIMITED)

    private val networkCallback = object : ConnectivityManager.NetworkCallback() {
    override fun onAvailable(network: Network) {
    super.onAvailable(network)
    log("onAvailable($network)")
    scheduleRefresh()
    }

    override fun onLost(network: Network) {
    super.onLost(network)
    log("onLost($network)")
    _state.value = NetworkState.Disconnected
    scheduleRefresh() // Still schedule refresh to check for other networks
    }

    override fun onCapabilitiesChanged(network: Network, networkCapabilities: NetworkCapabilities) {
    super.onCapabilitiesChanged(network, networkCapabilities)
    log("onCapabilitiesChanged($network)")
    scheduleRefresh()
    }

    override fun onLinkPropertiesChanged(network: Network, linkProperties: LinkProperties) {
    super.onLinkPropertiesChanged(network, linkProperties)
    log("onLinkPropertiesChanged($network)")
    scheduleRefresh()
    }

    override fun onUnavailable() {
    super.onUnavailable()
    log("onUnavailable")
    _state.value = NetworkState.Disconnected
    }
    }

    init {
    // Start debounce processing
    scope.launch {
    for (unit in networkChangeChannel) {
    delay(config.debounceDelayMs)
    // Consume all pending changes
    while (!networkChangeChannel.isEmpty) {
    networkChangeChannel.tryReceive()
    }
    if (!isDisposed.get()) {
    refreshNetworkInfo()
    }
    }
    }
    }

    fun start() {
    if (!isStarted.compareAndSet(false, true)) {
    log("Already started")
    return
    }

    if (isDisposed.get()) {
    log("Cannot start disposed ConnectionHelper")
    return
    }

    try {
    registerNetworkCallback()

    // Perform immediate refresh
    refreshNetworkInfo()

    // Start periodic tests if enabled
    if (config.enablePeriodicInternetTest) {
    startPeriodicTesting()
    }

    log("ConnectionHelper started successfully")
    } catch (t: Throwable) {
    log("Failed to start: ${t.message}")
    _state.value = NetworkState.Error(t)
    }
    }

    fun stop() {
    if (!isStarted.compareAndSet(true, false)) {
    log("Already stopped or never started")
    return
    }

    try {
    unregisterNetworkCallback()
    periodicJob?.cancel()
    periodicJob = null
    debounceJob?.cancel()
    debounceJob = null

    if (ownsScope) {
    internalScope.cancel()
    }

    log("ConnectionHelper stopped")
    } catch (t: Throwable) {
    log("Error during stop: ${t.message}")
    }
    }

    fun dispose() {
    if (isDisposed.compareAndSet(false, true)) {
    stop()
    networkChangeChannel.close()
    try {
    okHttpClient.dispatcher.executorService.shutdown()
    } catch (t: Throwable) {
    log("Error disposing HTTP client: ${t.message}")
    }
    log("ConnectionHelper disposed")
    }
    }

    // Manual trigger (safe to call)
    fun refreshNetworkInfo() {
    if (isDisposed.get()) return

    scope.launch {
    try {
    updateNetworkState()
    } catch (t: Throwable) {
    log("Error in refreshNetworkInfo: ${t.message}")
    _state.value = NetworkState.Error(t)
    }
    }
    }

    private fun registerNetworkCallback() {
    try {
    val builder = NetworkRequest.Builder()
    .addCapability(NetworkCapabilities.NET_CAPABILITY_INTERNET)

    // Add transports
    builder.addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
    .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR)
    .addTransportType(NetworkCapabilities.TRANSPORT_ETHERNET)
    .addTransportType(NetworkCapabilities.TRANSPORT_VPN)

    if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
    builder.addTransportType(NetworkCapabilities.TRANSPORT_BLUETOOTH)
    }
    if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.S) {
    builder.addTransportType(NetworkCapabilities.TRANSPORT_USB)
    }

    cm.registerNetworkCallback(builder.build(), networkCallback)
    log("NetworkCallback registered")
    } catch (t: Throwable) {
    log("Failed to register callback: ${t.message}")
    throw t
    }
    }

    private fun unregisterNetworkCallback() {
    try {
    cm.unregisterNetworkCallback(networkCallback)
    log("NetworkCallback unregistered")
    } catch (t: Throwable) {
    log("Error unregistering callback: ${t.message}")
    }
    }

    private fun scheduleRefresh() {
    if (isDisposed.get()) return
    networkChangeChannel.trySend(Unit)
    }

    private fun startPeriodicTesting() {
    periodicJob = scope.launch {
    while (isActive && !isDisposed.get()) {
    try {
    val currentState = _state.value
    if (currentState is NetworkState.Connected) {
    performInternetTestsAndPublish()
    }
    } catch (t: Throwable) {
    log("Periodic test failed: ${t.message}")
    }

    delay(config.periodicIntervalMs)
    }
    }
    }

    private suspend fun updateNetworkState() = withContext(Dispatchers.IO) {
    try {
    val activeNetwork = cm.activeNetwork
    val caps = activeNetwork?.let { cm.getNetworkCapabilities(it) }
    val linkProps = activeNetwork?.let { cm.getLinkProperties(it) }

    if (activeNetwork == null || caps == null) {
    log("No active network or capabilities")
    _state.value = NetworkState.Disconnected
    return@withContext
    }

    val connectionType = determineConnectionType(caps)
    val isVpnActive = isVpnActive(caps)
    val vpnInfo = if (isVpnActive) createVpnInfo(linkProps) else null
    val details = buildNetworkDetails(caps, linkProps, activeNetwork)

    // Quick internet check for initial state
    val hasInternetInitial = if (config.enablePeriodicInternetTest) {
    performQuickInternetCheck()
    } else {
    caps.hasCapability(NetworkCapabilities.NET_CAPABILITY_VALIDATED)
    }

    val connInfo = ConnectionInfo(
    type = connectionType,
    isVpnActive = isVpnActive,
    hasInternet = hasInternetInitial,
    vpnInfo = vpnInfo,
    details = details
    )

    _state.value = NetworkState.Connected(connInfo)

    // Perform full test asynchronously if enabled
    if (config.enablePeriodicInternetTest) {
    scope.launch {
    try {
    performInternetTestsAndPublish()
    } catch (t: Throwable) {
    log("Async internet test failed: ${t.message}")
    }
    }
    }

    } catch (t: Throwable) {
    log("updateNetworkState error: ${t.message}")
    _state.value = NetworkState.Error(t)
    }
    }

    private fun determineConnectionType(caps: NetworkCapabilities): ConnectionType {
    return when {
    caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN) -> ConnectionType.VPN
    caps.hasTransport(NetworkCapabilities.TRANSPORT_WIFI) -> ConnectionType.WIFI
    caps.hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR) -> ConnectionType.CELLULAR
    caps.hasTransport(NetworkCapabilities.TRANSPORT_ETHERNET) -> ConnectionType.ETHERNET
    Build.VERSION.SDK_INT >= Build.VERSION_CODES.O &&
    caps.hasTransport(NetworkCapabilities.TRANSPORT_BLUETOOTH) -> ConnectionType.BLUETOOTH
    Build.VERSION.SDK_INT >= Build.VERSION_CODES.S &&
    caps.hasTransport(NetworkCapabilities.TRANSPORT_USB) -> ConnectionType.USB
    else -> ConnectionType.UNKNOWN
    }
    }

    private fun isVpnActive(caps: NetworkCapabilities): Boolean {
    return caps.hasTransport(NetworkCapabilities.TRANSPORT_VPN) || checkVpnByInterface()
    }

    private fun createVpnInfo(linkProps: LinkProperties?): VpnInfo {
    val ifName = linkProps?.interfaceName ?: guessVpnInterfaceName()
    val gateway = linkProps?.routes?.find { it.isDefaultRoute }?.gateway?.hostAddress

    return VpnInfo(
    interfaceName = ifName,
    probableType = determineVpnType(ifName),
    gateway = gateway,
    isBlockingInternet = false // Could be enhanced with deeper inspection
    )
    }

    private fun buildNetworkDetails(
    caps: NetworkCapabilities,
    linkProps: LinkProperties?,
    network: Network
    ): NetworkDetails {
    val ips = linkProps?.linkAddresses?.mapNotNull { it.address?.hostAddress } ?: emptyList()
    val dns = linkProps?.dnsServers?.mapNotNull { it.hostAddress } ?: emptyList()
    val gateway = linkProps?.routes?.find { it.isDefaultRoute }?.gateway?.hostAddress

    val ssid: String? = if (caps.hasTransport(NetworkCapabilities.TRANSPORT_WIFI)) {
    getSsidSafely()
    } else null

    return NetworkDetails(
    networkId = network.toString(),
    ipAddresses = ips,
    dnsServers = dns,
    gateway = gateway,
    ssid = ssid,
    isMetered = !caps.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_METERED),
    downstreamKbps = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
    caps.linkDownstreamBandwidthKbps.takeIf { it > 0 }
    } else null,
    upstreamKbps = if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
    caps.linkUpstreamBandwidthKbps.takeIf { it > 0 }
    } else null,
    signalStrength = getSignalStrength(caps)
    )
    }

    @SuppressLint("HardwareIds")
    private fun getSsidSafely(): String? {
    return try {
    val wm = context.applicationContext.getSystemService(Context.WIFI_SERVICE) as? WifiManager
    wm?.connectionInfo?.ssid?.removeSurrounding("\"")?.takeIf { it != "<unknown ssid>" }
    } catch (se: SecurityException) {
    log("No permission to read SSID: ${se.message}")
    null
    } catch (t: Throwable) {
    log("Error reading SSID: ${t.message}")
    null
    }
    }

    private fun getSignalStrength(caps: NetworkCapabilities): Int? {
    return if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
    caps.signalStrength.takeIf { it != Int.MIN_VALUE }
    } else null
    }

    private fun checkVpnByInterface(): Boolean {
    return try {
    NetworkInterface.getNetworkInterfaces()?.asSequence()?.any { nif ->
    nif.isUp && isVpnInterface(nif.name)
    } ?: false
    } catch (t: Throwable) {
    log("checkVpnByInterface failed: ${t.message}")
    false
    }
    }

    private fun isVpnInterface(name: String): Boolean {
    return name.startsWith("tun") || name.startsWith("ppp") ||
    name.startsWith("wg") || name.startsWith("ipsec") ||
    name.startsWith("utun") || name.startsWith("tap")
    }

    private fun guessVpnInterfaceName(): String? {
    return try {
    NetworkInterface.getNetworkInterfaces()?.asSequence()
    ?.firstOrNull { it.isUp && isVpnInterface(it.name) }
    ?.name
    } catch (t: Throwable) {
    null
    }
    }

    private fun determineVpnType(interfaceName: String?): VpnType {
    val name = interfaceName?.lowercase() ?: return VpnType.UNKNOWN
    return when {
    name.startsWith("wg") -> VpnType.WIREGUARD
    name.startsWith("tun") -> VpnType.OPENVPN
    name.startsWith("ppp") -> VpnType.PPTP
    name.startsWith("ipsec") -> VpnType.IPSEC
    name.startsWith("l2tp") -> VpnType.L2TP
    name.contains("sstp") -> VpnType.SSTP
    else -> VpnType.UNKNOWN
    }
    }

    private suspend fun performQuickInternetCheck(): Boolean = withContext(Dispatchers.IO) {
    for (endpoint in config.testEndpoints) {
    if (testEndpointWithRetry(endpoint, maxRetries = 1)) {
    return@withContext true
    }
    }
    false
    }

    private suspend fun performInternetTestsAndPublish() = withContext(Dispatchers.IO) {
    if (isDisposed.get()) return@withContext

    val endpoints = config.testEndpoints.distinct()
    val results = mutableListOf<Pair<String, Boolean>>()
    val latencies = mutableListOf<Long>()

    // Test with limited concurrency
    val tests = endpoints.map { endpoint ->
    async(Dispatchers.IO) {
    testSemaphore.acquire()
    try {
    val startTime = System.currentTimeMillis()
    val success = testEndpointWithRetry(endpoint, config.maxRetries)
    val latency = System.currentTimeMillis() - startTime

    synchronized(results) { results.add(endpoint to success) }
    if (success) {
    synchronized(latencies) { latencies.add(latency) }
    }

    endpoint to (success to latency)
    } finally {
    testSemaphore.release()
    }
    }
    }

    val resolved = tests.awaitAll()

    val successes = resolved.filter { it.second.first }.map { it.first }
    val failures = resolved.filterNot { it.second.first }.map { it.first }
    val avgLatency = if (latencies.isNotEmpty()) latencies.average().roundToLong() else 0L
    val packetLoss = if (resolved.isNotEmpty()) {
    ((resolved.size - successes.size).toFloat() / resolved.size) * 100f
    } else 0f

    val testResults = InternetTestResults(
    successfulEndpoints = successes,
    failedEndpoints = failures,
    averageLatencyMs = avgLatency,
    packetLossPercent = packetLoss,
    lastTestAtMs = System.currentTimeMillis()
    )

    // Update state with results
    val currentState = _state.value
    if (currentState is NetworkState.Connected) {
    val updatedInfo = currentState.connectionInfo.copy(
    hasInternet = successes.isNotEmpty(),
    internetTestResults = testResults
    )
    _state.value = NetworkState.Connected(updatedInfo)
    }
    }

    private suspend fun testEndpointWithRetry(endpoint: String, maxRetries: Int): Boolean {
    repeat(maxRetries + 1) { attempt ->
    if (testSingleEndpoint(endpoint)) {
    return true
    }
    if (attempt < maxRetries) {
    delay(500L * (attempt + 1)) // Exponential backoff
    }
    }
    return false
    }

    private suspend fun testSingleEndpoint(endpoint: String): Boolean = withContext(Dispatchers.IO) {
    val trimmedEndpoint = endpoint.trim()

    when {
    trimmedEndpoint.startsWith("http", ignoreCase = true) -> {
    testHttpEndpoint(trimmedEndpoint)
    }
    trimmedEndpoint.contains(":") -> {
    val parts = trimmedEndpoint.split(":", limit = 2)
    val host = parts[0]
    val port = parts.getOrNull(1)?.toIntOrNull() ?: 53
    testTcpConnection(host, port)
    }
    else -> {
    // Try multiple ports for plain host
    testTcpConnection(trimmedEndpoint, 443) ||
    testTcpConnection(trimmedEndpoint, 80) ||
    testTcpConnection(trimmedEndpoint, 53)
    }
    }
    }

    private suspend fun testHttpEndpoint(url: String): Boolean = withContext(Dispatchers.IO) {
    try {
    val request = Request.Builder().url(url).head().build()
    okHttpClient.newCall(request).execute().use { response ->
    val success = response.isSuccessful || response.code in 200..499
    log("HTTP $url -> ${response.code} success=$success")
    return@withContext success
    }
    } catch (e: IOException) {
    log("HTTP test $url failed: ${e.message}")
    if (config.useTcpFallback) {
    return@withContext testTcpFallbackFromUrl(url)
    }
    return@withContext false
    }
    }

    private suspend fun testTcpFallbackFromUrl(url: String): Boolean {
    return try {
    val uri = URI(url)
    val host = uri.host ?: return false
    val port = if (uri.port > 0) uri.port else if (uri.scheme == "https") 443 else 80
    testTcpConnection(host, port)
    } catch (t: Throwable) {
    log("TCP fallback from URL failed: ${t.message}")
    false
    }
    }

    private suspend fun testTcpConnection(host: String, port: Int): Boolean = withContext(Dispatchers.IO) {
    var socket: Socket? = null
    try {
    socket = Socket()
    socket.connect(InetSocketAddress(host, port), config.connectionTimeoutMs.toInt())
    log("TCP connect $host:$port OK")
    return@withContext true
    } catch (t: Throwable) {
    log("TCP connect $host:$port failed: ${t.message}")
    return@withContext false
    } finally {
    try {
    socket?.close()
    } catch (_: Throwable) {
    // Ignore close errors
    }
    }
    }

    // Public API for manual endpoint testing
    suspend fun testCustomEndpoint(endpoint: String): Boolean {
    return if (isDisposed.get()) {
    false
    } else {
    testEndpointWithRetry(endpoint, config.maxRetries)
    }
    }

    // Get current network info safely
    fun getCurrentNetworkInfo(): ConnectionInfo? {
    return (state.value as? NetworkState.Connected)?.connectionInfo
    }

    private fun log(message: String) {
    if (config.enableDetailedLogging) {
    Log.d(TAG, message)
    }
    }
    }

    // Extension function for easier observation
    fun ConnectionHelper.observeNetworkState(
    scope: CoroutineScope,
    onStateChange: (NetworkState) -> Unit
    ) {
    state.onEach(onStateChange).launchIn(scope)
    }