Rewrite RateLimitInterceptor (#7889)

This commit is contained in:
stevenyomi 2022-08-31 01:17:37 +08:00 committed by GitHub
parent 53f5ea7fe9
commit 532f662b05
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 89 deletions

View file

@ -5,6 +5,8 @@ import okhttp3.Interceptor
import okhttp3.OkHttpClient
import okhttp3.Response
import java.io.IOException
import java.util.ArrayDeque
import java.util.concurrent.Semaphore
import java.util.concurrent.TimeUnit
/**
@ -25,54 +27,77 @@ fun OkHttpClient.Builder.rateLimit(
permits: Int,
period: Long = 1,
unit: TimeUnit = TimeUnit.SECONDS,
) = addInterceptor(RateLimitInterceptor(permits, period, unit))
) = addInterceptor(RateLimitInterceptor(null, permits, period, unit))
private class RateLimitInterceptor(
/** We can probably accept domains or wildcards by comparing with [endsWith], etc. */
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
internal class RateLimitInterceptor(
private val host: String?,
private val permits: Int,
period: Long,
unit: TimeUnit,
) : Interceptor {
private val requestQueue = ArrayList<Long>(permits)
private val requestQueue = ArrayDeque<Long>(permits)
private val rateLimitMillis = unit.toMillis(period)
private val fairLock = Semaphore(1, true)
override fun intercept(chain: Interceptor.Chain): Response {
// Ignore canceled calls, otherwise they would jam the queue
if (chain.call().isCanceled()) {
throw IOException()
val call = chain.call()
if (call.isCanceled()) throw IOException("Canceled")
val request = chain.request()
when (host) {
null, request.url.host -> {} // need rate limit
else -> return chain.proceed(request)
}
synchronized(requestQueue) {
val now = SystemClock.elapsedRealtime()
val waitTime = if (requestQueue.size < permits) {
0
} else {
val oldestReq = requestQueue[0]
val newestReq = requestQueue[permits - 1]
try {
fairLock.acquire()
} catch (e: InterruptedException) {
throw IOException(e)
}
if (newestReq - oldestReq > rateLimitMillis) {
0
} else {
oldestReq + rateLimitMillis - now // Remaining time
val requestQueue = this.requestQueue
val timestamp: Long
try {
synchronized(requestQueue) {
while (requestQueue.size >= permits) { // queue is full, remove expired entries
val periodStart = SystemClock.elapsedRealtime() - rateLimitMillis
var hasRemovedExpired = false
while (requestQueue.isEmpty().not() && requestQueue.first <= periodStart) {
requestQueue.removeFirst()
hasRemovedExpired = true
}
if (call.isCanceled()) {
throw IOException("Canceled")
} else if (hasRemovedExpired) {
break
} else try { // wait for the first entry to expire, or notified by cached response
(requestQueue as Object).wait(requestQueue.first - periodStart)
} catch (_: InterruptedException) {
continue
}
}
}
// Final check
if (chain.call().isCanceled()) {
throw IOException()
// add request to queue
timestamp = SystemClock.elapsedRealtime()
requestQueue.addLast(timestamp)
}
} finally {
fairLock.release()
}
if (requestQueue.size == permits) {
requestQueue.removeAt(0)
}
if (waitTime > 0) {
requestQueue.add(now + waitTime)
Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests
} else {
requestQueue.add(now)
val response = chain.proceed(request)
if (response.networkResponse == null) { // response is cached, remove it from queue
synchronized(requestQueue) {
if (requestQueue.isEmpty() || timestamp < requestQueue.first) return@synchronized
requestQueue.removeFirstOccurrence(timestamp)
(requestQueue as Object).notifyAll()
}
}
return chain.proceed(chain.request())
return response
}
}

View file

@ -1,11 +1,7 @@
package eu.kanade.tachiyomi.network.interceptor
import android.os.SystemClock
import okhttp3.HttpUrl
import okhttp3.Interceptor
import okhttp3.OkHttpClient
import okhttp3.Response
import java.io.IOException
import java.util.concurrent.TimeUnit
/**
@ -28,58 +24,4 @@ fun OkHttpClient.Builder.rateLimitHost(
permits: Int,
period: Long = 1,
unit: TimeUnit = TimeUnit.SECONDS,
) = addInterceptor(SpecificHostRateLimitInterceptor(httpUrl, permits, period, unit))
class SpecificHostRateLimitInterceptor(
httpUrl: HttpUrl,
private val permits: Int,
period: Long,
unit: TimeUnit,
) : Interceptor {
private val requestQueue = ArrayList<Long>(permits)
private val rateLimitMillis = unit.toMillis(period)
private val host = httpUrl.host
override fun intercept(chain: Interceptor.Chain): Response {
// Ignore canceled calls, otherwise they would jam the queue
if (chain.call().isCanceled()) {
throw IOException()
} else if (chain.request().url.host != host) {
return chain.proceed(chain.request())
}
synchronized(requestQueue) {
val now = SystemClock.elapsedRealtime()
val waitTime = if (requestQueue.size < permits) {
0
} else {
val oldestReq = requestQueue[0]
val newestReq = requestQueue[permits - 1]
if (newestReq - oldestReq > rateLimitMillis) {
0
} else {
oldestReq + rateLimitMillis - now // Remaining time
}
}
// Final check
if (chain.call().isCanceled()) {
throw IOException()
}
if (requestQueue.size == permits) {
requestQueue.removeAt(0)
}
if (waitTime > 0) {
requestQueue.add(now + waitTime)
Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests
} else {
requestQueue.add(now)
}
}
return chain.proceed(chain.request())
}
}
) = addInterceptor(RateLimitInterceptor(httpUrl.host, permits, period, unit))