diff --git a/CHANGELOG.md b/CHANGELOG.md index 05ed3faa7e3b..e4b1fb43a194 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,7 @@ All notable changes to this project will be documented in this file. When sendin ### Language -- +- [Add getAllDependents and getAllDependencies method to the DependencyGraph class](https://github.com/ballerina-platform/ballerina-lang/pull/41561) - - - diff --git a/compiler/ballerina-lang/src/main/java/io/ballerina/projects/DependencyGraph.java b/compiler/ballerina-lang/src/main/java/io/ballerina/projects/DependencyGraph.java index 784ec17c2697..9bfe85bf9466 100644 --- a/compiler/ballerina-lang/src/main/java/io/ballerina/projects/DependencyGraph.java +++ b/compiler/ballerina-lang/src/main/java/io/ballerina/projects/DependencyGraph.java @@ -129,6 +129,22 @@ public T getRoot() { return rootNode; } + // Returns all direct and indirect dependents of the node T + public Collection getAllDependents(T node) { + Set allDependents = new HashSet<>(); + Set visited = new HashSet<>(); + getAllDependentsRecursive(node, allDependents, visited); + return allDependents; + } + + // Returns all direct and indirect dependencies of node T + public Collection getAllDependencies(T node) { + Set allDependencies = new HashSet<>(); + Set visited = new HashSet<>(); + getAllDependenciesRecursive(node, allDependencies, visited); + return allDependencies; + } + public boolean contains(T node) { return dependencies.containsKey(node); } @@ -192,6 +208,30 @@ private void sortTopologically(T vertex, List visited, List ancestors, Lis ancestors.remove(vertex); } + private void getAllDependentsRecursive(T node, Set allDependents, Set visited) { + visited.add(node); + Collection directDependents = getDirectDependents(node); + allDependents.addAll(directDependents); + + for (T dependent : directDependents) { + if (!visited.contains(dependent)) { + getAllDependentsRecursive(dependent, allDependents, visited); + } + } + } + + private void getAllDependenciesRecursive(T node, Set allDependencies, Set visited) { + visited.add(node); + Collection directDependencies = getDirectDependencies(node); + allDependencies.addAll(directDependencies); + + for (T dependency : directDependencies) { + if (!visited.contains(dependency)) { + getAllDependenciesRecursive(dependency, allDependencies, visited); + } + } + } + /** * Builds a {@code DependencyGraph}. * diff --git a/project-api/project-api-test/src/test/java/io/ballerina/projects/test/DependencyGraphTests.java b/project-api/project-api-test/src/test/java/io/ballerina/projects/test/DependencyGraphTests.java index f6ccbb0f2795..3d1696388f01 100644 --- a/project-api/project-api-test/src/test/java/io/ballerina/projects/test/DependencyGraphTests.java +++ b/project-api/project-api-test/src/test/java/io/ballerina/projects/test/DependencyGraphTests.java @@ -58,8 +58,10 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.LinkedList; @@ -628,6 +630,56 @@ public void testVersionResolutionHARD() { ResolutionResponse.ResolutionStatus.RESOLVED); } + @Test + public void testGetAllDependents() { + // Create a sample DependencyGraph with String nodes + DependencyGraph dependencyGraph = DependencyGraph.from(new LinkedHashMap<>() {{ + put("A", new LinkedHashSet<>() {{ + add("B"); + add("C"); + }}); + put("B", new LinkedHashSet<>() {{ + add("D"); + }}); + put("C", new LinkedHashSet<>()); + put("D", new LinkedHashSet<>()); + put("E", new LinkedHashSet<>() {{ + add("F"); + }}); + put("F", new LinkedHashSet<>()); + }}); + + Collection allDependents = dependencyGraph.getAllDependents("D"); + Set expectedDependents = new HashSet<>(Arrays.asList("A", "B")); + + Assert.assertEquals(expectedDependents, allDependents); + } + + @Test + public void testGetAllDependencies() { + // Create a sample DependencyGraph with String nodes + DependencyGraph dependencyGraph = DependencyGraph.from(new LinkedHashMap<>() {{ + put("A", new LinkedHashSet<>() {{ + add("B"); + add("C"); + }}); + put("B", new LinkedHashSet<>() {{ + add("D"); + }}); + put("C", new LinkedHashSet<>()); + put("D", new LinkedHashSet<>()); + put("E", new LinkedHashSet<>() {{ + add("F"); + }}); + put("F", new LinkedHashSet<>()); + }}); + + Collection allDependencies = dependencyGraph.getAllDependencies("A"); + Set expectedDependencies = new HashSet<>(Arrays.asList("B", "C", "D")); + + Assert.assertEquals(expectedDependencies, allDependencies); + } + @Test public void testTopologicalSortOfModuleDescriptor() { PackageName packageName = PackageName.from("package");