Skip to content

Commit

Permalink
fixing ml tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
masseyke committed Nov 30, 2023
1 parent 9b698e1 commit 99f2cd2
Show file tree
Hide file tree
Showing 12 changed files with 161 additions and 75 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -60,9 +61,11 @@ public AnnotationPersister(ResultsPersisterService resultsPersisterService) {
public Tuple<String, Annotation> persistAnnotation(@Nullable String annotationId, Annotation annotation) {
Objects.requireNonNull(annotation);
String jobId = annotation.getJobId();
BulkResponse bulkResponse = bulkPersisterBuilder(jobId).persistAnnotation(annotationId, annotation).executeRequest();
assert bulkResponse.getItems().length == 1;
return Tuple.tuple(bulkResponse.getItems()[0].getId(), annotation);
try (Builder builder = bulkPersisterBuilder(jobId)) {
BulkResponse bulkResponse = builder.persistAnnotation(annotationId, annotation).executeRequest();
assert bulkResponse.getItems().length == 1;
return Tuple.tuple(bulkResponse.getItems()[0].getId(), annotation);
}
}

public Builder bulkPersisterBuilder(String jobId) {
Expand All @@ -73,7 +76,7 @@ public Builder bulkPersisterBuilder(String jobId, Supplier<Boolean> shouldRetry)
return new Builder(jobId, shouldRetry);
}

public class Builder {
public class Builder implements Releasable {

private final String jobId;
private BulkRequest bulkRequest = new BulkRequest(AnnotationIndex.WRITE_ALIAS_NAME);
Expand Down Expand Up @@ -109,8 +112,9 @@ public BulkResponse executeRequest() {
if (bulkRequest.numberOfActions() == 0) {
return null;
}
BulkResponse bulkResponse;
logger.trace("[{}] ES API CALL: bulk request with {} actions", () -> jobId, () -> bulkRequest.numberOfActions());
BulkResponse bulkResponse = resultsPersisterService.bulkIndexWithRetry(
bulkResponse = resultsPersisterService.bulkIndexWithRetry(
bulkRequest,
jobId,
shouldRetry,
Expand All @@ -119,5 +123,10 @@ public BulkResponse executeRequest() {
bulkRequest = new BulkRequest(AnnotationIndex.WRITE_ALIAS_NAME);
return bulkResponse;
}

@Override
public void close() {
bulkRequest.close();
}
}
}
Original file line number | Diff line number | Diff line change
Expand Up @@ -209,20 +209,21 @@ void updateStats() {
if (stats.isEmpty()) {
return;
}
BulkRequest bulkRequest = new BulkRequest();
stats.stream().map(TrainedModelStatsService::buildUpdateRequest).filter(Objects::nonNull).forEach(bulkRequest::add);
bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
if (bulkRequest.requests().isEmpty()) {
return;
}
if (shouldStop()) {
return;
}
String jobPattern = stats.stream().map(InferenceStats::getModelId).collect(Collectors.joining(","));
try {
String jobPattern = "";
try (BulkRequest bulkRequest = new BulkRequest()) {
stats.stream().map(TrainedModelStatsService::buildUpdateRequest).filter(Objects::nonNull).forEach(bulkRequest::add);
bulkRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
if (bulkRequest.requests().isEmpty()) {
return;
}
if (shouldStop()) {
return;
}
jobPattern = stats.stream().map(InferenceStats::getModelId).collect(Collectors.joining(","));

resultsPersisterService.bulkIndexWithRetry(bulkRequest, jobPattern, () -> shouldStop() == false, (msg) -> {});
} catch (ElasticsearchException ex) {
logger.warn(() -> "failed to store stats for [" + jobPattern + "]", ex);
logger.warn("failed to store stats for [" + jobPattern + "]", ex);
}
}

Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -524,7 +524,13 @@ private void storeTrainedModelAndDefinition(
wrappedListener.onResponse(true);
}, wrappedListener::onFailure);

executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest.request(), bulkResponseActionListener);
executeAsyncWithOrigin(
client,
ML_ORIGIN,
BulkAction.INSTANCE,
bulkRequest.request(),
ActionListener.releaseAfter(bulkResponseActionListener, bulkRequest)
);
}

/**
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.job.results.Bucket;
Expand All @@ -35,7 +36,7 @@
* <p>
* This class is NOT thread safe.
*/
public class JobRenormalizedResultsPersister {
public class JobRenormalizedResultsPersister implements Releasable {

private static final Logger logger = LogManager.getLogger(JobRenormalizedResultsPersister.class);

Expand Down Expand Up @@ -100,9 +101,13 @@ public void executeRequest() {
logger.trace("[{}] ES API CALL: bulk request with {} actions", jobId, bulkRequest.numberOfActions());

try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
BulkResponse addRecordsResponse = client.bulk(bulkRequest).actionGet();
if (addRecordsResponse.hasFailures()) {
logger.error("[{}] Bulk index of results has errors: {}", jobId, addRecordsResponse.buildFailureMessage());
try {
BulkResponse addRecordsResponse = client.bulk(bulkRequest).actionGet();
if (addRecordsResponse.hasFailures()) {
logger.error("[{}] Bulk index of results has errors: {}", jobId, addRecordsResponse.buildFailureMessage());
}
} finally {
bulkRequest.close();
}
}

Expand All @@ -112,4 +117,9 @@ public void executeRequest() {
BulkRequest getBulkRequest() {
return bulkRequest;
}

@Override
public void close() {
bulkRequest.close();
}
}
Original file line number | Diff line number | Diff line change
Expand Up @@ -764,12 +764,14 @@ AutodetectCommunicator create(JobTask jobTask, Job job, AutodetectParams autodet
// A TP with no queue, so that we fail immediately if there are no threads available
ExecutorService autodetectExecutorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
DataCountsReporter dataCountsReporter = new DataCountsReporter(job, autodetectParams.dataCounts(), jobDataCountsPersister);
ScoresUpdater scoresUpdater = new ScoresUpdater(
job,
jobResultsProvider,
new JobRenormalizedResultsPersister(job.getId(), client),
normalizerFactory
);
JobRenormalizedResultsPersister jobRenormalizedResultsPersister = new JobRenormalizedResultsPersister(job.getId(), client);
ScoresUpdater scoresUpdater;
try {
scoresUpdater = new ScoresUpdater(job, jobResultsProvider, jobRenormalizedResultsPersister, normalizerFactory);
} catch (Exception e) {
jobRenormalizedResultsPersister.close();
throw e;
}
ExecutorService renormalizerExecutorService = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME);
Renormalizer renormalizer = new ShortCircuitingRenormalizer(jobId, scoresUpdater, renormalizerExecutorService);

Expand All @@ -793,12 +795,18 @@ AutodetectCommunicator create(JobTask jobTask, Job job, AutodetectParams autodet
ExecutorService autodetectWorkerExecutor;
try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
autodetectWorkerExecutor = createAutodetectExecutorService(autodetectExecutorService);
autodetectExecutorService.submit(processor::process);
autodetectExecutorService.submit(() -> {
try {
processor.process();
} finally {
jobRenormalizedResultsPersister.close();
}
});
} catch (EsRejectedExecutionException e) {
// If submitting the operation to read the results from the process fails we need to close
// the process too, so that other submitted operations to threadpool are stopped.
try {
IOUtils.close(process);
IOUtils.close(process, jobRenormalizedResultsPersister);
} catch (IOException ioe) {
logger.error("Can't close autodetect", ioe);
}
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -156,6 +156,8 @@ void persist(String indexOrAlias, BytesReference bytes) throws IOException {
String msg = "failed indexing updated state docs";
LOGGER.error(() -> format("[%s] %s", jobId, msg), ex);
auditor.error(jobId, msg + " error: " + ex.getMessage());
} finally {
bulkRequest.close();
}
}
}
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -52,20 +52,28 @@ public void addAndExecuteIfNeeded(IndexRequest indexRequest) {

private void execute() {
if (currentBulkRequest.numberOfActions() > 0) {
LOGGER.debug(
"Executing bulk request; current bytes [{}]; bytes limit [{}]; number of actions [{}]",
currentRamBytes,
bytesLimit,
currentBulkRequest.numberOfActions()
);
executor.accept(currentBulkRequest);
try {
LOGGER.debug(
"Executing bulk request; current bytes [{}]; bytes limit [{}]; number of actions [{}]",
currentRamBytes,
bytesLimit,
currentBulkRequest.numberOfActions()
);
executor.accept(currentBulkRequest);
} finally {
currentBulkRequest.close();
}
currentBulkRequest = new BulkRequest();
currentRamBytes = 0;
}
}

@Override
public void close() {
execute();
try {
execute();
} finally {
currentBulkRequest.close();
}
}
}
Original file line number | Diff line number | Diff line change
Expand Up @@ -171,7 +171,7 @@ public void indexWithRetry(
try (XContentBuilder content = object.toXContent(XContentFactory.jsonBuilder(), params)) {
bulkRequest.add(new IndexRequest(indexName).id(id).source(content).setRequireAlias(requireAlias));
}
bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, retryMsgHandler, finalListener);
bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, retryMsgHandler, ActionListener.releaseAfter(finalListener, bulkRequest));
}

public BulkResponse bulkIndexWithRetry(
Expand Down Expand Up @@ -239,7 +239,14 @@ private BulkResponse bulkIndexWithRetry(
);
}
final PlainActionFuture<BulkResponse> getResponseFuture = new PlainActionFuture<>();
bulkIndexWithRetry(bulkRequest, jobId, shouldRetry, retryMsgHandler, actionExecutor, getResponseFuture);
bulkIndexWithRetry(
bulkRequest,
jobId,
shouldRetry,
retryMsgHandler,
actionExecutor,
ActionListener.releaseAfter(getResponseFuture, bulkRequest)
);
return getResponseFuture.actionGet();
}

Expand Down Expand Up @@ -385,7 +392,11 @@ private class BulkRetryableAction extends MlRetryableAction<BulkRequest, BulkRes
}
bulkRequestRewriter.rewriteRequest(bulkResponse);
// Let the listener attempt again with the new bulk request
retryableListener.onFailure(new RecoverableException());
ActionListener<BulkResponse> releasingRetryableListener = ActionListener.releaseAfter(
retryableListener,
bulkRequestRewriter.bulkRequest
);
releasingRetryableListener.onFailure(new RecoverableException());
}, retryableListener::onFailure)),
listener
);
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -137,13 +137,14 @@ public void testPersistMultipleAnnotationsWithBulk() {
.execute(eq(BulkAction.INSTANCE), any(), any());

AnnotationPersister persister = new AnnotationPersister(resultsPersisterService);
persister.bulkPersisterBuilder(JOB_ID)
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.executeRequest();
try (AnnotationPersister.Builder builder = persister.bulkPersisterBuilder(JOB_ID)) {
builder.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.executeRequest();
}

verify(client).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any());

Expand All @@ -157,13 +158,14 @@ public void testPersistMultipleAnnotationsWithBulk_LowBulkLimit() {
.execute(eq(BulkAction.INSTANCE), any(), any());

AnnotationPersister persister = new AnnotationPersister(resultsPersisterService, 2);
persister.bulkPersisterBuilder(JOB_ID)
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.executeRequest();
try (AnnotationPersister.Builder builder = persister.bulkPersisterBuilder(JOB_ID)) {
builder.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.persistAnnotation(AnnotationTests.randomAnnotation(JOB_ID))
.executeRequest();
}

verify(client, times(3)).execute(eq(BulkAction.INSTANCE), bulkRequestCaptor.capture(), any());

Expand All @@ -176,7 +178,9 @@ public void testPersistMultipleAnnotationsWithBulk_LowBulkLimit() {

public void testPersistMultipleAnnotationsWithBulk_EmptyRequest() {
AnnotationPersister persister = new AnnotationPersister(resultsPersisterService);
assertThat(persister.bulkPersisterBuilder(JOB_ID).executeRequest(), is(nullValue()));
try (AnnotationPersister.Builder builder = persister.bulkPersisterBuilder(JOB_ID)) {
assertThat(builder.executeRequest(), is(nullValue()));
}
}

public void testPersistMultipleAnnotationsWithBulk_Failure() {
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,31 +27,32 @@ public class JobRenormalizedResultsPersisterTests extends ESTestCase {

public void testUpdateBucket() {
BucketNormalizable bn = createBucketNormalizable();
JobRenormalizedResultsPersister persister = createJobRenormalizedResultsPersister();
persister.updateBucket(bn);
try (JobRenormalizedResultsPersister persister = createJobRenormalizedResultsPersister()) {
persister.updateBucket(bn);

assertEquals(3, persister.getBulkRequest().numberOfActions());
assertEquals("foo-index", persister.getBulkRequest().requests().get(0).index());
assertEquals(3, persister.getBulkRequest().numberOfActions());
assertEquals("foo-index", persister.getBulkRequest().requests().get(0).index());
}
}

public void testExecuteRequestResetsBulkRequest() {
BucketNormalizable bn = createBucketNormalizable();
JobRenormalizedResultsPersister persister = createJobRenormalizedResultsPersister();
persister.updateBucket(bn);
persister.executeRequest();
assertEquals(0, persister.getBulkRequest().numberOfActions());
try (JobRenormalizedResultsPersister persister = createJobRenormalizedResultsPersister()) {
persister.updateBucket(bn);
persister.executeRequest();
assertEquals(0, persister.getBulkRequest().numberOfActions());
}
}

public void testBulkRequestExecutesWhenReachMaxDocs() {
BulkResponse bulkResponse = mock(BulkResponse.class);
Client client = new MockClientBuilder("cluster").bulk(bulkResponse).build();
JobRenormalizedResultsPersister persister = new JobRenormalizedResultsPersister("foo", client);

ModelPlot modelPlot = new ModelPlot("foo", new Date(), 123456, 0);
for (int i = 0; i <= JobRenormalizedResultsPersister.BULK_LIMIT; i++) {
persister.updateResult("bar", "index-foo", modelPlot);
try (JobRenormalizedResultsPersister persister = new JobRenormalizedResultsPersister("foo", client)) {
ModelPlot modelPlot = new ModelPlot("foo", new Date(), 123456, 0);
for (int i = 0; i <= JobRenormalizedResultsPersister.BULK_LIMIT; i++) {
persister.updateResult("bar", "index-foo", modelPlot);
}
}

verify(client, times(1)).bulk(any());
verify(client, times(1)).threadPool();
verifyNoMoreInteractions(client);
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -219,7 +219,9 @@ public void testExecuteRequest_ClearsBulkRequest() {

JobResultsPersister.Builder builder = persister.bulkPersisterBuilder(JOB_ID);
builder.persistInfluencers(influencers).executeRequest();
assertEquals(0, builder.getBulkRequest().numberOfActions());
try (BulkRequest bulkRequest = builder.getBulkRequest()) {
assertEquals(0, bulkRequest.numberOfActions());
}
}

public void testBulkRequestExecutesWhenReachMaxDocs() {
Expand Down
Loading

0 comments on commit 99f2cd2

Please sign in to comment.