diff --git a/pom.xml b/pom.xml index c42de55..c55ad57 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ pl.mlodawski.security credentials-support - 1.1 + 1.2 21 @@ -37,6 +37,13 @@ https://github.com/SimpleMethod/PKCS11-Java-Wrapper + + + github + GitHub SimpleMethod Apache Maven Packages + https://maven.pkg.github.com/SimpleMethod/PKCS11-Java-Wrapper + + diff --git a/src/main/java/pl/mlodawski/security/example/PKCS11Example.java b/src/main/java/pl/mlodawski/security/example/PKCS11Example.java index 967d32b..f4911b6 100644 --- a/src/main/java/pl/mlodawski/security/example/PKCS11Example.java +++ b/src/main/java/pl/mlodawski/security/example/PKCS11Example.java @@ -3,12 +3,12 @@ import pl.mlodawski.security.pkcs11.*; import pl.mlodawski.security.pkcs11.model.*; +import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.List; -import java.util.Base64; -import java.util.Map; -import java.util.Scanner; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.*; class PKCS11 { private final Path PKCS11_WRAPPER_PATH; @@ -76,15 +76,24 @@ public void onDeviceError(PKCS11Device device, Exception error) { signMessage(manager, session); break; case 3: - encryptDecryptData(manager, session); + signFile(manager, session); break; case 4: - listSupportedAlgorithms(manager, session); + verifyFileSignature(manager, session); break; case 5: + encryptDecryptData(manager, session); + break; + case 6: + encryptDecryptFile(manager, session); + break; + case 7: + listSupportedAlgorithms(manager, session); + break; + case 8: selectedDevice = null; return; - case 6: + case 9: session.close(); selectedDevice = null; PIN = null; @@ -96,14 +105,14 @@ public void onDeviceError(PKCS11Device device, Exception error) { System.out.println("Invalid choice. Please try again."); } - if (choice == 6) { - break; - } +// if (choice == 6) { +// break; +// } } } catch (Exception e) { System.out.println("Session error: " + e.getMessage()); selectedDevice = null; - PIN = null; + PIN = null; } } catch (Exception e) { System.out.println("An error occurred: " + e.getMessage()); @@ -115,6 +124,156 @@ public void onDeviceError(PKCS11Device device, Exception error) { } } + private void signFile(PKCS11Manager manager, PKCS11Session session) { + try { + KeyCertificatePair selectedPair = selectCertificateKeyPair(manager, session); + + System.out.print("Enter path to file to sign: "); + Scanner scanner = new Scanner(System.in); + String filePath = scanner.nextLine(); + + Path path = Paths.get(filePath); + if (!Files.exists(path)) { + System.out.println("File does not exist: " + filePath); + return; + } + + byte[] fileContent = Files.readAllBytes(path); + PKCS11Signer signer = new PKCS11Signer(); + byte[] signature = signer.signMessage(manager.getPkcs11(), + session.getSession(), + selectedPair.getKeyHandle(), + fileContent); + + // Save signature to file + String signatureFilePath = filePath + ".sig"; + Files.write(Paths.get(signatureFilePath), signature); + + System.out.println("File signed successfully. Signature saved to: " + signatureFilePath); + System.out.println("Signature (Base64): " + Base64.getEncoder().encodeToString(signature)); + + // Verify signature immediately + boolean isSignatureValid = signer.verifySignature(fileContent, signature, selectedPair.getCertificate()); + System.out.println("Signature verification: " + (isSignatureValid ? "Valid" : "Invalid")); + } catch (Exception e) { + System.out.println("Error during file signing: " + e.getMessage()); + throw new RuntimeException(e); + } + } + + private void verifyFileSignature(PKCS11Manager manager, PKCS11Session session) { + try { + KeyCertificatePair selectedPair = selectCertificateKeyPair(manager, session); + Scanner scanner = new Scanner(System.in); + + System.out.print("Enter path to file to verify: "); + String filePath = scanner.nextLine(); + + System.out.print("Enter path to signature file: "); + String signatureFilePath = scanner.nextLine(); + + if (!Files.exists(Paths.get(filePath))) { + System.out.println("File does not exist: " + filePath); + return; + } + if (!Files.exists(Paths.get(signatureFilePath))) { + System.out.println("Signature file does not exist: " + signatureFilePath); + return; + } + + byte[] fileContent = Files.readAllBytes(Paths.get(filePath)); + byte[] signature = Files.readAllBytes(Paths.get(signatureFilePath)); + + PKCS11Signer signer = new PKCS11Signer(); + boolean isSignatureValid = signer.verifySignature(fileContent, signature, selectedPair.getCertificate()); + + System.out.println("Signature verification result: " + (isSignatureValid ? "Valid" : "Invalid")); + } catch (Exception e) { + System.out.println("Error during signature verification: " + e.getMessage()); + } + } + + private void encryptDecryptFile(PKCS11Manager manager, PKCS11Session session) { + try { + KeyCertificatePair selectedPair = selectCertificateKeyPair(manager, session); + Scanner scanner = new Scanner(System.in); + + System.out.print("Enter path to file to encrypt: "); + String filePath = scanner.nextLine(); + + if (!Files.exists(Paths.get(filePath))) { + System.out.println("File does not exist: " + filePath); + return; + } + + byte[] fileContent = Files.readAllBytes(Paths.get(filePath)); + + // Encrypt file using hybrid encryption + PKCS11Crypto crypto = new PKCS11Crypto(); + byte[][] encryptedPackage = crypto.encryptData(fileContent, selectedPair.getCertificate()); + + // Save encrypted components + String encryptedKeyPath = filePath + ".key.enc"; + String encryptedIVPath = filePath + ".iv"; + String encryptedDataPath = filePath + ".data.enc"; + + Files.write(Paths.get(encryptedKeyPath), encryptedPackage[0]); + Files.write(Paths.get(encryptedIVPath), encryptedPackage[1]); + Files.write(Paths.get(encryptedDataPath), encryptedPackage[2]); + + System.out.println("File encrypted successfully."); + System.out.println("Encrypted key saved to: " + encryptedKeyPath); + System.out.println("IV saved to: " + encryptedIVPath); + System.out.println("Encrypted data saved to: " + encryptedDataPath); + + System.out.println("\nDo you want to decrypt the file? (y/n)"); + String answer = scanner.nextLine().toLowerCase(); + if (!answer.equals("y")) { + return; + } + + // Decrypt file + byte[][] decryptPackage = new byte[][]{ + Files.readAllBytes(Paths.get(encryptedKeyPath)), + Files.readAllBytes(Paths.get(encryptedIVPath)), + Files.readAllBytes(Paths.get(encryptedDataPath)) + }; + + byte[] decryptedData = crypto.decryptData( + manager.getPkcs11(), + session.getSession(), + selectedPair.getKeyHandle(), + decryptPackage + ); + + // Save decrypted file + String decryptedFilePath = filePath + ".dec"; + Files.write(Paths.get(decryptedFilePath), decryptedData); + System.out.println("File decrypted successfully. Saved to: " + decryptedFilePath); + + // Calculate and display checksums + String originalChecksum = getFileChecksum(fileContent); + String decryptedChecksum = getFileChecksum(decryptedData); + + System.out.println("\nFile integrity verification:"); + System.out.println("Original file SHA-256: " + originalChecksum); + System.out.println("Decrypted file SHA-256: " + decryptedChecksum); + + if (Arrays.equals(fileContent, decryptedData)) { + System.out.println("File integrity verified: Original and decrypted files match."); + } else { + System.out.println("Warning: Decrypted file does not match original!"); + } + } catch (Exception e) { + System.out.println("Error during file encryption/decryption: " + e.getMessage()); + } + } + + private String getFileChecksum(byte[] fileData) throws NoSuchAlgorithmException { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(fileData); + return Base64.getEncoder().encodeToString(hash); + } private boolean handleDeviceChange(PKCS11Manager manager) { int maxRetries = 3; @@ -130,7 +289,6 @@ private boolean handleDeviceChange(PKCS11Manager manager) { continue; } - if (!getPINFromUser()) { retryCount++; continue; @@ -223,10 +381,13 @@ private void displayMenu() { System.out.println("Current device: " + selectedDevice.getLabel()); System.out.println("1. List Available Certificates"); System.out.println("2. Sign a Message"); - System.out.println("3. Encrypt and Decrypt Data"); - System.out.println("4. List Supported Algorithms"); - System.out.println("5. Exit"); - System.out.println("6. Change Device"); + System.out.println("3. Sign a File"); + System.out.println("4. Verify File Signature"); + System.out.println("5. Encrypt and Decrypt Data"); + System.out.println("6. Encrypt and Decrypt File"); + System.out.println("7. List Supported Algorithms"); + System.out.println("8. Exit"); + System.out.println("9. Change Device"); System.out.print("Enter your choice: "); } @@ -294,12 +455,20 @@ private void encryptDecryptData(PKCS11Manager manager, PKCS11Session session) { Scanner scanner = new Scanner(System.in); String dataToEncrypt = scanner.nextLine(); - PKCS11Crypto decryptor = new PKCS11Crypto(); + PKCS11Crypto crypto = new PKCS11Crypto(); - byte[] encryptedData = decryptor.encryptData(dataToEncrypt.getBytes(), selectedPair.getCertificate()); + // Encrypt data + byte[][] encryptedPackage = crypto.encryptData(dataToEncrypt.getBytes(), selectedPair.getCertificate()); System.out.println("Data encrypted successfully."); + System.out.println("Encrypted data (Base64): " + Base64.getEncoder().encodeToString(encryptedPackage[2])); - byte[] decryptedData = decryptor.decryptData(manager.getPkcs11(), session.getSession(), selectedPair.getKeyHandle(), encryptedData); + // Decrypt data + byte[] decryptedData = crypto.decryptData( + manager.getPkcs11(), + session.getSession(), + selectedPair.getKeyHandle(), + encryptedPackage + ); System.out.println("Decrypted data: " + new String(decryptedData)); if (dataToEncrypt.equals(new String(decryptedData))) { @@ -309,10 +478,8 @@ private void encryptDecryptData(PKCS11Manager manager, PKCS11Session session) { } } catch (IllegalArgumentException e) { System.out.println("Invalid input: " + e.getMessage()); - throw e; } catch (Exception e) { System.out.println("Error during encryption/decryption: " + e.getMessage()); - throw e; } } @@ -339,6 +506,7 @@ private void listSupportedAlgorithms(PKCS11Manager manager, PKCS11Session sessio } } + public class PKCS11Example { public static void main(String[] args) { String userDir = System.getProperty("user.dir"); @@ -347,4 +515,4 @@ public static void main(String[] args) { ); example.run(); } -} +} \ No newline at end of file diff --git a/src/main/java/pl/mlodawski/security/pkcs11/PKCS11Crypto.java b/src/main/java/pl/mlodawski/security/pkcs11/PKCS11Crypto.java index ec9be4f..c7a1317 100644 --- a/src/main/java/pl/mlodawski/security/pkcs11/PKCS11Crypto.java +++ b/src/main/java/pl/mlodawski/security/pkcs11/PKCS11Crypto.java @@ -13,62 +13,54 @@ import com.sun.jna.ptr.NativeLongByReference; import javax.crypto.Cipher; +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; import java.security.cert.X509Certificate; +/** + * The {@code PKCS11Crypto} class provides methods for encrypting and decrypting data using + * a combination of AES and RSA algorithms. It leverages PKCS#11 for RSA decryption in hardware + * security modules (HSMs). + * + * This class uses the following transformations: + * - AES: "AES/CBC/PKCS5Padding" + * - RSA: "RSA/ECB/PKCS1Padding" + */ @Slf4j public class PKCS11Crypto { + private static final int AES_KEY_SIZE = 256; + private static final String AES_TRANSFORMATION = "AES/CBC/PKCS5Padding"; + private static final String RSA_TRANSFORMATION = "RSA/ECB/PKCS1Padding"; + /** - * Initializes the cryptology process with the specified PKCS11 object, session, and private key handle. + * Encrypts the given data using AES encryption and then encrypts the AES key with RSA. * - * @param pkcs11 the PKCS11 object used for decryption - * @param session the session used for decryption - * @param privateKeyHandle the handle to the private key used for decryption - * @throws IllegalArgumentException if any of the parameters is null - * @throws RuntimeException if the decryption initialization fails + * @param dataToEncrypt the data to be encrypted + * @param certificate the X509 certificate containing the public key for RSA encryption + * @return a byte array containing the encrypted AES key, IV, and the encrypted data + * @throws EncryptionException if any error occurs during the encryption process */ - private void initCrypto(Pkcs11 pkcs11, NativeLong session, NativeLong privateKeyHandle) { - if (pkcs11 == null) { - throw new IllegalArgumentException("pkcs11 cannot be null"); - } - if (session == null) { - throw new IllegalArgumentException("session cannot be null"); - } - if (privateKeyHandle == null) { - throw new IllegalArgumentException("privateKeyHandle cannot be null"); - } + public byte[][] encryptData(byte[] dataToEncrypt, X509Certificate certificate) { + validateEncryptInput(dataToEncrypt, certificate); try { - CK_MECHANISM mechanism = new CK_MECHANISM(); - mechanism.mechanism = new NativeLong(Pkcs11Constants.CKM_RSA_PKCS); - pkcs11.C_DecryptInit(session, mechanism, privateKeyHandle); - } catch (Exception e) { - log.error("Crypto initialization failed", e); - throw new CryptoInitializationException("Crypto initialization failed", e); - } - } + KeyGenerator keyGen = KeyGenerator.getInstance("AES"); + keyGen.init(AES_KEY_SIZE); + SecretKey aesKey = keyGen.generateKey(); - /** - * Encrypts the given data using RSA algorithm with the provided X509 certificate. - * - * @param dataToEncrypt the data to be encrypted - * @param certificate the X509 certificate used for encryption - * @return the encrypted data - * @throws IllegalArgumentException if dataToEncrypt is null or empty, or if certificate is null - * @throws RuntimeException if encryption fails - */ - public byte[] encryptData(byte[] dataToEncrypt, X509Certificate certificate) { - if (dataToEncrypt == null || dataToEncrypt.length == 0) { - throw new IllegalArgumentException("dataToEncrypt cannot be null or empty"); - } - if (certificate == null) { - throw new IllegalArgumentException("certificate cannot be null"); - } + Cipher aesCipher = Cipher.getInstance(AES_TRANSFORMATION); + aesCipher.init(Cipher.ENCRYPT_MODE, aesKey); + byte[] iv = aesCipher.getIV(); - try { - Cipher cipher = Cipher.getInstance("RSA/ECB/PKCS1Padding", new BouncyCastleProvider()); - cipher.init(Cipher.ENCRYPT_MODE, certificate.getPublicKey()); - return cipher.doFinal(dataToEncrypt); + byte[] encryptedData = aesCipher.doFinal(dataToEncrypt); + + Cipher rsaCipher = Cipher.getInstance(RSA_TRANSFORMATION, new BouncyCastleProvider()); + rsaCipher.init(Cipher.ENCRYPT_MODE, certificate.getPublicKey()); + byte[] encryptedKey = rsaCipher.doFinal(aesKey.getEncoded()); + + return new byte[][]{encryptedKey, iv, encryptedData}; } catch (Exception e) { log.error("Encryption failed", e); throw new EncryptionException("Encryption failed", e); @@ -76,76 +68,58 @@ public byte[] encryptData(byte[] dataToEncrypt, X509Certificate certificate) { } /** - * Decrypts the given encrypted data using the specified private key. + * Decrypts the provided encrypted package using the specified PKCS#11 session and private key. * - * @param pkcs11 the Pkcs11 instance used for decryption - * @param session the native long value representing the session - * @param privateKeyHandle the native long value representing the private key handle - * @param encryptedData the byte array of encrypted data to be decrypted - * @return the decrypted data as a byte array - * @throws IllegalArgumentException if any of the input parameters are null or invalid - * @throws RuntimeException if decryption fails + * @param pkcs11 the PKCS#11 instance to use for decryption. + * @param session the active PKCS#11 session. + * @param privateKeyHandle the handle of the private key to use for decryption. + * @param encryptedPackage a 2D array containing the encrypted components: the encrypted AES key, the IV, and the encrypted data. + * @return the decrypted data as a byte array. + * @throws DecryptionException if the decryption process fails. */ - public byte[] decryptData(Pkcs11 pkcs11, NativeLong session, NativeLong privateKeyHandle, byte[] encryptedData) { - if (pkcs11 == null) { - throw new IllegalArgumentException("pkcs11 cannot be null"); - } - if (session == null) { - throw new IllegalArgumentException("session cannot be null"); - } - if (privateKeyHandle == null) { - throw new IllegalArgumentException("privateKeyHandle cannot be null"); - } - if (encryptedData == null || encryptedData.length == 0) { - throw new InvalidInputException("encryptedData cannot be null or empty"); - } + public byte[] decryptData(Pkcs11 pkcs11, NativeLong session, NativeLong privateKeyHandle, byte[][] encryptedPackage) { + validateDecryptInput(pkcs11, session, privateKeyHandle, encryptedPackage); + + byte[] encryptedKey = encryptedPackage[0]; + byte[] iv = encryptedPackage[1]; + byte[] encryptedData = encryptedPackage[2]; try { - initCrypto(pkcs11, session, privateKeyHandle); - return decrypt(pkcs11, session, encryptedData); - } catch (CryptoInitializationException e) { - log.error("Crypto initialization failed", e); - throw e; + CK_MECHANISM mechanism = new CK_MECHANISM(); + mechanism.mechanism = new NativeLong(Pkcs11Constants.CKM_RSA_PKCS); + NativeLong rv = pkcs11.C_DecryptInit(session, mechanism, privateKeyHandle); + checkResult(rv, "Failed to initialize decryption"); + + byte[] aesKeyBytes = decrypt(pkcs11, session, encryptedKey); + SecretKey aesKey = new SecretKeySpec(aesKeyBytes, "AES"); + + Cipher aesCipher = Cipher.getInstance(AES_TRANSFORMATION); + aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new javax.crypto.spec.IvParameterSpec(iv)); + return aesCipher.doFinal(encryptedData); } catch (Exception e) { log.error("Decryption failed", e); throw new DecryptionException("Decryption failed", e); } } - /** - * Decrypts the given encrypted data using PKCS11. + * Decrypts the provided encrypted data using the given PKCS#11 session. * - * @param pkcs11 the PKCS11 object used for encryption - * @param session the session ID - * @param encryptedData the data to decrypt - * @return the decrypted data - * @throws IllegalArgumentException if pkcs11, session, or encryptedData is null/empty - * @throws RuntimeException if decryption fails + * @param pkcs11 the PKCS#11 interface to perform cryptographic operations + * @param session the session handle used for decryption operations + * @param encryptedData the data to be decrypted + * @return the decrypted data as a byte array + * @throws DecryptionException if the decryption process fails */ - public byte[] decrypt(Pkcs11 pkcs11, NativeLong session, byte[] encryptedData) { - if (pkcs11 == null) { - throw new IllegalArgumentException("pkcs11 cannot be null"); - } - if (session == null) { - throw new IllegalArgumentException("session cannot be null"); - } - if (encryptedData == null || encryptedData.length == 0) { - throw new IllegalArgumentException("encryptedData cannot be null or empty"); - } - + private byte[] decrypt(Pkcs11 pkcs11, NativeLong session, byte[] encryptedData) { try { NativeLongByReference decryptedDataLen = new NativeLongByReference(); NativeLong result = pkcs11.C_Decrypt(session, encryptedData, new NativeLong(encryptedData.length), null, decryptedDataLen); - if (!result.equals(new NativeLong(Pkcs11Constants.CKR_OK))) { - throw new DecryptionException("Decryption failed with error code: " + result, null); - } + checkResult(result, "Decryption failed with error code"); byte[] decryptedData = new byte[decryptedDataLen.getValue().intValue()]; result = pkcs11.C_Decrypt(session, encryptedData, new NativeLong(encryptedData.length), decryptedData, decryptedDataLen); - if (!result.equals(new NativeLong(Pkcs11Constants.CKR_OK))) { - throw new DecryptionException("Decryption failed with error code: " + result, null); - } + checkResult(result, "Decryption failed with error code"); return decryptedData; } catch (Exception e) { @@ -154,4 +128,56 @@ public byte[] decrypt(Pkcs11 pkcs11, NativeLong session, byte[] encryptedData) { } } + /** + * Helper method to check the result from PKCS#11 calls. + * + * @param result the result code returned from the PKCS#11 call. + * @param errorMessage the error message to include in the exception if the result is not CKR_OK. + * @throws DecryptionException if the result is not CKR_OK. + */ + private void checkResult(NativeLong result, String errorMessage) { + if (!result.equals(new NativeLong(Pkcs11Constants.CKR_OK))) { + throw new DecryptionException(errorMessage + ": " + result, null); + } + } + + /** + * Validates the input parameters for the encryption process. + * + * @param dataToEncrypt the data to be encrypted, which must not be null or empty + * @param certificate the X509Certificate to be used for encryption, which must not be null and must contain a public key + * @throws InvalidInputException if any of the input parameters are invalid + */ + private void validateEncryptInput(byte[] dataToEncrypt, X509Certificate certificate) { + if (dataToEncrypt == null || dataToEncrypt.length == 0) { + throw new InvalidInputException("dataToEncrypt cannot be null or empty"); + } + if (certificate == null || certificate.getPublicKey() == null) { + throw new InvalidInputException("certificate or its public key cannot be null"); + } + } + + /** + * Validates the input for decryption. + * + * @param pkcs11 the instance of Pkcs11, cannot be null + * @param session the session handle, cannot be null + * @param privateKeyHandle the handle of the private key, cannot be null + * @param encryptedPackage the encrypted data package, must be an array of three non-null byte arrays + * @throws InvalidInputException if any of the input parameters are invalid + */ + private void validateDecryptInput(Pkcs11 pkcs11, NativeLong session, NativeLong privateKeyHandle, byte[][] encryptedPackage) { + if (pkcs11 == null) { + throw new InvalidInputException("pkcs11 instance cannot be null"); + } + if (session == null || privateKeyHandle == null) { + throw new InvalidInputException("session and privateKeyHandle cannot be null"); + } + if (encryptedPackage == null || encryptedPackage.length != 3) { + throw new InvalidInputException("encryptedPackage format is invalid"); + } + if (encryptedPackage[0] == null || encryptedPackage[1] == null || encryptedPackage[2] == null) { + throw new InvalidInputException("encryptedPackage elements cannot be null"); + } + } } \ No newline at end of file diff --git a/src/main/java/pl/mlodawski/security/pkcs11/PKCS11RSACrypto.java b/src/main/java/pl/mlodawski/security/pkcs11/PKCS11RSACrypto.java new file mode 100644 index 0000000..dda6be8 --- /dev/null +++ b/src/main/java/pl/mlodawski/security/pkcs11/PKCS11RSACrypto.java @@ -0,0 +1,161 @@ +package pl.mlodawski.security.pkcs11; + +import com.sun.jna.NativeLong; +import com.sun.jna.ptr.NativeLongByReference; +import lombok.extern.slf4j.Slf4j; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import pl.mlodawski.security.pkcs11.exceptions.CryptoInitializationException; +import pl.mlodawski.security.pkcs11.exceptions.DecryptionException; +import pl.mlodawski.security.pkcs11.exceptions.EncryptionException; +import pl.mlodawski.security.pkcs11.exceptions.InvalidInputException; +import ru.rutoken.pkcs11jna.CK_MECHANISM; +import ru.rutoken.pkcs11jna.Pkcs11; +import ru.rutoken.pkcs11jna.Pkcs11Constants; + +import javax.crypto.Cipher; +import java.security.cert.X509Certificate; + +/** + * PKCS11RSACrypto is a class that provides encryption and decryption methods + * using RSA algorithm and PKCS#11 standard. + */ +@Slf4j +public class PKCS11RSACrypto { + + /** + * Initializes the cryptology process with the specified PKCS11 object, session, and private key handle. + * + * @param pkcs11 the PKCS11 object used for decryption + * @param session the session used for decryption + * @param privateKeyHandle the handle to the private key used for decryption + * @throws IllegalArgumentException if any of the parameters is null + * @throws RuntimeException if the decryption initialization fails + */ + private void initCrypto(Pkcs11 pkcs11, NativeLong session, NativeLong privateKeyHandle) { + if (pkcs11 == null) { + throw new IllegalArgumentException("pkcs11 cannot be null"); + } + if (session == null) { + throw new IllegalArgumentException("session cannot be null"); + } + if (privateKeyHandle == null) { + throw new IllegalArgumentException("privateKeyHandle cannot be null"); + } + + try { + CK_MECHANISM mechanism = new CK_MECHANISM(); + mechanism.mechanism = new NativeLong(Pkcs11Constants.CKM_RSA_PKCS); + pkcs11.C_DecryptInit(session, mechanism, privateKeyHandle); + } catch (Exception e) { + log.error("Crypto initialization failed", e); + throw new CryptoInitializationException("Crypto initialization failed", e); + } + } + + /** + * Encrypts the given data using RSA algorithm with the provided X509 certificate. + * + * @param dataToEncrypt the data to be encrypted + * @param certificate the X509 certificate used for encryption + * @return the encrypted data + * @throws IllegalArgumentException if dataToEncrypt is null or empty, or if certificate is null + * @throws RuntimeException if encryption fails + */ + public byte[] encryptData(byte[] dataToEncrypt, X509Certificate certificate) { + if (dataToEncrypt == null || dataToEncrypt.length == 0) { + throw new IllegalArgumentException("dataToEncrypt cannot be null or empty"); + } + if (certificate == null) { + throw new IllegalArgumentException("certificate cannot be null"); + } + + try { + Cipher cipher = Cipher.getInstance("RSA/ECB/PKCS1Padding", new BouncyCastleProvider()); + cipher.init(Cipher.ENCRYPT_MODE, certificate.getPublicKey()); + return cipher.doFinal(dataToEncrypt); + } catch (Exception e) { + log.error("Encryption failed", e); + throw new EncryptionException("Encryption failed", e); + } + } + + /** + * Decrypts the given encrypted data using the specified private key. + * + * @param pkcs11 the Pkcs11 instance used for decryption + * @param session the native long value representing the session + * @param privateKeyHandle the native long value representing the private key handle + * @param encryptedData the byte array of encrypted data to be decrypted + * @return the decrypted data as a byte array + * @throws IllegalArgumentException if any of the input parameters are null or invalid + * @throws RuntimeException if decryption fails + */ + public byte[] decryptData(Pkcs11 pkcs11, NativeLong session, NativeLong privateKeyHandle, byte[] encryptedData) { + if (pkcs11 == null) { + throw new IllegalArgumentException("pkcs11 cannot be null"); + } + if (session == null) { + throw new IllegalArgumentException("session cannot be null"); + } + if (privateKeyHandle == null) { + throw new IllegalArgumentException("privateKeyHandle cannot be null"); + } + if (encryptedData == null || encryptedData.length == 0) { + throw new InvalidInputException("encryptedData cannot be null or empty"); + } + + try { + initCrypto(pkcs11, session, privateKeyHandle); + return decrypt(pkcs11, session, encryptedData); + } catch (CryptoInitializationException e) { + log.error("Crypto initialization failed", e); + throw e; + } catch (Exception e) { + log.error("Decryption failed", e); + throw new DecryptionException("Decryption failed", e); + } + } + + + /** + * Decrypts the given encrypted data using PKCS11. + * + * @param pkcs11 the PKCS11 object used for encryption + * @param session the session ID + * @param encryptedData the data to decrypt + * @return the decrypted data + * @throws IllegalArgumentException if pkcs11, session, or encryptedData is null/empty + * @throws RuntimeException if decryption fails + */ + public byte[] decrypt(Pkcs11 pkcs11, NativeLong session, byte[] encryptedData) { + if (pkcs11 == null) { + throw new IllegalArgumentException("pkcs11 cannot be null"); + } + if (session == null) { + throw new IllegalArgumentException("session cannot be null"); + } + if (encryptedData == null || encryptedData.length == 0) { + throw new IllegalArgumentException("encryptedData cannot be null or empty"); + } + + try { + NativeLongByReference decryptedDataLen = new NativeLongByReference(); + NativeLong result = pkcs11.C_Decrypt(session, encryptedData, new NativeLong(encryptedData.length), null, decryptedDataLen); + if (!result.equals(new NativeLong(Pkcs11Constants.CKR_OK))) { + throw new DecryptionException("Decryption failed with error code: " + result, null); + } + + byte[] decryptedData = new byte[decryptedDataLen.getValue().intValue()]; + result = pkcs11.C_Decrypt(session, encryptedData, new NativeLong(encryptedData.length), decryptedData, decryptedDataLen); + if (!result.equals(new NativeLong(Pkcs11Constants.CKR_OK))) { + throw new DecryptionException("Decryption failed with error code: " + result, null); + } + + return decryptedData; + } catch (Exception e) { + log.error("Decryption failed", e); + throw new DecryptionException("Decryption failed", e); + } + } + +} \ No newline at end of file diff --git a/src/test/java/pl/mlodawski/security/pkcs11/PKCS11CryptoTest.java b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11CryptoTest.java index 768332a..b8ebc03 100644 --- a/src/test/java/pl/mlodawski/security/pkcs11/PKCS11CryptoTest.java +++ b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11CryptoTest.java @@ -1,158 +1,79 @@ package pl.mlodawski.security.pkcs11; -import org.junit.jupiter.api.BeforeEach; +import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import pl.mlodawski.security.pkcs11.exceptions.CryptoInitializationException; -import pl.mlodawski.security.pkcs11.exceptions.DecryptionException; import pl.mlodawski.security.pkcs11.exceptions.EncryptionException; import pl.mlodawski.security.pkcs11.exceptions.InvalidInputException; -import ru.rutoken.pkcs11jna.CK_MECHANISM; -import ru.rutoken.pkcs11jna.Pkcs11; -import ru.rutoken.pkcs11jna.Pkcs11Constants; -import javax.crypto.Cipher; +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; import java.security.KeyPair; import java.security.KeyPairGenerator; +import java.security.PublicKey; import java.security.cert.X509Certificate; -import com.sun.jna.NativeLong; -import com.sun.jna.ptr.NativeLongByReference; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; -import static org.junit.jupiter.api.Assertions.*; -import static org.mockito.ArgumentMatchers.*; -import static org.mockito.Mockito.*; +public class PKCS11CryptoTest { + private static final String ENCRYPTION_FAILED_MESSAGE = "Encryption failed"; -@ExtendWith(MockitoExtension.class) -class PKCS11CryptoTest { - @Mock - private Pkcs11 pkcs11Mock; - - @Mock - private X509Certificate certificateMock; - - private PKCS11Crypto pkcs11Crypto; - private KeyPair keyPair; - private NativeLong session; - private NativeLong privateKeyHandle; - - @BeforeEach - void setUp() throws Exception { - pkcs11Crypto = new PKCS11Crypto(); - session = new NativeLong(1L); - privateKeyHandle = new NativeLong(2L); - - // Generate a real RSA key pair for testing - KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); - keyPairGenerator.initialize(2048); - keyPair = keyPairGenerator.generateKeyPair(); - - // We'll set up the certificate mock in individual tests where it's needed - } @Test - void encryptData_validInput_shouldEncryptSuccessfully() throws Exception { - when(certificateMock.getPublicKey()).thenReturn(keyPair.getPublic()); - byte[] dataToEncrypt = "Hello, World!".getBytes(); + public void encryptData_InvalidInput_ThrowsException() { + PKCS11Crypto pkcs11Crypto = new PKCS11Crypto(); - byte[] encryptedData = pkcs11Crypto.encryptData(dataToEncrypt, certificateMock); + assertThrows( + InvalidInputException.class, + () -> pkcs11Crypto.encryptData(null, mock(X509Certificate.class)) + ); - assertNotNull(encryptedData); - assertNotEquals(dataToEncrypt, encryptedData); - } + X509Certificate certificate = mock(X509Certificate.class); + when(certificate.getPublicKey()).thenReturn(null); - @Test - void encryptData_nullData_shouldThrowIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.encryptData(null, certificateMock)); + assertThrows( + InvalidInputException.class, + () -> pkcs11Crypto.encryptData(new byte[1], certificate) + ); } @Test - void encryptData_emptyData_shouldThrowIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.encryptData(new byte[0], certificateMock)); - } + public void encryptData_EncryptionFails_ThrowsException() throws Exception { + PKCS11Crypto pkcs11Crypto = new PKCS11Crypto(); - @Test - void encryptData_nullCertificate_shouldThrowIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.encryptData("Test".getBytes(), null)); - } + X509Certificate certificate = mock(X509Certificate.class); + PublicKey publicKey = mock(PublicKey.class); - @Test - void encryptData_invalidCertificate_shouldThrowEncryptionException() { - when(certificateMock.getPublicKey()).thenThrow(new RuntimeException("Invalid certificate")); - assertThrows(EncryptionException.class, () -> pkcs11Crypto.encryptData("Test".getBytes(), certificateMock)); - } + when(certificate.getPublicKey()).thenReturn(publicKey); - @Test - void decryptData_validInput_shouldDecryptSuccessfully() throws Exception { - byte[] originalData = "Hello, World!".getBytes(); - byte[] encryptedData = encryptWithRealKey(originalData); - - // Mock PKCS11 behavior - when(pkcs11Mock.C_Decrypt(eq(session), any(), any(), isNull(), any())).thenAnswer(invocation -> { - NativeLongByReference lengthRef = invocation.getArgument(4); - lengthRef.setValue(new NativeLong(originalData.length)); - return new NativeLong(0); - }); - - when(pkcs11Mock.C_Decrypt(eq(session), any(), any(), any(byte[].class), any())).thenAnswer(invocation -> { - byte[] outputBuffer = invocation.getArgument(3); - System.arraycopy(originalData, 0, outputBuffer, 0, originalData.length); - return new NativeLong(0); - }); - - byte[] decryptedData = pkcs11Crypto.decryptData(pkcs11Mock, session, privateKeyHandle, encryptedData); - - assertArrayEquals(originalData, decryptedData); - } + KeyGenerator keyGen = mock(KeyGenerator.class); + SecretKey secretKey = mock(SecretKey.class); - @Test - void decryptData_nullPkcs11_shouldThrowIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.decryptData(null, session, privateKeyHandle, new byte[]{1})); - } - - @Test - void decryptData_nullSession_shouldThrowIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.decryptData(pkcs11Mock, null, privateKeyHandle, new byte[]{1})); - } + when(keyGen.generateKey()).thenReturn(secretKey); + when(keyGen.generateKey()).thenThrow(new RuntimeException(ENCRYPTION_FAILED_MESSAGE)); - @Test - void decryptData_nullPrivateKeyHandle_shouldThrowIllegalArgumentException() { - assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.decryptData(pkcs11Mock, session, null, new byte[]{1})); + assertThrows( + EncryptionException.class, + () -> pkcs11Crypto.encryptData(new byte[1], certificate), + ENCRYPTION_FAILED_MESSAGE + ); } @Test - void decryptData_nullEncryptedData_shouldThrowInvalidInputException() { - assertThrows(InvalidInputException.class, () -> pkcs11Crypto.decryptData(pkcs11Mock, session, privateKeyHandle, null)); - } + public void encryptData_ValidInput_Success() throws Exception { + PKCS11Crypto pkcs11Crypto = new PKCS11Crypto(); - @Test - void decryptData_emptyEncryptedData_shouldThrowInvalidInputException() { - assertThrows(InvalidInputException.class, () -> pkcs11Crypto.decryptData(pkcs11Mock, session, privateKeyHandle, new byte[0])); - } + KeyPairGenerator keyGen = KeyPairGenerator.getInstance("RSA"); + keyGen.initialize(2048); + KeyPair keyPair = keyGen.generateKeyPair(); + PublicKey publicKey = keyPair.getPublic(); - @Test - void decryptData_initializationFailure_shouldThrowCryptoInitializationException() { - doThrow(new RuntimeException("Initialization failed")).when(pkcs11Mock).C_DecryptInit(any(), any(), any()); - - assertThrows(CryptoInitializationException.class, () -> - pkcs11Crypto.decryptData(pkcs11Mock, session, privateKeyHandle, new byte[]{1})); - } - - @Test - void decryptData_decryptionFailure_shouldThrowDecryptionException() { - when(pkcs11Mock.C_Decrypt(any(), any(), any(), any(), any())).thenReturn(new NativeLong(1)); - - assertThrows(DecryptionException.class, () -> - pkcs11Crypto.decrypt(pkcs11Mock, session, new byte[]{1})); - } + X509Certificate certificate = mock(X509Certificate.class); + when(certificate.getPublicKey()).thenReturn(publicKey); - // Helper method to encrypt data with the real public key - private byte[] encryptWithRealKey(byte[] data) throws Exception { - Cipher cipher = Cipher.getInstance("RSA/ECB/PKCS1Padding"); - cipher.init(Cipher.ENCRYPT_MODE, keyPair.getPublic()); - return cipher.doFinal(data); + byte[] dataToEncrypt = new byte[] { 0x01 }; + pkcs11Crypto.encryptData(dataToEncrypt, certificate); } } \ No newline at end of file diff --git a/src/test/java/pl/mlodawski/security/pkcs11/PKCS11DeviceManagerTest.java b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11DeviceManagerTest.java index 9cc3f84..9bc2456 100644 --- a/src/test/java/pl/mlodawski/security/pkcs11/PKCS11DeviceManagerTest.java +++ b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11DeviceManagerTest.java @@ -35,7 +35,6 @@ class PKCS11DeviceManagerTest { @BeforeEach void setUp() { - // Podstawowe mockowanie dla inicjalizacji lenient().when(pkcs11Mock.C_GetSlotList(eq(TOKEN_PRESENT), any(), any(NativeLongByReference.class))) .thenAnswer(inv -> { NativeLongByReference count = inv.getArgument(2); diff --git a/src/test/java/pl/mlodawski/security/pkcs11/PKCS11ManagerTest.java b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11ManagerTest.java index c43c1a2..b7c935c 100644 --- a/src/test/java/pl/mlodawski/security/pkcs11/PKCS11ManagerTest.java +++ b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11ManagerTest.java @@ -36,11 +36,11 @@ class PKCS11ManagerTest { @BeforeEach void setUp() { libraryPathMock = mock(Path.class); - // Mockowanie podstawowych wywołań PKCS11 + when(pkcs11Mock.C_GetSlotList(anyByte(), any(), any(NativeLongByReference.class))) .thenAnswer(invocation -> { NativeLongByReference count = invocation.getArgument(2); - count.setValue(new NativeLong(1)); // Symulujemy jeden slot + count.setValue(new NativeLong(1)); return new NativeLong(Pkcs11Constants.CKR_OK); }); diff --git a/src/test/java/pl/mlodawski/security/pkcs11/PKCS11RSACryptoTest.java b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11RSACryptoTest.java new file mode 100644 index 0000000..5a71d82 --- /dev/null +++ b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11RSACryptoTest.java @@ -0,0 +1,153 @@ +package pl.mlodawski.security.pkcs11; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import pl.mlodawski.security.pkcs11.exceptions.CryptoInitializationException; +import pl.mlodawski.security.pkcs11.exceptions.DecryptionException; +import pl.mlodawski.security.pkcs11.exceptions.EncryptionException; +import pl.mlodawski.security.pkcs11.exceptions.InvalidInputException; +import ru.rutoken.pkcs11jna.Pkcs11; + +import javax.crypto.Cipher; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.cert.X509Certificate; + +import com.sun.jna.NativeLong; +import com.sun.jna.ptr.NativeLongByReference; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +class PKCS11RSACryptoTest { + + @Mock + private Pkcs11 pkcs11Mock; + + @Mock + private X509Certificate certificateMock; + + private PKCS11RSACrypto pkcs11Crypto; + private KeyPair keyPair; + private NativeLong session; + private NativeLong privateKeyHandle; + + @BeforeEach + void setUp() throws Exception { + pkcs11Crypto = new PKCS11RSACrypto(); + session = new NativeLong(1L); + privateKeyHandle = new NativeLong(2L); + + KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); + keyPairGenerator.initialize(2048); + keyPair = keyPairGenerator.generateKeyPair(); + + } + + @Test + void encryptData_validInput_shouldEncryptSuccessfully() throws Exception { + when(certificateMock.getPublicKey()).thenReturn(keyPair.getPublic()); + byte[] dataToEncrypt = "Hello, World!".getBytes(); + + byte[] encryptedData = pkcs11Crypto.encryptData(dataToEncrypt, certificateMock); + + assertNotNull(encryptedData); + assertNotEquals(dataToEncrypt, encryptedData); + } + + @Test + void encryptData_nullData_shouldThrowIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.encryptData(null, certificateMock)); + } + + @Test + void encryptData_emptyData_shouldThrowIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.encryptData(new byte[0], certificateMock)); + } + + @Test + void encryptData_nullCertificate_shouldThrowIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.encryptData("Test".getBytes(), null)); + } + + @Test + void encryptData_invalidCertificate_shouldThrowEncryptionException() { + when(certificateMock.getPublicKey()).thenThrow(new RuntimeException("Invalid certificate")); + assertThrows(EncryptionException.class, () -> pkcs11Crypto.encryptData("Test".getBytes(), certificateMock)); + } + + @Test + void decryptData_validInput_shouldDecryptSuccessfully() throws Exception { + byte[] originalData = "Hello, World!".getBytes(); + byte[] encryptedData = encryptWithRealKey(originalData); + + // Mock PKCS11 behavior + when(pkcs11Mock.C_Decrypt(eq(session), any(), any(), isNull(), any())).thenAnswer(invocation -> { + NativeLongByReference lengthRef = invocation.getArgument(4); + lengthRef.setValue(new NativeLong(originalData.length)); + return new NativeLong(0); + }); + + when(pkcs11Mock.C_Decrypt(eq(session), any(), any(), any(byte[].class), any())).thenAnswer(invocation -> { + byte[] outputBuffer = invocation.getArgument(3); + System.arraycopy(originalData, 0, outputBuffer, 0, originalData.length); + return new NativeLong(0); + }); + + byte[] decryptedData = pkcs11Crypto.decryptData(pkcs11Mock, session, privateKeyHandle, encryptedData); + + assertArrayEquals(originalData, decryptedData); + } + + @Test + void decryptData_nullPkcs11_shouldThrowIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.decryptData(null, session, privateKeyHandle, new byte[]{1})); + } + + @Test + void decryptData_nullSession_shouldThrowIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.decryptData(pkcs11Mock, null, privateKeyHandle, new byte[]{1})); + } + + @Test + void decryptData_nullPrivateKeyHandle_shouldThrowIllegalArgumentException() { + assertThrows(IllegalArgumentException.class, () -> pkcs11Crypto.decryptData(pkcs11Mock, session, null, new byte[]{1})); + } + + @Test + void decryptData_nullEncryptedData_shouldThrowInvalidInputException() { + assertThrows(InvalidInputException.class, () -> pkcs11Crypto.decryptData(pkcs11Mock, session, privateKeyHandle, null)); + } + + @Test + void decryptData_emptyEncryptedData_shouldThrowInvalidInputException() { + assertThrows(InvalidInputException.class, () -> pkcs11Crypto.decryptData(pkcs11Mock, session, privateKeyHandle, new byte[0])); + } + + @Test + void decryptData_initializationFailure_shouldThrowCryptoInitializationException() { + doThrow(new RuntimeException("Initialization failed")).when(pkcs11Mock).C_DecryptInit(any(), any(), any()); + + assertThrows(CryptoInitializationException.class, () -> + pkcs11Crypto.decryptData(pkcs11Mock, session, privateKeyHandle, new byte[]{1})); + } + + @Test + void decryptData_decryptionFailure_shouldThrowDecryptionException() { + when(pkcs11Mock.C_Decrypt(any(), any(), any(), any(), any())).thenReturn(new NativeLong(1)); + + assertThrows(DecryptionException.class, () -> + pkcs11Crypto.decrypt(pkcs11Mock, session, new byte[]{1})); + } + + private byte[] encryptWithRealKey(byte[] data) throws Exception { + Cipher cipher = Cipher.getInstance("RSA/ECB/PKCS1Padding"); + cipher.init(Cipher.ENCRYPT_MODE, keyPair.getPublic()); + return cipher.doFinal(data); + } +} \ No newline at end of file diff --git a/src/test/java/pl/mlodawski/security/pkcs11/PKCS11UtilsTest.java b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11UtilsTest.java index ce32cc6..5af35af 100644 --- a/src/test/java/pl/mlodawski/security/pkcs11/PKCS11UtilsTest.java +++ b/src/test/java/pl/mlodawski/security/pkcs11/PKCS11UtilsTest.java @@ -86,7 +86,7 @@ void findAllPrivateKeys_nullSession_shouldThrowIllegalArgumentException() { @Test void getCKA_ID_validInput_shouldReturnCKA_ID() { - // Mock behavior for C_GetAttributeValue + when(pkcs11Mock.C_GetAttributeValue(any(NativeLong.class), any(NativeLong.class), any(CK_ATTRIBUTE[].class), any(NativeLong.class))) .thenAnswer(invocation -> { CK_ATTRIBUTE[] template = invocation.getArgument(2);