Skip to content

Commit

Permalink
scheduler service leaves current transaction before executing task
Browse files Browse the repository at this point in the history
  • Loading branch information
sterlp committed Jan 8, 2025
1 parent 3afd730 commit 57f0399
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;

import org.springframework.lang.NonNull;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.event.TransactionPhase;
import org.springframework.transaction.event.TransactionalEventListener;
import org.springframework.transaction.support.TransactionTemplate;
import org.sterl.spring.persistent_tasks.api.AddTriggerRequest;
import org.sterl.spring.persistent_tasks.api.TriggerKey;
import org.sterl.spring.persistent_tasks.scheduler.component.EditSchedulerStatusComponent;
import org.sterl.spring.persistent_tasks.scheduler.component.TaskExecutorComponent;
import org.sterl.spring.persistent_tasks.scheduler.entity.SchedulerEntity;
import org.sterl.spring.persistent_tasks.shared.model.TriggerStatus;
import org.sterl.spring.persistent_tasks.trigger.TriggerService;
import org.sterl.spring.persistent_tasks.trigger.event.TriggerAddedEvent;
import org.sterl.spring.persistent_tasks.trigger.model.TriggerEntity;

import jakarta.annotation.PostConstruct;
Expand Down Expand Up @@ -100,6 +101,9 @@ public List<Future<TriggerKey>> triggerNextTasks() {
/**
* Like {@link #triggerNextTasks()} but allows to set the time e.g. to the future to trigger
* tasks which wouldn't be triggered now.
* <p>
* This method should not be called in a transaction!
* </p>
*/
@NonNull
public List<Future<TriggerKey>> triggerNextTasks(OffsetDateTime timeDue) {
Expand All @@ -123,29 +127,30 @@ public List<Future<TriggerKey>> triggerNextTasks(OffsetDateTime timeDue) {
* and the runAt time is not in the future.
* @return the reference to the {@link Future} with the key, if no threads are available it is resolved
*/
public <T extends Serializable> Future<TriggerKey> runOrQueue(
AddTriggerRequest<T> triggerRequest) {
final var runningTrigger = trx.execute(t -> {
var trigger = triggerService.queue(triggerRequest);
// exit now if this trigger is for the future ...
if (trigger.shouldRunInFuture()) return trigger;

@Transactional(timeout = 10)
public <T extends Serializable> TriggerKey runOrQueue(
AddTriggerRequest<T> triggerRequest) {
var trigger = triggerService.queue(triggerRequest);

if (!trigger.shouldRunInFuture()) {
if (taskExecutor.getFreeThreads() > 0) {
trigger = triggerService.markTriggersAsRunning(trigger, name);
pingRegistry().addRunning(1);
} else {
log.debug("Currently not enough free thread available {} of {} in use. PersistentTask {} queued.",
taskExecutor.getFreeThreads(), taskExecutor.getMaxThreads(), trigger.getKey());
}
return trigger;
});
Future<TriggerKey> result;
if (runningTrigger.isRunning()) {
result = taskExecutor.submit(runningTrigger);
} else {
result = CompletableFuture.completedFuture(runningTrigger.getKey());
}
return result;
// we will listen for the commit event to execute this trigger ...
return trigger.getKey();
}

@TransactionalEventListener(phase = TransactionPhase.AFTER_COMMIT)
void checkIfTrigerIsRunning(TriggerAddedEvent addedTrigger) {
if (addedTrigger.isRunningOn(name) && !taskExecutor.isRunning(addedTrigger.trigger())) {
log.debug("New triger added for imidiate execution {}", addedTrigger.key());
taskExecutor.submit(addedTrigger.trigger());
}
}

public SchedulerEntity getStatus() {
Expand All @@ -166,16 +171,4 @@ public List<TriggerEntity> rescheduleAbandonedTasks(OffsetDateTime timeout) {
running, runningKeys, schedulers);
return triggerService.rescheduleAbandonedTasks(timeout);
}

/**
* Adds or updates an existing trigger based on its {@link TriggerKey}
*
* @param <T> the state type
* @param trigger the {@link AddTriggerRequest} to save
* @return the saved {@link TriggerEntity}
* @throws IllegalStateException if the trigger already exists and is {@link TriggerStatus#RUNNING}
*/
public <T extends Serializable> TriggerEntity queue(AddTriggerRequest<T> trigger) {
return triggerService.queue(trigger);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,4 +151,8 @@ public void setMaxThreads(int value) {
public int getMaxThreads() {
return this.maxThreads.get();
}

public boolean isRunning(TriggerEntity trigger) {
return runningTasks.contains(trigger);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ public <T extends Serializable> TriggerEntity addTrigger(AddTriggerRequest<T> ti
} else {
result = triggerRepository.save(result);
log.debug("Added trigger={}", result);
publisher.publishEvent(new TriggerAddedEvent(result, tigger.state()));
}
publisher.publishEvent(new TriggerAddedEvent(result, tigger.state()));
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@

public record TriggerAddedEvent(TriggerEntity trigger, Serializable state) implements TriggerLifeCycleEvent {

public boolean isRunningOn(String name) {
return trigger.isRunning() && name != null && name.equals(trigger.getRunningOn());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.net.UnknownHostException;
import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeoutException;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -24,6 +25,7 @@
import org.sterl.spring.persistent_tasks.scheduler.SchedulerService;
import org.sterl.spring.persistent_tasks.scheduler.component.EditSchedulerStatusComponent;
import org.sterl.spring.persistent_tasks.scheduler.component.TaskExecutorComponent;
import org.sterl.spring.persistent_tasks.shared.model.TriggerStatus;
import org.sterl.spring.persistent_tasks.task.TaskService;
import org.sterl.spring.persistent_tasks.trigger.TriggerService;
import org.sterl.spring.persistent_tasks.trigger.model.TriggerEntity;
Expand Down Expand Up @@ -160,6 +162,16 @@ PersistentTask<Long> slowTask(AsyncAsserts asserts) {
protected Optional<TriggerEntity> runNextTrigger() {
return triggerService.run(triggerService.lockNextTrigger("test"));
}

protected void awaitRunningTasks() throws TimeoutException, InterruptedException {
final long start = System.currentTimeMillis();
if (triggerService.countTriggers(TriggerStatus.RUNNING) > 0) {
if (System.currentTimeMillis() - start > 2000) {
throw new TimeoutException("Still running after 2s");
}
Thread.sleep(50);
}
}

@BeforeEach
public void beforeEach() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.time.Duration;
import java.time.OffsetDateTime;
import java.util.concurrent.Future;
import java.util.concurrent.TimeoutException;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -94,16 +95,18 @@ void testRunOrQueue() throws Exception {

// WHEN
var ref = subject.runOrQueue(triggerRequest);

// THEN
assertThat(subject.getScheduler().getRunnungTasks()).isOne();
assertThat(persistentTaskService.getLastTriggerData(
ref.get()).get().getStatus()).isEqualTo(TriggerStatus.SUCCESS);
// AND
awaitRunningTasks();
assertThat(persistentTaskService.getLastTriggerData(ref).get().getStatus())
.isEqualTo(TriggerStatus.SUCCESS);
asserts.assertValue(Task3.NAME + "::Hallo");
}

@Test
void testQueuedInFuture() {
void testQueuedInFuture() throws TimeoutException, InterruptedException {
// GIVEN
final AddTriggerRequest<String> triggerRequest = Task3.ID
.newTrigger("Hallo")
Expand All @@ -113,6 +116,7 @@ void testQueuedInFuture() {

// WHEN
persistentTaskService.executeTriggersAndWait();
awaitRunningTasks();

// THEN
asserts.assertMissing(Task3.NAME + "::Hallo");
Expand All @@ -124,7 +128,7 @@ void runSimpleTaskMultipleTimesTest() throws Exception {
// GIVEN
TaskId<String> taskId = taskService.replace("foo", c -> asserts.info(c));
for (int i = 1; i < 21; ++i) {
subject.queue(taskId.newTrigger(i + " state").build());
triggerService.queue(taskId.newTrigger(i + " state").build());
}

// WHEN
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,24 @@ void testFailTrxCount() throws Exception {
// third to write the history
hibernateAsserts.assertTrxCount(3);
}

@Test
void testRunOrQueue() throws Exception {
// GIVEN
var k1 = subject.runOrQueue(TaskTriggerBuilder.newTrigger("savePersonInTrx").state("Paul").build());
var k2 = subject.runOrQueue(TaskTriggerBuilder.newTrigger("savePersonInTrx").state("Paul").build());

// WHEN
assertThat(persistentTaskService.getLastTriggerData(k1).get().getStatus())
.isEqualTo(TriggerStatus.RUNNING);
assertThat(persistentTaskService.getLastTriggerData(k2).get().getStatus())
.isEqualTo(TriggerStatus.RUNNING);


// THEN
awaitRunningTasks();
assertThat(personRepository.count()).isEqualTo(2);
}

@Test
void testRollbackAndRetry() throws Exception {
Expand All @@ -131,9 +149,10 @@ void testRollbackAndRetry() throws Exception {

// WHEN
var key = subject.runOrQueue(triggerRequest);

// THEN
key.get();
assertThat(persistentTaskService.getLastTriggerData(key.get()).get().getStatus())
awaitRunningTasks();
assertThat(persistentTaskService.getLastTriggerData(key).get().getStatus())
.isEqualTo(TriggerStatus.WAITING);

// WHEN
Expand All @@ -142,7 +161,7 @@ void testRollbackAndRetry() throws Exception {

// THEN
assertThat(executed).hasSize(1);
assertExecutionCount(key.get(), 2);
assertExecutionCount(key, 2);
assertThat(personRepository.count()).isOne();
}

Expand Down

0 comments on commit 57f0399

Please sign in to comment.