Skip to content

Commit

Permalink
Change trigger response (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
v9n authored Jan 7, 2025
1 parent 3b8201c commit a6d3ede
Show file tree
Hide file tree
Showing 11 changed files with 782 additions and 365 deletions.
16 changes: 15 additions & 1 deletion aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func (r *RpcServer) ListExecutions(ctx context.Context, payload *avsproto.ListEx
return r.engine.ListExecutions(user, payload)
}

func (r *RpcServer) GetExecution(ctx context.Context, payload *avsproto.GetExecutionReq) (*avsproto.Execution, error) {
func (r *RpcServer) GetExecution(ctx context.Context, payload *avsproto.ExecutionReq) (*avsproto.Execution, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
Expand All @@ -191,6 +191,20 @@ func (r *RpcServer) GetExecution(ctx context.Context, payload *avsproto.GetExecu
return r.engine.GetExecution(user, payload)
}

func (r *RpcServer) GetExecutionStatus(ctx context.Context, payload *avsproto.ExecutionReq) (*avsproto.ExecutionStatusResp, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "%s: %s", auth.AuthenticationError, err.Error())
}

r.config.Logger.Info("process get execution",
"user", user.Address.String(),
"task_id", payload.TaskId,
"execution_id", payload.ExecutionId,
)
return r.engine.GetExecutionStatus(user, payload)
}

func (r *RpcServer) GetTask(ctx context.Context, payload *avsproto.IdReq) (*avsproto.Task, error) {
user, err := r.verifyAuth(ctx)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion aggregator/task_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (agg *Aggregator) startTaskEngine(ctx context.Context) {
macros.SetRpc(agg.config.SmartWallet.EthRpcUrl)

agg.worker.RegisterProcessor(
taskengine.ExecuteTask,
taskengine.JobTypeExecuteTask,
taskExecutor,
)

Expand Down
49 changes: 27 additions & 22 deletions core/apqueue/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,37 @@ func NewWorker(q *Queue, db storage.Storage) *Worker {
return w
}

// wake up and pop first item in the queue to process
func (w *Worker) ProcessSignal(jid uint64) {
w.logger.Info("process job from queue", "signal", jid)
job, err := w.q.Dequeue()
if err != nil {
w.logger.Error("failed to dequeue", "error", err)
}

processor, ok := w.processorRegistry[job.Type]
if ok {
err = processor.Perform(job)
} else {
w.logger.Info("unsupported job", "job", string(job.Data))
}
w.logger.Info("decoded job", "job_id", jid, "jobName", job.Name, "jobdata", string(job.Data))

if err == nil {
w.q.markJobDone(job, jobComplete)
w.logger.Info("succesfully perform job", "job_id", jid, "task_id", job.Name)
} else {
// TODO: move to a retry queue depend on what kind of error
w.q.markJobDone(job, jobFailed)
w.logger.Error("failed to perform job", "error", err, "job_id", jid, "task_id", job.Name)
}
}

func (w *Worker) loop() {
for {
select {
case jid := <-w.q.eventCh:
w.logger.Info("process job from queue", "job_id", jid)
job, err := w.q.Dequeue()
if err != nil {
w.logger.Error("failed to dequeue", "error", err)
}

processor, ok := w.processorRegistry[job.Type]
if ok {
err = processor.Perform(job)
} else {
w.logger.Info("unsupported job", "job", string(job.Data))
}
w.logger.Info("decoded job", "job_id", jid, "jobName", job.Name, "jobdata", string(job.Data))

if err == nil {
w.q.markJobDone(job, jobComplete)
w.logger.Info("succesfully perform job", "job_id", jid, "task_id", job.Name)
} else {
// TODO: move to a retry queue depend on what kind of error
w.q.markJobDone(job, jobFailed)
w.logger.Error("failed to perform job", "error", err, "job_id", jid, "task_id", job.Name)
}
w.ProcessSignal(jid)
case <-w.q.closeCh: // loop was stopped
return
}
Expand Down
1 change: 1 addition & 0 deletions core/taskengine/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ w:<eoa>:<smart-wallet-address> -> {factory, salt}
t:<task-status>:<task-id> -> task payload, the source of truth of task information
u:<eoa>:<smart-wallet-address>:<task-id> -> task status
history:<task-id>:<execution-id> -> an execution history
trigger:<task-id>:<execution-id> -> execution status
The task storage was designed for fast retrieve time at the cost of extra storage.
Expand Down
91 changes: 75 additions & 16 deletions core/taskengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import (
)

const (
ExecuteTask = "execute_task"
JobTypeExecuteTask = "execute_task"
DefaultItemPerPage = 50
)

Expand Down Expand Up @@ -390,13 +390,18 @@ func (n *Engine) AggregateChecksResult(address string, payload *avsproto.NotifyT
n.logger.Info("processed aggregator check hit", "operator", address, "task_id", payload.TaskId)
n.lock.Unlock()

data, err := json.Marshal(payload.TriggerMetadata)
queueTaskData := QueueExecutionData{
TriggerMetadata: payload.TriggerMetadata,
ExecutionID: ulid.Make().String(),
}

data, err := json.Marshal(queueTaskData)
if err != nil {
n.logger.Error("error serialize trigger to json", err)
return err
}

n.queue.Enqueue(ExecuteTask, payload.TaskId, data)
n.queue.Enqueue(JobTypeExecuteTask, payload.TaskId, data)
n.logger.Info("enqueue task into the queue system", "task_id", payload.TaskId)

// if the task can still run, add it back
Expand Down Expand Up @@ -558,36 +563,40 @@ func (n *Engine) TriggerTask(user *model.User, payload *avsproto.UserTriggerTask
return nil, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}

data, err := json.Marshal(payload.TriggerMetadata)
if err != nil {
n.logger.Error("error serialize trigger to json", err)
return nil, status.Errorf(codes.InvalidArgument, codes.InvalidArgument.String())
queueTaskData := QueueExecutionData{
TriggerMetadata: payload.TriggerMetadata,
ExecutionID: ulid.Make().String(),
}

if payload.IsBlocking {
// Run the task inline, by pass the queue system
executor := NewExecutor(n.db, n.logger)
execution, err := executor.RunTask(task, payload.TriggerMetadata)
execution, err := executor.RunTask(task, &queueTaskData)
if err == nil {
return &avsproto.UserTriggerTaskResp{
Result: true,
ExecutionId: execution.Id,
}, nil
}

return &avsproto.UserTriggerTaskResp{
Result: false,
}, err
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_TaskTriggerError), fmt.Sprintf("error trigger task: %s", err.Error()))
}

jid, err := n.queue.Enqueue(ExecuteTask, payload.TaskId, data)
data, err := json.Marshal(queueTaskData)
if err != nil {
n.logger.Error("error serialize trigger to json", err)
return nil, status.Errorf(codes.InvalidArgument, codes.InvalidArgument.String())
}

jid, err := n.queue.Enqueue(JobTypeExecuteTask, payload.TaskId, data)
if err != nil {
n.logger.Error("error enqueue job %s %s %w", payload.TaskId, string(data), err)
return nil, grpcstatus.Errorf(codes.Code(avsproto.Error_StorageUnavailable), StorageQueueUnavailableError)
}

n.logger.Info("enqueue task into the queue system", "task_id", payload.TaskId, "jid", jid)
n.setExecutionStatusQueue(task, queueTaskData.ExecutionID)
n.logger.Info("enqueue task into the queue system", "task_id", payload.TaskId, "jid", jid, "execution_id", queueTaskData.ExecutionID)
return &avsproto.UserTriggerTaskResp{
Result: true,
ExecutionId: queueTaskData.ExecutionID,
}, nil
}

Expand Down Expand Up @@ -696,10 +705,30 @@ func (n *Engine) ListExecutions(user *model.User, payload *avsproto.ListExecutio
return executioResp, nil
}

func (n *Engine) setExecutionStatusQueue(task *model.Task, executionID string) error {
status := strconv.Itoa(int(avsproto.ExecutionStatus_Queued))
return n.db.Set(TaskTriggerKey(task, executionID), []byte(status))
}

func (n *Engine) getExecutonStatusFromQueue(task *model.Task, executionID string) (*avsproto.ExecutionStatus, error) {
status, err := n.db.GetKey(TaskTriggerKey(task, executionID))
if err != nil {
return nil, err
}

value, err := strconv.Atoi(string(status))
if err != nil {
return nil, err
}
statusValue := avsproto.ExecutionStatus(value)
return &statusValue, nil
}

// Get xecution for a given task id and execution id
func (n *Engine) GetExecution(user *model.User, payload *avsproto.GetExecutionReq) (*avsproto.Execution, error) {
func (n *Engine) GetExecution(user *model.User, payload *avsproto.ExecutionReq) (*avsproto.Execution, error) {
// Validate all tasks own by the caller, if there are any tasks won't be owned by caller, we return permission error
task, err := n.GetTaskByID(payload.TaskId)

if err != nil {
return nil, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}
Expand Down Expand Up @@ -727,9 +756,39 @@ func (n *Engine) GetExecution(user *model.User, payload *avsproto.GetExecutionRe
exec.TriggerMetadata.Type = avsproto.TriggerMetadata_Event
}
}

return &exec, nil
}

func (n *Engine) GetExecutionStatus(user *model.User, payload *avsproto.ExecutionReq) (*avsproto.ExecutionStatusResp, error) {
task, err := n.GetTaskByID(payload.TaskId)

if err != nil {
return nil, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}

if !task.OwnedBy(user.Address) {
return nil, grpcstatus.Errorf(codes.NotFound, TaskNotFoundError)
}

// First look into execution first
if _, err = n.db.GetKey(TaskExecutionKey(task, payload.ExecutionId)); err != nil {
// When execution not found, it could be in pending status, we will check that storage
if status, err := n.getExecutonStatusFromQueue(task, payload.ExecutionId); err == nil {
return &avsproto.ExecutionStatusResp{
Status: *status,
}, nil
}
return nil, fmt.Errorf("invalid ")
}

// if the key existed, the execution has finished, no need to decode the whole storage, we just return the status in this call
return &avsproto.ExecutionStatusResp{
Status: avsproto.ExecutionStatus_Finished,
}, nil

}

func (n *Engine) DeleteTaskByUser(user *model.User, taskID string) (bool, error) {
task, err := n.GetTask(user, taskID)

Expand Down
Loading

0 comments on commit a6d3ede

Please sign in to comment.