Rewrite RateLimitInterceptor (#7889)
This commit is contained in:
parent
53f5ea7fe9
commit
532f662b05
2 changed files with 56 additions and 89 deletions
|
@ -5,6 +5,8 @@ import okhttp3.Interceptor
|
|||
import okhttp3.OkHttpClient
|
||||
import okhttp3.Response
|
||||
import java.io.IOException
|
||||
import java.util.ArrayDeque
|
||||
import java.util.concurrent.Semaphore
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
/**
|
||||
|
@ -25,54 +27,77 @@ fun OkHttpClient.Builder.rateLimit(
|
|||
permits: Int,
|
||||
period: Long = 1,
|
||||
unit: TimeUnit = TimeUnit.SECONDS,
|
||||
) = addInterceptor(RateLimitInterceptor(permits, period, unit))
|
||||
) = addInterceptor(RateLimitInterceptor(null, permits, period, unit))
|
||||
|
||||
private class RateLimitInterceptor(
|
||||
/** We can probably accept domains or wildcards by comparing with [endsWith], etc. */
|
||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
|
||||
internal class RateLimitInterceptor(
|
||||
private val host: String?,
|
||||
private val permits: Int,
|
||||
period: Long,
|
||||
unit: TimeUnit,
|
||||
) : Interceptor {
|
||||
|
||||
private val requestQueue = ArrayList<Long>(permits)
|
||||
private val requestQueue = ArrayDeque<Long>(permits)
|
||||
private val rateLimitMillis = unit.toMillis(period)
|
||||
private val fairLock = Semaphore(1, true)
|
||||
|
||||
override fun intercept(chain: Interceptor.Chain): Response {
|
||||
// Ignore canceled calls, otherwise they would jam the queue
|
||||
if (chain.call().isCanceled()) {
|
||||
throw IOException()
|
||||
val call = chain.call()
|
||||
if (call.isCanceled()) throw IOException("Canceled")
|
||||
|
||||
val request = chain.request()
|
||||
when (host) {
|
||||
null, request.url.host -> {} // need rate limit
|
||||
else -> return chain.proceed(request)
|
||||
}
|
||||
|
||||
synchronized(requestQueue) {
|
||||
val now = SystemClock.elapsedRealtime()
|
||||
val waitTime = if (requestQueue.size < permits) {
|
||||
0
|
||||
} else {
|
||||
val oldestReq = requestQueue[0]
|
||||
val newestReq = requestQueue[permits - 1]
|
||||
try {
|
||||
fairLock.acquire()
|
||||
} catch (e: InterruptedException) {
|
||||
throw IOException(e)
|
||||
}
|
||||
|
||||
if (newestReq - oldestReq > rateLimitMillis) {
|
||||
0
|
||||
} else {
|
||||
oldestReq + rateLimitMillis - now // Remaining time
|
||||
val requestQueue = this.requestQueue
|
||||
val timestamp: Long
|
||||
|
||||
try {
|
||||
synchronized(requestQueue) {
|
||||
while (requestQueue.size >= permits) { // queue is full, remove expired entries
|
||||
val periodStart = SystemClock.elapsedRealtime() - rateLimitMillis
|
||||
var hasRemovedExpired = false
|
||||
while (requestQueue.isEmpty().not() && requestQueue.first <= periodStart) {
|
||||
requestQueue.removeFirst()
|
||||
hasRemovedExpired = true
|
||||
}
|
||||
if (call.isCanceled()) {
|
||||
throw IOException("Canceled")
|
||||
} else if (hasRemovedExpired) {
|
||||
break
|
||||
} else try { // wait for the first entry to expire, or notified by cached response
|
||||
(requestQueue as Object).wait(requestQueue.first - periodStart)
|
||||
} catch (_: InterruptedException) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Final check
|
||||
if (chain.call().isCanceled()) {
|
||||
throw IOException()
|
||||
// add request to queue
|
||||
timestamp = SystemClock.elapsedRealtime()
|
||||
requestQueue.addLast(timestamp)
|
||||
}
|
||||
} finally {
|
||||
fairLock.release()
|
||||
}
|
||||
|
||||
if (requestQueue.size == permits) {
|
||||
requestQueue.removeAt(0)
|
||||
}
|
||||
if (waitTime > 0) {
|
||||
requestQueue.add(now + waitTime)
|
||||
Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests
|
||||
} else {
|
||||
requestQueue.add(now)
|
||||
val response = chain.proceed(request)
|
||||
if (response.networkResponse == null) { // response is cached, remove it from queue
|
||||
synchronized(requestQueue) {
|
||||
if (requestQueue.isEmpty() || timestamp < requestQueue.first) return@synchronized
|
||||
requestQueue.removeFirstOccurrence(timestamp)
|
||||
(requestQueue as Object).notifyAll()
|
||||
}
|
||||
}
|
||||
|
||||
return chain.proceed(chain.request())
|
||||
return response
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +1,7 @@
|
|||
package eu.kanade.tachiyomi.network.interceptor
|
||||
|
||||
import android.os.SystemClock
|
||||
import okhttp3.HttpUrl
|
||||
import okhttp3.Interceptor
|
||||
import okhttp3.OkHttpClient
|
||||
import okhttp3.Response
|
||||
import java.io.IOException
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
/**
|
||||
|
@ -28,58 +24,4 @@ fun OkHttpClient.Builder.rateLimitHost(
|
|||
permits: Int,
|
||||
period: Long = 1,
|
||||
unit: TimeUnit = TimeUnit.SECONDS,
|
||||
) = addInterceptor(SpecificHostRateLimitInterceptor(httpUrl, permits, period, unit))
|
||||
|
||||
class SpecificHostRateLimitInterceptor(
|
||||
httpUrl: HttpUrl,
|
||||
private val permits: Int,
|
||||
period: Long,
|
||||
unit: TimeUnit,
|
||||
) : Interceptor {
|
||||
|
||||
private val requestQueue = ArrayList<Long>(permits)
|
||||
private val rateLimitMillis = unit.toMillis(period)
|
||||
private val host = httpUrl.host
|
||||
|
||||
override fun intercept(chain: Interceptor.Chain): Response {
|
||||
// Ignore canceled calls, otherwise they would jam the queue
|
||||
if (chain.call().isCanceled()) {
|
||||
throw IOException()
|
||||
} else if (chain.request().url.host != host) {
|
||||
return chain.proceed(chain.request())
|
||||
}
|
||||
|
||||
synchronized(requestQueue) {
|
||||
val now = SystemClock.elapsedRealtime()
|
||||
val waitTime = if (requestQueue.size < permits) {
|
||||
0
|
||||
} else {
|
||||
val oldestReq = requestQueue[0]
|
||||
val newestReq = requestQueue[permits - 1]
|
||||
|
||||
if (newestReq - oldestReq > rateLimitMillis) {
|
||||
0
|
||||
} else {
|
||||
oldestReq + rateLimitMillis - now // Remaining time
|
||||
}
|
||||
}
|
||||
|
||||
// Final check
|
||||
if (chain.call().isCanceled()) {
|
||||
throw IOException()
|
||||
}
|
||||
|
||||
if (requestQueue.size == permits) {
|
||||
requestQueue.removeAt(0)
|
||||
}
|
||||
if (waitTime > 0) {
|
||||
requestQueue.add(now + waitTime)
|
||||
Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests
|
||||
} else {
|
||||
requestQueue.add(now)
|
||||
}
|
||||
}
|
||||
|
||||
return chain.proceed(chain.request())
|
||||
}
|
||||
}
|
||||
) = addInterceptor(RateLimitInterceptor(httpUrl.host, permits, period, unit))
|
||||
|
|
Reference in a new issue