Skip to content

Commit

Permalink
Polished new rabbit implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dylan-mulligan committed Nov 22, 2023
1 parent debcde6 commit dc69c4a
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 92 deletions.
2 changes: 1 addition & 1 deletion productnameextractor/env.list
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ PNE_OUTPUT_QUEUE_PATCH=PNE_OUT_PATCH
PNE_OUTPUT_QUEUE_FIX=PNE_OUT_FIX

# --- PRODUCT NAME EXTRACTOR VARS ---
INPUT_TYPE=db
INPUT_MODE=rabbit
CVE_LIMIT=6000
CHAR_2_VEC_CONFIG=c2v_model_config_50.json
CHAR_2_VEC_WEIGHTS=c2v_model_weights_50.h5
Expand Down
27 changes: 14 additions & 13 deletions productnameextractor/src/main/java/ProductNameExtractorMain.java
Original file line number Diff line number Diff line change
Expand Up @@ -210,19 +210,20 @@ private static void dbMain(DatabaseHelper databaseHelper) {

// Process vulnerabilities
final long getProdStart = System.currentTimeMillis();
final List<AffectedProduct> affectedProducts = new ArrayList<>();
int numAffectedProducts = 0;

for(CompositeVulnerability vuln : vulnList) {
affectedProducts.addAll(affectedProductIdentifier.identifyAffectedProducts(vuln));
final List<AffectedProduct> products = affectedProductIdentifier.identifyAffectedProducts(vuln);
databaseHelper.insertAffectedProductsToDB(products);
numAffectedProducts += products.size();
}

int numAffectedProducts = affectedProducts.size();

logger.info("Product Name Extractor found {} affected products in {} seconds", numAffectedProducts, Math.floor(((double) (System.currentTimeMillis() - getProdStart) / 1000) * 100) / 100);

// Insert the affected products found into the database
databaseHelper.insertAffectedProductsToDB(affectedProducts);
logger.info("Product Name Extractor found and inserted {} affected products to the database in {} seconds", affectedProducts.size(), Math.floor(((double) (System.currentTimeMillis() - getProdStart) / 1000) * 100) / 100);
// // Insert the affected products found into the database
// databaseHelper.insertAffectedProductsToDB(affectedProducts);
// logger.info("Product Name Extractor found and inserted {} affected products to the database in {} seconds", affectedProducts.size(), Math.floor(((double) (System.currentTimeMillis() - getProdStart) / 1000) * 100) / 100);
}

// TODO: Implement job streaming (queue received jobs to be consumed, support end messages)
Expand All @@ -239,13 +240,13 @@ private static void rabbitMain(DatabaseHelper databaseHelper) {
factory.setUsername(ProductNameExtractorEnvVars.getRabbitUsername());
factory.setPassword(ProductNameExtractorEnvVars.getRabbitPassword());

try {
factory.useSslProtocol();
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
} catch (KeyManagementException e) {
throw new RuntimeException(e);
}
// try {
// factory.useSslProtocol();
// } catch (NoSuchAlgorithmException e) {
// throw new RuntimeException(e);
// } catch (KeyManagementException e) {
// throw new RuntimeException(e);
// }

final Messenger rabbitMQ = new Messenger(
factory,
Expand Down
133 changes: 56 additions & 77 deletions productnameextractor/src/main/java/messenger/Messenger.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,80 +78,53 @@ public Messenger(ConnectionFactory connectionFactory, String inputQueue, String
}

public void run() {
try (Connection connection = factory.newConnection();
Channel channel = connection.createChannel()) {
try {
Connection connection = factory.newConnection();
Channel channel = connection.createChannel();

// TODO: Needed?
// channel.queueDeclare(inputQueue, true, false, false, null);
// channel.queueDeclare(outputQueue, true, false, false, null);
channel.queueDeclare(inputQueue, true, false, false, null);
channel.queueDeclare(patchFinderOutputQueue, true, false, false, null);
channel.queueDeclare(fixFinderOutputQueue, true, false, false, null);

channel.basicConsume(inputQueue, false, new DefaultConsumer(channel) {
@Override
public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException {
// Get cveId and ensure it is not null
String cveId = parseMessage(new String(body, StandardCharsets.UTF_8));

if(cveId != null){
// Pull specific cve information from database for each CVE ID passed from reconciler
// Pull specific cve information from database for each CVE ID passed from reconciler (ensure not null)
CompositeVulnerability vuln = databaseHelper.getSpecificCompositeVulnerability(cveId);

// Identify affected products from the CVEs
final long getProdStart = System.currentTimeMillis();
List<AffectedProduct> affectedProducts = affectedProductIdentifier.identifyAffectedProducts(vuln);

// Insert the affected products found into the database
databaseHelper.insertAffectedProductsToDB(affectedProducts);
logger.info("Product Name Extractor found and inserted {} affected products to the database in {} seconds", affectedProducts.size(), Math.floor(((double) (System.currentTimeMillis() - getProdStart) / 1000) * 100) / 100);

// // Clear cveIds, extract only the cveIds for which affected products were found to be sent to the Patchfinder
// cveIds.clear();
// for (AffectedProduct affectedProduct : affectedProducts) {
// if (!cveIds.contains(affectedProduct.getCveId())) cveIds.add(affectedProduct.getCveId());
// }

logger.info("Sending jobs to patchfinder and fixfinder...");
String response = genJson(cveId);
channel.basicPublish("", patchFinderOutputQueue, null, response.getBytes(StandardCharsets.UTF_8));
channel.basicPublish("", fixFinderOutputQueue, null, response.getBytes(StandardCharsets.UTF_8));
logger.info("Jobs have been sent!\n\n");
if(vuln == null) {
logger.warn("Could not find CVE '{}' in database", cveId);
} else {
// Identify affected products from the CVEs
final long getProdStart = System.currentTimeMillis();
List<AffectedProduct> affectedProducts = affectedProductIdentifier.identifyAffectedProducts(vuln);

// Insert the affected products found into the database
databaseHelper.insertAffectedProductsToDB(affectedProducts);
logger.info("Product Name Extractor found and inserted {} affected products to the database in {} seconds", affectedProducts.size(), Math.floor(((double) (System.currentTimeMillis() - getProdStart) / 1000) * 100) / 100);

// // Clear cveIds, extract only the cveIds for which affected products were found to be sent to the Patchfinder
// cveIds.clear();
// for (AffectedProduct affectedProduct : affectedProducts) {
// if (!cveIds.contains(affectedProduct.getCveId())) cveIds.add(affectedProduct.getCveId());
// }

logger.info("Sending jobs to patchfinder and fixfinder...");
String response = genJson(cveId);
channel.basicPublish("", patchFinderOutputQueue, null, response.getBytes(StandardCharsets.UTF_8));
channel.basicPublish("", fixFinderOutputQueue, null, response.getBytes(StandardCharsets.UTF_8));
logger.info("Jobs have been sent!\n\n");
}

// Acknowledge job after completion
channel.basicAck(envelope.getDeliveryTag(), false);
}
}
});

// DeliverCallback deliverCallback = (consumerTag, delivery) -> {
// String message = new String(delivery.getBody(), StandardCharsets.UTF_8);
// List<String> cveIds = parseIds(message);
//
// if(!cveIds.isEmpty()){
// logger.info("Received job with CVE(s) {}", cveIds);
//
// // Pull specific cve information from database for each CVE ID passed from reconciler
// List<CompositeVulnerability> vulnList = databaseHelper.getSpecificCompositeVulnerabilities(cveIds);
//
// // Identify affected products from the CVEs
// final long getProdStart = System.currentTimeMillis();
// List<AffectedProduct> affectedProducts = affectedProductIdentifier.identifyAffectedProducts(vulnList);
//
// // Insert the affected products found into the database
// databaseHelper.insertAffectedProductsToDB(affectedProducts);
// logger.info("Product Name Extractor found and inserted {} affected products to the database in {} seconds", affectedProducts.size(), Math.floor(((double) (System.currentTimeMillis() - getProdStart) / 1000) * 100) / 100);
//
// // Clear cveIds, extract only the cveIds for which affected products were found to be sent to the Patchfinder
// cveIds.clear();
// for (AffectedProduct affectedProduct : affectedProducts) {
// if (!cveIds.contains(affectedProduct.getCveId())) cveIds.add(affectedProduct.getCveId());
// }
//
// logger.info("Sending jobs to patchfinder...");
// String response = genJson(cveIds);
// channel.basicPublish("", outputQueue, null, response.getBytes(StandardCharsets.UTF_8));
// logger.info("Jobs have been sent!\n\n");
// }
// };

// channel.basicConsume(inputQueue, true, deliverCallback, consumerTag -> {});

} catch (IOException | TimeoutException e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -191,7 +164,7 @@ private String genJson(String cveId) {
private void sendDummyMessage(String queue, String cveId) {
try (Connection connection = factory.newConnection();
Channel channel = connection.createChannel()) {
channel.queueDeclare(queue, false, false, false, null);
channel.queueDeclare(queue, true, false, false, null);
String message = genJson(cveId);
channel.basicPublish("", queue, null, message.getBytes(StandardCharsets.UTF_8));
logger.info("Successfully sent message:\n\"{}\"", message);
Expand Down Expand Up @@ -266,24 +239,30 @@ public static void main(String[] args) {
factory.setUsername(ProductNameExtractorEnvVars.getRabbitUsername());
factory.setPassword(ProductNameExtractorEnvVars.getRabbitPassword());

try {
factory.useSslProtocol();
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
} catch (KeyManagementException e) {
throw new RuntimeException(e);
}
// try {
// factory.useSslProtocol();
// } catch (NoSuchAlgorithmException e) {
// throw new RuntimeException(e);
// } catch (KeyManagementException e) {
// throw new RuntimeException(e);
// }

// Messenger messenger = new Messenger(
// factory,
// ProductNameExtractorEnvVars.getRabbitInputQueue(),
// ProductNameExtractorEnvVars.getRabbitOutputQueue(),
// affectedProductIdentifier,
// databaseHelper);
List<String> cveIds = new ArrayList<>();
cveIds.addAll(getIdsFromJson("test_output.json"));
writeIdsToFile(cveIds, "test_ids.txt");
// messenger.sendDummyMessage("CRAWLER_OUT", cveIds);
Messenger messenger = new Messenger(
factory,
ProductNameExtractorEnvVars.getRabbitInputQueue(),
ProductNameExtractorEnvVars.getRabbitPatchfinderOutputQueue(),
ProductNameExtractorEnvVars.getRabbitFixfinderOutputQueue(),
null,
new DatabaseHelper(
ProductNameExtractorEnvVars.getDatabaseType(),
ProductNameExtractorEnvVars.getHikariUrl(),
ProductNameExtractorEnvVars.getHikariUser(),
ProductNameExtractorEnvVars.getHikariPassword()
));
// List<String> cveIds = new ArrayList<>();
// cveIds.addAll(getIdsFromJson("test_output.json"));
// writeIdsToFile(cveIds, "test_ids.txt");
messenger.sendDummyMessage("RECONCILER_OUT", "CVE-2013-4190");
// cveIds.add("CVE-2008-2951");
// cveIds.add("CVE-2014-0472");
// cveIds.add("TERMINATE");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ public List<AffectedProduct> identifyAffectedProducts(CompositeVulnerability vul
final int result = processVulnerability(productDetector, cpeLookUp, vuln);

List<AffectedProduct> affectedProducts = new ArrayList<>();
if (vuln.getCveReconcileStatus() == CompositeVulnerability.CveReconcileStatus.DO_NOT_CHANGE)
if (vuln.getCveReconcileStatus() != CompositeVulnerability.CveReconcileStatus.DO_NOT_CHANGE)
affectedProducts.addAll(vuln.getAffectedProducts());

return affectedProducts;
Expand Down

0 comments on commit dc69c4a

Please sign in to comment.