Skip to content

Commit

Permalink
#1654 Report to security hub with correct accId (#1655)
Browse files Browse the repository at this point in the history
* #1654 https://github.com/deepfence/enterprise-roadmap/issues/1929 report to security hub with correct accId

* Changes for report and scanID fix
  • Loading branch information
saurabh2253 authored Oct 13, 2023
1 parent 781097c commit cce58b0
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 46 deletions.
1 change: 1 addition & 0 deletions deepfence_server/model/scans.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ type ScanResultsCommon struct {
ScanID string `json:"scan_id" required:"true"`
UpdatedAt int64 `json:"updated_at" required:"true" format:"int64"`
CreatedAt int64 `json:"created_at" required:"true" format:"int64"`
CloudAccountID string `json:"cloud_account_id" required:"true"`
}

type FiltersReq struct {
Expand Down
86 changes: 53 additions & 33 deletions deepfence_server/pkg/integration/aws-security-hub/awssecurityhub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/aws/aws-sdk-go/service/sts"
"strings"
"time"

Expand Down Expand Up @@ -50,29 +51,18 @@ func New(ctx context.Context, b []byte) (*AwsSecurityHub, error) {

func (a AwsSecurityHub) SendNotification(ctx context.Context, message string, extras map[string]interface{}) error {

nodeID, ok := extras["node_id"]
scanID, ok := extras["scan_id"]
if !ok {
log.Error().Msgf("AwsSecurityHub: SendNotification: node_id not found in extras")
return nil
}

nodeIDStr, ok := nodeID.(string)
scanIDStr, ok := scanID.(string)
if !ok {
log.Error().Msgf("AwsSecurityHub: SendNotification: node_id not string")
return nil
}

resource, err := getResource(ctx, a.Resource, nodeIDStr, a.Config.AWSRegion, a.Config.AWSAccountId)
if err != nil {
// if err.Err check here
if err.Error() == "not aws" {
log.Info().Msgf("skipping non aws resource")
return nil
}
log.Error().Msg(err.Error())
return nil
}

// Create an AWS session with your credentials and region
sess, err := session.NewSession(&aws.Config{
Region: aws.String(a.Config.AWSRegion),
Expand All @@ -82,6 +72,23 @@ func (a AwsSecurityHub) SendNotification(ctx context.Context, message string, ex
fmt.Println("Failed to create AWS session", err)
return nil
}
stsSvc := sts.New(sess)
id, err := stsSvc.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
fmt.Println("Failed to get caller identity", err)
return nil
}

resource, err := getResource(ctx, a.Resource, scanIDStr, a.Config.AWSRegion, *id.Account)
if err != nil {
// if err.Err check here
if err.Error() == "not aws" {
log.Info().Msgf("skipping non aws resource")
return nil
}
log.Error().Msg(err.Error())
return nil
}

svc := securityhub.New(sess)
var msg []map[string]interface{}
Expand All @@ -91,8 +98,7 @@ func (a AwsSecurityHub) SendNotification(ctx context.Context, message string, ex
fmt.Println("Failed to marshal JSON data", err)
return nil
}

fs := a.mapPayloadToFindings(msg, resource)
fs := a.mapPayloadToFindings(msg, resource, *id.Account)

// Split the JSON data into batches of 100
var batches []*securityhub.BatchImportFindingsInput
Expand Down Expand Up @@ -121,16 +127,16 @@ func (a AwsSecurityHub) SendNotification(ctx context.Context, message string, ex
return nil
}

func getResource(ctx context.Context, scanType, nodeID, region, accountID string) ([]*securityhub.Resource, error) {
func getResource(ctx context.Context, scanType, scanID, region, accountID string) ([]*securityhub.Resource, error) {
if scanType == utils.ScanTypeDetectedNode[utils.NEO4J_VULNERABILITY_SCAN] {
return getResourceForVulnerability(ctx, nodeID, region, accountID)
return getResourceForVulnerability(ctx, scanID, region, accountID)
} else if scanType == utils.ScanTypeDetectedNode[utils.NEO4J_COMPLIANCE_SCAN] {
return getResourceForCompliance(ctx, nodeID, region, accountID)
return getResourceForCompliance(ctx, scanID, region, accountID)
}
return nil, fmt.Errorf("not aws")
}

func getResourceForVulnerability(ctx context.Context, nodeID, region, accountID string) ([]*securityhub.Resource, error) {
func getResourceForVulnerability(ctx context.Context, scanID, region, accountID string) ([]*securityhub.Resource, error) {
driver, err := directory.Neo4jClient(ctx)
if err != nil {
log.Error().Msg(err.Error())
Expand All @@ -152,8 +158,8 @@ func getResourceForVulnerability(ctx context.Context, nodeID, region, accountID
defer tx.Close()

//query for Host/Node
query := `MATCH (m:VulnerabilityScan{node_id: $id})-[:SCHEDULED|SCANNED]->(o:Node) WHERE o.pseudo <> true RETURN o.cloud_provider as cp, o.instance_id as instanceID`
vars := map[string]interface{}{"id": nodeID}
query := `MATCH (m:VulnerabilityScan{node_id: $id})-[:SCHEDULED|SCANNED]->(o:Node) WHERE o.pseudo <> true RETURN o.cloud_provider as cp, o.instance_id as instanceID, o.cloud_account_id as cloudAccountID`
vars := map[string]interface{}{"id": scanID}
r, err := tx.Run(query, vars)

if err != nil {
Expand All @@ -175,7 +181,7 @@ func getResourceForVulnerability(ctx context.Context, nodeID, region, accountID
return []*securityhub.Resource{
{
Type: aws.String("AwsEc2Instance"),
Id: aws.String(fmt.Sprintf("arn:aws:ec2:%s:%s:instance/%s", region, accountID, rec.Values[1].(string))),
Id: aws.String(fmt.Sprintf("arn:aws:ec2:%s:%s:instance/%s", region, rec.Values[2].(string), rec.Values[1].(string))),
},
}, nil
}
Expand Down Expand Up @@ -209,11 +215,10 @@ func getResourceForVulnerability(ctx context.Context, nodeID, region, accountID
}, nil
}
}

return nil, fmt.Errorf("not aws")
}

func getResourceForCompliance(ctx context.Context, nodeID, region, accountID string) ([]*securityhub.Resource, error) {
func getResourceForCompliance(ctx context.Context, scanID, region, accountID string) ([]*securityhub.Resource, error) {
driver, err := directory.Neo4jClient(ctx)
if err != nil {
log.Error().Msg(err.Error())
Expand All @@ -235,8 +240,8 @@ func getResourceForCompliance(ctx context.Context, nodeID, region, accountID str
defer tx.Close()

//query for Host/Node
query := `MATCH (m:ComplianceScan{node_id: $id})-[:SCHEDULED|SCANNED]->(o:Node) WHERE o.pseudo <> true RETURN o.cloud_provider as cp, o.instance_id as instanceID`
vars := map[string]interface{}{"id": nodeID}
query := `MATCH (m:ComplianceScan{node_id: $id})-[:SCHEDULED|SCANNED]->(o:Node) WHERE o.pseudo <> true RETURN o.cloud_provider as cp, o.instance_id as instanceID, o.cloud_account_id as cloudAccountID`
vars := map[string]interface{}{"id": scanID}
r, err := tx.Run(query, vars)

if err != nil {
Expand All @@ -258,19 +263,26 @@ func getResourceForCompliance(ctx context.Context, nodeID, region, accountID str
return []*securityhub.Resource{
{
Type: aws.String("AwsEc2Instance"),
Id: aws.String(fmt.Sprintf("arn:aws:ec2:%s:%s:instance/%s", region, accountID, rec.Values[1].(string))),
Id: aws.String(fmt.Sprintf("arn:aws:ec2:%s:%s:instance/%s", region, rec.Values[2].(string), rec.Values[1].(string))),
},
}, nil
}
}

return nil, fmt.Errorf("not aws")
}

func (a AwsSecurityHub) mapPayloadToFindings(msg []map[string]interface{}, resource []*securityhub.Resource) *securityhub.BatchImportFindingsInput {
func (a AwsSecurityHub) mapPayloadToFindings(msg []map[string]interface{}, resource []*securityhub.Resource, accountID string) *securityhub.BatchImportFindingsInput {
findings := securityhub.BatchImportFindingsInput{}
if a.Resource == utils.ScanTypeDetectedNode[utils.NEO4J_VULNERABILITY_SCAN] {
for _, m := range msg {
accID, found := m["cloud_account_id"]
if !found {
accID = accountID
}
if !utils.InSlice(accID.(string), a.Config.AWSAccountId) {
fmt.Println("Skipping result as not in list of selected account IDs:", accID)
continue
}
finding := securityhub.AwsSecurityFinding{}

var pkgName, pkgVersion string
Expand Down Expand Up @@ -314,8 +326,8 @@ func (a AwsSecurityHub) mapPayloadToFindings(msg []map[string]interface{}, resou
cveDescription = cveDescription[:1024]
}

finding.SetProductArn(fmt.Sprintf("arn:aws:securityhub:%s:%s:product/%s/default", a.Config.AWSRegion, a.Config.AWSAccountId, a.Config.AWSAccountId))
finding.SetAwsAccountId(a.Config.AWSAccountId)
finding.SetProductArn(fmt.Sprintf("arn:aws:securityhub:%s:%s:product/%s/default", a.Config.AWSRegion, accID.(string), accID.(string)))
finding.SetAwsAccountId(accID.(string))
finding.SetCreatedAt(updatedAtStr)
finding.SetUpdatedAt(updatedAtStr)
finding.SetTitle(m["cve_id"].(string))
Expand Down Expand Up @@ -344,6 +356,14 @@ func (a AwsSecurityHub) mapPayloadToFindings(msg []map[string]interface{}, resou
}
} else if a.Resource == utils.ScanTypeDetectedNode[utils.NEO4J_COMPLIANCE_SCAN] {
for _, m := range msg {
accID, found := m["cloud_account_id"]
if !found {
accID = accountID
}
if !utils.InSlice(accID.(string), a.Config.AWSAccountId) {
fmt.Println("Skipping result as not in list of selected account IDs:", accID)
continue
}
finding := securityhub.AwsSecurityFinding{}

updatedAt, ok := m["updated_at"].(int64)
Expand Down Expand Up @@ -372,8 +392,8 @@ func (a AwsSecurityHub) mapPayloadToFindings(msg []map[string]interface{}, resou
compDescription = compDescription[:1024]
}

finding.SetProductArn(fmt.Sprintf("arn:aws:securityhub:%s:%s:product/%s/default", a.Config.AWSRegion, a.Config.AWSAccountId, a.Config.AWSAccountId))
finding.SetAwsAccountId(a.Config.AWSAccountId)
finding.SetProductArn(fmt.Sprintf("arn:aws:securityhub:%s:%s:product/%s/default", a.Config.AWSRegion, accID.(string), accID.(string)))
finding.SetAwsAccountId(accID.(string))
finding.SetCreatedAt(updatedAtStr)
finding.SetUpdatedAt(updatedAtStr)
finding.SetTitle(m["test_category"].(string))
Expand Down
8 changes: 4 additions & 4 deletions deepfence_server/pkg/integration/aws-security-hub/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ type AwsSecurityHub struct {
}

type Config struct {
AWSAccountId string `json:"aws_account_id" validate:"required,number,min=10,max=12" required:"true"`
AWSAccessKey string `json:"aws_access_key" validate:"required,min=16,max=128" required:"true"`
AWSSecretKey string `json:"aws_secret_key" validate:"required,min=16,max=128" required:"true"`
AWSRegion string `json:"aws_region" validate:"required,oneof=us-east-1 us-east-2 us-west-1 us-west-2 af-south-1 ap-east-1 ap-south-1 ap-northeast-1 ap-northeast-2 ap-northeast-3 ap-southeast-1 ap-southeast-2 ap-southeast-3 ca-central-1 eu-central-1 eu-west-1 eu-west-2 eu-west-3 eu-south-1 eu-north-1 me-south-1 me-central-1 sa-east-1 us-gov-east-1 us-gov-west-1" required:"true"`
AWSAccountId []string `json:"aws_account_id"`
AWSAccessKey string `json:"aws_access_key" validate:"required,min=16,max=128" required:"true"`
AWSSecretKey string `json:"aws_secret_key" validate:"required,min=16,max=128" required:"true"`
AWSRegion string `json:"aws_region" validate:"required,oneof=us-east-1 us-east-2 us-west-1 us-west-2 af-south-1 ap-east-1 ap-south-1 ap-northeast-1 ap-northeast-2 ap-northeast-3 ap-southeast-1 ap-southeast-2 ap-southeast-3 ca-central-1 eu-central-1 eu-west-1 eu-west-2 eu-west-3 eu-south-1 eu-north-1 me-south-1 me-central-1 sa-east-1 us-gov-east-1 us-gov-west-1" required:"true"`
}

func (a AwsSecurityHub) ValidateConfig(validate *validator.Validate) error {
Expand Down
20 changes: 18 additions & 2 deletions deepfence_worker/cronjobs/notification.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"encoding/json"
"errors"
reporters_search "github.com/deepfence/ThreatMapper/deepfence_server/reporters/search"
"strconv"
"sync"
"time"
Expand Down Expand Up @@ -160,7 +161,7 @@ func processIntegrationRow(integrationRow postgresql_db.Integration, ctx context
}

func injectNodeDatamap(results []map[string]interface{}, common model.ScanResultsCommon,
integrationType string) []map[string]interface{} {
integrationType string, ctx context.Context) []map[string]interface{} {

for _, r := range results {
//m := utils.ToMap[T](r)
Expand All @@ -176,6 +177,21 @@ func injectNodeDatamap(results []map[string]interface{}, common model.ScanResult
}
if common.HostName != "" {
r["host_name"] = common.HostName
filter := reporters_search.SearchFilter{
Filters: reporters.FieldsFilters{
ContainsFilter: reporters.ContainsFilter{
FieldsValues: map[string][]interface{}{
"host_name": {common.HostName},
},
},
},
}
eFilter := reporters_search.SearchFilter{}
hosts, err := reporters_search.SearchReport[model.Host](
ctx, filter, eFilter, nil, model.FetchWindow{})
if err == nil {
r["cloud_account_id"] = hosts[0].CloudAccountID
}
}
if common.KubernetesClusterName != "" {
r["kubernetes_cluster_name"] = common.KubernetesClusterName
Expand Down Expand Up @@ -285,7 +301,7 @@ func processIntegration[T any](ctx context.Context, task *asynq.Task, integratio
updatedResults = append(updatedResults, utils.ToMap[T](r))
}
}
updatedResults = injectNodeDatamap(updatedResults, common, integrationRow.IntegrationType)
updatedResults = injectNodeDatamap(updatedResults, common, integrationRow.IntegrationType, ctx)
messageByte, err := json.Marshal(updatedResults)
if err != nil {
return err
Expand Down
14 changes: 7 additions & 7 deletions deepfence_worker/tasks/reports/xlsx.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ var (
"L1": "cve_severity",
"M1": "cve_overall_score",
"N1": "cve_type",
"O1": "host",
"P1": "host_name",
"O1": "host_name",
"P1": "cloud_account_id",
"Q1": "masked",
}
secretHeader = map[string]string{
Expand Down Expand Up @@ -58,8 +58,8 @@ var (
"B1": "compliance_check_type",
"C1": "count",
"D1": "doc_id",
"E1": "host",
"F1": "host_name",
"E1": "host_name",
"F1": "cloud_account_id",
"G1": "masked",
"H1": "node_id",
"I1": "node_name",
Expand Down Expand Up @@ -162,7 +162,7 @@ func vulnerabilityXLSX(ctx context.Context, params utils.ReportParams) (string,
v.Cve_overall_score,
v.Cve_type,
nodeScanData.ScanInfo.HostName,
nodeScanData.ScanInfo.HostName,
nodeScanData.ScanInfo.CloudAccountID,
v.Masked,
}
xlsx.SetSheetRow("Sheet1", cellName, &value)
Expand Down Expand Up @@ -288,7 +288,7 @@ func complianceXLSX(ctx context.Context, params utils.ReportParams) (string, err
"",
"",
nodeScanData.ScanInfo.HostName,
nodeScanData.ScanInfo.HostName,
nodeScanData.ScanInfo.CloudAccountID,
c.Masked,
c.ComplianceNodeId,
nodeScanData.ScanInfo.NodeName,
Expand Down Expand Up @@ -335,7 +335,7 @@ func cloudComplianceXLSX(ctx context.Context, params utils.ReportParams) (string
"",
"",
data.ScanInfo.HostName,
data.ScanInfo.HostName,
data.ScanInfo.CloudAccountID,
c.Masked,
c.NodeID,
data.ScanInfo.NodeName,
Expand Down

0 comments on commit cce58b0

Please sign in to comment.