-
Notifications
You must be signed in to change notification settings - Fork 98
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: refactor constraint session logic into smaller pieces
- Loading branch information
Showing
3 changed files
with
105 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
core/src/main/java/ai/timefold/solver/core/impl/score/stream/bavet/NodeNetwork.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package ai.timefold.solver.core.impl.score.stream.bavet; | ||
|
||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
|
||
import ai.timefold.solver.core.impl.score.stream.bavet.common.Propagator; | ||
import ai.timefold.solver.core.impl.score.stream.bavet.uni.AbstractForEachUniNode; | ||
|
||
/** | ||
* Represents Bavet's network of nodes, specific to a particular session. | ||
* Nodes only used by disabled constraints have already been removed. | ||
* | ||
* @param declaredClassToNodeMap starting nodes, one for each class used in the constraints; | ||
* root nodes, layer index 0. | ||
* @param layeredNodes nodes grouped first by their layer, then by their index within the layer; | ||
* propagation needs to happen in this order. | ||
*/ | ||
record NodeNetwork(Map<Class<?>, List<AbstractForEachUniNode<Object>>> declaredClassToNodeMap, Propagator[][] layeredNodes) { | ||
|
||
public static final NodeNetwork EMPTY = new NodeNetwork(Map.of(), new Propagator[0][0]); | ||
|
||
public int forEachNodeCount() { | ||
return declaredClassToNodeMap.size(); | ||
} | ||
|
||
public int layerCount() { | ||
return layeredNodes.length; | ||
} | ||
|
||
@SuppressWarnings("unchecked") | ||
public AbstractForEachUniNode<Object>[] getApplicableForEachNodes(Class<?> factClass) { | ||
return declaredClassToNodeMap.entrySet() | ||
.stream() | ||
.filter(entry -> entry.getKey().isAssignableFrom(factClass)) | ||
.map(Map.Entry::getValue) | ||
.flatMap(List::stream) | ||
.toArray(AbstractForEachUniNode[]::new); | ||
} | ||
|
||
public void propagate() { | ||
for (var layerIndex = 0; layerIndex < layerCount(); layerIndex++) { | ||
propagateInLayer(layeredNodes[layerIndex]); | ||
} | ||
} | ||
|
||
private static void propagateInLayer(Propagator[] nodesInLayer) { | ||
var nodeCount = nodesInLayer.length; | ||
if (nodeCount == 1) { | ||
nodesInLayer[0].propagateEverything(); | ||
} else { | ||
for (var node : nodesInLayer) { | ||
node.propagateRetracts(); | ||
} | ||
for (var node : nodesInLayer) { | ||
node.propagateUpdates(); | ||
} | ||
for (var node : nodesInLayer) { | ||
node.propagateInserts(); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) | ||
return true; | ||
if (!(o instanceof NodeNetwork that)) | ||
return false; | ||
return Objects.equals(declaredClassToNodeMap, that.declaredClassToNodeMap) | ||
&& Objects.deepEquals(layeredNodes, that.layeredNodes); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(declaredClassToNodeMap, Arrays.deepHashCode(layeredNodes)); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return this.getClass().getSimpleName() + " with " + forEachNodeCount() + " forEach nodes."; | ||
} | ||
|
||
} |