diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/BavetConstraintSessionFactory.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/BavetConstraintSessionFactory.java index 1670d44378..d9127373b5 100644 --- a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/BavetConstraintSessionFactory.java +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/BavetConstraintSessionFactory.java @@ -11,12 +11,15 @@ import java.util.SortedMap; import java.util.TreeMap; +import ai.timefold.solver.constraint.streams.bavet.common.AbstractConcatNode; import ai.timefold.solver.constraint.streams.bavet.common.AbstractIfExistsNode; import ai.timefold.solver.constraint.streams.bavet.common.AbstractJoinNode; import ai.timefold.solver.constraint.streams.bavet.common.AbstractNode; import ai.timefold.solver.constraint.streams.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.BavetConcatConstraintStream; import ai.timefold.solver.constraint.streams.bavet.common.BavetIfExistsConstraintStream; import ai.timefold.solver.constraint.streams.bavet.common.BavetJoinConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.BavetStreamBinaryOperation; import ai.timefold.solver.constraint.streams.bavet.common.NodeBuildHelper; import ai.timefold.solver.constraint.streams.bavet.common.PropagationQueue; import ai.timefold.solver.constraint.streams.bavet.common.Propagator; @@ -134,19 +137,14 @@ private long determineLayerIndex(AbstractNode node, NodeBuildHelper buil if (node instanceof AbstractForEachUniNode) { // ForEach nodes, and only they, are in layer 0. return 0; } else if (node instanceof AbstractJoinNode joinNode) { - var nodeCreator = (BavetJoinConstraintStream) buildHelper.getNodeCreatingStream(joinNode); - var leftParent = nodeCreator.getLeftParent(); - var rightParent = nodeCreator.getRightParent(); - var leftParentNode = buildHelper.findParentNode(leftParent); - var rightParentNode = buildHelper.findParentNode(rightParent); - return Math.max(leftParentNode.getLayerIndex(), rightParentNode.getLayerIndex()) + 1; + return determineLayerIndexOfBinaryOperation( + (BavetJoinConstraintStream) buildHelper.getNodeCreatingStream(joinNode), buildHelper); + } else if (node instanceof AbstractConcatNode concatNode) { + return determineLayerIndexOfBinaryOperation( + (BavetConcatConstraintStream) buildHelper.getNodeCreatingStream(concatNode), buildHelper); } else if (node instanceof AbstractIfExistsNode ifExistsNode) { - var nodeCreator = (BavetIfExistsConstraintStream) buildHelper.getNodeCreatingStream(ifExistsNode); - var leftParent = nodeCreator.getLeftParent(); - var rightParent = nodeCreator.getRightParent(); - var leftParentNode = buildHelper.findParentNode(leftParent); - var rightParentNode = buildHelper.findParentNode(rightParent); - return Math.max(leftParentNode.getLayerIndex(), rightParentNode.getLayerIndex()) + 1; + return determineLayerIndexOfBinaryOperation( + (BavetIfExistsConstraintStream) buildHelper.getNodeCreatingStream(ifExistsNode), buildHelper); } else { var nodeCreator = (BavetAbstractConstraintStream) buildHelper.getNodeCreatingStream(node); var parentNode = buildHelper.findParentNode(nodeCreator.getParent()); @@ -154,4 +152,13 @@ private long determineLayerIndex(AbstractNode node, NodeBuildHelper buil } } + private long determineLayerIndexOfBinaryOperation(BavetStreamBinaryOperation nodeCreator, + NodeBuildHelper buildHelper) { + var leftParent = nodeCreator.getLeftParent(); + var rightParent = nodeCreator.getRightParent(); + var leftParentNode = buildHelper.findParentNode(leftParent); + var rightParentNode = buildHelper.findParentNode(rightParent); + return Math.max(leftParentNode.getLayerIndex(), rightParentNode.getLayerIndex()) + 1; + } + } diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetAbstractBiConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetAbstractBiConstraintStream.java index a725fbef27..83e55e260d 100644 --- a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetAbstractBiConstraintStream.java +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetAbstractBiConstraintStream.java @@ -359,6 +359,19 @@ public BiConstraintStream distinct() { } } + @Override + public BiConstraintStream concat(BiConstraintStream otherStream) { + var other = (BavetAbstractBiConstraintStream) otherStream; + var leftBridge = new BavetForeBridgeBiConstraintStream<>(constraintFactory, this); + var rightBridge = new BavetForeBridgeBiConstraintStream<>(constraintFactory, other); + var concatStream = new BavetConcatBiConstraintStream<>(constraintFactory, leftBridge, rightBridge); + return constraintFactory.share(concatStream, concatStream_ -> { + // Connect the bridges upstream + getChildStreamList().add(leftBridge); + other.getChildStreamList().add(rightBridge); + }); + } + @Override public UniConstraintStream map(BiFunction mapping) { var stream = shareAndAddChild(new BavetUniMapBiConstraintStream<>(constraintFactory, this, mapping)); diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetBiConcatNode.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetBiConcatNode.java new file mode 100644 index 0000000000..395d2e3546 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetBiConcatNode.java @@ -0,0 +1,25 @@ +package ai.timefold.solver.constraint.streams.bavet.bi; + +import ai.timefold.solver.constraint.streams.bavet.common.AbstractConcatNode; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.BiTuple; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleLifecycle; + +public final class BavetBiConcatNode extends AbstractConcatNode> { + + BavetBiConcatNode(TupleLifecycle> nextNodesTupleLifecycle, int inputStoreIndexLeftOutTupleList, + int inputStoreIndexRightOutTupleList, + int outputStoreSize) { + super(nextNodesTupleLifecycle, inputStoreIndexLeftOutTupleList, inputStoreIndexRightOutTupleList, outputStoreSize); + } + + @Override + protected BiTuple getOutTuple(BiTuple inTuple) { + return new BiTuple<>(inTuple.factA, inTuple.factB, outputStoreSize); + } + + @Override + protected void updateOutTuple(BiTuple inTuple, BiTuple outTuple) { + outTuple.factA = inTuple.factA; + outTuple.factB = inTuple.factB; + } +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetConcatBiConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetConcatBiConstraintStream.java new file mode 100644 index 0000000000..fdd08c8259 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/bi/BavetConcatBiConstraintStream.java @@ -0,0 +1,104 @@ +package ai.timefold.solver.constraint.streams.bavet.bi; + +import java.util.Objects; +import java.util.Set; + +import ai.timefold.solver.constraint.streams.bavet.BavetConstraintFactory; +import ai.timefold.solver.constraint.streams.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.BavetConcatConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.NodeBuildHelper; +import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetForeBridgeBiConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.BiTuple; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.api.score.Score; + +public final class BavetConcatBiConstraintStream extends BavetAbstractBiConstraintStream + implements BavetConcatConstraintStream { + + private final BavetForeBridgeBiConstraintStream leftParent; + private final BavetForeBridgeBiConstraintStream rightParent; + + public BavetConcatBiConstraintStream(BavetConstraintFactory constraintFactory, + BavetForeBridgeBiConstraintStream leftParent, + BavetForeBridgeBiConstraintStream rightParent) { + super(constraintFactory, leftParent.getRetrievalSemantics()); + this.leftParent = leftParent; + this.rightParent = rightParent; + } + + @Override + public boolean guaranteesDistinct() { + return false; + } + + // ************************************************************************ + // Node creation + // ************************************************************************ + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + leftParent.collectActiveConstraintStreams(constraintStreamSet); + rightParent.collectActiveConstraintStreams(constraintStreamSet); + constraintStreamSet.add(this); + } + + @Override + public > void buildNode(NodeBuildHelper buildHelper) { + TupleLifecycle> downstream = buildHelper.getAggregatedTupleLifecycle(childStreamList); + int leftCloneStoreIndex = buildHelper.reserveTupleStoreIndex(leftParent.getTupleSource()); + int rightCloneStoreIndex = buildHelper.reserveTupleStoreIndex(rightParent.getTupleSource()); + int outputStoreSize = buildHelper.extractTupleStoreSize(this); + var node = new BavetBiConcatNode<>(downstream, + leftCloneStoreIndex, + rightCloneStoreIndex, + outputStoreSize); + buildHelper.addNode(node, this, leftParent, rightParent); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BavetConcatBiConstraintStream other = (BavetConcatBiConstraintStream) o; + /* + * Bridge streams do not implement equality because their equals() would have to point back to this stream, + * resulting in StackOverflowError. + * Therefore we need to check bridge parents to see where this concat node comes from. + */ + return Objects.equals(leftParent.getParent(), other.leftParent.getParent()) + && Objects.equals(rightParent.getParent(), other.rightParent.getParent()); + } + + @Override + public int hashCode() { + return Objects.hash(leftParent.getParent(), rightParent.getParent()); + } + + @Override + public String toString() { + return "Concat() with " + childStreamList.size() + " children"; + } + + // ************************************************************************ + // Getters/setters + // ************************************************************************ + + @Override + public BavetAbstractConstraintStream getLeftParent() { + return leftParent; + } + + @Override + public BavetAbstractConstraintStream getRightParent() { + return rightParent; + } + +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/AbstractConcatNode.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/AbstractConcatNode.java new file mode 100644 index 0000000000..9d4f251a01 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/AbstractConcatNode.java @@ -0,0 +1,134 @@ +package ai.timefold.solver.constraint.streams.bavet.common; + +import static ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleState.ABORTING; +import static ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleState.CREATING; +import static ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleState.DYING; +import static ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleState.UPDATING; + +import ai.timefold.solver.constraint.streams.bavet.common.tuple.AbstractTuple; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.LeftTupleLifecycle; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.RightTupleLifecycle; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleState; + +/** + * Implements the concat operation. Concat cannot be implemented as a pass-through operation because of two caveats: + * + *
    + *
  • It is possible to have the same {@link TupleSource} for both parent streams, + * in which case the exact same tuple can be inserted twice. Such a tuple + * should be counted twice downstream, and thus need to be cloned. + *
  • + * + *
  • Because concat has two parent nodes, it must be a {@link TupleSource} (since + * all nodes have exactly one {@link TupleSource}, and the source tuple can come from + * either parent). {@link TupleSource} must produce new tuples and not reuse them, since + * if tuples are reused, the stores inside them get corrupted. + *
  • + *
+ * + * The {@link AbstractConcatNode} works by creating a copy of the source tuple and putting it into + * the tuple's store. If the same tuple is inserted twice (i.e. when the left and right parent + * have the same {@link TupleSource}), it creates another clone. + */ +public abstract class AbstractConcatNode + extends AbstractNode + implements LeftTupleLifecycle, RightTupleLifecycle { + private final int leftSourceTupleCloneStoreIndex; + private final int rightSourceTupleCloneStoreIndex; + protected final int outputStoreSize; + private final StaticPropagationQueue propagationQueue; + + protected AbstractConcatNode(TupleLifecycle nextNodesTupleLifecycle, + int leftSourceTupleCloneStoreIndex, + int rightSourceTupleCloneStoreIndex, + int outputStoreSize) { + this.propagationQueue = new StaticPropagationQueue<>(nextNodesTupleLifecycle); + this.leftSourceTupleCloneStoreIndex = leftSourceTupleCloneStoreIndex; + this.rightSourceTupleCloneStoreIndex = rightSourceTupleCloneStoreIndex; + this.outputStoreSize = outputStoreSize; + } + + /** + * Creates a copy of the inTuple with the same fact (and new store/state). + */ + protected abstract Tuple_ getOutTuple(Tuple_ inTuple); + + /** + * Updates outTuple to contain the same facts as inTuple. + */ + protected abstract void updateOutTuple(Tuple_ inTuple, Tuple_ outTuple); + + private void insert(Tuple_ tuple, int storeIndex) { + Tuple_ outTuple = getOutTuple(tuple); + tuple.setStore(storeIndex, outTuple); + propagationQueue.insert(outTuple); + } + + private void update(Tuple_ tuple, int storeIndex) { + Tuple_ outTuple = tuple.getStore(storeIndex); + if (outTuple == null) { + // No fail fast if null because we don't track which tuples made it through the filter predicate(s) + insert(tuple, storeIndex); + return; + } + + updateOutTuple(tuple, outTuple); + // Even if the facts of tuple do not change, an update MUST be done so + // downstream nodes get notified of updates in planning variables. + TupleState previousState = outTuple.state; + if (previousState == CREATING || previousState == UPDATING) { + return; + } + propagationQueue.update(outTuple); + } + + private void retract(Tuple_ tuple, int storeIndex) { + Tuple_ outTuple = tuple.getStore(storeIndex); + if (outTuple == null) { + // No fail fast if null because we don't track which tuples made it through the filter predicate(s) + return; + } + TupleState state = outTuple.state; + if (!state.isActive()) { + throw new IllegalStateException("Impossible state: The tuple (" + outTuple.state + ") in node (" + this + + ") is in an unexpected state (" + outTuple.state + ")."); + } + propagationQueue.retract(outTuple, state == CREATING ? ABORTING : DYING); + } + + @Override + public final void insertLeft(Tuple_ tuple) { + insert(tuple, leftSourceTupleCloneStoreIndex); + } + + @Override + public final void updateLeft(Tuple_ tuple) { + update(tuple, leftSourceTupleCloneStoreIndex); + } + + @Override + public final void retractLeft(Tuple_ tuple) { + retract(tuple, leftSourceTupleCloneStoreIndex); + } + + @Override + public final void insertRight(Tuple_ tuple) { + insert(tuple, rightSourceTupleCloneStoreIndex); + } + + @Override + public final void updateRight(Tuple_ tuple) { + update(tuple, rightSourceTupleCloneStoreIndex); + } + + @Override + public final void retractRight(Tuple_ tuple) { + retract(tuple, rightSourceTupleCloneStoreIndex); + } + + @Override + public Propagator getPropagator() { + return propagationQueue; + } +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetConcatConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetConcatConstraintStream.java new file mode 100644 index 0000000000..d46217b49f --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetConcatConstraintStream.java @@ -0,0 +1,6 @@ +package ai.timefold.solver.constraint.streams.bavet.common; + +public interface BavetConcatConstraintStream + extends BavetStreamBinaryOperation, TupleSource { + +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetIfExistsConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetIfExistsConstraintStream.java index ef466636de..08561fe1e6 100644 --- a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetIfExistsConstraintStream.java +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetIfExistsConstraintStream.java @@ -1,15 +1,5 @@ package ai.timefold.solver.constraint.streams.bavet.common; -import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetForeBridgeUniConstraintStream; - -public interface BavetIfExistsConstraintStream { - - BavetAbstractConstraintStream getLeftParent(); - - /** - * - * @return An instance of {@link BavetForeBridgeUniConstraintStream}. - */ - BavetAbstractConstraintStream getRightParent(); +public interface BavetIfExistsConstraintStream extends BavetStreamBinaryOperation { } diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetJoinConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetJoinConstraintStream.java index adb21f51e1..05b90482e5 100644 --- a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetJoinConstraintStream.java +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetJoinConstraintStream.java @@ -3,7 +3,7 @@ import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetForeBridgeUniConstraintStream; public interface BavetJoinConstraintStream - extends TupleSource { + extends BavetStreamBinaryOperation, TupleSource { /** * diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetStreamBinaryOperation.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetStreamBinaryOperation.java new file mode 100644 index 0000000000..aba1e38425 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/BavetStreamBinaryOperation.java @@ -0,0 +1,16 @@ +package ai.timefold.solver.constraint.streams.bavet.common; + +import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetForeBridgeUniConstraintStream; + +public interface BavetStreamBinaryOperation { + /** + * @return An instance of {@link BavetForeBridgeUniConstraintStream}. + */ + BavetAbstractConstraintStream getLeftParent(); + + /** + * @return An instance of {@link BavetForeBridgeUniConstraintStream}. + */ + BavetAbstractConstraintStream getRightParent(); + +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/bridge/BavetForeBridgeQuadConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/bridge/BavetForeBridgeQuadConstraintStream.java new file mode 100644 index 0000000000..77b8f30bb3 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/common/bridge/BavetForeBridgeQuadConstraintStream.java @@ -0,0 +1,34 @@ +package ai.timefold.solver.constraint.streams.bavet.common.bridge; + +import ai.timefold.solver.constraint.streams.bavet.BavetConstraintFactory; +import ai.timefold.solver.constraint.streams.bavet.common.NodeBuildHelper; +import ai.timefold.solver.constraint.streams.bavet.quad.BavetAbstractQuadConstraintStream; +import ai.timefold.solver.core.api.score.Score; + +public final class BavetForeBridgeQuadConstraintStream + extends BavetAbstractQuadConstraintStream { + + public BavetForeBridgeQuadConstraintStream(BavetConstraintFactory constraintFactory, + BavetAbstractQuadConstraintStream parent) { + super(constraintFactory, parent); + } + + // ************************************************************************ + // Node creation + // ************************************************************************ + + @Override + public > void buildNode(NodeBuildHelper buildHelper) { + // Do nothing. The child stream builds everything. + } + + @Override + public String toString() { + return "Generic bridge"; + } + + // ************************************************************************ + // Getters/setters + // ************************************************************************ + +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetAbstractQuadConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetAbstractQuadConstraintStream.java index 71cc3ba05f..7684ba04aa 100644 --- a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetAbstractQuadConstraintStream.java +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetAbstractQuadConstraintStream.java @@ -14,6 +14,7 @@ import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetAftBridgeQuadConstraintStream; import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetAftBridgeTriConstraintStream; import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetAftBridgeUniConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetForeBridgeQuadConstraintStream; import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetForeBridgeUniConstraintStream; import ai.timefold.solver.constraint.streams.bavet.common.tuple.BiTuple; import ai.timefold.solver.constraint.streams.bavet.common.tuple.QuadTuple; @@ -347,6 +348,19 @@ public QuadConstraintStream distinct() { } } + @Override + public QuadConstraintStream concat(QuadConstraintStream otherStream) { + var other = (BavetAbstractQuadConstraintStream) otherStream; + var leftBridge = new BavetForeBridgeQuadConstraintStream<>(constraintFactory, this); + var rightBridge = new BavetForeBridgeQuadConstraintStream<>(constraintFactory, other); + var concatStream = new BavetConcatQuadConstraintStream<>(constraintFactory, leftBridge, rightBridge); + return constraintFactory.share(concatStream, concatStream_ -> { + // Connect the bridges upstream + getChildStreamList().add(leftBridge); + other.getChildStreamList().add(rightBridge); + }); + } + @Override public UniConstraintStream map(QuadFunction mapping) { var stream = shareAndAddChild(new BavetUniMapQuadConstraintStream<>(constraintFactory, this, mapping)); diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetConcatQuadConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetConcatQuadConstraintStream.java new file mode 100644 index 0000000000..c7da61cf51 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetConcatQuadConstraintStream.java @@ -0,0 +1,105 @@ +package ai.timefold.solver.constraint.streams.bavet.quad; + +import java.util.Objects; +import java.util.Set; + +import ai.timefold.solver.constraint.streams.bavet.BavetConstraintFactory; +import ai.timefold.solver.constraint.streams.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.BavetConcatConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.NodeBuildHelper; +import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetForeBridgeQuadConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.QuadTuple; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.api.score.Score; + +public final class BavetConcatQuadConstraintStream + extends BavetAbstractQuadConstraintStream + implements BavetConcatConstraintStream { + + private final BavetForeBridgeQuadConstraintStream leftParent; + private final BavetForeBridgeQuadConstraintStream rightParent; + + public BavetConcatQuadConstraintStream(BavetConstraintFactory constraintFactory, + BavetForeBridgeQuadConstraintStream leftParent, + BavetForeBridgeQuadConstraintStream rightParent) { + super(constraintFactory, leftParent.getRetrievalSemantics()); + this.leftParent = leftParent; + this.rightParent = rightParent; + } + + @Override + public boolean guaranteesDistinct() { + return false; + } + + // ************************************************************************ + // Node creation + // ************************************************************************ + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + leftParent.collectActiveConstraintStreams(constraintStreamSet); + rightParent.collectActiveConstraintStreams(constraintStreamSet); + constraintStreamSet.add(this); + } + + @Override + public > void buildNode(NodeBuildHelper buildHelper) { + TupleLifecycle> downstream = buildHelper.getAggregatedTupleLifecycle(childStreamList); + int leftCloneStoreIndex = buildHelper.reserveTupleStoreIndex(leftParent.getTupleSource()); + int rightCloneStoreIndex = buildHelper.reserveTupleStoreIndex(rightParent.getTupleSource()); + int outputStoreSize = buildHelper.extractTupleStoreSize(this); + var node = new BavetQuadConcatNode<>(downstream, + leftCloneStoreIndex, + rightCloneStoreIndex, + outputStoreSize); + buildHelper.addNode(node, this, leftParent, rightParent); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BavetConcatQuadConstraintStream other = (BavetConcatQuadConstraintStream) o; + /* + * Bridge streams do not implement equality because their equals() would have to point back to this stream, + * resulting in StackOverflowError. + * Therefore we need to check bridge parents to see where this concat node comes from. + */ + return Objects.equals(leftParent.getParent(), other.leftParent.getParent()) + && Objects.equals(rightParent.getParent(), other.rightParent.getParent()); + } + + @Override + public int hashCode() { + return Objects.hash(leftParent.getParent(), rightParent.getParent()); + } + + @Override + public String toString() { + return "Concat() with " + childStreamList.size() + " children"; + } + + // ************************************************************************ + // Getters/setters + // ************************************************************************ + + @Override + public BavetAbstractConstraintStream getLeftParent() { + return leftParent; + } + + @Override + public BavetAbstractConstraintStream getRightParent() { + return rightParent; + } + +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetQuadConcatNode.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetQuadConcatNode.java new file mode 100644 index 0000000000..fa9603a7d1 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/quad/BavetQuadConcatNode.java @@ -0,0 +1,26 @@ +package ai.timefold.solver.constraint.streams.bavet.quad; + +import ai.timefold.solver.constraint.streams.bavet.common.AbstractConcatNode; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.QuadTuple; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleLifecycle; + +public final class BavetQuadConcatNode extends AbstractConcatNode> { + + BavetQuadConcatNode(TupleLifecycle> nextNodesTupleLifecycle, + int inputStoreIndexLeftOutTupleList, int inputStoreIndexRightOutTupleList, int outputStoreSize) { + super(nextNodesTupleLifecycle, inputStoreIndexLeftOutTupleList, inputStoreIndexRightOutTupleList, outputStoreSize); + } + + @Override + protected QuadTuple getOutTuple(QuadTuple inTuple) { + return new QuadTuple<>(inTuple.factA, inTuple.factB, inTuple.factC, inTuple.factD, outputStoreSize); + } + + @Override + protected void updateOutTuple(QuadTuple inTuple, QuadTuple outTuple) { + outTuple.factA = inTuple.factA; + outTuple.factB = inTuple.factB; + outTuple.factC = inTuple.factC; + outTuple.factD = inTuple.factD; + } +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetAbstractTriConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetAbstractTriConstraintStream.java index b97e3c73c2..5812b8dd69 100644 --- a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetAbstractTriConstraintStream.java +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetAbstractTriConstraintStream.java @@ -370,6 +370,19 @@ public TriConstraintStream distinct() { } } + @Override + public TriConstraintStream concat(TriConstraintStream otherStream) { + var other = (BavetAbstractTriConstraintStream) otherStream; + var leftBridge = new BavetForeBridgeTriConstraintStream<>(constraintFactory, this); + var rightBridge = new BavetForeBridgeTriConstraintStream<>(constraintFactory, other); + var concatStream = new BavetConcatTriConstraintStream<>(constraintFactory, leftBridge, rightBridge); + return constraintFactory.share(concatStream, concatStream_ -> { + // Connect the bridges upstream + getChildStreamList().add(leftBridge); + other.getChildStreamList().add(rightBridge); + }); + } + @Override public UniConstraintStream map(TriFunction mapping) { var stream = shareAndAddChild(new BavetUniMapTriConstraintStream<>(constraintFactory, this, mapping)); diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetConcatTriConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetConcatTriConstraintStream.java new file mode 100644 index 0000000000..1191f1977d --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetConcatTriConstraintStream.java @@ -0,0 +1,105 @@ +package ai.timefold.solver.constraint.streams.bavet.tri; + +import java.util.Objects; +import java.util.Set; + +import ai.timefold.solver.constraint.streams.bavet.BavetConstraintFactory; +import ai.timefold.solver.constraint.streams.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.BavetConcatConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.NodeBuildHelper; +import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetForeBridgeTriConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TriTuple; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.core.api.score.Score; + +public final class BavetConcatTriConstraintStream + extends BavetAbstractTriConstraintStream + implements BavetConcatConstraintStream { + + private final BavetForeBridgeTriConstraintStream leftParent; + private final BavetForeBridgeTriConstraintStream rightParent; + + public BavetConcatTriConstraintStream(BavetConstraintFactory constraintFactory, + BavetForeBridgeTriConstraintStream leftParent, + BavetForeBridgeTriConstraintStream rightParent) { + super(constraintFactory, leftParent.getRetrievalSemantics()); + this.leftParent = leftParent; + this.rightParent = rightParent; + } + + @Override + public boolean guaranteesDistinct() { + return false; + } + + // ************************************************************************ + // Node creation + // ************************************************************************ + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + leftParent.collectActiveConstraintStreams(constraintStreamSet); + rightParent.collectActiveConstraintStreams(constraintStreamSet); + constraintStreamSet.add(this); + } + + @Override + public > void buildNode(NodeBuildHelper buildHelper) { + TupleLifecycle> downstream = buildHelper.getAggregatedTupleLifecycle(childStreamList); + int leftCloneStoreIndex = buildHelper.reserveTupleStoreIndex(leftParent.getTupleSource()); + int rightCloneStoreIndex = buildHelper.reserveTupleStoreIndex(rightParent.getTupleSource()); + int outputStoreSize = buildHelper.extractTupleStoreSize(this); + var node = new BavetTriConcatNode<>(downstream, + leftCloneStoreIndex, + rightCloneStoreIndex, + outputStoreSize); + buildHelper.addNode(node, this, leftParent, rightParent); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BavetConcatTriConstraintStream other = (BavetConcatTriConstraintStream) o; + /* + * Bridge streams do not implement equality because their equals() would have to point back to this stream, + * resulting in StackOverflowError. + * Therefore we need to check bridge parents to see where this concat node comes from. + */ + return Objects.equals(leftParent.getParent(), other.leftParent.getParent()) + && Objects.equals(rightParent.getParent(), other.rightParent.getParent()); + } + + @Override + public int hashCode() { + return Objects.hash(leftParent.getParent(), rightParent.getParent()); + } + + @Override + public String toString() { + return "Concat() with " + childStreamList.size() + " children"; + } + + // ************************************************************************ + // Getters/setters + // ************************************************************************ + + @Override + public BavetAbstractConstraintStream getLeftParent() { + return leftParent; + } + + @Override + public BavetAbstractConstraintStream getRightParent() { + return rightParent; + } + +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetTriConcatNode.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetTriConcatNode.java new file mode 100644 index 0000000000..acf052d660 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/tri/BavetTriConcatNode.java @@ -0,0 +1,26 @@ +package ai.timefold.solver.constraint.streams.bavet.tri; + +import ai.timefold.solver.constraint.streams.bavet.common.AbstractConcatNode; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TriTuple; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleLifecycle; + +public final class BavetTriConcatNode extends AbstractConcatNode> { + + BavetTriConcatNode(TupleLifecycle> nextNodesTupleLifecycle, int inputStoreIndexLeftOutTupleList, + int inputStoreIndexRightOutTupleList, + int outputStoreSize) { + super(nextNodesTupleLifecycle, inputStoreIndexLeftOutTupleList, inputStoreIndexRightOutTupleList, outputStoreSize); + } + + @Override + protected TriTuple getOutTuple(TriTuple inTuple) { + return new TriTuple<>(inTuple.factA, inTuple.factB, inTuple.factC, outputStoreSize); + } + + @Override + protected void updateOutTuple(TriTuple inTuple, TriTuple outTuple) { + outTuple.factA = inTuple.factA; + outTuple.factB = inTuple.factB; + outTuple.factC = inTuple.factC; + } +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetAbstractUniConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetAbstractUniConstraintStream.java index 3591139d50..5b9bd01569 100644 --- a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetAbstractUniConstraintStream.java +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetAbstractUniConstraintStream.java @@ -361,6 +361,19 @@ public UniConstraintStream distinct() { } } + @Override + public UniConstraintStream concat(UniConstraintStream otherStream) { + var other = (BavetAbstractUniConstraintStream) otherStream; + var leftBridge = new BavetForeBridgeUniConstraintStream<>(constraintFactory, this); + var rightBridge = new BavetForeBridgeUniConstraintStream<>(constraintFactory, other); + var concatStream = new BavetConcatUniConstraintStream<>(constraintFactory, leftBridge, rightBridge); + return constraintFactory.share(concatStream, concatStream_ -> { + // Connect the bridges upstream + getChildStreamList().add(leftBridge); + other.getChildStreamList().add(rightBridge); + }); + } + @Override public UniConstraintStream map(Function mapping) { var stream = shareAndAddChild(new BavetUniMapUniConstraintStream<>(constraintFactory, this, mapping)); diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetConcatUniConstraintStream.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetConcatUniConstraintStream.java new file mode 100644 index 0000000000..c8576c1fb8 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetConcatUniConstraintStream.java @@ -0,0 +1,104 @@ +package ai.timefold.solver.constraint.streams.bavet.uni; + +import java.util.Objects; +import java.util.Set; + +import ai.timefold.solver.constraint.streams.bavet.BavetConstraintFactory; +import ai.timefold.solver.constraint.streams.bavet.common.BavetAbstractConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.BavetConcatConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.NodeBuildHelper; +import ai.timefold.solver.constraint.streams.bavet.common.bridge.BavetForeBridgeUniConstraintStream; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.UniTuple; +import ai.timefold.solver.core.api.score.Score; + +public final class BavetConcatUniConstraintStream extends BavetAbstractUniConstraintStream + implements BavetConcatConstraintStream { + + private final BavetForeBridgeUniConstraintStream leftParent; + private final BavetForeBridgeUniConstraintStream rightParent; + + public BavetConcatUniConstraintStream(BavetConstraintFactory constraintFactory, + BavetForeBridgeUniConstraintStream leftParent, + BavetForeBridgeUniConstraintStream rightParent) { + super(constraintFactory, leftParent.getRetrievalSemantics()); + this.leftParent = leftParent; + this.rightParent = rightParent; + } + + @Override + public boolean guaranteesDistinct() { + return false; + } + + // ************************************************************************ + // Node creation + // ************************************************************************ + + @Override + public void collectActiveConstraintStreams(Set> constraintStreamSet) { + leftParent.collectActiveConstraintStreams(constraintStreamSet); + rightParent.collectActiveConstraintStreams(constraintStreamSet); + constraintStreamSet.add(this); + } + + @Override + public > void buildNode(NodeBuildHelper buildHelper) { + TupleLifecycle> downstream = buildHelper.getAggregatedTupleLifecycle(childStreamList); + int leftCloneStoreIndex = buildHelper.reserveTupleStoreIndex(leftParent.getTupleSource()); + int rightCloneStoreIndex = buildHelper.reserveTupleStoreIndex(rightParent.getTupleSource()); + int outputStoreSize = buildHelper.extractTupleStoreSize(this); + var node = new BavetUniConcatNode<>(downstream, + leftCloneStoreIndex, + rightCloneStoreIndex, + outputStoreSize); + buildHelper.addNode(node, this, leftParent, rightParent); + } + + // ************************************************************************ + // Equality for node sharing + // ************************************************************************ + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BavetConcatUniConstraintStream other = (BavetConcatUniConstraintStream) o; + /* + * Bridge streams do not implement equality because their equals() would have to point back to this stream, + * resulting in StackOverflowError. + * Therefore we need to check bridge parents to see where this concat node comes from. + */ + return Objects.equals(leftParent.getParent(), other.leftParent.getParent()) + && Objects.equals(rightParent.getParent(), other.rightParent.getParent()); + } + + @Override + public int hashCode() { + return Objects.hash(leftParent.getParent(), rightParent.getParent()); + } + + @Override + public String toString() { + return "Concat() with " + childStreamList.size() + " children"; + } + + // ************************************************************************ + // Getters/setters + // ************************************************************************ + + @Override + public BavetAbstractConstraintStream getLeftParent() { + return leftParent; + } + + @Override + public BavetAbstractConstraintStream getRightParent() { + return rightParent; + } + +} diff --git a/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetUniConcatNode.java b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetUniConcatNode.java new file mode 100644 index 0000000000..14faae40c0 --- /dev/null +++ b/core/constraint-streams/src/main/java/ai/timefold/solver/constraint/streams/bavet/uni/BavetUniConcatNode.java @@ -0,0 +1,25 @@ +package ai.timefold.solver.constraint.streams.bavet.uni; + +import ai.timefold.solver.constraint.streams.bavet.common.AbstractConcatNode; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.TupleLifecycle; +import ai.timefold.solver.constraint.streams.bavet.common.tuple.UniTuple; + +public final class BavetUniConcatNode extends AbstractConcatNode> { + + BavetUniConcatNode(TupleLifecycle> nextNodesTupleLifecycle, int inputStoreIndexLeftOutTupleList, + int inputStoreIndexRightOutTupleList, + int outputStoreSize) { + super(nextNodesTupleLifecycle, inputStoreIndexLeftOutTupleList, inputStoreIndexRightOutTupleList, + outputStoreSize); + } + + @Override + protected UniTuple getOutTuple(UniTuple inTuple) { + return new UniTuple<>(inTuple.factA, outputStoreSize); + } + + @Override + protected void updateOutTuple(UniTuple inTuple, UniTuple outTuple) { + outTuple.factA = inTuple.factA; + } +} diff --git a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/ConstraintStreamFunctionalTest.java b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/ConstraintStreamFunctionalTest.java index ee22d25902..2557bc5943 100644 --- a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/ConstraintStreamFunctionalTest.java +++ b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/ConstraintStreamFunctionalTest.java @@ -99,7 +99,7 @@ default void joinAfterGroupBy() { void groupBy_4Mapping0Collector(); // ************************************************************************ - // Map/expand/flatten/distinct + // Map/expand/flatten/distinct/concat // ************************************************************************ void distinct(); @@ -118,6 +118,16 @@ default void joinAfterGroupBy() { void mapToQuad(); + void concatWithoutValueDuplicates(); + + void concatAndDistinctWithoutValueDuplicates(); + + void concatWithValueDuplicates(); + + void concatAndDistinctWithValueDuplicates(); + + void concatAfterGroupBy(); + default void expandToBi() { // Only Uni can be expanded to Bi, so don't force it. } diff --git a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/bi/AbstractBiConstraintStreamTest.java b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/bi/AbstractBiConstraintStreamTest.java index 363b6d9ba4..b735c98e7f 100644 --- a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/bi/AbstractBiConstraintStreamTest.java +++ b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/bi/AbstractBiConstraintStreamTest.java @@ -1344,7 +1344,7 @@ public void groupBy_4Mapping0Collector() { } // ************************************************************************ - // Map/flatten/distinct + // Map/flatten/distinct/concat // ************************************************************************ @Override @@ -1793,6 +1793,294 @@ public void flattenLastAndDistinctWithoutDuplicates() { assertMatch(entity2, group1)); } + @Override + @TestTemplate + public void concatWithoutValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3))) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2), + assertMatch(entity2, entity3)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity2), + assertMatch(entity1, entity3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity3, entity2), + assertMatch(entity2, entity1)); + } + + @Override + @TestTemplate + public void concatWithValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1 || entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2 || entity.getValue() == value3)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3))) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2), + assertMatch(entity1, entity3), + assertMatch(entity2, entity3), + assertMatch(entity2, entity2), + assertMatch(entity2, entity3)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity2), + assertMatch(entity1, entity2), + assertMatch(entity2, entity2), + assertMatch(entity3, entity3), + assertMatch(entity1, entity3), + assertMatch(entity2, entity3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity3, entity2), + assertMatch(entity3, entity1), + assertMatch(entity2, entity1), + assertMatch(entity2, entity1), + assertMatch(entity2, entity2)); + } + + @Override + @TestTemplate + public void concatAndDistinctWithoutValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3))) + .distinct() + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2), + assertMatch(entity2, entity3)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity2), + assertMatch(entity1, entity3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity3, entity2), + assertMatch(entity2, entity1)); + } + + @Override + @TestTemplate + public void concatAndDistinctWithValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1 || entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2 || entity.getValue() == value3)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3))) + .distinct() + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2), + assertMatch(entity1, entity3), + assertMatch(entity2, entity2), + assertMatch(entity2, entity3)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity2), + assertMatch(entity1, entity2), + assertMatch(entity2, entity2), + assertMatch(entity3, entity3), + assertMatch(entity1, entity3), + assertMatch(entity2, entity3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity3, entity2), + assertMatch(entity3, entity1), + assertMatch(entity2, entity1), + assertMatch(entity2, entity2)); + } + + @Override + @TestTemplate + public void concatAfterGroupBy() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .filter((e1, e2) -> e1.getValue() == value1 && e2.getValue() == value2) + .groupBy((e1, e2) -> e1.getValue(), + (e1, e2) -> e2.getValue(), + ConstraintCollectors.countBi()) + .concat(factory.forEach(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .filter((e1, e2) -> e1.getValue() == value2 && e2.getValue() == value3) + .groupBy((e1, e2) -> e1.getValue(), + (e1, e2) -> e2.getValue(), + ConstraintCollectors.countBi())) + .penalize(SimpleScore.ONE, (v1, v2, count) -> count) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatchWithScore(-1, value1, value2, 1), + assertMatchWithScore(-1, value2, value3, 1)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatchWithScore(-2, value1, value2, 2)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatchWithScore(-1, value1, value2, 1), + assertMatchWithScore(-1, value2, value3, 1)); + } + // ************************************************************************ // Penalize/reward // ************************************************************************ diff --git a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/quad/AbstractQuadConstraintStreamTest.java b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/quad/AbstractQuadConstraintStreamTest.java index 0164550f22..606547b690 100644 --- a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/quad/AbstractQuadConstraintStreamTest.java +++ b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/quad/AbstractQuadConstraintStreamTest.java @@ -1423,6 +1423,285 @@ public void flattenLastAndDistinctWithoutDuplicates() { assertScore(scoreDirector); } + @Override + @TestTemplate + public void concatWithoutValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2))) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2, entity3, entity1), + assertMatch(entity2, entity3, entity1, entity2)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + + scoreDirector.beforeVariableChanged(entity2, "value"); + entity2.setValue(value3); + scoreDirector.afterVariableChanged(entity2, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity3, entity2, entity1), + assertMatch(entity3, entity2, entity1, entity3)); + } + + @Override + @TestTemplate + public void concatWithValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1))) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2, entity3, entity1), + assertMatch(entity1, entity2, entity3, entity1)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + + scoreDirector.beforeVariableChanged(entity2, "value"); + entity2.setValue(value3); + scoreDirector.afterVariableChanged(entity2, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity3, entity2, entity1), + assertMatch(entity1, entity3, entity2, entity1)); + } + + @Override + @TestTemplate + public void concatAndDistinctWithoutValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2))) + .distinct() + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2, entity3, entity1), + assertMatch(entity2, entity3, entity1, entity2)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + + scoreDirector.beforeVariableChanged(entity2, "value"); + entity2.setValue(value3); + scoreDirector.afterVariableChanged(entity2, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity3, entity2, entity1), + assertMatch(entity3, entity2, entity1, entity3)); + } + + @Override + @TestTemplate + public void concatAndDistinctWithValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1))) + .distinct() + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2, entity3, entity1)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + + scoreDirector.beforeVariableChanged(entity2, "value"); + entity2.setValue(value3); + scoreDirector.afterVariableChanged(entity2, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity3, entity2, entity1)); + } + + @Override + @TestTemplate + public void concatAfterGroupBy() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .filter((e1, e2, e3, e4) -> e1.getValue() == value1 && e2.getValue() == value2 + && e3.getValue() == value3 && e4.getValue() == value1) + .groupBy((e1, e2, e3, e4) -> e1.getValue(), + (e1, e2, e3, e4) -> e2.getValue(), + (e1, e2, e3, e4) -> e3.getValue().getCode() + e4.getValue().getCode(), + ConstraintCollectors.countQuad()) + .concat(factory.forEach(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .filter((e1, e2, e3, e4) -> e1.getValue() == value3 && e2.getValue() == value2 + && e3.getValue() == value1 && e4.getValue() == value3) + .groupBy((e1, e2, e3, e4) -> e1.getValue(), + (e1, e2, e3, e4) -> e2.getValue(), + (e1, e2, e3, e4) -> e3.getValue().getCode() + e4.getValue().getCode(), + ConstraintCollectors.countQuad())) + .penalize(SimpleScore.ONE, (v1, v2, v3, count) -> count) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatchWithScore(-1, value1, value2, value3.getCode() + value1.getCode(), 1), + assertMatchWithScore(-1, value3, value2, value1.getCode() + value3.getCode(), 1)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatchWithScore(-1, value1, value2, value3.getCode() + value1.getCode(), 1), + assertMatchWithScore(-1, value3, value2, value1.getCode() + value3.getCode(), 1)); + } + // ************************************************************************ // Penalize/reward // ************************************************************************ diff --git a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/tri/AbstractTriConstraintStreamTest.java b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/tri/AbstractTriConstraintStreamTest.java index a9fa5280d3..354ea0b20e 100644 --- a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/tri/AbstractTriConstraintStreamTest.java +++ b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/tri/AbstractTriConstraintStreamTest.java @@ -1745,6 +1745,266 @@ public void flattenLastAndDistinctWithoutDuplicates() { assertScore(scoreDirector); } + @Override + @TestTemplate + public void concatWithoutValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1))) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2, entity3), + assertMatch(entity2, entity3, entity1)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + + scoreDirector.beforeVariableChanged(entity2, "value"); + entity2.setValue(value3); + scoreDirector.afterVariableChanged(entity2, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity3, entity2), + assertMatch(entity1, entity3, entity2)); + } + + @Override + @TestTemplate + public void concatWithValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3))) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2, entity3), + assertMatch(entity1, entity2, entity3)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + + scoreDirector.beforeVariableChanged(entity2, "value"); + entity2.setValue(value3); + scoreDirector.afterVariableChanged(entity2, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity3, entity2), + assertMatch(entity1, entity3, entity2)); + } + + @Override + @TestTemplate + public void concatAndDistinctWithoutValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1))) + .distinct() + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2, entity3), + assertMatch(entity2, entity3, entity1)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + + scoreDirector.beforeVariableChanged(entity2, "value"); + entity2.setValue(value3); + scoreDirector.afterVariableChanged(entity2, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity3, entity2), + assertMatch(entity1, entity3, entity2)); + } + + @Override + @TestTemplate + public void concatAndDistinctWithValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3)) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .join(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value3))) + .distinct() + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1, entity2, entity3)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + + scoreDirector.beforeVariableChanged(entity2, "value"); + entity2.setValue(value3); + scoreDirector.afterVariableChanged(entity2, "value"); + assertScore(scoreDirector, + assertMatch(entity1, entity3, entity2)); + } + + @Override + @TestTemplate + public void concatAfterGroupBy() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .filter((e1, e2, e3) -> e1.getValue() == value1 && e2.getValue() == value2 && e3.getValue() == value3) + .groupBy((e1, e2, e3) -> e1.getValue(), + (e1, e2, e3) -> e2.getValue(), + (e1, e2, e3) -> e3.getValue(), + ConstraintCollectors.countTri()) + .concat(factory.forEach(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .join(TestdataLavishEntity.class) + .filter((e1, e2, e3) -> e1.getValue() == value3 && e2.getValue() == value2 + && e3.getValue() == value1) + .groupBy((e1, e2, e3) -> e1.getValue(), + (e1, e2, e3) -> e2.getValue(), + (e1, e2, e3) -> e3.getValue(), + ConstraintCollectors.countTri())) + .penalize(SimpleScore.ONE, (v1, v2, v3, count) -> count) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatchWithScore(-1, value1, value2, value3, 1), + assertMatchWithScore(-1, value3, value2, value1, 1)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatchWithScore(-1, value1, value2, value3, 1), + assertMatchWithScore(-1, value3, value2, value1, 1)); + } + // ************************************************************************ // Penalize/reward // ************************************************************************ diff --git a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/uni/AbstractUniConstraintStreamTest.java b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/uni/AbstractUniConstraintStreamTest.java index 45aa51b546..0913c21e0f 100644 --- a/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/uni/AbstractUniConstraintStreamTest.java +++ b/core/constraint-streams/src/test/java/ai/timefold/solver/constraint/streams/common/uni/AbstractUniConstraintStreamTest.java @@ -1687,7 +1687,7 @@ public void groupBy_4Mapping0Collector() { } // ************************************************************************ - // Map/flatten/distinct + // Map/flatten/distinct/concat // ************************************************************************ @Override @@ -2149,6 +2149,367 @@ public void flattenLastAndDistinctWithoutDuplicates() { assertMatch(group2)); } + @Override + @TestTemplate + public void concatWithoutValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity2), + assertMatch(entity3)); + } + + @Override + @TestTemplate + public void concatWithValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1 || entity.getValue() == value3) + .concat(factory.forEach(TestdataLavishEntity.class) + .map(Function.identity()) // This map make the tuples not reference equal + .filter(entity -> entity.getValue() == value2 || entity.getValue() == value3)) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3), + assertMatch(entity3)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + } + + @TestTemplate + public void concatWithReferenceDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1 || entity.getValue() == value3) + .concat(factory.forEach(TestdataLavishEntity.class) + // The tuples are reference equal since filter is not a TupleSource + .filter(entity -> entity.getValue() == value2 || entity.getValue() == value3)) + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3), + assertMatch(entity3)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + } + + @TestTemplate + public void concatWithReferenceDuplicatesGroupBy() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1 || entity.getValue() == value3) + .concat(factory.forEach(TestdataLavishEntity.class) + // The tuples are reference equal since filter is not a TupleSource + .filter(entity -> entity.getValue() == value2 || entity.getValue() == value3)) + .groupBy(ConstraintCollectors.count()) + .penalize(SimpleScore.ONE, count -> count) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatchWithScore(-4, 4)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatchWithScore(-3, 3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatchWithScore(-4, 4)); + } + + @Override + @TestTemplate + public void concatAndDistinctWithoutValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2)) + .distinct() + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity2), + assertMatch(entity3)); + } + + @Override + @TestTemplate + public void concatAndDistinctWithValueDuplicates() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1 || entity.getValue() == value3) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2 || entity.getValue() == value3)) + .distinct() + .penalize(SimpleScore.ONE) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatch(entity1), + assertMatch(entity2), + assertMatch(entity3)); + } + + @Override + @TestTemplate + public void concatAfterGroupBy() { + TestdataLavishSolution solution = TestdataLavishSolution.generateSolution(2, 5, 1, 1); + TestdataLavishValue value1 = solution.getFirstValue(); + TestdataLavishValue value2 = new TestdataLavishValue("MyValue 2", solution.getFirstValueGroup()); + TestdataLavishValue value3 = new TestdataLavishValue("MyValue 3", solution.getFirstValueGroup()); + TestdataLavishEntity entity1 = solution.getFirstEntity(); + TestdataLavishEntity entity2 = new TestdataLavishEntity("MyEntity 2", solution.getFirstEntityGroup(), + value2); + solution.getEntityList().add(entity2); + TestdataLavishEntity entity3 = new TestdataLavishEntity("MyEntity 3", solution.getFirstEntityGroup(), + value3); + solution.getEntityList().add(entity3); + + InnerScoreDirector scoreDirector = + buildScoreDirector(factory -> factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value1) + .groupBy(TestdataLavishEntity::getValue, ConstraintCollectors.count()) + .concat(factory.forEach(TestdataLavishEntity.class) + .filter(entity -> entity.getValue() == value2) + .groupBy(TestdataLavishEntity::getValue, ConstraintCollectors.count())) + .penalize(SimpleScore.ONE, (value, count) -> count) + .asConstraint(TEST_CONSTRAINT_NAME)); + + // From scratch + scoreDirector.setWorkingSolution(solution); + assertScore(scoreDirector, + assertMatchWithScore(-1, value1, 1), + assertMatchWithScore(-1, value2, 1)); + + // Incremental + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value2); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatchWithScore(-1, value1, 1), + assertMatchWithScore(-2, value2, 2)); + + // Incremental for which the first change matches a join that doesn't survive the second change + scoreDirector.beforeVariableChanged(entity1, "value"); + entity1.setValue(value3); + scoreDirector.afterVariableChanged(entity1, "value"); + scoreDirector.beforeVariableChanged(entity3, "value"); + entity3.setValue(value1); + scoreDirector.afterVariableChanged(entity3, "value"); + assertScore(scoreDirector, + assertMatchWithScore(-1, value1, 1), + assertMatchWithScore(-1, value2, 1)); + } + // ************************************************************************ // Penalize/reward // ************************************************************************ diff --git a/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/bi/BiConstraintStream.java b/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/bi/BiConstraintStream.java index f88d1aa245..e9640a4cd1 100644 --- a/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/bi/BiConstraintStream.java +++ b/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/bi/BiConstraintStream.java @@ -1055,6 +1055,21 @@ QuadConstraintStream distinct(); + /** + * Returns a new {@link BiConstraintStream} containing all the tuples of both this {@link BiConstraintStream} and the + * provided {@link BiConstraintStream}. Tuples in both this {@link BiConstraintStream} and the provided + * {@link BiConstraintStream} will appear at least twice. + * + *

+ * For instance, if this stream consists of {@code [(A, 1), (B, 2), (C, 3)]} and the other stream consists of + * {@code [(C, 3), (D, 4), (E, 5)]}, {@code this.concat(other)} will consist of + * {@code [(A, 1), (B, 2), (C, 3), (C, 3), (D, 4), (E, 5)]}. This operation can be thought of as an or between streams. + * + * @param otherStream + * @return + */ + BiConstraintStream concat(BiConstraintStream otherStream); + // ************************************************************************ // Other operations // ************************************************************************ diff --git a/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/quad/QuadConstraintStream.java b/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/quad/QuadConstraintStream.java index ca8d19e2da..65d16b701f 100644 --- a/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/quad/QuadConstraintStream.java +++ b/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/quad/QuadConstraintStream.java @@ -837,6 +837,22 @@ QuadConstraintStream distinct(); + /** + * Returns a new {@link QuadConstraintStream} containing all the tuples of both this {@link QuadConstraintStream} and the + * provided {@link QuadConstraintStream}. Tuples in both this {@link QuadConstraintStream} and the provided + * {@link QuadConstraintStream} will appear at least twice. + * + *

+ * For instance, if this stream consists of {@code [(A, 1, -1, a), (B, 2, -2, b), (C, 3, -3, c)]} and the other stream + * consists of {@code [(C, 3, -3, c), (D, 4, -4, d), (E, 5, -5, e)]}, {@code this.concat(other)} will consist of + * {@code [(A, 1, -1, a), (B, 2, -2, b), (C, 3, -3, c), (C, 3, -3, c), (D, 4, -4, d), (E, 5, -5,e)]}. This operation can be + * thought of as an or between streams. + * + * @param otherStream + * @return + */ + QuadConstraintStream concat(QuadConstraintStream otherStream); + // ************************************************************************ // Penalize/reward // ************************************************************************ diff --git a/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/tri/TriConstraintStream.java b/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/tri/TriConstraintStream.java index 0c11950615..b1b152dee8 100644 --- a/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/tri/TriConstraintStream.java +++ b/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/tri/TriConstraintStream.java @@ -1049,6 +1049,22 @@ QuadConstraintStream distinct(); + /** + * Returns a new {@link TriConstraintStream} containing all the tuples of both this {@link TriConstraintStream} and the + * provided {@link TriConstraintStream}. Tuples in both this {@link TriConstraintStream} and the provided + * {@link TriConstraintStream} will appear at least twice. + * + *

+ * For instance, if this stream consists of {@code [(A, 1, -1), (B, 2, -2), (C, 3, -3)]} and the other stream consists of + * {@code [(C, 3, -3), (D, 4, -4), (E, 5, -5)]}, {@code this.concat(other)} will consist of + * {@code [(A, 1, -1), (B, 2, -2), (C, 3, -3), (C, 3, -3), (D, 4, -4), (E, 5, -5)]}. This operation can be thought of as an + * or between streams. + * + * @param otherStream + * @return + */ + TriConstraintStream concat(TriConstraintStream otherStream); + // ************************************************************************ // Other operations // ************************************************************************ diff --git a/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/uni/UniConstraintStream.java b/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/uni/UniConstraintStream.java index b22f8067c9..e3627d7663 100644 --- a/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/uni/UniConstraintStream.java +++ b/core/core-impl/src/main/java/ai/timefold/solver/core/api/score/stream/uni/UniConstraintStream.java @@ -1482,6 +1482,21 @@ QuadConstraintStream distinct(); + /** + * Returns a new {@link UniConstraintStream} containing all the tuples of both this {@link UniConstraintStream} and the + * provided {@link UniConstraintStream}. Tuples in both this {@link UniConstraintStream} and the provided + * {@link UniConstraintStream} will appear at least twice. + * + *

+ * For instance, if this stream consists of {@code [A, B, C]} and the other stream consists of {@code [C, D, E]}, + * {@code this.concat(other)} will consist of {@code [A, B, C, C, D, E]}. This operation can be thought of as an or between + * streams. + * + * @param otherStream + * @return + */ + UniConstraintStream concat(UniConstraintStream otherStream); + // ************************************************************************ // Other operations // ************************************************************************ diff --git a/docs/src/modules/ROOT/images/constraints-and-score/constraintStreamConcat.png b/docs/src/modules/ROOT/images/constraints-and-score/constraintStreamConcat.png new file mode 100644 index 0000000000..c4ef1d2d72 Binary files /dev/null and b/docs/src/modules/ROOT/images/constraints-and-score/constraintStreamConcat.png differ diff --git a/docs/src/modules/ROOT/images/constraints-and-score/constraintStreamConcat.svg b/docs/src/modules/ROOT/images/constraints-and-score/constraintStreamConcat.svg new file mode 100644 index 0000000000..3af70fba93 --- /dev/null +++ b/docs/src/modules/ROOT/images/constraints-and-score/constraintStreamConcat.svg @@ -0,0 +1,1114 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + inkscape:perspective sodipodi:type="inkscape:persp3d" inkscape:vp_x="0 : 526.18109 : 1" inkscape:vp_y="0 : 1000 : 0" inkscape:vp_z="744.09448 : 526.18109 : 1" inkscape:persp3d-origin="372.04724 : 350.78739 : 1" id="perspective10" /> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + en + + + + + + + Constraint Streams: concat + Similar to an SQL UNION ALL. If tuples are repeated, they appear twice. + forEach(Employee.class) + forEach(Employee.class) + + + + + .filter(employee -> employee.hasDog()) + .filter(employee -> employee.hasCat()) + + + + + + + Ann + + + + Beth + + + + Carl + + + + Ann + + + + Carl + + + + Beth + + + + Carl + + + + Ann + + + + Carl + + + + Beth + + + + Carl + + + + + + + + .concat() + ... + + + + + + + + diff --git a/docs/src/modules/ROOT/pages/constraints-and-score/constraints-and-score.adoc b/docs/src/modules/ROOT/pages/constraints-and-score/constraints-and-score.adoc index a0afb8841d..5b2cbfb6f9 100644 --- a/docs/src/modules/ROOT/pages/constraints-and-score/constraints-and-score.adoc +++ b/docs/src/modules/ROOT/pages/constraints-and-score/constraints-and-score.adoc @@ -1511,6 +1511,57 @@ the tuple `(SomePerson, USER)` is sent downstream twice. See <> for how to deal with duplicate tuples. ==== +[[constraintStreamsConcat]] +==== Concat + +The `concat` building block allows you to create a constraint stream containing tuples of two constraint streams of the same <>. +If <> acts like a cartesian product of two lists, `concat` acts like a concatenation of two lists. +Unlike union of sets, concatenation of lists repeats duplicated elements. +If the two constraint stream parents share tuples, which happens eg. when the streams being concatenated come from the same source of data, the tuples will be repeated downstream. +If this is undesired, use the <>. + +image::constraints-and-score/constraintStreamConcat.png[align="center"] + +For example, to ensure each employee has a minimum number of assigned shifts: + +[source,java,options="nowrap"] +---- + private Constraint ensureEachEmployeeHasAtLeastTwoShifts(ConstraintFactory constraintFactory) { + return constraintFactory.forEach(Employee.class) + .join(Shift.class, equal(Function.identity(), Shift::getEmployee)) + .concat( + constraintFactory.forEach(Employee.class) + .ifNotExists(Shift.class, equal(Function.identity(), Shift::getEmployee)) + .expand(employee -> (Shift) null) + ) + .groupBy((employee, shift) -> employee, + conditionally((employee, shift) -> shift != null, + countBi()) + ) + .filter((employee, shiftCount) -> shiftCount < employee.minimumAssignedShifts) + .penalize(HardSoftScore.ONE_SOFT, (employee, shiftCount) -> employee.minimumAssignedShifts - shiftCount) + .asConstraint("Minimum number of assigned shifts"); + } +---- + +This correctly counts the number of shifts each Employee has, *even when the Employee has no shifts*. +If it was implemented without `concat` like this: + +[source,java,options="nowrap"] +---- + private Constraint incorrectEnsureEachEmployeeHasAtLeastTwoShifts(ConstraintFactory constraintFactory) { + return constraintFactory.forEach(Employee.class) + .join(Shift.class, equal(Function.identity(), Shift::getEmployee)) + .groupBy((employee, shift) -> employee, + countBi()) + ) + .filter((employee, shiftCount) -> shiftCount < employee.minimumAssignedShifts) + .penalize(HardSoftScore.ONE_SOFT, (employee, shiftCount) -> employee.minimumAssignedShifts - shiftCount) + .asConstraint("Minimum number of assigned shifts (incorrect)"); + } +---- + +An employee with no assigned shifts _will not be penalized because no tuples were passed to the `groupBy` building block_. [[constraintStreamsTesting]] === Testing a constraint stream @@ -2525,4 +2576,4 @@ The sum of all the `Indictment.getScoreTotal()` differs from the overall score, <> support constraint matches automatically, but <> requires <>. -==== \ No newline at end of file +==== diff --git a/examples/src/main/java/ai/timefold/solver/examples/nurserostering/score/NurseRosteringConstraintProvider.java b/examples/src/main/java/ai/timefold/solver/examples/nurserostering/score/NurseRosteringConstraintProvider.java index 7cc1305939..d2d5898654 100644 --- a/examples/src/main/java/ai/timefold/solver/examples/nurserostering/score/NurseRosteringConstraintProvider.java +++ b/examples/src/main/java/ai/timefold/solver/examples/nurserostering/score/NurseRosteringConstraintProvider.java @@ -14,6 +14,9 @@ import ai.timefold.solver.core.api.score.stream.ConstraintFactory; import ai.timefold.solver.core.api.score.stream.ConstraintProvider; import ai.timefold.solver.core.api.score.stream.Joiners; +import ai.timefold.solver.core.api.score.stream.bi.BiConstraintStream; +import ai.timefold.solver.core.api.score.stream.tri.TriConstraintStream; +import ai.timefold.solver.core.api.score.stream.tri.TriJoiner; import ai.timefold.solver.examples.common.experimental.ExperimentalConstraintCollectors; import ai.timefold.solver.examples.common.experimental.api.ConsecutiveInfo; import ai.timefold.solver.examples.common.util.Pair; @@ -43,7 +46,6 @@ public Constraint[] defineConstraints(ConstraintFactory constraintFactory) { return new Constraint[] { oneShiftPerDay(constraintFactory), minimumAndMaximumNumberOfAssignments(constraintFactory), - minimumNumberOfAssignmentsNoAssignments(constraintFactory), consecutiveWorkingDays(constraintFactory), consecutiveFreeDays(constraintFactory), maximumConsecutiveFreeDaysNoAssignments(constraintFactory), @@ -78,16 +80,27 @@ Constraint oneShiftPerDay(ConstraintFactory constraintFactory) { // ############################################################################ // Soft constraints // ############################################################################ + @SafeVarargs + private static TriConstraintStream outerJoin(BiConstraintStream source, + Class joinedClass, TriJoiner... joiners) { + return source.join(joinedClass, joiners).concat( + source.ifNotExists(joinedClass, joiners).expand((ingoredA, ingoredB) -> null)); + } + Constraint minimumAndMaximumNumberOfAssignments(ConstraintFactory constraintFactory) { - return constraintFactory.forEach(MinMaxContractLine.class) - .filter(minMaxContractLine -> minMaxContractLine.getContractLineType() == ContractLineType.TOTAL_ASSIGNMENTS && - minMaxContractLine.isEnabled()) - .join(constraintFactory.forEach(ShiftAssignment.class) - .filter(shift -> shift.getEmployee() != null), - Joiners.equal(ContractLine::getContract, ShiftAssignment::getContract)) - .groupBy((line, shift) -> shift.getEmployee(), - (line, shift) -> line, - ConstraintCollectors.countBi()) + return outerJoin( + constraintFactory + .forEach(MinMaxContractLine.class) + .filter(minMaxContractLine -> minMaxContractLine + .getContractLineType() == ContractLineType.TOTAL_ASSIGNMENTS && minMaxContractLine.isEnabled()) + .join(Employee.class, Joiners.equal(ContractLine::getContract, Employee::getContract)), + ShiftAssignment.class, + Joiners.equal((contractLine, employee) -> employee, ShiftAssignment::getEmployee)) + .groupBy((line, employee, shift) -> employee, + (line, employee, shift) -> line, + ConstraintCollectors.conditionally( + (line, employee, shift) -> shift != null, + ConstraintCollectors.countTri())) .map((employee, contract, shiftCount) -> employee, (employee, contract, shiftCount) -> contract, (employee, contract, shiftCount) -> contract.getViolationAmount(shiftCount)) @@ -97,21 +110,6 @@ Constraint minimumAndMaximumNumberOfAssignments(ConstraintFactory constraintFact .asConstraint("Minimum and maximum number of assignments"); } - Constraint minimumNumberOfAssignmentsNoAssignments(ConstraintFactory constraintFactory) { - return constraintFactory.forEach(MinMaxContractLine.class) - .filter(minMaxContractLine -> minMaxContractLine.getContractLineType() == ContractLineType.TOTAL_ASSIGNMENTS && - minMaxContractLine.isEnabled()) - .join(Employee.class, - Joiners.equal(MinMaxContractLine::getContract, Employee::getContract)) - .ifNotExists(ShiftAssignment.class, - Joiners.equal((contractLine, employee) -> employee, ShiftAssignment::getEmployee)) - .expand((contract, employee) -> contract.getViolationAmount(0)) - .filter((contract, employee, violationAmount) -> violationAmount != 0) - .penalize(HardSoftScore.ONE_SOFT, (contract, employee, violationAmount) -> violationAmount) - .indictWith((contract, employee, violationAmount) -> Arrays.asList(employee, contract)) - .asConstraint("Minimum and maximum number of assignments (no assignments)"); - } - // Min/Max consecutive working days Constraint consecutiveWorkingDays(ConstraintFactory constraintFactory) { return constraintFactory.forEach(MinMaxContractLine.class) diff --git a/examples/src/test/java/ai/timefold/solver/examples/nurserostering/score/NurseRosteringConstraintProviderTest.java b/examples/src/test/java/ai/timefold/solver/examples/nurserostering/score/NurseRosteringConstraintProviderTest.java index 6ee982dc2c..6bb29c6b3d 100644 --- a/examples/src/test/java/ai/timefold/solver/examples/nurserostering/score/NurseRosteringConstraintProviderTest.java +++ b/examples/src/test/java/ai/timefold/solver/examples/nurserostering/score/NurseRosteringConstraintProviderTest.java @@ -347,6 +347,10 @@ void minimumAndMaximumNumberOfAssignments( minimumAndMaximumNumberOfAssignmentsConstraint.given(contract.getFirstConstractLine(), employee, shift1) .penalizesBy(2); + minimumAndMaximumNumberOfAssignmentsConstraint + .given(contract.getFirstConstractLine(), employee) + .penalizesBy(4); + minimumAndMaximumNumberOfAssignmentsConstraint .given(contract.getFirstConstractLine(), employee, shift1, shift2, shift3) .penalizesBy(0); @@ -355,26 +359,6 @@ void minimumAndMaximumNumberOfAssignments( .penalizesBy(0); } - @ConstraintProviderTest - void minimumNumberOfAssignmentsNoAssignments( - ConstraintVerifier constraintVerifier) { - Contract contract = new MinMaxContractBuilder(ContractLineType.TOTAL_ASSIGNMENTS) - .withMinimum(2) - .withMinimumWeight(5) - .build(); - - Employee employeeNoShifts = getEmployee(contract); - Employee employeeWithShifts = getEmployee(contract); - - ShiftAssignment shift = getShiftAssignment(0, employeeWithShifts); - - constraintVerifier.verifyThat(NurseRosteringConstraintProvider::minimumNumberOfAssignmentsNoAssignments) - .given(contract.getFirstConstractLine(), - employeeNoShifts, employeeWithShifts, - shift) - .penalizesBy(10); - } - @ConstraintProviderTest void consecutiveWorkingDays(ConstraintVerifier constraintVerifier) { Contract contract = new MinMaxContractBuilder(ContractLineType.CONSECUTIVE_WORKING_DAYS)