Commit 9e91cc80 authored by vicotor's avatar vicotor

implement worker msg handler

parent 6ea89bcb
...@@ -2,6 +2,7 @@ package nmregistry ...@@ -2,6 +2,7 @@ package nmregistry
import ( import (
"context" "context"
"crypto/ecdsa"
"fmt" "fmt"
"github.com/odysseus/nodemanager/config" "github.com/odysseus/nodemanager/config"
"github.com/odysseus/nodemanager/utils" "github.com/odysseus/nodemanager/utils"
...@@ -31,14 +32,16 @@ type RegistryService struct { ...@@ -31,14 +32,16 @@ type RegistryService struct {
rdb *redis.Client rdb *redis.Client
conf *config.Config conf *config.Config
rw sync.RWMutex rw sync.RWMutex
public ecdsa.PublicKey
quit chan struct{} quit chan struct{}
} }
func NewRegistryService(conf *config.Config, rdb *redis.Client) *RegistryService { func NewRegistryService(conf *config.Config, rdb *redis.Client, public ecdsa.PublicKey) *RegistryService {
return &RegistryService{ return &RegistryService{
rdb: rdb, rdb: rdb,
conf: conf, conf: conf,
quit: make(chan struct{}), public: public,
quit: make(chan struct{}),
} }
} }
...@@ -86,7 +89,7 @@ func (s *RegistryService) registry(rdb *redis.Client) error { ...@@ -86,7 +89,7 @@ func (s *RegistryService) registry(rdb *redis.Client) error {
err = rdb.HSet(context.Background(), config.NODE_MANAGER_SET+addr, RegistryInfo{ err = rdb.HSet(context.Background(), config.NODE_MANAGER_SET+addr, RegistryInfo{
Pubkey: pubHex, Pubkey: pubHex,
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
Endpoint: config.GetConfig().Endpoint, Endpoint: s.conf.Endpoint,
Addr: addr, Addr: addr,
}).Err() }).Err()
return err return err
......
package server package server
import ( import (
"crypto/ecdsa"
"crypto/rand"
"github.com/odysseus/nodemanager/config" "github.com/odysseus/nodemanager/config"
"github.com/odysseus/nodemanager/nmregistry" "github.com/odysseus/nodemanager/nmregistry"
"github.com/odysseus/nodemanager/utils" "github.com/odysseus/nodemanager/utils"
...@@ -16,6 +18,7 @@ type Node struct { ...@@ -16,6 +18,7 @@ type Node struct {
apiServer *grpc.Server apiServer *grpc.Server
rdb *redis.Client rdb *redis.Client
wm *WorkerManager wm *WorkerManager
privk *ecdsa.PrivateKey
} }
func NewNode() *Node { func NewNode() *Node {
...@@ -25,16 +28,27 @@ func NewNode() *Node { ...@@ -25,16 +28,27 @@ func NewNode() *Node {
Password: redisConfig.Password, Password: redisConfig.Password,
DbIndex: redisConfig.DbIndex, DbIndex: redisConfig.DbIndex,
}) })
privk, err := utils.HexToPrivatekey(config.GetConfig().PrivateKey)
if err != nil {
log.WithError(err).Error("failed to parse node manager private key")
return nil
}
node := &Node{ node := &Node{
rdb: rdb, rdb: rdb,
wm: NewWorkerManager(rdb), privk: privk,
apiServer: grpc.NewServer(grpc.MaxSendMsgSize(1024*1024*20), grpc.MaxRecvMsgSize(1024*1024*20)), apiServer: grpc.NewServer(grpc.MaxSendMsgSize(1024*1024*20), grpc.MaxRecvMsgSize(1024*1024*20)),
registry: nmregistry.NewRegistryService(config.GetConfig(), rdb), registry: nmregistry.NewRegistryService(config.GetConfig(), rdb, privk.PublicKey),
} }
node.wm = NewWorkerManager(rdb, node)
return node return node
} }
func (n *Node) Sign(hash []byte) ([]byte, error) {
return n.privk.Sign(rand.Reader, hash, nil)
}
func (n *Node) Start() error { func (n *Node) Start() error {
n.registry.Start() n.registry.Start()
if err := n.apiStart(); err != nil { if err := n.apiStart(); err != nil {
......
...@@ -4,10 +4,13 @@ import ( ...@@ -4,10 +4,13 @@ import (
"errors" "errors"
"fmt" "fmt"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/odysseus/nodemanager/utils"
odysseus "github.com/odysseus/odysseus-protocol/gen/proto/go/base/v1" odysseus "github.com/odysseus/odysseus-protocol/gen/proto/go/base/v1"
omanager "github.com/odysseus/odysseus-protocol/gen/proto/go/nodemanager/v1" omanager "github.com/odysseus/odysseus-protocol/gen/proto/go/nodemanager/v1"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/crypto/sha3"
"math/big"
"sync" "sync"
"time" "time"
) )
...@@ -27,6 +30,7 @@ type Worker struct { ...@@ -27,6 +30,7 @@ type Worker struct {
resultCh chan *omanager.SubmitTaskResult resultCh chan *omanager.SubmitTaskResult
uuid int64 uuid int64
publicKey string publicKey string
deviceInfo []*omanager.DeviceInfo
recentTask *lru.Cache recentTask *lru.Cache
stream omanager.NodeManagerService_RegisterWorkerServer stream omanager.NodeManagerService_RegisterWorkerServer
} }
...@@ -39,17 +43,24 @@ type WorkerManager struct { ...@@ -39,17 +43,24 @@ type WorkerManager struct {
workers map[int64]*Worker workers map[int64]*Worker
wkRwLock sync.RWMutex wkRwLock sync.RWMutex
quit chan struct{} quit chan struct{}
node *Node
} }
func NewWorkerManager(rdb *redis.Client) *WorkerManager { func NewWorkerManager(rdb *redis.Client, node *Node) *WorkerManager {
return &WorkerManager{ return &WorkerManager{
heartBeat: make(map[int64]int64), heartBeat: make(map[int64]int64),
workers: make(map[int64]*Worker), workers: make(map[int64]*Worker),
quit: make(chan struct{}), quit: make(chan struct{}),
rdb: rdb, rdb: rdb,
node: node,
} }
} }
func (wm *WorkerManager) Stop() {
close(wm.quit)
}
func (wm *WorkerManager) UpdateHeartBeat(uuid int64) { func (wm *WorkerManager) UpdateHeartBeat(uuid int64) {
wm.hbRwLock.Lock() wm.hbRwLock.Lock()
defer wm.hbRwLock.Unlock() defer wm.hbRwLock.Unlock()
...@@ -192,6 +203,7 @@ func (wm *WorkerManager) manageWorker(worker *Worker) error { ...@@ -192,6 +203,7 @@ func (wm *WorkerManager) manageWorker(worker *Worker) error {
msg.Message = taskMsg msg.Message = taskMsg
callback = func(err error) bool { callback = func(err error) bool {
if err == nil { if err == nil {
// add task to cache.
worker.recentTask.Add(task.TaskId, task) worker.recentTask.Add(task.TaskId, task)
} }
...@@ -202,6 +214,45 @@ func (wm *WorkerManager) manageWorker(worker *Worker) error { ...@@ -202,6 +214,45 @@ func (wm *WorkerManager) manageWorker(worker *Worker) error {
} }
return true return true
} }
case result := <-worker.resultCh:
// verify result and make a new signature.
data, exist := worker.recentTask.Get(result.TaskId)
if !exist {
log.WithField("worker", worker.uuid).Error("task not found for verify result")
continue
}
task := data.(*odysseus.TaskContent)
if result.TaskId != task.TaskId {
log.WithField("worker", worker.uuid).Error("task id not match")
continue
}
// todo: verify container_signature and miner_signature
//manager_signature = sign(hash((task_id+hash(task_param)+hash(task_result)+container_signature+miner_signature+workload))
paramHash := sha3.Sum256(task.TaskParam)
resultHash := sha3.Sum256(result.TaskResult)
dataHash := sha3.Sum256(utils.CombineBytes([]byte(result.TaskId), paramHash[:], resultHash[:],
result.ContainerSignature, result.MinerSignature, big.NewInt(0).Bytes()))
//result.ContainerSignature, result.MinerSignature, big.NewInt(int64(task.Workload)).Bytes()))
signature, err := wm.node.Sign(dataHash[:])
if err != nil {
log.WithError(err).Error("sign result failed")
continue
}
proof := new(omanager.ManagerMessage_ProofTaskResult)
proof.ProofTaskResult = &omanager.ProofTaskResult{
TaskId: result.TaskId,
ManagerSignature: signature,
}
callback = func(err error) bool {
if err == nil {
// remove task from cache.
worker.recentTask.Remove(result.TaskId)
}
// todo: post event for task succeed or failed
return true
}
} }
if msg != nil { if msg != nil {
...@@ -235,10 +286,35 @@ func (wm *WorkerManager) handleWorkerMsg(worker *Worker) { ...@@ -235,10 +286,35 @@ func (wm *WorkerManager) handleWorkerMsg(worker *Worker) {
close(worker.taskCh) close(worker.taskCh)
return return
case *omanager.WorkerMessage_SubmitTaskResult: case *omanager.WorkerMessage_SubmitTaskResult:
worker.resultCh <- msg.SubmitTaskResult
case *omanager.WorkerMessage_HeartbeatResponse: case *omanager.WorkerMessage_HeartbeatResponse:
wm.UpdateHeartBeat(worker.uuid)
log.WithFields(log.Fields{
"worker": worker.uuid,
"hearBeat": time.Now().Unix() - int64(msg.HeartbeatResponse.Timestamp),
}).Debug("receive worker heartbeat")
case *omanager.WorkerMessage_Status: case *omanager.WorkerMessage_Status:
// todo: store worker status
log.WithFields(log.Fields{
"worker": worker.uuid,
}).Debugf("receive worker status:0x%x", msg.Status.DeviceStatus)
case *omanager.WorkerMessage_DeviceInfo: case *omanager.WorkerMessage_DeviceInfo:
// todo: handler worker device info
log.WithFields(log.Fields{
"worker": worker.uuid,
}).Debugf("receive worker device info:%v", msg.DeviceInfo.Devices)
if worker.deviceInfo == nil {
// first time receive device info
worker.publicKey = msg.DeviceInfo.MinerPubkey
worker.deviceInfo = msg.DeviceInfo.Devices
}
case *omanager.WorkerMessage_DeviceUsage: case *omanager.WorkerMessage_DeviceUsage:
// todo: handler worker device usage
log.WithFields(log.Fields{
"worker": worker.uuid,
}).Debugf("receive worker device usage:%v", msg.DeviceUsage.Usage)
default: default:
log.WithField("worker", worker.uuid).Error(fmt.Sprintf("unsupport msg type %T", msg)) log.WithField("worker", worker.uuid).Error(fmt.Sprintf("unsupport msg type %T", msg))
} }
......
package utils
func CombineBytes(b ...[]byte) []byte {
var result []byte
for _, v := range b {
result = append(result, v...)
}
return result
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment