Skip to content

Commit

Permalink
refactor: add generic Set implementation (#8149)
Browse files Browse the repository at this point in the history
Signed-off-by: knqyf263 <knqyf263@gmail.com>
  • Loading branch information
knqyf263 authored Dec 24, 2024
1 parent e6d0ba5 commit b5859d3
Show file tree
Hide file tree
Showing 34 changed files with 968 additions and 270 deletions.
5 changes: 5 additions & 0 deletions misc/lint/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ func errorsJoin(m dsl.Matcher) {
m.Match(`errors.Join($*args)`).
Report("use github.com/hashicorp/go-multierror.Append instead of errors.Join.")
}

func mapSet(m dsl.Matcher) {
m.Match(`map[$x]struct{}`).
Report("use github.com/aquasecurity/trivy/pkg/set.Set instead of map.")
}
8 changes: 4 additions & 4 deletions pkg/compliance/spec/compliance.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import (
"path/filepath"
"strings"

"github.com/samber/lo"
"golang.org/x/xerrors"
"gopkg.in/yaml.v3"

sp "github.com/aquasecurity/trivy-checks/pkg/spec"
iacTypes "github.com/aquasecurity/trivy/pkg/iac/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/types"
)

Expand All @@ -31,17 +31,17 @@ const (

// Scanners reads spec control and determines the scanners by check ID prefix
func (cs *ComplianceSpec) Scanners() (types.Scanners, error) {
scannerTypes := make(map[types.Scanner]struct{})
scannerTypes := set.New[types.Scanner]()
for _, control := range cs.Spec.Controls {
for _, check := range control.Checks {
scannerType := scannerByCheckID(check.ID)
if scannerType == types.UnknownScanner {
return nil, xerrors.Errorf("unsupported check ID: %s", check.ID)
}
scannerTypes[scannerType] = struct{}{}
scannerTypes.Append(scannerType)
}
}
return lo.Keys(scannerTypes), nil
return scannerTypes.Items(), nil
}

// CheckIDs return list of compliance check IDs
Expand Down
3 changes: 2 additions & 1 deletion pkg/dependency/parser/java/pom/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/version/doc"
)

Expand All @@ -30,7 +31,7 @@ type artifact struct {
Version version
Licenses []string

Exclusions map[string]struct{}
Exclusions set.Set[string]

Module bool
Relationship ftypes.Relationship
Expand Down
30 changes: 17 additions & 13 deletions pkg/dependency/parser/java/pom/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)

Expand Down Expand Up @@ -118,11 +119,11 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
rootArt := root.artifact()
rootArt.Relationship = ftypes.RelationshipRoot

return p.parseRoot(rootArt, make(map[string]struct{}))
return p.parseRoot(rootArt, set.New[string]())
}

// nolint: gocyclo
func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ftypes.Package, []ftypes.Dependency, error) {
func (p *Parser) parseRoot(root artifact, uniqModules set.Set[string]) ([]ftypes.Package, []ftypes.Dependency, error) {
// Prepare a queue for dependencies
queue := newArtifactQueue()

Expand All @@ -145,10 +146,10 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
// Modules should be handled separately so that they can have independent dependencies.
// It means multi-module allows for duplicate dependencies.
if art.Module {
if _, ok := uniqModules[art.String()]; ok {
if uniqModules.Contains(art.String()) {
continue
}
uniqModules[art.String()] = struct{}{}
uniqModules.Append(art.String())

modulePkgs, moduleDeps, err := p.parseRoot(art, uniqModules)
if err != nil {
Expand Down Expand Up @@ -251,7 +252,7 @@ func (p *Parser) parseRoot(root artifact, uniqModules map[string]struct{}) ([]ft
// `mvn` shows modules separately from the root package and does not show module nesting.
// So we can add all modules as dependencies of root package.
if art.Relationship == ftypes.RelationshipRoot {
dependsOn = append(dependsOn, lo.Keys(uniqModules)...)
dependsOn = append(dependsOn, uniqModules.Items()...)
}

sort.Strings(dependsOn)
Expand Down Expand Up @@ -340,14 +341,17 @@ type analysisResult struct {
}

type analysisOptions struct {
exclusions map[string]struct{}
exclusions set.Set[string]
depManagement []pomDependency // from the root POM
}

func (p *Parser) analyze(pom *pom, opts analysisOptions) (analysisResult, error) {
if pom.nil() {
return analysisResult{}, nil
}
if opts.exclusions == nil {
opts.exclusions = set.New[string]()
}
// Update remoteRepositories
pomReleaseRemoteRepos, pomSnapshotRemoteRepos := pom.repositories(p.servers)
p.releaseRemoteRepos = lo.Uniq(append(pomReleaseRemoteRepos, p.releaseRemoteRepos...))
Expand Down Expand Up @@ -408,16 +412,16 @@ func (p *Parser) resolveParent(pom *pom) error {
}

func (p *Parser) mergeDependencyManagements(depManagements ...[]pomDependency) []pomDependency {
uniq := make(map[string]struct{})
uniq := set.New[string]()
var depManagement []pomDependency
// The preceding argument takes precedence.
for _, dm := range depManagements {
for _, dep := range dm {
if _, ok := uniq[dep.Name()]; ok {
if uniq.Contains(dep.Name()) {
continue
}
depManagement = append(depManagement, dep)
uniq[dep.Name()] = struct{}{}
uniq.Append(dep.Name())
}
}
return depManagement
Expand Down Expand Up @@ -492,19 +496,19 @@ func (p *Parser) mergeDependencies(child, parent []pomDependency) []pomDependenc
})
}

func (p *Parser) filterDependencies(artifacts []artifact, exclusions map[string]struct{}) []artifact {
func (p *Parser) filterDependencies(artifacts []artifact, exclusions set.Set[string]) []artifact {
return lo.Filter(artifacts, func(art artifact, _ int) bool {
return !excludeDep(exclusions, art)
})
}

func excludeDep(exclusions map[string]struct{}, art artifact) bool {
if _, ok := exclusions[art.Name()]; ok {
func excludeDep(exclusions set.Set[string], art artifact) bool {
if exclusions.Contains(art.Name()) {
return true
}
// Maven can use "*" in GroupID and ArtifactID fields to exclude dependencies
// https://maven.apache.org/pom.html#exclusions
for exlusion := range exclusions {
for exlusion := range exclusions.Iter() {
// exclusion format - "<groupID>:<artifactID>"
e := strings.Split(exlusion, ":")
if (e[0] == art.GroupID || e[0] == "*") && (e[1] == art.ArtifactID || e[1] == "*") {
Expand Down
8 changes: 4 additions & 4 deletions pkg/dependency/parser/java/pom/pom.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/xml"
"fmt"
"io"
"maps"
"net/url"
"reflect"
"strings"
Expand All @@ -15,6 +14,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
"github.com/aquasecurity/trivy/pkg/x/slices"
)

Expand Down Expand Up @@ -287,12 +287,12 @@ func (d pomDependency) ToArtifact(opts analysisOptions) artifact {
// To avoid shadow adding exclusions to top pom's,
// we need to initialize a new map for each new artifact
// See `exclusions in child` test for more information
exclusions := make(map[string]struct{})
exclusions := set.New[string]()
if opts.exclusions != nil {
exclusions = maps.Clone(opts.exclusions)
exclusions = opts.exclusions.Clone()
}
for _, e := range d.Exclusions.Exclusion {
exclusions[fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID)] = struct{}{}
exclusions.Append(fmt.Sprintf("%s:%s", e.GroupID, e.ArtifactID))
}

var locations ftypes.Locations
Expand Down
15 changes: 8 additions & 7 deletions pkg/dependency/parser/nodejs/npm/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency/parser/utils"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)

Expand Down Expand Up @@ -91,7 +92,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
// https://docs.npmjs.com/cli/v9/configuring-npm/package-lock-json#packages
p.resolveLinks(packages)

directDeps := make(map[string]struct{})
directDeps := set.New[string]()
for name, version := range lo.Assign(packages[""].Dependencies, packages[""].OptionalDependencies, packages[""].DevDependencies, packages[""].PeerDependencies) {
pkgPath := joinPaths(nodeModulesDir, name)
if _, ok := packages[pkgPath]; !ok {
Expand All @@ -101,7 +102,7 @@ func (p *Parser) parseV2(packages map[string]Package) ([]ftypes.Package, []ftype
}
// Store the package paths of direct dependencies
// e.g. node_modules/body-parser
directDeps[pkgPath] = struct{}{}
directDeps.Append(pkgPath)
}

for pkgPath, pkg := range packages {
Expand Down Expand Up @@ -366,13 +367,13 @@ func (p *Parser) pkgNameFromPath(pkgPath string) string {

func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
var uniqDeps ftypes.Dependencies
unique := make(map[string]struct{})
unique := set.New[string]()

for _, dep := range deps {
sort.Strings(dep.DependsOn)
depKey := fmt.Sprintf("%s:%s", dep.ID, strings.Join(dep.DependsOn, ","))
if _, ok := unique[depKey]; !ok {
unique[depKey] = struct{}{}
if !unique.Contains(depKey) {
unique.Append(depKey)
uniqDeps = append(uniqDeps, dep)
}
}
Expand All @@ -381,11 +382,11 @@ func uniqueDeps(deps []ftypes.Dependency) []ftypes.Dependency {
return uniqDeps
}

func isIndirectPkg(pkgPath string, directDeps map[string]struct{}) bool {
func isIndirectPkg(pkgPath string, directDeps set.Set[string]) bool {
// A project can contain 2 different versions of the same dependency.
// e.g. `node_modules/string-width/node_modules/strip-ansi` and `node_modules/string-ansi`
// direct dependencies always have root path (`node_modules/<pkg_name>`)
if _, ok := directDeps[pkgPath]; ok {
if directDeps.Contains(pkgPath) {
return false
}
return true
Expand Down
9 changes: 5 additions & 4 deletions pkg/dependency/parser/nodejs/pnpm/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/aquasecurity/trivy/pkg/dependency"
ftypes "github.com/aquasecurity/trivy/pkg/fanal/types"
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/set"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)

Expand Down Expand Up @@ -215,7 +216,7 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
}
}

visited := make(map[string]struct{})
visited := set.New[string]()
// Overwrite the `Dev` field for dev deps and their child dependencies.
for _, pkg := range resolvedPkgs {
if !pkg.Dev {
Expand All @@ -227,8 +228,8 @@ func (p *Parser) parseV9(lockFile LockFile) ([]ftypes.Package, []ftypes.Dependen
}

// markRootPkgs sets `Dev` to false for non dev dependency.
func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited map[string]struct{}) {
if _, ok := visited[id]; ok {
func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps map[string]ftypes.Dependency, visited set.Set[string]) {
if visited.Contains(id) {
return
}
pkg, ok := pkgs[id]
Expand All @@ -238,7 +239,7 @@ func (p *Parser) markRootPkgs(id string, pkgs map[string]ftypes.Package, deps ma

pkg.Dev = false
pkgs[id] = pkg
visited[id] = struct{}{}
visited.Append(id)

// Update child deps
for _, depID := range deps[id].DependsOn {
Expand Down
2 changes: 1 addition & 1 deletion pkg/dependency/parser/nuget/lock/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (p *Parser) Parse(r xio.ReadSeekerAt) ([]ftypes.Package, []ftypes.Dependenc
}

if savedDependsOn, ok := depsMap[depId]; ok {
dependsOn = utils.UniqueStrings(append(dependsOn, savedDependsOn...))
dependsOn = lo.Uniq(append(dependsOn, savedDependsOn...))
}

if len(dependsOn) > 0 {
Expand Down
17 changes: 10 additions & 7 deletions pkg/dependency/parser/python/pyproject/pyproject.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/dependency/parser/python"
"github.com/aquasecurity/trivy/pkg/set"
)

type PyProject struct {
Expand All @@ -19,25 +20,27 @@ type Tool struct {
}

type Poetry struct {
Dependencies dependencies `toml:"dependencies"`
Dependencies Dependencies `toml:"dependencies"`
Groups map[string]Group `toml:"group"`
}

type Group struct {
Dependencies dependencies `toml:"dependencies"`
Dependencies Dependencies `toml:"dependencies"`
}

type dependencies map[string]struct{}
type Dependencies struct {
set.Set[string]
}

func (d *dependencies) UnmarshalTOML(data any) error {
func (d *Dependencies) UnmarshalTOML(data any) error {
m, ok := data.(map[string]any)
if !ok {
return xerrors.Errorf("dependencies must be map, but got: %T", data)
}

*d = lo.MapEntries(m, func(pkgName string, _ any) (string, struct{}) {
return python.NormalizePkgName(pkgName), struct{}{}
})
d.Set = set.New[string](lo.MapToSlice(m, func(pkgName string, _ any) string {
return python.NormalizePkgName(pkgName)
})...)
return nil
}

Expand Down
16 changes: 7 additions & 9 deletions pkg/dependency/parser/python/pyproject/pyproject_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/aquasecurity/trivy/pkg/dependency/parser/python/pyproject"
"github.com/aquasecurity/trivy/pkg/set"
)

func TestParser_Parse(t *testing.T) {
Expand All @@ -24,21 +25,18 @@ func TestParser_Parse(t *testing.T) {
want: pyproject.PyProject{
Tool: pyproject.Tool{
Poetry: pyproject.Poetry{
Dependencies: map[string]struct{}{
"flask": {},
"python": {},
"requests": {},
"virtualenv": {},
Dependencies: pyproject.Dependencies{
Set: set.New[string]("flask", "python", "requests", "virtualenv"),
},
Groups: map[string]pyproject.Group{
"dev": {
Dependencies: map[string]struct{}{
"pytest": {},
Dependencies: pyproject.Dependencies{
Set: set.New[string]("pytest"),
},
},
"lint": {
Dependencies: map[string]struct{}{
"ruff": {},
Dependencies: pyproject.Dependencies{
Set: set.New[string]("ruff"),
},
},
},
Expand Down
Loading

0 comments on commit b5859d3

Please sign in to comment.