diff --git a/docs/changelog/99107.yaml b/docs/changelog/99107.yaml new file mode 100644 index 0000000000000..a808fb57fcf80 --- /dev/null +++ b/docs/changelog/99107.yaml @@ -0,0 +1,5 @@ +pr: 99107 +summary: Wait to gracefully stop deployments until alternative allocation exists +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index 96ac120356283..e92e6e9b99119 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -183,12 +183,19 @@ public String[] getStartedNodes() { .toArray(String[]::new); } - public List> selectRandomStartedNodesWeighedOnAllocationsForNRequests(int numberOfRequests) { + public boolean hasStartedRoutes() { + return nodeRoutingTable.values().stream().anyMatch(routeInfo -> routeInfo.getState() == RoutingState.STARTED); + } + + public List> selectRandomStartedNodesWeighedOnAllocationsForNRequests( + int numberOfRequests, + RoutingState requiredState + ) { List nodeIds = new ArrayList<>(nodeRoutingTable.size()); List cumulativeAllocations = new ArrayList<>(nodeRoutingTable.size()); int allocationSum = 0; for (Map.Entry routingEntry : nodeRoutingTable.entrySet()) { - if (RoutingState.STARTED.equals(routingEntry.getValue().getState())) { + if (routingEntry.getValue().getState() == requiredState) { nodeIds.add(routingEntry.getKey()); allocationSum += routingEntry.getValue().getCurrentAllocations(); cumulativeAllocations.add(allocationSum); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java index ca777be21b3be..4e6b88d2ff054 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java @@ -27,6 +27,7 @@ import static org.hamcrest.Matchers.arrayContainingInAnyOrder; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; @@ -168,7 +169,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoS builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STOPPED, "")); TrainedModelAssignment assignment = builder.build(); - assertThat(assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1).isEmpty(), is(true)); + assertThat(assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED).isEmpty(), is(true)); } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSingleStartedNode() { @@ -176,7 +177,28 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSin builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); - var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1); + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED); + + assertThat(nodes, hasSize(1)); + assertThat(nodes.get(0), equalTo(new Tuple<>("node-1", 1))); + } + + public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenAShuttingDownRoute_ItReturnsNoNodes() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); + TrainedModelAssignment assignment = builder.build(); + + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STOPPING); + + assertThat(nodes, empty()); + } + + public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenAShuttingDownRoute_ItReturnsNode1() { + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STOPPING, "")); + TrainedModelAssignment assignment = builder.build(); + + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STOPPING); assertThat(nodes, hasSize(1)); assertThat(nodes.get(0), equalTo(new Tuple<>("node-1", 1))); @@ -188,7 +210,7 @@ public void testSingleRequestWith2Nodes() { builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); - var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1); + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(1, RoutingState.STARTED); assertThat(nodes, hasSize(1)); assertEquals(nodes.get(0).v2(), Integer.valueOf(1)); } @@ -202,7 +224,7 @@ public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul final int selectionCount = 10000; final CountAccumulator countsPerNodeAccumulator = new CountAccumulator(); - var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount); + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount, RoutingState.STARTED); assertThat(nodes, hasSize(3)); assertThat(nodes.stream().mapToInt(Tuple::v2).sum(), equalTo(selectionCount)); @@ -223,7 +245,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul builder.addRoutingEntry("node-3", new RoutingInfo(0, 0, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); final int selectionCount = 1000; - var nodeCounts = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount); + var nodeCounts = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(selectionCount, RoutingState.STARTED); assertThat(nodeCounts, hasSize(3)); var selectedNodes = new HashSet(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index b1d799cd33622..d414a013a0e8c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.TrainedModelType; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; +import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -243,7 +244,13 @@ private void inferAgainstAllocatedModel( // Get a list of nodes to send the requests to and the number of // documents for each node. - var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments()); + var nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments(), RoutingState.STARTED); + + // We couldn't find any nodes in the started state so let's look for ones that are stopping in case we're shutting down some nodes + if (nodes.isEmpty()) { + nodes = assignment.selectRandomStartedNodesWeighedOnAllocationsForNRequests(request.numberOfDocuments(), RoutingState.STOPPING); + } + if (nodes.isEmpty()) { logger.trace(() -> format("[%s] model deployment not allocated to any node", assignment.getDeploymentId())); listener.onFailure( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java index 7ecb580d5feac..2caf338d2a3c7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java @@ -653,6 +653,11 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( ); Set shuttingDownNodeIds = currentState.metadata().nodeShutdowns().getAllNodeIds(); + /* + * To signal that we should gracefully stop the deployments routed to a particular node we set the routing state to stopping. + * The TrainedModelAssignmentNodeService will see that the route is in stopping for a shutting down node and gracefully shut down + * the native process after draining the queues. + */ TrainedModelAssignmentMetadata.Builder rebalanced = setShuttingDownNodeRoutesToStopping( currentMetadata, shuttingDownNodeIds, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index 44ead5ca35dd3..26adca10b8da8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -373,7 +373,10 @@ public void clusterChanged(ClusterChangedEvent event) { } } - if (isAssignmentOnShuttingDownNode(routingInfo, trainedModelAssignment.getDeploymentId(), shuttingDownNodes, currentNode)) { + /* + * Check if this is a shutting down node and if we can gracefully shut down the native process after draining its queues + */ + if (shouldGracefullyShutdownDeployment(trainedModelAssignment, shuttingDownNodes, currentNode)) { gracefullyStopDeployment(trainedModelAssignment.getDeploymentId(), currentNode); } } else { @@ -440,15 +443,48 @@ private static StartTrainedModelDeploymentAction.TaskParams createStartTrainedMo ); } - private boolean isAssignmentOnShuttingDownNode( - RoutingInfo routingInfo, - String deploymentId, + private boolean shouldGracefullyShutdownDeployment( + TrainedModelAssignment trainedModelAssignment, Set shuttingDownNodes, String currentNode ) { - return deploymentIdToTask.containsKey(deploymentId) - && routingInfo.getState() == RoutingState.STOPPING - && shuttingDownNodes.contains(currentNode); + RoutingInfo routingInfo = trainedModelAssignment.getNodeRoutingTable().get(currentNode); + + if (routingInfo == null) { + return true; + } + + boolean isCurrentNodeShuttingDown = shuttingDownNodes.contains(currentNode); + boolean isRouteStopping = routingInfo.getState() == RoutingState.STOPPING; + boolean hasDeploymentTask = deploymentIdToTask.containsKey(trainedModelAssignment.getDeploymentId()); + boolean hasStartedRoutes = trainedModelAssignment.hasStartedRoutes(); + boolean assignmentIsRoutedToOneOrFewerNodes = trainedModelAssignment.getNodeRoutingTable().size() <= 1; + + // To avoid spamming the logs we'll only print these if we meet the base criteria + if (isCurrentNodeShuttingDown && isRouteStopping && hasDeploymentTask) { + logger.debug( + () -> format( + "[%s] Checking if deployment can be gracefully shutdown on node %s, " + + "has other started routes: %s, " + + "single or no routed nodes: %s", + trainedModelAssignment.getDeploymentId(), + currentNode, + hasStartedRoutes, + assignmentIsRoutedToOneOrFewerNodes + ) + ); + } + + // the current node is shutting down + return isCurrentNodeShuttingDown + // the route is marked as ready to shut down during a rebalance + && isRouteStopping + // the deployment wasn't already being stopped by a stop deployment API call + && hasDeploymentTask + // the assignment has another allocation that can serve any additional requests or the shutting down node is the only node that + // serves this model (maybe the other available nodes are already full or no other ML nodes exist) in which case we can't wait + // for another node to become available so allow a graceful shutdown + && (hasStartedRoutes || assignmentIsRoutedToOneOrFewerNodes); } private void gracefullyStopDeployment(String deploymentId, String currentNode) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java index 0bd2e716758e4..795f184a49a4d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java @@ -425,6 +425,54 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } + public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherAllocationIsNotReady_DoesNotCallStop() { + final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); + String node2 = "test-node-2"; + final DiscoveryNodes nodes = DiscoveryNodes.builder() + .localNodeId(NODE_ID) + .add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)) + .add(DiscoveryNodeUtils.create(node2, node2)) + .build(); + String modelOne = "model-1"; + String deploymentOne = "deployment-1"; + + var taskParams = newParams(deploymentOne, modelOne); + + ClusterChangedEvent event = new ClusterChangedEvent( + "testClusterChanged", + ClusterState.builder(new ClusterName("testClusterChanged")) + .nodes(nodes) + .metadata( + Metadata.builder() + .putCustom( + TrainedModelAssignmentMetadata.NAME, + TrainedModelAssignmentMetadata.Builder.empty() + .addNewAssignment( + deploymentOne, + TrainedModelAssignment.Builder.empty(taskParams) + .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) + .addRoutingEntry(node2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) + ) + .build() + ) + .putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata(NODE_ID)) + .build() + ) + .build(), + ClusterState.EMPTY_STATE + ); + + trainedModelAssignmentNodeService.prepareModelToLoad(taskParams); + trainedModelAssignmentNodeService.clusterChanged(event); + + verify(deploymentManager, never()).stopAfterCompletingPendingWork(any()); + verify(trainedModelAssignmentService, never()).updateModelAssignmentState( + any(UpdateTrainedModelAssignmentRoutingInfoAction.Request.class), + any() + ); + verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); + } + public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlreadyRemoved_DoesNotCallStop() { final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build();