diff --git a/coordinator/internal/controller/api/get_task.go b/coordinator/internal/controller/api/get_task.go index c430f45ee3..0753f04548 100644 --- a/coordinator/internal/controller/api/get_task.go +++ b/coordinator/internal/controller/api/get_task.go @@ -23,7 +23,8 @@ import ( // GetTaskController the get prover task api controller type GetTaskController struct { - proverTasks map[message.ProofType]provertask.ProverTask + proverTasks map[message.ProofType]provertask.ProverTask + proverTaskManager *provertask.ProverTaskManager getTaskAccessCounter *prometheus.CounterVec @@ -32,12 +33,15 @@ type GetTaskController struct { // NewGetTaskController create a get prover task controller func NewGetTaskController(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, verifier *verifier.Verifier, reg prometheus.Registerer) *GetTaskController { + proverTaskManager := provertask.NewProverTaskManager(db) + chunkProverTask := provertask.NewChunkProverTask(cfg, chainCfg, db, verifier.ChunkVk, reg) batchProverTask := provertask.NewBatchProverTask(cfg, chainCfg, db, verifier.BatchVk, reg) bundleProverTask := provertask.NewBundleProverTask(cfg, chainCfg, db, verifier.BundleVk, reg) ptc := &GetTaskController{ - proverTasks: make(map[message.ProofType]provertask.ProverTask), + proverTasks: make(map[message.ProofType]provertask.ProverTask), + proverTaskManager: proverTaskManager, getTaskAccessCounter: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Name: "coordinator_get_task_access_count", Help: "Multi dimensions get task counter.", @@ -99,7 +103,19 @@ func (ptc *GetTaskController) GetTasks(ctx *gin.Context) { } } - proofType := ptc.proofType(&getTaskParameter) + assigned, err := ptc.proverTaskManager.CheckParameter(ctx) + if err != nil { + nerr := fmt.Errorf("check prover task parameter failed, error:%w", err) + types.RenderFailure(ctx, types.ErrCoordinatorGetTaskFailure, nerr) + return + } + + var proofType message.ProofType + if assigned != nil { + proofType = message.ProofType(assigned.TaskType) + } else { + proofType = ptc.proofType(&getTaskParameter) + } proverTask, isExist := ptc.proverTasks[proofType] if !isExist { nerr := fmt.Errorf("parameter wrong proof type:%v", proofType) diff --git a/coordinator/internal/logic/provertask/batch_prover_task.go b/coordinator/internal/logic/provertask/batch_prover_task.go index ff1bc25920..55cb362540 100644 --- a/coordinator/internal/logic/provertask/batch_prover_task.go +++ b/coordinator/internal/logic/provertask/batch_prover_task.go @@ -39,15 +39,14 @@ type BatchProverTask struct { func NewBatchProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, expectedVk map[string][]byte, reg prometheus.Registerer) *BatchProverTask { bp := &BatchProverTask{ BaseProverTask: BaseProverTask{ - db: db, - cfg: cfg, - chainCfg: chainCfg, - expectedVk: expectedVk, - blockOrm: orm.NewL2Block(db), - chunkOrm: orm.NewChunk(db), - batchOrm: orm.NewBatch(db), - proverTaskOrm: orm.NewProverTask(db), - proverBlockListOrm: orm.NewProverBlockList(db), + db: db, + cfg: cfg, + chainCfg: chainCfg, + expectedVk: expectedVk, + blockOrm: orm.NewL2Block(db), + chunkOrm: orm.NewChunk(db), + batchOrm: orm.NewBatch(db), + proverTaskOrm: orm.NewProverTask(db), }, batchTaskGetTaskTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Name: "coordinator_batch_get_task_total", @@ -60,9 +59,9 @@ func NewBatchProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go // Assign load and assign batch tasks func (bp *BatchProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) { - taskCtx, err := bp.checkParameter(ctx) - if err != nil || taskCtx == nil { - return nil, fmt.Errorf("check prover task parameter failed, error:%w", err) + taskCtx := bp.checkParameter(ctx) + if taskCtx == nil { + return nil, fmt.Errorf("check prover task parameter missed") } maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession diff --git a/coordinator/internal/logic/provertask/bundle_prover_task.go b/coordinator/internal/logic/provertask/bundle_prover_task.go index 1d8377f554..d39d63a71a 100644 --- a/coordinator/internal/logic/provertask/bundle_prover_task.go +++ b/coordinator/internal/logic/provertask/bundle_prover_task.go @@ -36,16 +36,15 @@ type BundleProverTask struct { func NewBundleProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, expectedVk map[string][]byte, reg prometheus.Registerer) *BundleProverTask { bp := &BundleProverTask{ BaseProverTask: BaseProverTask{ - db: db, - chainCfg: chainCfg, - cfg: cfg, - expectedVk: expectedVk, - blockOrm: orm.NewL2Block(db), - chunkOrm: orm.NewChunk(db), - batchOrm: orm.NewBatch(db), - bundleOrm: orm.NewBundle(db), - proverTaskOrm: orm.NewProverTask(db), - proverBlockListOrm: orm.NewProverBlockList(db), + db: db, + chainCfg: chainCfg, + cfg: cfg, + expectedVk: expectedVk, + blockOrm: orm.NewL2Block(db), + chunkOrm: orm.NewChunk(db), + batchOrm: orm.NewBatch(db), + bundleOrm: orm.NewBundle(db), + proverTaskOrm: orm.NewProverTask(db), }, bundleTaskGetTaskTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Name: "coordinator_bundle_get_task_total", @@ -58,9 +57,9 @@ func NewBundleProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *g // Assign load and assign batch tasks func (bp *BundleProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) { - taskCtx, err := bp.checkParameter(ctx) - if err != nil || taskCtx == nil { - return nil, fmt.Errorf("check prover task parameter failed, error:%w", err) + taskCtx := bp.checkParameter(ctx) + if taskCtx == nil { + return nil, fmt.Errorf("check prover task parameter missed") } maxActiveAttempts := bp.cfg.ProverManager.ProversPerSession diff --git a/coordinator/internal/logic/provertask/chunk_prover_task.go b/coordinator/internal/logic/provertask/chunk_prover_task.go index 9971d294e6..7bc8d756e9 100644 --- a/coordinator/internal/logic/provertask/chunk_prover_task.go +++ b/coordinator/internal/logic/provertask/chunk_prover_task.go @@ -36,14 +36,13 @@ type ChunkProverTask struct { func NewChunkProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *gorm.DB, expectedVk map[string][]byte, reg prometheus.Registerer) *ChunkProverTask { cp := &ChunkProverTask{ BaseProverTask: BaseProverTask{ - db: db, - cfg: cfg, - chainCfg: chainCfg, - expectedVk: expectedVk, - chunkOrm: orm.NewChunk(db), - blockOrm: orm.NewL2Block(db), - proverTaskOrm: orm.NewProverTask(db), - proverBlockListOrm: orm.NewProverBlockList(db), + db: db, + cfg: cfg, + chainCfg: chainCfg, + expectedVk: expectedVk, + chunkOrm: orm.NewChunk(db), + blockOrm: orm.NewL2Block(db), + proverTaskOrm: orm.NewProverTask(db), }, chunkTaskGetTaskTotal: promauto.With(reg).NewCounterVec(prometheus.CounterOpts{ Name: "coordinator_chunk_get_task_total", @@ -56,9 +55,9 @@ func NewChunkProverTask(cfg *config.Config, chainCfg *params.ChainConfig, db *go // Assign the chunk proof which need to prove func (cp *ChunkProverTask) Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) { - taskCtx, err := cp.checkParameter(ctx) - if err != nil || taskCtx == nil { - return nil, fmt.Errorf("check prover task parameter failed, error:%w", err) + taskCtx := cp.checkParameter(ctx) + if taskCtx == nil { + return nil, fmt.Errorf("check prover task parameter missed") } maxActiveAttempts := cp.cfg.ProverManager.ProversPerSession diff --git a/coordinator/internal/logic/provertask/prover_task.go b/coordinator/internal/logic/provertask/prover_task.go index e894ea7d14..4e5c0cd727 100644 --- a/coordinator/internal/logic/provertask/prover_task.go +++ b/coordinator/internal/logic/provertask/prover_task.go @@ -37,6 +37,81 @@ type ProverTask interface { Assign(ctx *gin.Context, getTaskParameter *coordinatorType.GetTaskParameter) (*coordinatorType.GetTaskSchema, error) } +// ProverTaskManager manage task which has been assigned +type ProverTaskManager struct { + proverTaskOrm *orm.ProverTask + proverBlockListOrm *orm.ProverBlockList +} + +const proverTaskCtxKey = "prover_task_context_key" + +// NewProverTaskManager new a prover task manager +func NewProverTaskManager(db *gorm.DB) *ProverTaskManager { + return &ProverTaskManager{ + proverTaskOrm: orm.NewProverTask(db), + proverBlockListOrm: orm.NewProverBlockList(db), + } +} + +// checkParameter check the prover task parameter illegal +func (b *ProverTaskManager) CheckParameter(ctx *gin.Context) (*orm.ProverTask, error) { + var ptc proverTaskContext + ptc.HardForkNames = make(map[string]struct{}) + + publicKey, publicKeyExist := ctx.Get(coordinatorType.PublicKey) + if !publicKeyExist { + return nil, errors.New("get public key from context failed") + } + ptc.PublicKey = publicKey.(string) + + proverName, proverNameExist := ctx.Get(coordinatorType.ProverName) + if !proverNameExist { + return nil, errors.New("get prover name from context failed") + } + ptc.ProverName = proverName.(string) + + proverVersion, proverVersionExist := ctx.Get(coordinatorType.ProverVersion) + if !proverVersionExist { + return nil, errors.New("get prover version from context failed") + } + ptc.ProverVersion = proverVersion.(string) + + ProverProviderType, ProverProviderTypeExist := ctx.Get(coordinatorType.ProverProviderTypeKey) + if !ProverProviderTypeExist { + // for backward compatibility, set ProverProviderType as internal + ProverProviderType = float64(coordinatorType.ProverProviderTypeInternal) + } + ptc.ProverProviderType = uint8(ProverProviderType.(float64)) + + hardForkNamesStr, hardForkNameExist := ctx.Get(coordinatorType.HardForkName) + if !hardForkNameExist { + return nil, errors.New("get hard fork name from context failed") + } + hardForkNames := strings.Split(hardForkNamesStr.(string), ",") + for _, hardForkName := range hardForkNames { + ptc.HardForkNames[hardForkName] = struct{}{} + } + + isBlocked, err := b.proverBlockListOrm.IsPublicKeyBlocked(ctx.Copy(), publicKey.(string)) + if err != nil { + return nil, fmt.Errorf("failed to check whether the public key %s is blocked before assigning a chunk task, err: %w, proverName: %s, proverVersion: %s", publicKey, err, proverName, proverVersion) + } + if isBlocked { + return nil, fmt.Errorf("public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion) + } + + assigned, err := b.proverTaskOrm.IsProverAssigned(ctx.Copy(), publicKey.(string)) + if err != nil { + return nil, fmt.Errorf("failed to check if prover %s is assigned a task, err: %w", publicKey.(string), err) + } + + ptc.hasAssignedTask = assigned + + ctx.Set(proverTaskCtxKey, &ptc) + + return assigned, nil +} + // BaseProverTask a base prover task which contain series functions type BaseProverTask struct { cfg *config.Config @@ -44,12 +119,12 @@ type BaseProverTask struct { db *gorm.DB expectedVk map[string][]byte - batchOrm *orm.Batch - chunkOrm *orm.Chunk - bundleOrm *orm.Bundle - blockOrm *orm.L2Block - proverTaskOrm *orm.ProverTask - proverBlockListOrm *orm.ProverBlockList + batchOrm *orm.Batch + chunkOrm *orm.Chunk + bundleOrm *orm.Bundle + blockOrm *orm.L2Block + + proverTaskOrm *orm.ProverTask } type proverTaskContext struct { @@ -132,59 +207,13 @@ func (b *BaseProverTask) hardForkSanityCheck(ctx *gin.Context, taskCtx *proverTa } // checkParameter check the prover task parameter illegal -func (b *BaseProverTask) checkParameter(ctx *gin.Context) (*proverTaskContext, error) { - var ptc proverTaskContext - ptc.HardForkNames = make(map[string]struct{}) - - publicKey, publicKeyExist := ctx.Get(coordinatorType.PublicKey) - if !publicKeyExist { - return nil, errors.New("get public key from context failed") - } - ptc.PublicKey = publicKey.(string) - - proverName, proverNameExist := ctx.Get(coordinatorType.ProverName) - if !proverNameExist { - return nil, errors.New("get prover name from context failed") - } - ptc.ProverName = proverName.(string) - - proverVersion, proverVersionExist := ctx.Get(coordinatorType.ProverVersion) - if !proverVersionExist { - return nil, errors.New("get prover version from context failed") - } - ptc.ProverVersion = proverVersion.(string) - - ProverProviderType, ProverProviderTypeExist := ctx.Get(coordinatorType.ProverProviderTypeKey) - if !ProverProviderTypeExist { - // for backward compatibility, set ProverProviderType as internal - ProverProviderType = float64(coordinatorType.ProverProviderTypeInternal) - } - ptc.ProverProviderType = uint8(ProverProviderType.(float64)) - - hardForkNamesStr, hardForkNameExist := ctx.Get(coordinatorType.HardForkName) - if !hardForkNameExist { - return nil, errors.New("get hard fork name from context failed") - } - hardForkNames := strings.Split(hardForkNamesStr.(string), ",") - for _, hardForkName := range hardForkNames { - ptc.HardForkNames[hardForkName] = struct{}{} +func (b *BaseProverTask) checkParameter(ctx *gin.Context) *proverTaskContext { + pctx, exist := ctx.Get(proverTaskCtxKey) + if !exist { + return nil } - isBlocked, err := b.proverBlockListOrm.IsPublicKeyBlocked(ctx.Copy(), publicKey.(string)) - if err != nil { - return nil, fmt.Errorf("failed to check whether the public key %s is blocked before assigning a chunk task, err: %w, proverName: %s, proverVersion: %s", publicKey, err, proverName, proverVersion) - } - if isBlocked { - return nil, fmt.Errorf("public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", publicKey, proverName, proverVersion) - } - - assigned, err := b.proverTaskOrm.IsProverAssigned(ctx.Copy(), publicKey.(string)) - if err != nil { - return nil, fmt.Errorf("failed to check if prover %s is assigned a task, err: %w", publicKey.(string), err) - } - - ptc.hasAssignedTask = assigned - return &ptc, nil + return pctx.(*proverTaskContext) } func (b *BaseProverTask) applyUniversal(schema *coordinatorType.GetTaskSchema) (*coordinatorType.GetTaskSchema, []byte, error) { diff --git a/coordinator/test/api_test.go b/coordinator/test/api_test.go index 5d119e49c0..d0a4c36f42 100644 --- a/coordinator/test/api_test.go +++ b/coordinator/test/api_test.go @@ -234,7 +234,7 @@ func testGetTaskBlocked(t *testing.T) { err := proverBlockListOrm.InsertProverPublicKey(context.Background(), chunkProver.proverName, chunkProver.publicKey()) assert.NoError(t, err) - expectedErr := fmt.Errorf("return prover task err:check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", chunkProver.publicKey(), chunkProver.proverName, chunkProver.proverVersion) + expectedErr := fmt.Errorf("check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", chunkProver.publicKey(), chunkProver.proverName, chunkProver.proverVersion) code, errMsg := chunkProver.tryGetProverTask(t, message.ProofTypeChunk) assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code) assert.Equal(t, expectedErr, errors.New(errMsg)) @@ -255,7 +255,7 @@ func testGetTaskBlocked(t *testing.T) { assert.Equal(t, types.ErrCoordinatorEmptyProofData, code) assert.Equal(t, expectedErr, errors.New(errMsg)) - expectedErr = fmt.Errorf("return prover task err:check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", batchProver.publicKey(), batchProver.proverName, batchProver.proverVersion) + expectedErr = fmt.Errorf("check prover task parameter failed, error:public key %s is blocked from fetching tasks. ProverName: %s, ProverVersion: %s", batchProver.publicKey(), batchProver.proverName, batchProver.proverVersion) code, errMsg = batchProver.tryGetProverTask(t, message.ProofTypeBatch) assert.Equal(t, types.ErrCoordinatorGetTaskFailure, code) assert.Equal(t, expectedErr, errors.New(errMsg))