package graph import ( "crypto/sha256" "encoding/json" "fmt" "go-ethereum-advance/common" "go-ethereum-advance/core/state" "go-ethereum-advance/crypto" "go-ethereum-advance/log" "go-ethereum-advance/rlp" "runtime" "sync" ) const ( ErrorAlreadyInclude = "regist contract already exit" ErrorAddressNotExit = "regist contract not exit" ErrorAddressRegistAgain = "regist contract is the same with registed" ) var ( //AccountStateHash = common.Hash(Hash(params.Account)) //AccountAddr = params.Account SparseGraphKey = common.BytesToAddress(crypto.Keccak256([]byte("contract_relations"))) SparseGraph = NewSparseGraph() ContractCache = sync.Map{} ) //邻接表 type sparseGraph struct { Graph [][]common.Address `json:"graph"` Retrieve map[common.Address]int `json:"retrieve"` //节点检索序号 lock sync.RWMutex } func NewSparseGraph() *sparseGraph { buf := make([][]common.Address, 0) cache := make(map[common.Address]int, 0) return &sparseGraph{ Graph: buf, Retrieve: cache, } } func NewSparseGraphOnce() *sparseGraph { var once sync.Once var spGraph *sparseGraph once.Do(func() { buf := make([][]common.Address, 0) cache := make(map[common.Address]int, 0) spGraph = &sparseGraph{ Graph: buf, Retrieve: cache, } }) return spGraph } func Contains(addresses []common.Address, address common.Address) (int, bool) { for i, item := range addresses { if item == address { return i, true } } return 0, false } func (s *sparseGraph) GetContact(addr common.Address) ([]common.Address, bool) { s.lock.RLock() defer s.lock.RUnlock() if addr == [20]byte{} { return nil, false } index, ok := s.Retrieve[addr] if ok == false { return []common.Address{}, false } list := s.Graph[index] return list, true } func (s *sparseGraph) String() { s.lock.RLock() defer s.lock.RUnlock() str := "" for key, index := range s.Retrieve { str += key.String() + " -> " nexts := s.Graph[index] for _, next := range nexts { str += next.String() + " " } str += "\n" } fmt.Println(str) } func (s *sparseGraph) Encode() ([]byte, error) { return json.Marshal(s) } func (s *sparseGraph) Decode(data []byte) error { return json.Unmarshal(data, s) } type NodeQueue struct { nodes []common.Address lock sync.RWMutex } //生成节点队列 func NewNodeQueue() *NodeQueue { q := NodeQueue{} q.lock.Lock() defer q.lock.Unlock() q.nodes = []common.Address{} return &q } //入队列 func (q *NodeQueue) Enqueue(address common.Address) { q.lock.Lock() defer q.lock.Unlock() q.nodes = append(q.nodes, address) } //出队列 func (q *NodeQueue) Dequeue() common.Address { q.lock.Lock() defer q.lock.Unlock() node := q.nodes[0] q.nodes = q.nodes[1:] return node } //判空 func (q *NodeQueue) IsEmpty() bool { q.lock.RLock() defer q.lock.RUnlock() return len(q.nodes) == 0 } func GetContractRelations(addr common.Address, stateDB *state.StateDB) []common.Address { data := stateDB.GetState(addr, SparseGraphKey.Hash()) if len(data) == 0 { return []common.Address{} } list := make([]common.Address, 0) err := rlp.DecodeBytes(data.Bytes(), &list) if err != nil { return []common.Address{} } return list } func PrintCallerName() string { pc, _, _, _ := runtime.Caller(3) return runtime.FuncForPC(pc).Name() } func BFSGetAllContracts(addr common.Address, stateDB *state.StateDB) []common.Address { list := GetContractRelations(addr, stateDB) if len(list) == 0 { return []common.Address{} } result := make([]common.Address, 0) q := NewNodeQueue() for _, item := range list { q.Enqueue(item) result = append(result, item) } //标记地址已经访问过 visited := make(map[common.Address]bool) visited[addr] = true //遍历所有节点直到队列为空 for { if q.IsEmpty() { break } item := q.Dequeue() visited[item] = true res := GetContractRelations(item, stateDB) if len(res) == 0 { continue } else { for _, count := range res { if visited[count] { continue } q.Enqueue(count) visited[count] = true result = append(result, count) } } } return result } func RemoveRepeatedElement(arr []common.Address) (newArr []common.Address) { newArr = make([]common.Address, 0) for i := 0; i < len(arr); i++ { repeat := false for j := i + 1; j < len(arr); j++ { if arr[i] == arr[j] { repeat = true break } } if !repeat { newArr = append(newArr, arr[i]) } } return } // //func SetAllContracts(contractAddr common.Address, registerList []common.Address, stateDB *state.StateDB) { // relyAddr := registerList // //取出所有关联地址 // var account bool // for _, addr := range registerList { // if addr == AccountAddr { // account = true // } // contracts := BFSGetAllContracts(addr, stateDB) // //判断是否是账户state的交易 // for _, contractsAddr := range contracts { // if contractsAddr == AccountAddr { // account = true // } // } // // //需要去重 // relyAddr = append(relyAddr, contracts...) // } // //把自己放进去 // relyAddr = append(relyAddr, contractAddr) // // element := RemoveRepeatedElement(relyAddr) // registerData, err := rlp.EncodeToBytes(element) // if err != nil { // return // } // //为了保证不会清空obj // stateDB.SetNonce(contractAddr, 1) // //更新其他关系地址的列表情况 // sort.Sort(common.SortAddress(relyAddr)) // // for _, addr := range relyAddr { // stateDB.SetState(addr, SparseGraphKey.Hash(), common.BytesToHash(registerData)) // //缓存合约簇ID // if account { // //这个合约关联了 账户state 把所有的关联账户全部写成账户state的id // log.Info("这个账户关联了账户state", "addr", addr) // ContractCache.Store(addr, AccountStateHash) // } else { // ContractCache.Store(addr, common.Hash(Hash(relyAddr))) // } // // } // //} ////通过获取关系列表,进行hash后来得到合约ID //func GetContractId(addr common.Address, stateDB *state.StateDB) (common.Hash, []common.Address) { // if load, ok := ContractCache.Load(addr); ok { // return load.(common.Hash), []common.Address{} // } // if addr == params.Account { // return AccountStateHash, []common.Address{} // } // if stateDB==nil{ // //log.Warn("GetContractId,stateDB==nil") // return common.Hash{},[]common.Address{} // } // contracts := BFSGetAllContracts(addr, stateDB) // if len(contracts) == 0 { // //如果没有找到地址,直接用当前地址hash后做为合约簇id // relyAddr := make([]common.Address, 0) // relyAddr = append(relyAddr, addr) // return Hash(relyAddr), contracts // // } // for _, addr := range contracts { // //如果关系列表中关联了账户state // //log.Debug("查看账户stata", "addr", addr.Hex()) // if addr == AccountAddr { // return AccountStateHash, contracts // } // } // //排序是为了计算hash // sort.Sort(common.SortAddress(contracts)) // return Hash(contracts), contracts //} func Hash(v interface{}) [32]byte { //bytes, err := json.Marshal(v) bytes, err := rlp.EncodeToBytes(v) if err != nil { log.Error("rlp encode error :") return common.Hash{} } return sha256.Sum256(bytes) }