Skip to content

Commit

Permalink
fix(authentication): Assign New Token to New Requests (#207)
Browse files Browse the repository at this point in the history
PR: #207
  • Loading branch information
osama-salman99 authored May 30, 2023
1 parent 4549fa1 commit 80e9ac4
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,34 +37,32 @@ internal object AuthenticationHookFactory : HookFactory<AuthenticationConfigurat

private class AuthenticationHookBuilder(private val client: Client) : HookBuilder<AuthenticationConfiguration> {
private val log = OpenWorldLoggerFactory.getLogger(javaClass)
private val isLock = atomic(false)
private val lock = atomic(false)
private val authenticationStrategy = client.getAuthenticationStrategy()

override fun build(configs: AuthenticationConfiguration) {
val httpClient = client.httpClient

httpClient.plugin(HttpSend).intercept { request ->
if (authenticationStrategy.isNotIdentityRequest(request) && authenticationStrategy.isTokenAboutToExpire()) {
log.info(TOKEN_EXPIRED)
if (!isLock.getAndSet(true)) {
try {
authenticationStrategy.renewToken()
} finally {
isLock.compareAndSet(expect = true, update = false)
if (!authenticationStrategy.isIdentityRequest(request)) {
if (authenticationStrategy.isTokenAboutToExpire()) {
log.info(TOKEN_EXPIRED)
if (!lock.getAndSet(true)) {
try {
authenticationStrategy.renewToken()
} finally {
lock.compareAndSet(expect = true, update = false)
}
}
}
waitForTokenRenewal()
assignNewToken(request)
while (lock.value) delay(AUTHORIZATION_REQUEST_LOCK_DELAY)
assignLatestToken(request)
}
execute(request)
}
}

private fun assignNewToken(request: HttpRequestBuilder) {
private fun assignLatestToken(request: HttpRequestBuilder) {
request.headers[HeaderKey.AUTHORIZATION] = authenticationStrategy.getAuthorizationHeader()
}

private suspend fun waitForTokenRenewal() {
while (isLock.value) delay(AUTHORIZATION_REQUEST_LOCK_DELAY)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ internal interface AuthenticationStrategy {

fun renewToken()

fun isNotIdentityRequest(request: HttpRequestBuilder): Boolean
fun isIdentityRequest(request: HttpRequestBuilder): Boolean

fun getAuthorizationHeader(): String

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,11 @@ internal class OpenWorldAuthenticationStrategy(
) : AuthenticationStrategy {
private val log = OpenWorldLoggerFactory.getLogger(javaClass)
private var bearerTokenStorage = BearerTokensInfo.emptyBearerTokenInfo
private val loadTokensBlock: () -> BearerTokens = {
getTokens()
}

override fun loadAuth(auth: Auth) {
auth.bearer {
loadTokens(loadTokensBlock)

sendWithoutRequest { request ->
isNotIdentityRequest(request)
isIdentityRequest(request)
}
}
}
Expand Down Expand Up @@ -105,7 +100,7 @@ internal class OpenWorldAuthenticationStrategy(
)
}

override fun isNotIdentityRequest(request: HttpRequestBuilder): Boolean = request.url.buildString() != configs.authUrl
override fun isIdentityRequest(request: HttpRequestBuilder): Boolean = request.url.buildString() == configs.authUrl

override fun getAuthorizationHeader() = "${Authentication.BEARER} ${getTokens().accessToken}"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ internal class RapidAuthenticationStrategy(private val configs: AuthenticationCo
signature = calculateSignature(credentials.key, credentials.secret, Instant.now().epochSecond)
}

override fun isNotIdentityRequest(request: HttpRequestBuilder) = true
override fun isIdentityRequest(request: HttpRequestBuilder) = false

override fun getAuthorizationHeader() = createAuthorizationHeader(signature)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.expediagroup.openworld.sdk.core.configuration.Credentials
import com.expediagroup.openworld.sdk.core.configuration.OpenWorldClientConfiguration
import com.expediagroup.openworld.sdk.core.configuration.provider.OpenWorldConfigurationProvider
import com.expediagroup.openworld.sdk.core.constant.Authentication.BEARER
import com.expediagroup.openworld.sdk.core.constant.Constant.SUCCESSFUL_STATUS_CODES_RANGE
import com.expediagroup.openworld.sdk.core.constant.ExceptionMessage
import com.expediagroup.openworld.sdk.core.constant.HeaderKey
import com.expediagroup.openworld.sdk.core.model.exception.service.OpenWorldAuthException
Expand All @@ -40,6 +41,7 @@ import io.ktor.client.HttpClientConfig
import io.ktor.client.request.get
import io.ktor.client.request.request
import io.ktor.client.request.url
import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.request
import io.ktor.http.HttpMethod
import io.ktor.http.HttpStatusCode
Expand All @@ -57,6 +59,10 @@ import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.params.ParameterizedTest
import org.junit.jupiter.params.provider.ArgumentsSource
import org.junit.jupiter.params.provider.ValueSource
import java.util.concurrent.Callable
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
import java.util.concurrent.Future

internal class OpenWorldAuthenticationStrategyTest : AuthenticationPluginTest() {

Expand Down Expand Up @@ -130,6 +136,32 @@ internal class OpenWorldAuthenticationStrategyTest : AuthenticationPluginTest()
}
}

@Test
fun `given requests constructed during token renewal then get assigned the new token`() {
runBlocking {
val client = ClientFactory.createOpenWorldClient()
val httpClient = client.httpClient
val authentication = client.getAuthenticationStrategy()
mockkObject(authentication)

val numberOfThreads = 8
val threadPool: ExecutorService = Executors.newFixedThreadPool(numberOfThreads)
val futures: MutableList<Future<HttpResponse>> = mutableListOf()
repeat(numberOfThreads + 5) {
futures.add(threadPool.submit(Callable { runBlocking { httpClient.get(ANY_URL) } }))
}

val failedRequests: List<HttpResponse> = futures.map { it.get() }.filter {
it.status.value !in SUCCESSFUL_STATUS_CODES_RANGE
}
assertThat(failedRequests).isEmpty()

verify(exactly = 1) {
authentication.renewToken()
}
}
}

@ParameterizedTest
@ValueSource(ints = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
fun `given request when token almost or is expired then should renew token`(expiresIn: Int) {
Expand Down

0 comments on commit 80e9ac4

Please sign in to comment.