Commit 127ae72f authored by Mark Tyneway's avatar Mark Tyneway Committed by GitHub

state-surgery: small fixes (#3246)

- Add support for `&common.Address`
- Take locks in initialization of `hardhat.Hardhat`
parent b1fa4d8c
...@@ -53,6 +53,11 @@ func New(network string, artifacts, deployments []string) (*Hardhat, error) { ...@@ -53,6 +53,11 @@ func New(network string, artifacts, deployments []string) (*Hardhat, error) {
// init is called in the constructor and will cache required files to disk. // init is called in the constructor and will cache required files to disk.
func (h *Hardhat) init() error { func (h *Hardhat) init() error {
h.amu.Lock()
defer h.amu.Unlock()
h.dmu.Lock()
defer h.dmu.Unlock()
if err := h.initArtifacts(); err != nil { if err := h.initArtifacts(); err != nil {
return err return err
} }
...@@ -130,6 +135,7 @@ func (h *Hardhat) initArtifacts() error { ...@@ -130,6 +135,7 @@ func (h *Hardhat) initArtifacts() error {
if err := json.Unmarshal(file, &artifact); err != nil { if err := json.Unmarshal(file, &artifact); err != nil {
return err return err
} }
h.artifacts = append(h.artifacts, &artifact) h.artifacts = append(h.artifacts, &artifact)
return nil return nil
}) })
......
...@@ -31,13 +31,13 @@ func EncodeStorageKeyValue(value any, entry solc.StorageLayoutEntry, storageType ...@@ -31,13 +31,13 @@ func EncodeStorageKeyValue(value any, entry solc.StorageLayoutEntry, storageType
case "bool": case "bool":
val, err := EncodeBoolValue(value, entry.Offset) val, err := EncodeBoolValue(value, entry.Offset)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("%s bool invalid: %w", entry.Label, err)
} }
encoded = append(encoded, &EncodedStorage{key, val}) encoded = append(encoded, &EncodedStorage{key, val})
case "address": case "address":
val, err := EncodeAddressValue(value, entry.Offset) val, err := EncodeAddressValue(value, entry.Offset)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("%s address invalid: %w", entry.Label, err)
} }
encoded = append(encoded, &EncodedStorage{key, val}) encoded = append(encoded, &EncodedStorage{key, val})
case "bytes": case "bytes":
...@@ -53,7 +53,7 @@ func EncodeStorageKeyValue(value any, entry solc.StorageLayoutEntry, storageType ...@@ -53,7 +53,7 @@ func EncodeStorageKeyValue(value any, entry solc.StorageLayoutEntry, storageType
case strings.HasPrefix(label, "contract"): case strings.HasPrefix(label, "contract"):
val, err := EncodeAddressValue(value, entry.Offset) val, err := EncodeAddressValue(value, entry.Offset)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("%s address invalid: %w", entry.Label, err)
} }
encoded = append(encoded, &EncodedStorage{key, val}) encoded = append(encoded, &EncodedStorage{key, val})
case strings.HasPrefix(label, "uint"): case strings.HasPrefix(label, "uint"):
...@@ -281,14 +281,27 @@ func EncodeAddressValue(value any, offset uint) (common.Hash, error) { ...@@ -281,14 +281,27 @@ func EncodeAddressValue(value any, offset uint) (common.Hash, error) {
// encodeAddressValue will encode an address value into // encodeAddressValue will encode an address value into
// a type suitable for solidity storage. // a type suitable for solidity storage.
func encodeAddressValue(value any) (common.Hash, error) { func encodeAddressValue(value any) (common.Hash, error) {
name := reflect.TypeOf(value).Name() typ := reflect.TypeOf(value)
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}
name := typ.Name()
switch name { switch name {
case "Address": case "Address":
address, ok := value.(common.Address) if reflect.TypeOf(value).Kind() == reflect.Ptr {
if !ok { address, ok := value.(*common.Address)
return common.Hash{}, errInvalidType if !ok {
return common.Hash{}, errInvalidType
}
return address.Hash(), nil
} else {
address, ok := value.(common.Address)
if !ok {
return common.Hash{}, errInvalidType
}
return address.Hash(), nil
} }
return address.Hash(), nil
case "string": case "string":
address, ok := value.(string) address, ok := value.(string)
if !ok { if !ok {
......
...@@ -411,6 +411,11 @@ func TestEncodeAddressValue(t *testing.T) { ...@@ -411,6 +411,11 @@ func TestEncodeAddressValue(t *testing.T) {
offset: 1, offset: 1,
expect: common.Hash{30: 0x01}, expect: common.Hash{30: 0x01},
}, },
{
addr: &common.Address{},
offset: 0,
expect: common.Hash{},
},
} }
for _, test := range cases { for _, test := range cases {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment