Skip to content

Commit

Permalink
[ML] Wait to gracefully stop deployments until alternative allocation…
Browse files Browse the repository at this point in the history
… exists (elastic#99107)

* Allowing stopping deployments to handle requests

* Adding some logging

* Update docs/changelog/99107.yaml

* Adding test

* Addressing feedback

* Fixing test
  • Loading branch information
jonathan-buttner authored Oct 12, 2023
1 parent 7d0672d commit 19fb25f
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 15 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/99107.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 99107
summary: Wait to gracefully stop deployments until alternative allocation exists
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,19 @@ public String[] getStartedNodes() {
.toArray(String[]::new);
}

public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocationsForNRequests(int numberOfRequests) {
public boolean hasStartedRoutes() {
return nodeRoutingTable.values().stream().anyMatch(routeInfo -> routeInfo.getState() == RoutingState.STARTED);
}

public List<Tuple<String, Integer>> selectRandomStartedNodesWeighedOnAllocationsForNRequests(
int numberOfRequests,
RoutingState requiredState
) {
List<String> nodeIds = new ArrayList<>(nodeRoutingTable.size());
List<Integer> cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size());
int allocationSum = 0;
for (Map.Entry<String, RoutingInfo> routingEntry : nodeRoutingTable.entrySet()) {
if (RoutingState.STARTED.equals(routingEntry.getValue().getState())) {
if (routingEntry.getValue().getState() == requiredState) {
nodeIds.add(routingEntry.getKey());
allocationSum += routingEntry.getValue().getCurrentAllocations();
cumulativeAllocations.add(allocationSum);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import static org.hamcrest.Matchers.arrayContainingInAnyOrder;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
Expand Down Expand Up @@ -168,15 +169,36 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoS
builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STOPPED, ""));
TrainedModelAssignment assignment = builder.build();

assertThat(assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1).isEmpty(), is(true));
assertThat(assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED).isEmpty(), is(true));
}

public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSingleStartedNode() {
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5));
builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, ""));
TrainedModelAssignment assignment = builder.build();

var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1);
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED);

assertThat(nodes, hasSize(1));
assertThat(nodes.get(0), equalTo(new Tuple<>("node-1", 1)));
}

public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenAShuttingDownRoute_ItReturnsNoNodes() {
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5));
builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, ""));
TrainedModelAssignment assignment = builder.build();

var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STOPPING);

assertThat(nodes, empty());
}

public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenAShuttingDownRoute_ItReturnsNode1() {
TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5));
builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STOPPING, ""));
TrainedModelAssignment assignment = builder.build();

var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STOPPING);

assertThat(nodes, hasSize(1));
assertThat(nodes.get(0), equalTo(new Tuple<>("node-1", 1)));
Expand All @@ -188,7 +210,7 @@ public void testSingleRequestWith2Nodes() {
builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, ""));
TrainedModelAssignment assignment = builder.build();

var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1);
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED);
assertThat(nodes, hasSize(1));
assertEquals(nodes.get(0).v2(), Integer.valueOf(1));
}
Expand All @@ -202,7 +224,7 @@ public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul

final int selectionCount = 10000;
final CountAccumulator countsPerNodeAccumulator = new CountAccumulator();
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount);
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount, RoutingState.STARTED);

assertThat(nodes, hasSize(3));
assertThat(nodes.stream().mapToInt(Tuple::v2).sum(), equalTo(selectionCount));
Expand All @@ -223,7 +245,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul
builder.addRoutingEntry("node-3", new RoutingInfo(0, 0, RoutingState.STARTED, ""));
TrainedModelAssignment assignment = builder.build();
final int selectionCount = 1000;
var nodeCounts = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount);
var nodeCounts = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount, RoutingState.STARTED);
assertThat(nodeCounts, hasSize(3));

var selectedNodes = new HashSet<String>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState;
import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down Expand Up @@ -243,7 +244,13 @@ private void inferAgainstAllocatedModel(

// Get a list of nodes to send the requests to and the number of
// documents for each node.
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments());
var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments(), RoutingState.STARTED);

// We couldn't find any nodes in the started state so let's look for ones that are stopping in case we're shutting down some nodes
if (nodes.isEmpty()) {
nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments(), RoutingState.STOPPING);
}

if (nodes.isEmpty()) {
logger.trace(() -> format("[%s] model deployment not allocated to any node", assignment.getDeploymentId()));
listener.onFailure(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,11 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments(
);

Set<String> shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds();
/*
* To signal that we should gracefully stop the deployments routed to a particular node we set the routing state to stopping.
* The TrainedModelAssignmentNodeService will see that the route is in stopping for a shutting down node and gracefully shut down
* the native process after draining the queues.
*/
TrainedModelAssignmentMetadata.Builder rebalanced = setShuttingDownNodeRoutesToStopping(
currentMetadata,
shuttingDownNodeIds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,10 @@ public void clusterChanged(ClusterChangedEvent event) {
}
}

if (isAssignmentOnShuttingDownNode(routingInfo, trainedModelAssignment.getDeploymentId(), shuttingDownNodes, currentNode)) {
/*
* Check if this is a shutting down node and if we can gracefully shut down the native process after draining its queues
*/
if (shouldGracefullyShutdownDeployment(trainedModelAssignment, shuttingDownNodes, currentNode)) {
gracefullyStopDeployment(trainedModelAssignment.getDeploymentId(), currentNode);
}
} else {
Expand Down Expand Up @@ -440,15 +443,48 @@ private static StartTrainedModelDeploymentAction.TaskParams createStartTrainedMo
);
}

private boolean isAssignmentOnShuttingDownNode(
RoutingInfo routingInfo,
String deploymentId,
private boolean shouldGracefullyShutdownDeployment(
TrainedModelAssignment trainedModelAssignment,
Set<String> shuttingDownNodes,
String currentNode
) {
return deploymentIdToTask.containsKey(deploymentId)
&& routingInfo.getState() == RoutingState.STOPPING
&& shuttingDownNodes.contains(currentNode);
RoutingInfo routingInfo = trainedModelAssignment.getNodeRoutingTable().get(currentNode);

if (routingInfo == null) {
return true;
}

boolean isCurrentNodeShuttingDown = shuttingDownNodes.contains(currentNode);
boolean isRouteStopping = routingInfo.getState() == RoutingState.STOPPING;
boolean hasDeploymentTask = deploymentIdToTask.containsKey(trainedModelAssignment.getDeploymentId());
boolean hasStartedRoutes = trainedModelAssignment.hasStartedRoutes();
boolean assignmentIsRoutedToOneOrFewerNodes = trainedModelAssignment.getNodeRoutingTable().size() <= 1;

// To avoid spamming the logs we'll only print these if we meet the base criteria
if (isCurrentNodeShuttingDown && isRouteStopping && hasDeploymentTask) {
logger.debug(
() -> format(
"[%s] Checking if deployment can be gracefully shutdown on node %s, "
+ "has other started routes: %s, "
+ "single or no routed nodes: %s",
trainedModelAssignment.getDeploymentId(),
currentNode,
hasStartedRoutes,
assignmentIsRoutedToOneOrFewerNodes
)
);
}

// the current node is shutting down
return isCurrentNodeShuttingDown
// the route is marked as ready to shut down during a rebalance
&& isRouteStopping
// the deployment wasn't already being stopped by a stop deployment API call
&& hasDeploymentTask
// the assignment has another allocation that can serve any additional requests or the shutting down node is the only node that
// serves this model (maybe the other available nodes are already full or no other ML nodes exist) in which case we can't wait
// for another node to become available so allow a graceful shutdown
&& (hasStartedRoutes || assignmentIsRoutedToOneOrFewerNodes);
}

private void gracefullyStopDeployment(String deploymentId, String currentNode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,54 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
String node2 = "test-node-2";
final DiscoveryNodes nodes = DiscoveryNodes.builder()
.localNodeId(NODE_ID)
.add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID))
.add(DiscoveryNodeUtils.create(node2, node2))
.build();
String modelOne = "model-1";
String deploymentOne = "deployment-1";

var taskParams = newParams(deploymentOne, modelOne);

ClusterChangedEvent event = new ClusterChangedEvent(
"testClusterChanged",
ClusterState.builder(new ClusterName("testClusterChanged"))
.nodes(nodes)
.metadata(
Metadata.builder()
.putCustom(
TrainedModelAssignmentMetadata.NAME,
TrainedModelAssignmentMetadata.Builder.empty()
.addNewAssignment(
deploymentOne,
TrainedModelAssignment.Builder.empty(taskParams)
.addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STOPPING, ""))
.addRoutingEntry(node2, new RoutingInfo(1, 1, RoutingState.STARTING, ""))
)
.build()
)
.putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata(NODE_ID))
.build()
)
.build(),
ClusterState.EMPTY_STATE
);

trainedModelAssignmentNodeService.prepareModelToLoad(taskParams);
trainedModelAssignmentNodeService.clusterChanged(event);

verify(deploymentManager, never()).stopAfterCompletingPendingWork(any());
verify(trainedModelAssignmentService, never()).updateModelAssignmentState(
any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class),
any()
);
verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService);
}

public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() {
final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService();
final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();
Expand Down

0 comments on commit 19fb25f

Please sign in to comment.