Commit 7c4ee096 authored by 贾浩@五瓣科技's avatar 贾浩@五瓣科技

update

parent c01567df
package api
import (
"context"
"net"
"witness/core"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
witnessv1 "github.com/odysseus/odysseus-protocol/gen/proto/go/witness/v1"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
)
type Server struct {
witnessv1.UnimplementedWitnessServiceServer
w *core.Witness
}
func (s *Server) WitnessStatus(ctx context.Context, req *witnessv1.WitnessStatusRequest) (resp *witnessv1.WitnessStatusResponse, err error) {
return nil, nil
}
func (s *Server) PushProof(ctx context.Context, req *witnessv1.PushProofRequest) (resp *witnessv1.PushProofResponse, err error) {
return resp, nil
}
func (s *Server) GetMinerProof(ctx context.Context, req *witnessv1.GetMinerProofRequest) (*witnessv1.GetMinerProofResponse, error) {
return nil, nil
}
func (s *Server) GetProof(ctx context.Context, req *witnessv1.GetProofRequest) (*witnessv1.GetProofResponse, error) {
return nil, nil
}
func (s *Server) GetWithdrawProof(ctx context.Context, req *witnessv1.GetWithdrawProofRequest) (*witnessv1.GetWithdrawProofResponse, error) {
return nil, nil
}
func StartGRPC(listenAddress string, w *core.Witness) {
ln, err := net.Listen("tcp", listenAddress)
if err != nil {
log.WithError(err).Errorf("failed to listen on %s", listenAddress)
return
}
log.WithField("listen", listenAddress).Info("start gRPC server")
server := grpc.NewServer(grpc.MaxRecvMsgSize(1024*1024*1024), grpc.MaxSendMsgSize(1024*1024*1024))
witnessv1.RegisterWitnessServiceServer(server, &Server{w: w})
grpc_prometheus.Register(server)
err = server.Serve(ln)
if err != nil {
log.WithError(err).Error("failed to serve")
return
}
server.Stop()
}
package api
import (
"context"
"testing"
witnessv1 "github.com/odysseus/odysseus-protocol/gen/proto/go/witness/v1"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func TestClient(t *testing.T) {
client, err := grpc.Dial("127.0.0.1:9431", grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(
grpc.MaxCallRecvMsgSize(1024*1024*1024),
grpc.MaxCallSendMsgSize(1024*1024*1024)),
)
if err != nil {
panic(err)
}
gc := witnessv1.NewWitnessServiceClient(client)
req := &witnessv1.PushProofRequest{
Proofs: []*witnessv1.Proof{
{
Workload: 10,
TaskId: "1",
ReqHash: []byte("req"),
RespHash: []byte("resp"),
ManagerSignature: nil,
ContainerSignature: nil,
MinerSignature: nil,
},
},
RewardAddress: "",
MinerAddress: "",
}
resp, err := gc.PushProof(context.Background(), req)
if err != nil {
log.Fatal(err)
}
log.Info(resp.Workload)
}
......@@ -69,6 +69,9 @@ func rpcHandle(w http.ResponseWriter, r *http.Request) {
case "getDailyMerkleNodes":
getDailyMerkleNodes(req.Params, resp)
_ = json.NewEncoder(w).Encode(resp)
case "getDailyMerkleSumNodes":
getDailyMerkleSumNodes(req.Params, resp)
_ = json.NewEncoder(w).Encode(resp)
default:
resp.Error = &jsonError{
Code: -32601,
......@@ -198,6 +201,55 @@ func getDailyMerkleNodes(params []byte, resp *jsonrpcMessage) {
resp.Result, _ = json.Marshal(nodes)
}
func getDailyMerkleSumNodes(params []byte, resp *jsonrpcMessage) {
// date string, depth int, rootHash common.Hash
paramList := make([]interface{}, 0)
err := json.Unmarshal(params, &paramList)
if err != nil || len(paramList) < 1 || len(paramList) > 3 {
resp.Error = &jsonError{
Code: -32602,
Message: "invalid params",
}
return
}
var date string
var depth = 1
var rootHash common.Hash
_, err = time.Parse("2006-01-02", paramList[1].(string))
if err != nil {
resp.Error = &jsonError{
Code: -32602,
Message: "invalid params",
}
return
}
date = paramList[1].(string)
if len(paramList) >= 2 {
_depth, ok := paramList[2].(float64)
if !ok {
resp.Error = &jsonError{
Code: -32602,
Message: "invalid params",
}
return
}
depth = int(uint(_depth))
}
if len(paramList) >= 3 {
rootHash = common.HexToHash(paramList[3].(string))
}
nodes, vals := witness.GetDailyMerkleSumNodes(date, depth, rootHash)
resp.Result, _ = json.Marshal(map[string]interface{}{
"nodes": nodes,
"values": vals,
})
}
func StartJSONRPC(listenAddress string, w *core.Witness) {
witness = w
http.HandleFunc("/", rpcHandle)
......
......@@ -105,10 +105,6 @@ func runMetrics(listen string) {
}()
}
func runGrpcServer(listen string, w *core.Witness) {
go api.StartGRPC(listen, w)
}
func runJSONRPCServer(listen string, w *core.Witness) {
go api.StartJSONRPC(listen, w)
}
package core
import (
"fmt"
"math/big"
"witness/tree"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
log "github.com/sirupsen/logrus"
)
func (w *Witness) GetPendingWorkload(address common.Address) (workload, globalWorkload uint64) {
wl, err := w.q.GetPendingWorkload(w.todayTimestamp(), address.Hex())
if err != nil {
log.WithError(err).Error("failed to get pending workload")
return
}
log.WithField("workload", wl).Debug("quest get pending workload")
return wl, w.pendingWorkload
}
func (w *Witness) GetMerkleProof(address common.Address, date string) (balance string, proofs []common.Hash) {
if date == "" {
date = w.date
}
w.Lock()
cacheTree, ok := w.mtTreeCache[date]
w.Unlock()
if !ok {
if ok = w.LoadMerkleTree(date); !ok {
log.WithFields(log.Fields{
"date": date,
}).Error("load merkle proof empty")
return "0", nil
}
}
w.Lock()
cacheTree = w.mtTreeCache[date]
w.Unlock()
dateStateRootKey := fmt.Sprintf("sroot:%s", date)
dateStateRoot, err := w.lvdb.Get([]byte(dateStateRootKey))
if err != nil {
log.WithFields(log.Fields{
"key": dateStateRootKey,
"err": err.Error(),
}).Error("failed to get state root")
return "0", nil
}
var sdb *StateDB
if date == w.date {
sdb = w.state
} else {
sdb, err = NewStateDB(w.lvdb, common.BytesToHash(dateStateRoot))
if err != nil {
log.WithError(err).Error("failed to create state db")
return "0", nil
}
}
object := sdb.GetMinerObject(address)
if object == nil {
return "0", nil
}
bigBalance, _ := new(big.Int).SetString(object.Balance, 10)
payload := append(common.HexToAddress(object.Miner).Bytes(), common.LeftPadBytes(bigBalance.Bytes(), 32)...)
leaf := crypto.Keccak256Hash(payload)
proofs, err = cacheTree.GetProof(leaf)
if err != nil {
log.WithError(err).Error("failed to get merkle proof")
return "0", make([]common.Hash, 0)
}
return object.Balance, proofs
}
func (w *Witness) GetDailyMerkleNodes(date string, depth int, rootHash common.Hash) (nodes [][]common.Hash) {
if date == "" {
date = w.date
}
w.Lock()
cacheTree, ok := w.mtTreeCache[date]
w.Unlock()
if !ok {
if ok = w.LoadMerkleTree(date); !ok {
log.WithFields(log.Fields{
"date": date,
}).Error("load merkle proof empty")
return nil
}
}
w.Lock()
cacheTree = w.mtTreeCache[date]
w.Unlock()
rootNode := cacheTree.GetRootNode()
if rootHash.Hex() == (common.Hash{}).Hex() {
rootNode = cacheTree.GetRootNode()
} else if rootHash.Hex() != rootNode.Hash.Hex() {
rootNode = cacheTree.FindMerkleNode(rootHash)
}
return tree.MerkleTreeTraversal(rootNode, depth)
}
func (w *Witness) GetDailyMerkleSumNodes(date string, depth int, rootHash common.Hash) (nodesHash [][]common.Hash, nodesVal [][]string) {
if date == "" {
date = w.date
}
w.Lock()
cacheTree, ok := w.mstTreeCache[date]
w.Unlock()
if !ok {
if ok = w.LoadMerkleSumTree(date); !ok {
log.WithFields(log.Fields{
"date": date,
}).Error("load merkle sum proof empty")
return nil, nil
}
}
w.Lock()
cacheTree = w.mstTreeCache[date]
w.Unlock()
rootNode := cacheTree.GetRoot()
if rootHash.Hex() == (common.Hash{}).Hex() {
rootNode = cacheTree.GetRoot()
} else if rootHash.Hex() != rootNode.Value.Hash.Hex() {
rootNode = cacheTree.FindMerkleSumNode(rootHash)
}
vlss := tree.MerkleSumTreeTraversal(rootNode, depth)
for i := 0; i < len(vlss); i++ {
_nodesHash := make([]common.Hash, 0)
_nodesVal := make([]string, 0)
for j := 0; j < len(vlss[i]); j++ {
_nodesHash = append(_nodesHash, vlss[i][j].Hash)
_nodesVal = append(_nodesVal, vlss[i][j].BigValue.String())
}
nodesHash = append(nodesHash, _nodesHash)
nodesVal = append(nodesVal, _nodesVal)
}
return nodesHash, nodesVal
}
package core
import (
"time"
log "github.com/sirupsen/logrus"
)
// UpdateContractAddressJob 定时更新合约内的地址
func (w *Witness) UpdateContractAddressJob() {
ticker := time.NewTicker(time.Minute * 10)
defer ticker.Stop()
log.Info("start update address task")
for {
addrs, err := w.rpc.GetContainerAddresses()
if err != nil {
log.WithError(err).Error("failed to get container addresses")
} else {
w.containerAddresses = addrs
}
addrs, err = w.rpc.GetNMAddresses()
if err != nil {
log.WithError(err).Error("failed to get NM addresses")
} else {
w.nmAddresses = addrs
}
log.WithFields(log.Fields{
"container_count": len(w.containerAddresses),
"nm_count": len(w.nmAddresses),
}).Info("store contract update address")
<-ticker.C
}
}
// UpdateGlobalWorkloadJob 定时从quest更新全局workload
func (w *Witness) UpdateGlobalWorkloadJob() {
ticker := time.NewTicker(time.Second * 30)
defer ticker.Stop()
log.Info("start update global workload task")
for {
wl := w.GetGlobalWorkload()
if wl > 0 {
w.pendingWorkload = wl
}
<-ticker.C
}
}
......@@ -27,17 +27,16 @@ func (w *Witness) LoadPendingProofs(startTimestamp, endTimestamp int64) {
}
for _, dbProof := range dbProofs {
miner, proof := w.VerifyProof(dbProof)
miner, proof := w.verifyProof(dbProof)
if proof != nil {
// w.AddPendingProof(miner, proof)
w.AddPendingProof(miner, proof)
}
_ = miner
lastTaskID = dbProof.TaskId
}
}
}
func (w *Witness) VerifyProof(dbProof *quest.ProofModel) (miner common.Address, proof *witnessv1.ValidatedProof) {
func (w *Witness) verifyProof(dbProof *quest.ProofModel) (miner common.Address, proof *witnessv1.ValidatedProof) {
if dbProof.TaskWorkload == 0 {
return
}
......@@ -83,3 +82,13 @@ func (w *Witness) VerifyProof(dbProof *quest.ProofModel) (miner common.Address,
}
return common.HexToAddress(dbProof.TaskProfitAccount), proof
}
func (w *Witness) GetGlobalWorkload() uint64 {
workload, err := w.q.GetGlobalWorkload(w.todayTimestamp())
if err != nil {
log.WithError(err).Error("failed to get global workload")
return 0
}
log.WithField("workload", workload).Debug("quest get global workload")
return workload
}
package core
import (
"bytes"
"fmt"
"math/big"
"sort"
"time"
"witness/tree"
"witness/util"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
witnessv1 "github.com/odysseus/odysseus-protocol/gen/proto/go/witness/v1"
log "github.com/sirupsen/logrus"
"github.com/syndtr/goleveldb/leveldb/errors"
)
// CommitMST commit workload of per day
func (w *Witness) CommitMST(proofMap map[common.Address]*witnessv1.ValidatedProof) (root common.Hash, sum *big.Int, err error) {
if len(proofMap) == 0 {
return common.Hash{}, big.NewInt(0), nil
}
rawKeys := make([]common.Address, 0)
keys := make([][]byte, len(proofMap))
dbKey := make([]byte, 0)
dbVal := make([]byte, 0)
vals := make([]*big.Int, len(proofMap))
for key := range proofMap {
rawKeys = append(rawKeys, key)
}
sort.Sort(util.AddressSlice(rawKeys))
for i, key := range rawKeys {
dbKey = append(dbKey, rawKeys[i][:]...)
keys[i] = rawKeys[i].Bytes()
vals[i] = new(big.Int).SetUint64(proofMap[key].Workload)
dbVal = append(append(dbVal, vals[i].Bytes()...), []byte(":")...)
}
st := time.Now()
mstTree := tree.NewMerkleSumTree(keys, vals)
rootNode := mstTree.GetRoot()
err = w.lvdb.Put([]byte(fmt.Sprintf("mstroot:%s", w.date)), root.Bytes())
if err != nil {
log.Error(err)
return
}
err = w.lvdb.Put([]byte(fmt.Sprintf("mstsum:%s", w.date)), sum.Bytes())
if err != nil {
log.Error(err)
return
}
err = w.lvdb.Put([]byte(fmt.Sprintf("mstk:%s", w.date)), dbKey)
if err != nil {
return
}
err = w.lvdb.Put([]byte(fmt.Sprintf("mstv:%s", w.date)), dbVal)
if err != nil {
return
}
log.WithFields(log.Fields{
"root": root.Hex(),
"sum": sum.String(),
"cost": time.Since(st).String(),
}).Info("commit MST root")
return rootNode.Value.Hash, rootNode.Value.BigValue, nil
}
// CommitMT commit all workload
func (w *Witness) CommitMT(objects []*witnessv1.MinerObject) (root common.Hash, err error) {
if len(objects) == 0 {
return common.Hash{}, nil
}
merkleProofs := make(tree.Proofs, 0)
dbProofs := make([]byte, 0)
for _, object := range objects {
bigBalance, _ := new(big.Int).SetString(object.Balance, 10)
payload := append(common.HexToAddress(object.Miner).Bytes(), common.LeftPadBytes(bigBalance.Bytes(), 32)...)
_proof := crypto.Keccak256Hash(payload)
merkleProofs = append(merkleProofs, _proof)
dbProofs = append(dbProofs, _proof[:]...)
}
mtTree, err := tree.NewMerkleTree(merkleProofs)
if err != nil {
return
}
root = mtTree.GetRoot()
st := time.Now()
err = w.lvdb.Put([]byte(fmt.Sprintf("mtroot:%s", w.date)), root.Bytes())
if err != nil {
return
}
log.WithFields(log.Fields{
"k": fmt.Sprintf("mtroot:%s", w.date),
"v": root.String(),
}).Debug()
err = w.lvdb.Put([]byte(fmt.Sprintf("mtk:%s", w.date)), dbProofs)
if err != nil {
return
}
log.WithFields(log.Fields{
"k": fmt.Sprintf("mtk:%s", w.date),
"v_length": len(dbProofs),
}).Debug()
w.mtTreeCache[w.date] = mtTree
log.WithFields(log.Fields{
"root": root.Hex(),
"cost": time.Since(st).String(),
}).Info("commit MT root")
return
}
func (w *Witness) LoadMerkleTree(date string) (ok bool) {
if date == "" {
return false
}
merkleTreeKey := fmt.Sprintf("mtk:%s", date)
data, err := w.lvdb.Get([]byte(merkleTreeKey))
if err != nil {
if err == errors.ErrNotFound {
return
}
log.WithError(err).Error("failed to load merkle proof")
return
}
log.WithFields(log.Fields{
"key": merkleTreeKey,
"length": len(data),
}).Info("diskdb load merkle proof")
proofs := make([]common.Hash, len(data)/32)
for i := 0; i < len(data)/32; i++ {
copy(proofs[i][:], data[i*32:(i+1)*32])
}
mTree, err := tree.NewMerkleTree(proofs)
if err != nil {
log.WithError(err).Error("failed to load merkle proof")
return
}
w.Lock()
w.mtTreeCache[date] = mTree
w.Unlock()
log.WithFields(log.Fields{"date": date, "root": mTree.GetRoot().Hex()}).Info("load merkle tree")
return true
}
func (w *Witness) LoadMerkleSumTree(date string) (ok bool) {
if date == "" {
return false
}
merkleSumTreeKey := fmt.Sprintf("mstk:%s", date)
keyData, err := w.lvdb.Get([]byte(merkleSumTreeKey))
if err != nil {
if err == errors.ErrNotFound {
return
}
log.WithError(err).Error("failed to load merkle sum proof k")
return
}
log.WithFields(log.Fields{
"key": merkleSumTreeKey,
"length": len(keyData),
}).Info("diskdb load merkle sum proof key")
// data = addr1:addr2
datas := make([][]byte, len(keyData)/20)
for i := 0; i < len(keyData)/20; i++ {
copy(datas[i], keyData[i*20:(i+1)*20])
}
merkleSumTreeVal := fmt.Sprintf("mstv:%s", date)
valData, err := w.lvdb.Get([]byte(merkleSumTreeVal))
if err != nil {
if err == errors.ErrNotFound {
return
}
log.WithError(err).Error("failed to load merkle sum proof v")
return
}
log.WithFields(log.Fields{
"key": merkleSumTreeVal,
"length": len(valData),
}).Info("diskdb load merkle sum proof val")
vals := bytes.Split(valData, []byte(":"))
bigVals := make([]*big.Int, len(vals))
for i := range vals {
bigVals[i] = big.NewInt(0).SetBytes(vals[i])
}
mstTree := tree.NewMerkleSumTree(datas, bigVals)
w.Lock()
w.mstTreeCache[date] = mstTree
w.Unlock()
log.WithFields(log.Fields{"date": date, "root": mstTree.GetRoot()}).Info("load merkle sum tree")
return true
}
This diff is collapsed.
......@@ -40,3 +40,30 @@ func (q *Quest) GetProofs(startTimestamp, endTimestamp int64, lastTaskID string,
err = q.db.Debug().Raw(querySQL, startTimestamp, endTimestamp, lastTaskID, limit).Scan(&proofs).Error
return
}
func (q *Quest) GetPendingWorkload(startTimestamp int64, address string) (workload uint64, err error) {
proof := new(ProofModel)
querySQL := "SELECT " +
"`TaskWorkload`, `TaskReqHash`, `TaskRespHash`, `TaskManagerSignature`, `TaskContainerSignature`, `TaskMinerSignature`, `TaskProfitAccount`, `TaskWorkerAccount` " +
"FROM `proof` " +
"WHERE `TaskFinishTimestamp` >= ?" +
"AND `TaskProfitAccount` = ? ;"
err = q.db.Debug().Raw(querySQL, startTimestamp, address).First(&proof).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return 0, nil
}
return 0, err
}
return proof.TaskWorkload, nil
}
func (q *Quest) GetGlobalWorkload(startTimestamp int64) (workload uint64, err error) {
querySQL := "SELECT " +
"SUM(`TaskWorkload`) " +
"FROM `proof` " +
"WHERE `TaskFinishTimestamp` >= ? ;"
err = q.db.Debug().Raw(querySQL, startTimestamp).Row().Scan(&workload)
return workload, err
}
......@@ -10,7 +10,7 @@ merkle sum tree:
mstsum:2020-01-01 -> sum
mstk:2020-01-01 -> key1(bytes32)key2(bytes32)
mstk:2020-01-01 -> key1(address bytes20)key2(address bytes20)
mstv:2020-01-01 -> val1:val2
......@@ -30,7 +30,3 @@ state:
txid:2020-0101 -> txhash(32bytes)
unconfirmed proof:
proof:2020-01-01 -> json(map[common.Address]*witnessv1.ValidatedProof)
\ No newline at end of file
......@@ -13,7 +13,7 @@ func (x Proofs) Len() int { return len(x) }
func (x Proofs) Less(i, j int) bool { return bytes.Compare(x[i][:], x[j][:]) == -1 }
func (x Proofs) Swap(i, j int) { x[i], x[j] = x[j], x[i] }
func (x Proofs) Dedup() Proofs {
func (x Proofs) Distinct() Proofs {
r := make(Proofs, 0, len(x))
for idx, el := range x {
if idx == 0 || !bytes.Equal(x[idx-1][:], el[:]) {
......
......@@ -13,7 +13,7 @@ import (
func TestMT(t *testing.T) {
proofs := make([]common.Hash, 0)
for i := 100; i < 200000; i++ {
for i := 1; i < 6; i++ {
_proof := crypto.Keccak256Hash(big.NewInt(0).SetInt64(int64(i)).Bytes())
proofs = append(proofs, _proof)
}
......@@ -50,11 +50,11 @@ func TestMT(t *testing.T) {
t.Log("---")
t.Log(time.Since(st))
tnode := tree.FindNode(common.HexToHash("0xa05a2b09dda15efbc0e7a2f9cd779d902d765c5ecb1abb3ef06a450f55a9fce7"))
t.Log("tnode", tnode)
t.Log(time.Since(st))
// tnode := tree.FindMerkleNode(common.HexToHash("0xa05a2b09dda15efbc0e7a2f9cd779d902d765c5ecb1abb3ef06a450f55a9fce7"))
// t.Log("tnode", tnode)
// t.Log(time.Since(st))
layers := Traversal(tnode, 5)
layers := MerkleTreeTraversal(tree.treeNode, 5)
t.Log(time.Since(st))
for _, layer := range layers {
......
package tree
import (
"bytes"
"math/big"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
)
var (
LeafPrefix = []byte{0x00}
NodePrefix = []byte{0x01}
)
type MerkleSumTree struct {
sideNodes [][32]byte
nodeSums []*big.Int
Levels [][]MerkleSumNode
RootNode *MerkleSumNode
}
func leafDigest(data [32]byte, val *big.Int) (digest [32]byte) {
valBytes := make([]byte, 32)
copy(valBytes[32-len(val.Bytes()):], val.Bytes())
buf := bytes.NewBuffer(LeafPrefix)
buf.Write(valBytes)
buf.Write(data[:])
copy(digest[:], crypto.Keccak256Hash(buf.Bytes()).Bytes())
return
type MerkleSumNode struct {
Left *MerkleSumNode
Right *MerkleSumNode
Value Value
}
func nodeDigest(leftDigest, rightDigest [32]byte, leftVal, rightVal *big.Int) (digest [32]byte) {
leftValBytes := make([]byte, 32)
rightValBytes := make([]byte, 32)
type Value struct {
BigValue *big.Int
Hash common.Hash
Raw []byte
}
copy(leftValBytes[32-len(leftVal.Bytes()):], leftVal.Bytes())
copy(rightValBytes[32-len(rightVal.Bytes()):], rightVal.Bytes())
func NewMerkleSumTree(datas [][]byte, vals []*big.Int) *MerkleSumTree {
var nodes []MerkleSumNode
var levels [][]MerkleSumNode
buf := bytes.NewBuffer(NodePrefix)
buf.Write(leftValBytes)
buf.Write(leftDigest[:])
buf.Write(rightValBytes)
buf.Write(rightDigest[:])
copy(digest[:], crypto.Keccak256Hash(buf.Bytes()).Bytes())
return
if len(datas)%2 != 0 {
datas = append(datas, []byte{})
vals = append(vals, big.NewInt(0))
}
for i := range datas {
nodes = append(nodes, *buildToNode(datas[i], vals[i]))
}
countOfDataNodes := len(nodes)
counterOfLevels := 0
for countOfDataNodes > 1 {
if countOfDataNodes%2 == 0 {
countOfDataNodes = countOfDataNodes / 2
counterOfLevels++
} else {
countOfDataNodes = (countOfDataNodes + 1) / 2
counterOfLevels++
}
}
levels = [][]MerkleSumNode{nodes}
for i := 0; i < counterOfLevels; i++ {
var level []MerkleSumNode
lastNodeIndex := len(nodes) - 1
for j := 0; j <= lastNodeIndex; j += 2 {
if j == lastNodeIndex && lastNodeIndex%2 == 0 {
node := newMerkleSumNode(&nodes[j], nil)
level = append(level, *node)
} else {
node := newMerkleSumNode(&nodes[j], &nodes[j+1])
level = append(level, *node)
}
}
nodes = level
levels = append(levels, level)
}
tree := MerkleSumTree{levels, &nodes[0]}
return &tree
}
func ComputeMST(dataList [][32]byte, values []*big.Int) (nodes [][32]byte, sums []*big.Int) {
nodeDigests := make([][32]byte, len(dataList))
nodeSums := make([]*big.Int, len(dataList))
func newMerkleSumNode(left, right *MerkleSumNode) *MerkleSumNode {
var node MerkleSumNode
if right == nil {
concatLeftNodeData := append(left.Value.BigValue.Bytes(), left.Value.Hash[:]...)
concatRightNodeData := append(big.NewInt(0).Bytes(), (common.Hash{}).Bytes()...)
prevHashes := append(concatLeftNodeData, concatRightNodeData...)
node.Value.Hash = crypto.Keccak256Hash(prevHashes)
node.Value.BigValue = left.Value.BigValue
} else {
concatLeftNodeData := append(left.Value.BigValue.Bytes(), left.Value.Hash[:]...)
concatRightNodeData := append(right.Value.BigValue.Bytes(), right.Value.Hash[:]...)
prevHashes := append(concatLeftNodeData, concatRightNodeData...)
node.Value.Hash = crypto.Keccak256Hash(prevHashes)
node.Value.BigValue = big.NewInt(0).Add(left.Value.BigValue, right.Value.BigValue)
}
for i := 0; i < len(dataList); i++ {
nodeDigests[i] = leafDigest(dataList[i], values[i])
nodeSums[i] = values[i]
node.Left = left
node.Right = right
return &node
}
func buildToNode(data []byte, val *big.Int) *MerkleSumNode {
node := MerkleSumNode{
Value: Value{
BigValue: val,
Hash: crypto.Keccak256Hash(append(val.Bytes(), data...)),
Raw: data,
},
}
return &node
}
func (m *MerkleSumTree) GetRoot() *MerkleSumNode {
return m.RootNode
}
func (m *MerkleSumTree) FindMerkleSumNode(hash common.Hash) *MerkleSumNode {
if m.RootNode == nil {
return nil
}
odd := len(dataList) & 1
size := (len(dataList) + 1) >> 1
queue := []*MerkleSumNode{m.RootNode}
pNodeDigests := make([][32]byte, len(dataList))
pNodeSums := make([]*big.Int, len(dataList))
for len(queue) > 0 {
size := len(queue)
copy(pNodeDigests, nodeDigests)
copy(pNodeSums, nodeSums)
for i := 0; i < size; i++ {
currentNode := queue[0]
queue = queue[1:]
for {
i := 0
for i < size-odd {
j := i << 1
nodeDigests[i] = nodeDigest(pNodeDigests[j], pNodeDigests[j+1], pNodeSums[j], pNodeSums[j+1])
nodeSums[i] = new(big.Int).Add(pNodeSums[j], pNodeSums[j+1])
i++
if currentNode.Value.Hash.Hex() == hash.Hex() {
return currentNode
}
if odd == 1 {
nodeDigests[i] = pNodeDigests[i<<1]
nodeSums[i] = pNodeSums[i<<1]
if currentNode.Left != nil {
queue = append(queue, currentNode.Left)
}
if currentNode.Right != nil {
queue = append(queue, currentNode.Right)
}
}
}
return nil
}
if size == 1 {
break
func MerkleSumTreeTraversal(root *MerkleSumNode, depth int) (retNodes [][]Value) {
if root == nil {
return nil
}
odd = size & 1
size = (size + 1) >> 1
copy(pNodeDigests, nodeDigests)
copy(pNodeSums, nodeSums)
queue := []*MerkleSumNode{root}
currentDepth := 1
for len(queue) > 0 {
size := len(queue)
layerNodes := make([]Value, 0)
for i := 0; i < size; i++ {
currentNode := queue[0]
queue = queue[1:]
layerNodes = append(layerNodes, currentNode.Value)
if currentNode.Left != nil {
queue = append(queue, currentNode.Left)
}
if currentNode.Right != nil {
queue = append(queue, currentNode.Right)
}
return nodeDigests, nodeSums
}
retNodes = append(retNodes, layerNodes)
currentDepth++
if currentDepth > depth {
break
}
}
return
}
package tree
import (
"crypto/sha256"
"fmt"
"math/big"
"testing"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
)
func Test(t *testing.T) {
dataList := [][]byte{[]byte{0x01, 0x01}, []byte{0x02, 0x02}, {0x03, 0x03}, {0x04, 0x04}}
func TestMst(t *testing.T) {
digestList := make([][32]byte, len(dataList))
datas := [][]byte{}
vals := []*big.Int{}
for i := 1; i <= 7; i++ {
datas = append(datas, crypto.Keccak256(big.NewInt(int64(i)).Bytes()))
vals = append(vals, big.NewInt(int64(i)))
}
for i := 0; i < len(dataList); i++ {
digestList[i] = sha256.Sum256(dataList[i])
tree := NewMerkleSumTree(datas, vals)
t.Log(tree.GetRoot())
t.Log(tree.RootNode.Value.Hash, tree.RootNode.Value.BigValue)
t.Log("---")
for i := len(tree.Levels) - 1; i >= 0; i-- {
t.Log(tree.Levels[i])
}
t.Log("---")
subNode := tree.FindMerkleSumNode(common.HexToHash("0x8a28d124351a95140ed4fad3468a9833523faf12e07c24eac13e1c1d468adb95"))
fmt.Printf("%x\n", digestList)
layers := MerkleSumTreeTraversal(subNode, 3)
valList := []*big.Int{big.NewInt(1), big.NewInt(2), big.NewInt(3), big.NewInt(4)}
for _, layer := range layers {
t.Log(layer)
}
roots, sums := ComputeMST(digestList, valList)
fmt.Printf("%x\n", roots[0])
fmt.Println(sums[0].String())
}
......@@ -3,7 +3,6 @@ package tree
import (
"bytes"
"errors"
"fmt"
"math"
"sort"
......@@ -29,18 +28,18 @@ type MerkleTree struct {
layers [][]common.Hash
proofs Proofs
bufferElementPositionIndex map[common.Hash]int
treeNode *TreeNode
treeNode *MerkleTreeNode
}
type TreeNode struct {
type MerkleTreeNode struct {
Hash common.Hash
Left *TreeNode
Right *TreeNode
Left *MerkleTreeNode
Right *MerkleTreeNode
}
func NewMerkleTree(proofs Proofs) (*MerkleTree, error) {
sort.Sort(proofs)
proofs = proofs.Dedup()
proofs = proofs.Distinct()
bufferElementPositionIndex := make(map[common.Hash]int)
for idx, el := range proofs {
......@@ -127,7 +126,7 @@ func (m *MerkleTree) getPairElement(idx int, layer Proofs) (common.Hash, bool) {
func (m *MerkleTree) buildTree() {
nodes := make([][]common.Hash, 0)
for i := len(m.layers) - 1; i >= 0; i-- {
if len(m.layers[i])%2 == 1 && i != len(m.layers)-1 {
if len(m.layers[i])%2 == 1 && i == 0 {
nodes = append(nodes, append(m.layers[i], common.Hash{}))
} else {
nodes = append(nodes, m.layers[i])
......@@ -136,17 +135,16 @@ func (m *MerkleTree) buildTree() {
m.treeNode = buildTree(nodes)
}
func (m *MerkleTree) GetRootNode() *TreeNode {
func (m *MerkleTree) GetRootNode() *MerkleTreeNode {
return m.treeNode
}
func (m *MerkleTree) FindNode(hash common.Hash) *TreeNode {
func (m *MerkleTree) FindMerkleNode(hash common.Hash) *MerkleTreeNode {
if m.treeNode == nil {
fmt.Println("tnn", m.treeNode)
return nil
}
queue := []*TreeNode{m.treeNode}
queue := []*MerkleTreeNode{m.treeNode}
for len(queue) > 0 {
size := len(queue)
......@@ -156,7 +154,6 @@ func (m *MerkleTree) FindNode(hash common.Hash) *TreeNode {
queue = queue[1:]
if currentNode.Hash.Hex() == hash.Hex() {
fmt.Println("nodev", currentNode.Hash.Hex())
return currentNode
}
......@@ -171,12 +168,12 @@ func (m *MerkleTree) FindNode(hash common.Hash) *TreeNode {
return nil
}
func Traversal(root *TreeNode, depth int) (retNodes [][]common.Hash) {
func MerkleTreeTraversal(root *MerkleTreeNode, depth int) (retNodes [][]common.Hash) {
if root == nil {
return nil
}
queue := []*TreeNode{root}
queue := []*MerkleTreeNode{root}
currentDepth := 1
......@@ -206,24 +203,24 @@ func Traversal(root *TreeNode, depth int) (retNodes [][]common.Hash) {
return
}
func buildTree(nodes [][]common.Hash) *TreeNode {
func buildTree(nodes [][]common.Hash) *MerkleTreeNode {
if len(nodes) == 0 {
return nil
}
root := &TreeNode{Hash: nodes[0][0]}
queue := []*TreeNode{root}
root := &MerkleTreeNode{Hash: nodes[0][0]}
queue := []*MerkleTreeNode{root}
for i := 1; i < len(nodes); i++ {
var levelNodes []*TreeNode
var levelNodes []*MerkleTreeNode
for j := 0; j < len(nodes[i]); j += 2 {
current := queue[0]
queue = queue[1:]
current.Left = &TreeNode{Hash: nodes[i][j]}
current.Left = &MerkleTreeNode{Hash: nodes[i][j]}
levelNodes = append(levelNodes, current.Left)
if j+1 < len(nodes[i]) {
current.Right = &TreeNode{Hash: nodes[i][j+1]}
current.Right = &MerkleTreeNode{Hash: nodes[i][j+1]}
levelNodes = append(levelNodes, current.Right)
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment