package vm

import (
	"bytes"
	"errors"
	"go-ethereum-advance/common"
	"go-ethereum-advance/core/vm/protoc"
	"go-ethereum-advance/graph"
	"go-ethereum-advance/log"
	"go-ethereum-advance/params"
	"github.com/golang/protobuf/proto"
)

// contractRelationship is the precompiled contract that records which
// contract addresses a given contract is related to. It has no fields; all
// of its data lives outside the struct.
type contractRelationship struct{}

// RequiredGas reports the gas charged for one invocation: a fixed base cost
// plus a per-word cost for every 32-byte word (rounded up) of input. The
// constants are those of the SHA-256 precompile's price schedule.
func (c *contractRelationship) RequiredGas(input []byte) uint64 {
	words := uint64(len(input)+31) / 32
	return params.Sha256BaseGas + words*params.Sha256PerWordGas
}

// Run executes the contract-relationship precompile.
//
// input must be a protobuf-encoded protoc.RelationInfo. The call is accepted
// only when info.OwnerAddress matches the owner recorded in the state of the
// "contract_owner_record" pseudo-account under the hash of
// info.ContractAddress, and info.Relations is non-empty. Depending on
// info.Type, the addresses in info.Relations are then merged into
// (AddRelation) or removed from (DelRelation) the contract's entry in the
// graph returned by graph.NewSparseGraphOnce. Any other type is rejected.
// On success Run returns an empty, non-nil byte slice.
//
// NOTE(review): the graph is mutated directly rather than through StateDB,
// so these changes are presumably not reverted when the enclosing call or
// transaction reverts, and no locking is visible here. Confirm against the
// graph package.
func (c *contractRelationship) Run(ctx *PrecompiledContractContext, input []byte) ([]byte, error) {
	info := &protoc.RelationInfo{}
	if err := proto.Unmarshal(input, info); err != nil {
		log.Info("contract relationship Run err", "err info:", err.Error())
		return nil, err
	}

	contractAddr := common.HexToAddress(info.ContractAddress)
	keyHash := contractAddr.Hash()
	senderAddr := common.HexToAddress(info.OwnerAddress)
	contractOwnerAddr := common.BytesToAddress([]byte("contract_owner_record"))

	// The owner record is a whole storage word. Compare it word-to-word with
	// the sender padded to a word: comparing the word's bytes with the raw
	// 20-byte address can never match. This assumes the owner was stored the
	// way Address.Hash lays it out. Confirm against the code that writes
	// contract_owner_record.
	// A zero sender is rejected so that an empty OwnerAddress cannot match an
	// unset (all-zero) record.
	value := ctx.Evm.StateDB.GetState(contractOwnerAddr, keyHash)
	if senderAddr == (common.Address{}) || !bytes.Equal(value.Bytes(), senderAddr.Hash().Bytes()) {
		// The sender updating the relationship does not match the record.
		return nil, errors.New("relationship tx sender not match with record")
	}

	if len(info.Relations) == 0 {
		return nil, errors.New("registration relationship must not be empty")
	}

	contactAddrs := make([]common.Address, len(info.Relations))
	for i, addr := range info.Relations {
		contactAddrs[i] = common.HexToAddress(addr)
	}

	graphRela := graph.NewSparseGraphOnce()
	index, registered := graphRela.Retrieve[contractAddr]

	switch info.Type {
	case protoc.RelationType_AddRelation:
		if registered {
			// Merge into a freshly allocated slice so append cannot write into
			// spare capacity shared with the slice stored in the graph, then
			// deduplicate.
			oldContracts := graphRela.Graph[index]
			merged := make([]common.Address, 0, len(oldContracts)+len(contactAddrs))
			merged = append(merged, oldContracts...)
			merged = append(merged, contactAddrs...)
			graphRela.Graph[index] = graph.RemoveRepeatedElement(merged)
		} else {
			// No relations recorded yet: store the deduplicated list at the
			// next index and remember that index for this contract.
			// NOTE(review): assigning at len(Graph) only works if Graph is a
			// map. If it is a slice this panics, so confirm its type.
			index = len(graphRela.Graph)
			graphRela.Graph[index] = graph.RemoveRepeatedElement(contactAddrs)
			graphRela.Retrieve[contractAddr] = index
		}

	case protoc.RelationType_DelRelation:
		if !registered {
			// Deleting from a contract with no recorded relations is an error.
			return nil, errors.New(graph.ErrorAddressNotExit)
		}
		// Filter into a new slice and store it back so the recorded list
		// actually shrinks. Any position can be removed, and the stored
		// slice's backing array is left untouched.
		graphRela.Graph[index] = removeRelationAddrs(graphRela.Graph[index], contactAddrs)

	default:
		return nil, errors.New("unsupported relationship type")
	}

	return []byte{}, nil
}

// removeRelationAddrs returns a new slice with the elements of list that do
// not occur in remove, keeping their original order. list is not modified.
func removeRelationAddrs(list, remove []common.Address) []common.Address {
	drop := make(map[common.Address]struct{}, len(remove))
	for _, a := range remove {
		drop[a] = struct{}{}
	}
	kept := make([]common.Address, 0, len(list))
	for _, a := range list {
		if _, ok := drop[a]; !ok {
			kept = append(kept, a)
		}
	}
	return kept
}

