Skip to content

Commit

Permalink
Implement new triggering and execution system (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
v9n authored Dec 7, 2024
1 parent 21923a0 commit 049bc66
Show file tree
Hide file tree
Showing 43 changed files with 8,819 additions and 2,992 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ jobs:
run: |
# TODO Implement test for all packages
go test -v ./core/taskengine
go test -v ./core/taskengine/trigger
go test -v ./core/taskengine/macros
go test -v ./pkg/timekeeper
publish-dev-build:
Expand Down
4 changes: 4 additions & 0 deletions aggregator/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ func handleConnection(agg *Aggregator, conn net.Conn) {
case "exit":
fmt.Fprintln(conn, "Exiting...")
return
case "trigger":
fmt.Fprintln(conn, "about to trigger on server")
//agg.engine.TriggerWith

default:
fmt.Fprintln(conn, "Unknown command:", command)
}
Expand Down
25 changes: 19 additions & 6 deletions aggregator/rpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import (
// RpcServer is our gRPC server struct; it holds the entry points of the request handlers
type RpcServer struct {
avsproto.UnimplementedAggregatorServer
avsproto.UnimplementedNodeServer

config *config.Config
cache *bigcache.BigCache
db storage.Storage
Expand Down Expand Up @@ -191,23 +193,30 @@ func (r *RpcServer) GetTask(ctx context.Context, payload *avsproto.IdReq) (*avsp
}

// Operator action
func (r *RpcServer) SyncTasks(payload *avsproto.SyncTasksReq, srv avsproto.Aggregator_SyncTasksServer) error {
func (r *RpcServer) SyncMessages(payload *avsproto.SyncMessagesReq, srv avsproto.Node_SyncMessagesServer) error {
err := r.engine.StreamCheckToOperator(payload, srv)

return err
}

// Operator action
func (r *RpcServer) UpdateChecks(ctx context.Context, payload *avsproto.UpdateChecksReq) (*avsproto.UpdateChecksResp, error) {
if err := r.engine.AggregateChecksResult(payload.Address, payload.Id); err != nil {
func (r *RpcServer) NotifyTriggers(ctx context.Context, payload *avsproto.NotifyTriggersReq) (*avsproto.NotifyTriggersResp, error) {
if err := r.engine.AggregateChecksResult(payload.Address, payload); err != nil {
return nil, err
}

return &avsproto.UpdateChecksResp{
return &avsproto.NotifyTriggersResp{
UpdatedAt: timestamppb.Now(),
}, nil
}

// Ack is an operator action that acknowledges a message.
// It is currently a stub: the payload is ignored and the call always
// returns true with a nil error.
func (r *RpcServer) Ack(ctx context.Context, payload *avsproto.AckMessageReq) (*wrapperspb.BoolValue, error) {
// TODO: Implement ACK before merge

return wrapperspb.Bool(true), nil
}

// startRpcServer initializes and establish a tcp socket on given address from
// config file
func (agg *Aggregator) startRpcServer(ctx context.Context) error {
Expand All @@ -231,7 +240,7 @@ func (agg *Aggregator) startRpcServer(ctx context.Context) error {
panic(err)
}

avsproto.RegisterAggregatorServer(s, &RpcServer{
rpcServer := &RpcServer{
cache: agg.cache,
db: agg.db,
engine: agg.engine,
Expand All @@ -241,7 +250,11 @@ func (agg *Aggregator) startRpcServer(ctx context.Context) error {

config: agg.config,
operatorPool: agg.operatorPool,
})
}

// TODO: split node and aggregator
avsproto.RegisterAggregatorServer(s, rpcServer)
avsproto.RegisterNodeServer(s, rpcServer)

// Register reflection service on gRPC server.
// This allows clients to discover URL endpoints
Expand Down
8 changes: 6 additions & 2 deletions aggregator/task_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,13 @@ func (agg *Aggregator) startTaskEngine(ctx context.Context) {
Prefix: "default",
})
agg.worker = apqueue.NewWorker(agg.queue, agg.db)
taskExecutor := taskengine.NewExecutor(agg.db, agg.logger)
taskengine.SetMacro(agg.config.Macros)
taskengine.SetCache(agg.cache)

agg.worker.RegisterProcessor(
"contract_run",
taskengine.NewProcessor(agg.db, agg.config.SmartWallet, agg.logger),
taskengine.ExecuteTask,
taskExecutor,
)

agg.engine = taskengine.New(
Expand Down
6 changes: 3 additions & 3 deletions core/apqueue/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ func (w *Worker) loop() {
} else {
w.logger.Info("unsupported job", "job", string(job.Data))
}
w.logger.Info("decoded job", "jobid", jid, "jobdata", string(job.Data))
w.logger.Info("decoded job", "jobid", jid, "jobName", job.Name, "jobdata", string(job.Data))

if err == nil {
w.q.markJobDone(job, jobComplete)
w.logger.Info("succesfully perform job", "jobid", jid)
w.logger.Info("succesfully perform job", "jobid", jid, "task_id", job.Name)
} else {
// TODO: move to a retry queue depend on what kind of error
w.q.markJobDone(job, jobFailed)
w.logger.Info("failed to perform job", "jobid", jid)
w.logger.Errorf("failed to perform job %w", err, "jobid", jid, "task_id", job.Name)
}
case <-w.q.closeCh: // loop was stopped
return
Expand Down
5 changes: 5 additions & 0 deletions core/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ type Config struct {

SocketPath string
Environment sdklogging.LogLevel

Macros map[string]string
}

type SmartWalletConfig struct {
Expand Down Expand Up @@ -85,6 +87,8 @@ type ConfigRaw struct {
} `yaml:"smart_wallet"`

SocketPath string `yaml:"socket_path"`

Macros map[string]string `yaml:"macros"`
}

// These are read from CredibleSquaringDeploymentFileFlag
Expand Down Expand Up @@ -188,6 +192,7 @@ func NewConfig(configFilePath string) (*Config, error) {
},

SocketPath: configRaw.SocketPath,
Macros: configRaw.Macros,
}

if config.SocketPath == "" {
Expand Down
1 change: 0 additions & 1 deletion core/taskengine/cron_trigger.go

This file was deleted.

149 changes: 91 additions & 58 deletions core/taskengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,87 @@ import (
"github.com/AvaProtocol/ap-avs/model"
"github.com/AvaProtocol/ap-avs/storage"
sdklogging "github.com/Layr-Labs/eigensdk-go/logging"
"github.com/allegro/bigcache/v3"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/math"
"github.com/ethereum/go-ethereum/ethclient"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
grpcstatus "google.golang.org/grpc/status"
"google.golang.org/protobuf/encoding/protojson"

avsproto "github.com/AvaProtocol/ap-avs/protobuf"
)

const (
// ExecuteTask is the queue job name under which task executions are
// enqueued and for which the task executor is registered as processor.
ExecuteTask = "execute_task"
)

var (
// shared RPC client, initialized by SetRpc
rpcConn *ethclient.Client
// websocket client used for subscription, initialized by retryWsRpc
wsEthClient *ethclient.Client
// URL dialed by retryWsRpc to create wsEthClient; recorded by SetWsRpc
wsRpcURL string
// package-level logger, assigned by SetLogger
logger sdklogging.Logger

// Global variables that we expose to our tasks. Users can write `{{name}}` to access them.
// These macros are defined in our aggregator YAML config file under `macros`.
macroEnvs map[string]string
// package-level cache instance, assigned by SetCache
cache *bigcache.BigCache
)

// SetLogger sets the package-level logger used by the task engine.
func SetLogger(mylogger sdklogging.Logger) {
logger = mylogger
}

// SetMacro sets the global macro table. Macros are meant to be static,
// immutable and available to all tasks at runtime. The map is stored as-is
// (not copied), so callers should not modify it after this call.
func SetMacro(v map[string]string) {
macroEnvs = v
}

// SetCache stores c in the package-level cache variable.
func SetCache(c *bigcache.BigCache) {
cache = c
}

// SetRpc dials rpcURL and keeps the resulting client in the package-level
// rpcConn so it can be shared. It panics when the dial fails.
func SetRpc(rpcURL string) {
	client, dialErr := ethclient.Dial(rpcURL)
	if dialErr != nil {
		panic(dialErr)
	}

	rpcConn = client
}

// SetWsRpc records rpcURL as the websocket RPC endpoint and then connects to
// it through retryWsRpc. It panics if retryWsRpc reports an error.
func SetWsRpc(rpcURL string) {
	wsRpcURL = rpcURL

	connectErr := retryWsRpc()
	if connectErr != nil {
		panic(connectErr)
	}
}

// retryWsRpc dials wsRpcURL until a websocket client is established and
// stores it in wsEthClient. It blocks, sleeping 15 seconds between failed
// attempts, and returns nil once a dial succeeds; it never returns a non-nil
// error.
func retryWsRpc() error {
	for {
		conn, err := ethclient.Dial(wsRpcURL)
		if err == nil {
			wsEthClient = conn
			return nil
		}

		// Use the structured Error method: the message has no format verbs,
		// so Errorf would render the key/value pair as "%!(EXTRA ...)" noise
		// instead of attaching the error to the log entry.
		logger.Error("cannot establish websocket client for RPC, retry in 15 seconds", "err", err)
		time.Sleep(15 * time.Second)
	}
}

// operatorState holds the per-operator sync bookkeeping.
type operatorState struct {
// set of task IDs that we have already synced to this operator
TaskID map[string]bool
// NOTE(review): presumably a logical clock for ordering this operator's
// updates — its use is not visible here; confirm against the engine code.
MonotonicClock int64
}

// The core datastructure of the task engine
type Engine struct {
db storage.Storage
queue *apqueue.Queue
Expand All @@ -66,31 +119,7 @@ type Engine struct {
logger sdklogging.Logger
}

func SetRpc(rpcURL string) {
if conn, err := ethclient.Dial(rpcURL); err == nil {
rpcConn = conn
} else {
panic(err)
}
}

func SetWsRpc(rpcURL string) {
wsRpcURL = rpcURL
if err := retryWsRpc(); err != nil {
panic(err)
}
}

func retryWsRpc() error {
conn, err := ethclient.Dial(wsRpcURL)
if err == nil {
wsEthClient = conn
return nil
}

return err
}

// New creates a task engine using the given storage, config and queue
func New(db storage.Storage, config *config.Config, queue *apqueue.Queue, logger sdklogging.Logger) *Engine {
e := Engine{
db: db,
Expand Down Expand Up @@ -129,9 +158,12 @@ func (n *Engine) MustStart() {
panic(e)
}
for _, item := range kvs {
var task model.Task
if err := json.Unmarshal(item.Value, &task); err == nil {
n.tasks[string(item.Key)] = &task
task := &model.Task{
Task: &avsproto.Task{},
}
err := protojson.Unmarshal(item.Value, task)
if err == nil {
n.tasks[string(item.Key)] = task
}
}
}
Expand Down Expand Up @@ -254,7 +286,7 @@ func (n *Engine) CreateTask(user *model.User, taskPayload *avsproto.CreateTaskRe
return task, nil
}

func (n *Engine) StreamCheckToOperator(payload *avsproto.SyncTasksReq, srv avsproto.Aggregator_SyncTasksServer) error {
func (n *Engine) StreamCheckToOperator(payload *avsproto.SyncMessagesReq, srv avsproto.Node_SyncMessagesServer) error {
ticker := time.NewTicker(5 * time.Second)
address := payload.Address

Expand Down Expand Up @@ -297,15 +329,23 @@ func (n *Engine) StreamCheckToOperator(payload *avsproto.SyncTasksReq, srv avspr
continue
}

n.logger.Info("stream check to operator", "taskID", task.Id, "operator", payload.Address)
resp := avsproto.SyncTasksResp{
Id: task.Id,
CheckType: "CheckTrigger",
Trigger: task.Trigger,
resp := avsproto.SyncMessagesResp{
Id: task.Id,
Op: avsproto.MessageOp_MonitorTaskTrigger,

TaskMetadata: &avsproto.SyncMessagesResp_TaskMetadata{
TaskId: task.Id,
Remain: task.MaxExecution,
ExpiredAt: task.ExpiredAt,
Trigger: task.Trigger,
},
}
n.logger.Info("stream check to operator", "taskID", task.Id, "operator", payload.Address, "resp", resp)

if err := srv.Send(&resp); err != nil {
return err
// return an error so that the client re-establishes the connection
n.logger.Info("error sending check to operator", "taskID", task.Id, "operator", payload.Address)
return fmt.Errorf("cannot send data back to grpc channel")
}

n.lock.Lock()
Expand All @@ -317,35 +357,29 @@ func (n *Engine) StreamCheckToOperator(payload *avsproto.SyncTasksReq, srv avspr
}

// TODO: Merge and verify from multiple operators
func (n *Engine) AggregateChecksResult(address string, ids []string) error {
if len(ids) < 1 {
func (n *Engine) AggregateChecksResult(address string, payload *avsproto.NotifyTriggersReq) error {
if len(payload.TaskId) < 1 {
return nil
}

n.logger.Debug("process aggregator check hits", "operator", address, "task_ids", ids)
for _, id := range ids {
n.lock.Lock()
delete(n.tasks, id)
delete(n.trackSyncedTasks[address].TaskID, id)
n.logger.Info("processed aggregator check hit", "operator", address, "id", id)
n.lock.Unlock()
}
n.lock.Lock()
// delete(n.tasks, payload.TaskId)
// delete(n.trackSyncedTasks[address].TaskID, payload.TaskId)
// uncomment later

// Now we will queue the job
for _, id := range ids {
n.logger.Debug("mark task in executing status", "task_id", id)
n.logger.Info("processed aggregator check hit", "operator", address, "task_id", payload.TaskId)
n.lock.Unlock()

if err := n.db.Move(
[]byte(TaskStorageKey(id, avsproto.TaskStatus_Active)),
[]byte(TaskStorageKey(id, avsproto.TaskStatus_Executing)),
); err != nil {
n.logger.Error("error moving the task storage from active to executing", "task", id, "error", err)
}

n.queue.Enqueue("contract_run", id, []byte(id))
n.logger.Info("enqueue contract_run job", "taskid", id)
data, err := json.Marshal(payload.TriggerMarker)
if err != nil {
n.logger.Error("error serialize trigger to json", err)
return err
}

n.queue.Enqueue(ExecuteTask, payload.TaskId, data)
n.logger.Info("enqueue task into the queue system", "task_id", payload.TaskId)

// if the task can still run, add it back
return nil
}

Expand Down Expand Up @@ -455,7 +489,6 @@ func (n *Engine) CancelTaskByUser(user *model.User, taskID string) (bool, error)
updates := map[string][]byte{}
oldStatus := task.Status
task.SetCanceled()
fmt.Println("found task", task, string(TaskStorageKey(task.Id, oldStatus)), string(TaskUserKey(task)))
updates[string(TaskStorageKey(task.Id, oldStatus))], err = task.ToJSON()
updates[string(TaskUserKey(task))] = []byte(fmt.Sprintf("%d", task.Status))

Expand Down
Loading

0 comments on commit 049bc66

Please sign in to comment.