Skip to content

Commit

Permalink
always return execution id when triggering task
Browse files Browse the repository at this point in the history
  • Loading branch information
v9n committed Jan 2, 2025
1 parent 3b8201c commit 395e147
Show file tree
Hide file tree
Showing 12 changed files with 642 additions and 303 deletions.
2 changes: 1 addition & 1 deletion aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (r *RpcServer) ListExecutions(ctx context.Context, payload *avsproto.ListEx
return r.engine.ListExecutions(user, payload)
}

func (r *RpcServer) GetExecution(ctx context.Context, payload *avsproto.GetExecutionReq) (*avsproto.Execution, error) {
func (r *RpcServer) GetExecution(ctx context.Context, payload *avsproto.GetExecutionReq) (*avsproto.GetExecutionResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
Expand Down
2 changes: 1 addition & 1 deletion aggregator/task_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (agg *Aggregator) startTaskEngine(ctx context.Context) {
macros.SetRpc(agg.config.SmartWallet.EthRpcUrl)

agg.worker.RegisterProcessor(
taskengine.ExecuteTask,
taskengine.JobTypeExecuteTask,
taskExecutor,
)

Expand Down
49 changes: 27 additions & 22 deletions core/apqueue/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,37 @@ func NewWorker(q *Queue, db storage.Storage) *Worker {
return w
}

// wake up and pop first item in the queue to process
func (w *Worker) ProcessSignal(jid uint64) {
w.logger.Info("process job from queue", "signal", jid)
job, err := w.q.Dequeue()
if err != nil {
w.logger.Error("failed to dequeue", "error", err)
}

processor, ok := w.processorRegistry[job.Type]
if ok {
err = processor.Perform(job)
} else {
w.logger.Info("unsupported job", "job", string(job.Data))
}
w.logger.Info("decoded job", "job_id", jid, "jobName", job.Name, "jobdata", string(job.Data))

if err == nil {
w.q.markJobDone(job, jobComplete)
w.logger.Info("succesfully perform job", "job_id", jid, "task_id", job.Name)
} else {
// TODO: move to a retry queue depend on what kind of error
w.q.markJobDone(job, jobFailed)
w.logger.Error("failed to perform job", "error", err, "job_id", jid, "task_id", job.Name)
}
}

func (w *Worker) loop() {
for {
select {
case jid := <-w.q.eventCh:
w.logger.Info("process job from queue", "job_id", jid)
job, err := w.q.Dequeue()
if err != nil {
w.logger.Error("failed to dequeue", "error", err)
}

processor, ok := w.processorRegistry[job.Type]
if ok {
err = processor.Perform(job)
} else {
w.logger.Info("unsupported job", "job", string(job.Data))
}
w.logger.Info("decoded job", "job_id", jid, "jobName", job.Name, "jobdata", string(job.Data))

if err == nil {
w.q.markJobDone(job, jobComplete)
w.logger.Info("succesfully perform job", "job_id", jid, "task_id", job.Name)
} else {
// TODO: move to a retry queue depend on what kind of error
w.q.markJobDone(job, jobFailed)
w.logger.Error("failed to perform job", "error", err, "job_id", jid, "task_id", job.Name)
}
w.ProcessSignal(jid)
case <-w.q.closeCh: // loop was stopped
return
}
Expand Down
1 change: 1 addition & 0 deletions core/taskengine/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ w:<eoa>:<smart-wallet-address> -> {factory, salt}
t:<task-status>:<task-id> -> task payload, the source of truth of task information
u:<eoa>:<smart-wallet-address>:<task-id> -> task status
history:<task-id>:<execution-id> -> an execution history
trigger:<task-id>:<execution-id> -> execution status
The task storage was designed for fast retrieve time at the cost of extra storage.
Expand Down
74 changes: 57 additions & 17 deletions core/taskengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
)

const (
ExecuteTask = "execute_task"
JobTypeExecuteTask = "execute_task"
DefaultItemPerPage = 50
)

Expand Down Expand Up @@ -390,13 +390,18 @@ func (n *Engine) AggregateChecksResult(address string, payload *avsproto.NotifyT
n.logger.Info("processed aggregator check hit", "operator", address, "task_id", payload.TaskId)
n.lock.Unlock()

data, err := json.Marshal(payload.TriggerMetadata)
queueTaskData := QueueExecutionData{
TriggerMetadata: payload.TriggerMetadata,
ExecutionID: ulid.Make().String(),
}

data, err := json.Marshal(queueTaskData)
if err != nil {
n.logger.Error("error serialize trigger to json", err)
return err
}

n.queue.Enqueue(ExecuteTask, payload.TaskId, data)
n.queue.Enqueue(JobTypeExecuteTask, payload.TaskId, data)
n.logger.Info("enqueue task into the queue system", "task_id", payload.TaskId)

// if the task can still run, add it back
Expand Down Expand Up @@ -558,36 +563,40 @@ func (n *Engine) TriggerTask(user *model.User, payload *avsproto.UserTriggerTask
return nil, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}

data, err := json.Marshal(payload.TriggerMetadata)
if err != nil {
n.logger.Error("error serialize trigger to json", err)
return nil, status.Errorf(codes.InvalidArgument, codes.InvalidArgument.String())
queueTaskData := QueueExecutionData{
TriggerMetadata: payload.TriggerMetadata,
ExecutionID: ulid.Make().String(),
}

if payload.IsBlocking {
// Run the task inline, by pass the queue system
executor := NewExecutor(n.db, n.logger)
execution, err := executor.RunTask(task, payload.TriggerMetadata)
execution, err := executor.RunTask(task, &queueTaskData)
if err == nil {
return &avsproto.UserTriggerTaskResp{
Result: true,
ExecutionId: execution.Id,
}, nil
}

return &avsproto.UserTriggerTaskResp{
Result: false,
}, err
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_TaskTriggerError), fmt.Sprintf("error trigger task: %s", err.Error()))
}

jid, err := n.queue.Enqueue(ExecuteTask, payload.TaskId, data)
data, err := json.Marshal(queueTaskData)
if err != nil {
n.logger.Error("error serialize trigger to json", err)
return nil, status.Errorf(codes.InvalidArgument, codes.InvalidArgument.String())
}

jid, err := n.queue.Enqueue(JobTypeExecuteTask, payload.TaskId, data)
if err != nil {
n.logger.Error("error enqueue job %s %s %w", payload.TaskId, string(data), err)
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_StorageUnavailable), StorageQueueUnavailableError)
}

n.logger.Info("enqueue task into the queue system", "task_id", payload.TaskId, "jid", jid)
n.setExecutionStatusQueue(task, queueTaskData.ExecutionID)
n.logger.Info("enqueue task into the queue system", "task_id", payload.TaskId, "jid", jid, "execution_id", queueTaskData.ExecutionID)
return &avsproto.UserTriggerTaskResp{
Result: true,
ExecutionId: queueTaskData.ExecutionID,
}, nil
}

Expand Down Expand Up @@ -696,10 +705,30 @@ func (n *Engine) ListExecutions(user *model.User, payload *avsproto.ListExecutio
return executioResp, nil
}

func (n *Engine) setExecutionStatusQueue(task *model.Task, executionID string) error {
status := strconv.Itoa(int(avsproto.GetExecutionResp_Queue))
return n.db.Set(TaskTriggerKey(task, executionID), []byte(status))
}

func (n *Engine) getExecutonStatusFromQueue(task *model.Task, executionID string) (*avsproto.GetExecutionResp_ExecutionStatus, error) {
status, err := n.db.GetKey(TaskTriggerKey(task, executionID))
if err != nil {
return nil, err
}

value, err := strconv.Atoi(string(status))
if err != nil {
return nil, err
}
statusValue := avsproto.GetExecutionResp_ExecutionStatus(value)
return &statusValue, nil
}

// Get xecution for a given task id and execution id
func (n *Engine) GetExecution(user *model.User, payload *avsproto.GetExecutionReq) (*avsproto.Execution, error) {
func (n *Engine) GetExecution(user *model.User, payload *avsproto.GetExecutionReq) (*avsproto.GetExecutionResp, error) {
// Validate all tasks own by the caller, if there are any tasks won't be owned by caller, we return permission error
task, err := n.GetTaskByID(payload.TaskId)

if err != nil {
return nil, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}
Expand All @@ -710,6 +739,12 @@ func (n *Engine) GetExecution(user *model.User, payload *avsproto.GetExecutionRe

executionValue, err := n.db.GetKey(TaskExecutionKey(task, payload.ExecutionId))
if err != nil {
// When execution not found, it could be in pending status, we will check that storage
if status, err := n.getExecutonStatusFromQueue(task, payload.ExecutionId); err == nil {
return &avsproto.GetExecutionResp{
Status: *status,
}, nil
}
return nil, grpcstatus.Errorf(codes.NotFound, ExecutionNotFoundError)
}
exec := avsproto.Execution{}
Expand All @@ -727,7 +762,12 @@ func (n *Engine) GetExecution(user *model.User, payload *avsproto.GetExecutionRe
exec.TriggerMetadata.Type = avsproto.TriggerMetadata_Event
}
}
return &exec, nil

result := &avsproto.GetExecutionResp{
Status: avsproto.GetExecutionResp_Completed,
Data: &exec,
}
return result, nil
}

func (n *Engine) DeleteTaskByUser(user *model.User, taskID string) (bool, error) {
Expand Down
134 changes: 130 additions & 4 deletions core/taskengine/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package taskengine

import (
"fmt"
"strings"
"testing"

"github.com/AvaProtocol/ap-avs/core/apqueue"
"github.com/AvaProtocol/ap-avs/core/testutil"
avsproto "github.com/AvaProtocol/ap-avs/protobuf"
"github.com/AvaProtocol/ap-avs/storage"
Expand Down Expand Up @@ -212,12 +214,12 @@ func TestGetExecution(t *testing.T) {
ExecutionId: resultTrigger.ExecutionId,
})

if execution.Id != resultTrigger.ExecutionId {
t.Errorf("invalid execution id. expect %s got %s", resultTrigger.ExecutionId, execution.Id)
if execution.Data.Id != resultTrigger.ExecutionId {
t.Errorf("invalid execution id. expect %s got %s", resultTrigger.ExecutionId, execution.Data.Id)
}

if execution.TriggerMetadata.BlockNumber != 101 {
t.Errorf("invalid triggered block. expect 101 got %d", execution.TriggerMetadata.BlockNumber)
if execution.Data.TriggerMetadata.BlockNumber != 101 {
t.Errorf("invalid triggered block. expect 101 got %d", execution.Data.TriggerMetadata.BlockNumber)
}

// Another user cannot get this executin id
Expand Down Expand Up @@ -288,3 +290,127 @@ func TestListWallets(t *testing.T) {
t.Errorf("expect only default wallet but got %d", len(wallets))
}
}

func TestTriggerSync(t *testing.T) {
db := testutil.TestMustDB()
defer storage.Destroy(db.(*storage.BadgerStorage))

config := testutil.GetAggregatorConfig()
n := New(db, config, nil, testutil.GetLogger())

// Now create a test task
tr1 := testutil.RestTask()
tr1.Memo = "t1"
// salt 0
tr1.SmartWalletAddress = "0x7c3a76086588230c7B3f4839A4c1F5BBafcd57C6"
result, _ := n.CreateTask(testutil.TestUser1(), tr1)

resultTrigger, err := n.TriggerTask(testutil.TestUser1(), &avsproto.UserTriggerTaskReq{
TaskId: result.Id,
TriggerMetadata: &avsproto.TriggerMetadata{
BlockNumber: 101,
},
IsBlocking: true,
})

if err != nil {
t.Errorf("expected trigger succesfully but got error: %s", err)
}

// Now get back that execution id
execution, err := n.GetExecution(testutil.TestUser1(), &avsproto.GetExecutionReq{
TaskId: result.Id,
ExecutionId: resultTrigger.ExecutionId,
})

if execution.Status != avsproto.GetExecutionResp_Completed {
t.Errorf("invalid execution status, expected conpleted but got %s", avsproto.GetExecutionResp_ExecutionStatus_name[int32(execution.Status)])
}

if execution.Data.Id != resultTrigger.ExecutionId {
t.Errorf("invalid execution id. expect %s got %s", resultTrigger.ExecutionId, execution.Data.Id)
}

if execution.Data.TriggerMetadata.BlockNumber != 101 {
t.Errorf("invalid triggered block. expect 101 got %d", execution.Data.TriggerMetadata.BlockNumber)
}
}

func TestTriggerAsync(t *testing.T) {
db := testutil.TestMustDB()
defer storage.Destroy(db.(*storage.BadgerStorage))

config := testutil.GetAggregatorConfig()
n := New(db, config, nil, testutil.GetLogger())
n.queue = apqueue.New(db, testutil.GetLogger(), &apqueue.QueueOption{
Prefix: "default",
})
worker := apqueue.NewWorker(n.queue, n.db)
taskExecutor := NewExecutor(n.db, testutil.GetLogger())
worker.RegisterProcessor(
JobTypeExecuteTask,
taskExecutor,
)
n.queue.MustStart()

// Now create a test task
tr1 := testutil.RestTask()
tr1.Memo = "t1"
// salt 0 wallet
tr1.SmartWalletAddress = "0x7c3a76086588230c7B3f4839A4c1F5BBafcd57C6"
result, _ := n.CreateTask(testutil.TestUser1(), tr1)

resultTrigger, err := n.TriggerTask(testutil.TestUser1(), &avsproto.UserTriggerTaskReq{
TaskId: result.Id,
TriggerMetadata: &avsproto.TriggerMetadata{
BlockNumber: 101,
},
IsBlocking: false,
})

if err != nil {
t.Errorf("expected trigger succesfully but got error: %s", err)
}

// Now get back that execution id, because the task is run async we won't have any data yet,
// just the status for now
execution, err := n.GetExecution(testutil.TestUser1(), &avsproto.GetExecutionReq{
TaskId: result.Id,
ExecutionId: resultTrigger.ExecutionId,
})

if execution.Data != nil {
t.Errorf("malform execution result. expect no data but got %s", execution.Data)
}

if execution.Status != avsproto.GetExecutionResp_Queue {
t.Errorf("invalid execution status, expected queue but got %s", avsproto.GetExecutionResp_ExecutionStatus_name[int32(execution.Status)])
}

// Now let the queue start and process job
// In our end to end system the worker will process the job eventually
worker.ProcessSignal(1)

execution, err = n.GetExecution(testutil.TestUser1(), &avsproto.GetExecutionReq{
TaskId: result.Id,
ExecutionId: resultTrigger.ExecutionId,
})
if execution.Status != avsproto.GetExecutionResp_Completed {
t.Errorf("invalid execution status, expected completed but got %s", avsproto.GetExecutionResp_ExecutionStatus_name[int32(execution.Status)])
}

if execution.Data.Id != resultTrigger.ExecutionId {
t.Errorf("wring execution id, expected %s got %s", resultTrigger.ExecutionId, execution.Data.Id)
}

if !execution.Data.Success {
t.Errorf("wrong success result, expected true got false")
}

if execution.Data.Steps[0].NodeId != "ping1" {
t.Errorf("wrong node id in execution log")
}
if !strings.Contains(execution.Data.Steps[0].OutputData, "httpbin.org") {
t.Error("Invalid output data")
}
}
Loading

0 comments on commit 395e147

Please sign in to comment.