Skip to content

Commit

Permalink
Add capability to use multiple keymanagers with same issuer
Browse files Browse the repository at this point in the history
  • Loading branch information
chamilaadhi committed Oct 9, 2024
1 parent 48ec34a commit 88b7607
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@ public boolean mediate(MessageContext messageContext) {
String issuer = authContext.getIssuer();
List<String> scopes = authContext.getRequestTokenScopes();
if (StringUtils.isNotBlank(issuer)) {
KeyManagerDto keyManagerDto = KeyManagerHolder.getKeyManagerByIssuer(tenantDomain, issuer);
List<KeyManagerDto> keyManagerDtoList = KeyManagerHolder.getKeyManagerByIssuer(tenantDomain, issuer);
KeyManagerDto keyManagerDto = keyManagerDtoList.get(0); // TODO : Does not support multiple km with same
// issuer

if (keyManagerDto != null && StringUtils.isNotBlank(scope) && scopes.contains(scope)) {
String token = authContext.getAccessToken();
String consumerKey = authContext.getConsumerKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ public void init() {
Mockito.when(authContext.getIssuer()).thenReturn("https://localhost:9443/oauth2/token");
Mockito.when(authContext.getRequestTokenScopes()).thenReturn(scopes);

List<KeyManagerDto> keymanagerList = new ArrayList<KeyManagerDto>();
keymanagerList.add(keyManagerDto);
Mockito.when(KeyManagerHolder.getKeyManagerByIssuer(Mockito.anyString(), Mockito.anyString())).
thenReturn(keyManagerDto);
thenReturn(keymanagerList);
Mockito.when(keyManagerDto.getKeyManager()).thenReturn(keyManager);
Mockito.doNothing().when(oneTimeExecutorService).execute(() ->
keyManagerDto.getKeyManager().revokeOneTimeToken(Mockito.anyString(), Mockito.anyString()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
import org.wso2.carbon.apimgt.impl.APIConstants;
import org.wso2.carbon.apimgt.impl.jwt.JWTValidator;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class OrganizationKeyManagerDto {

private Map<String, KeyManagerDto> keyManagerMap = new LinkedHashMap<>();
private Map<String, String> issuerNameMap = new HashMap<>();
private Map<String, List<String>> issuerNameMap = new HashMap<>();

public Map<String, KeyManagerDto> getKeyManagerMap() {

Expand All @@ -51,35 +53,43 @@ public void putKeyManagerDto(KeyManagerDto keyManagerDto) {
} else {
keyManagerMap.put(keyManagerDto.getName(), keyManagerDto);
}
issuerNameMap.put(keyManagerDto.getIssuer(), keyManagerDto.getName());
issuerNameMap.computeIfAbsent(keyManagerDto.getIssuer(), k -> new ArrayList<>()).add(keyManagerDto.getName());

}

public void removeKeyManagerDtoByName(String name) {

KeyManagerDto keyManagerDto = keyManagerMap.get(name);
if (keyManagerDto != null) {
issuerNameMap.remove(keyManagerDto.getIssuer());
issuerNameMap.get(keyManagerDto.getIssuer()).remove(name);
if (issuerNameMap.get(keyManagerDto.getIssuer()).isEmpty()) {
issuerNameMap.remove(keyManagerDto.getIssuer());
}
}
keyManagerMap.remove(name);
}

public JWTValidator getJWTValidatorByIssuer(String issuer) {

String keyManagerName = issuerNameMap.get(issuer);
if (StringUtils.isNotEmpty(keyManagerName)) {
KeyManagerDto keyManagerDto = keyManagerMap.get(keyManagerName);
List<String> keyManagerNames = issuerNameMap.get(issuer);
if (keyManagerNames != null && !keyManagerNames.isEmpty()) {
KeyManagerDto keyManagerDto = keyManagerMap.get(keyManagerNames.get(0));
if (keyManagerDto != null) {
return keyManagerDto.getJwtValidator();
}
}
return null;
}

public KeyManagerDto getKeyManagerDtoByIssuer(String issuer) {
public List<KeyManagerDto> getKeyManagerDtoByIssuer(String issuer) {

String keyManagerName = issuerNameMap.get(issuer);
if (StringUtils.isNotEmpty(keyManagerName)) {
return keyManagerMap.get(keyManagerName);
List<KeyManagerDto> dtoList = new ArrayList<KeyManagerDto>();
List<String> keyManagerNames = issuerNameMap.get(issuer);
if (keyManagerNames != null && !keyManagerNames.isEmpty()) {
for (String keyManagerName : keyManagerNames) {
dtoList.add(keyManagerMap.get(keyManagerName));
}
return dtoList;
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.wso2.carbon.apimgt.impl.utils.APIUtil;

import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -259,10 +260,12 @@ public static KeyManager getKeyManagerInstance(String tenantDomain, String keyMa
return keyManager;
}

public static KeyManagerDto getKeyManagerByIssuer(String tenantDomain, String issuer) {
public static List<KeyManagerDto> getKeyManagerByIssuer(String tenantDomain, String issuer) {

if (globalJWTValidatorMap.containsKey(issuer)) {
return globalJWTValidatorMap.get(issuer);
List list = new ArrayList<KeyManagerDto>();
list.add(globalJWTValidatorMap.get(issuer));
return list;
}
OrganizationKeyManagerDto organizationKeyManagerDto = getTenantKeyManagerDto(tenantDomain);
if (organizationKeyManagerDto != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.wso2.carbon.apimgt.impl.jwt;

import java.util.List;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand All @@ -39,7 +41,18 @@ public JWTValidationInfo validateJWTToken(SignedJWTInfo signedJWTInfo) throws AP
JWTValidationInfo jwtValidationInfo = new JWTValidationInfo();
String issuer = signedJWTInfo.getJwtClaimsSet().getIssuer();
if (StringUtils.isNotEmpty(issuer)) {
KeyManagerDto keyManagerDto = KeyManagerHolder.getKeyManagerByIssuer(tenantDomain, issuer);
List<KeyManagerDto> keyManagerDtoList = KeyManagerHolder.getKeyManagerByIssuer(tenantDomain, issuer);
KeyManagerDto keyManagerDto = null;
if (keyManagerDtoList.size() == 1) { // only one keymanager. no need to check if it can handle token
keyManagerDto = keyManagerDtoList.get(0);
} else {
for (KeyManagerDto kmrDto : keyManagerDtoList) {
if (kmrDto.getKeyManager().canHandleToken(signedJWTInfo.getToken())) {
keyManagerDto = kmrDto;
break;
}
}
}
if (keyManagerDto != null && keyManagerDto.getJwtValidator() != null) {
JWTValidationInfo validationInfo = keyManagerDto.getJwtValidator().validateToken(signedJWTInfo);
validationInfo.setKeyManager(keyManagerDto.getName());
Expand All @@ -56,7 +69,18 @@ public String getKeyManagerNameIfJwtValidatorExist(SignedJWTInfo signedJWTInfo)
String tenantDomain = CarbonContext.getThreadLocalCarbonContext().getTenantDomain();

String issuer = signedJWTInfo.getJwtClaimsSet().getIssuer();
KeyManagerDto keyManagerDto = KeyManagerHolder.getKeyManagerByIssuer(tenantDomain, issuer);
List<KeyManagerDto> keyManagerDtoList = KeyManagerHolder.getKeyManagerByIssuer(tenantDomain, issuer);
KeyManagerDto keyManagerDto = null;
if (keyManagerDtoList.size() == 1) { // only one keymanager. no need to check if it can handle token
keyManagerDto = keyManagerDtoList.get(0);
} else {
for (KeyManagerDto kmrDto : keyManagerDtoList) {
if (kmrDto.getKeyManager().canHandleToken(signedJWTInfo.getToken())) {
keyManagerDto = kmrDto;
break;
}
}
}
if (keyManagerDto != null && keyManagerDto.getJwtValidator() != null) {
return keyManagerDto.getName();
}else{
Expand Down

0 comments on commit 88b7607

Please sign in to comment.