Commit 99cf7b95 authored by Pavle Batuta's avatar Pavle Batuta Committed by GitHub

Add ENS contract address parameter to config (#1029)

parent aad68011
...@@ -5,11 +5,12 @@ ...@@ -5,11 +5,12 @@
package ens package ens
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/ethclient"
goens "github.com/wealdtech/go-ens/v3" goens "github.com/wealdtech/go-ens/v3"
...@@ -17,7 +18,10 @@ import ( ...@@ -17,7 +18,10 @@ import (
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
const swarmContentHashPrefix = "/swarm/" const (
defaultENSContractAddress = "00000000000C2E074eC69A0dFb2997BA6C7d2e1e"
swarmContentHashPrefix = "/swarm/"
)
// Address is the swarm bzz address. // Address is the swarm bzz address.
type Address = swarm.Address type Address = swarm.Address
...@@ -36,15 +40,19 @@ var ( ...@@ -36,15 +40,19 @@ var (
ErrInvalidContentHash = errors.New("invalid swarm content hash") ErrInvalidContentHash = errors.New("invalid swarm content hash")
// errNotImplemented denotes that the function has not been implemented. // errNotImplemented denotes that the function has not been implemented.
errNotImplemented = errors.New("function not implemented") errNotImplemented = errors.New("function not implemented")
// errNameNotRegistered denotes that the name is not registered.
errNameNotRegistered = errors.New("name is not registered")
) )
// Client is a name resolution client that can connect to ENS via an // Client is a name resolution client that can connect to ENS via an
// Ethereum endpoint. // Ethereum endpoint.
type Client struct { type Client struct {
endpoint string endpoint string
contractAddr string
ethCl *ethclient.Client ethCl *ethclient.Client
dialFn func(string) (*ethclient.Client, error) connectFn func(string, string) (*ethclient.Client, *goens.Registry, error)
resolveFn func(bind.ContractBackend, string) (string, error) resolveFn func(*goens.Registry, common.Address, string) (string, error)
registry *goens.Registry
} }
// Option is a function that applies an option to a Client. // Option is a function that applies an option to a Client.
...@@ -54,7 +62,7 @@ type Option func(*Client) ...@@ -54,7 +62,7 @@ type Option func(*Client)
func NewClient(endpoint string, opts ...Option) (client.Interface, error) { func NewClient(endpoint string, opts ...Option) (client.Interface, error) {
c := &Client{ c := &Client{
endpoint: endpoint, endpoint: endpoint,
dialFn: ethclient.Dial, connectFn: wrapDial,
resolveFn: wrapResolve, resolveFn: wrapResolve,
} }
...@@ -63,20 +71,32 @@ func NewClient(endpoint string, opts ...Option) (client.Interface, error) { ...@@ -63,20 +71,32 @@ func NewClient(endpoint string, opts ...Option) (client.Interface, error) {
o(c) o(c)
} }
// Connect to the name resolution service. // Set the default ENS contract address.
if c.dialFn == nil { if c.contractAddr == "" {
return nil, fmt.Errorf("dialFn: %w", errNotImplemented) c.contractAddr = defaultENSContractAddress
} }
ethCl, err := c.dialFn(c.endpoint) // Establish a connection to the ENS.
if c.connectFn == nil {
return nil, fmt.Errorf("connectFn: %w", errNotImplemented)
}
ethCl, registry, err := c.connectFn(c.endpoint, c.contractAddr)
if err != nil { if err != nil {
return nil, fmt.Errorf("%v: %w", err, ErrFailedToConnect) return nil, fmt.Errorf("%v: %w", err, ErrFailedToConnect)
} }
c.ethCl = ethCl c.ethCl = ethCl
c.registry = registry
return c, nil return c, nil
} }
// WithContractAddress will set the ENS contract address.
func WithContractAddress(addr string) Option {
return func(c *Client) {
c.contractAddr = addr
}
}
// IsConnected returns true if there is an active RPC connection with an // IsConnected returns true if there is an active RPC connection with an
// Ethereum node at the configured endpoint. // Ethereum node at the configured endpoint.
func (c *Client) IsConnected() bool { func (c *Client) IsConnected() bool {
...@@ -94,7 +114,7 @@ func (c *Client) Resolve(name string) (Address, error) { ...@@ -94,7 +114,7 @@ func (c *Client) Resolve(name string) (Address, error) {
return swarm.ZeroAddress, fmt.Errorf("resolveFn: %w", errNotImplemented) return swarm.ZeroAddress, fmt.Errorf("resolveFn: %w", errNotImplemented)
} }
hash, err := c.resolveFn(c.ethCl, name) hash, err := c.resolveFn(c.registry, common.HexToAddress(c.contractAddr), name)
if err != nil { if err != nil {
return swarm.ZeroAddress, fmt.Errorf("%v: %w", err, ErrResolveFailed) return swarm.ZeroAddress, fmt.Errorf("%v: %w", err, ErrResolveFailed)
} }
...@@ -121,18 +141,50 @@ func (c *Client) Close() error { ...@@ -121,18 +141,50 @@ func (c *Client) Close() error {
return nil return nil
} }
func wrapResolve(backend bind.ContractBackend, name string) (string, error) { func wrapDial(endpoint string, contractAddr string) (*ethclient.Client, *goens.Registry, error) {
// Dial the eth client.
ethCl, err := ethclient.Dial(endpoint)
if err != nil {
return nil, nil, fmt.Errorf("dial: %w", err)
}
// Obtain the ENS registry.
registry, err := goens.NewRegistryAt(ethCl, common.HexToAddress(contractAddr))
if err != nil {
return nil, nil, fmt.Errorf("new registry: %w", err)
}
// Ensure that the ENS registry client is deployed to the given contract address.
_, err = registry.Owner("")
if err != nil {
return nil, nil, fmt.Errorf("owner: %w", err)
}
return ethCl, registry, nil
}
func wrapResolve(registry *goens.Registry, addr common.Address, name string) (string, error) {
// Ensure the name is registered.
ownerAddress, err := registry.Owner(name)
if err != nil {
return "", fmt.Errorf("owner: %w", err)
}
// If the name is not registered, return an error.
if bytes.Equal(ownerAddress.Bytes(), goens.UnknownAddress.Bytes()) {
return "", errNameNotRegistered
}
// Connect to the ENS resolver for the provided name. // Obtain the resolver for this domain name.
ensR, err := goens.NewResolver(backend, name) ensR, err := registry.Resolver(name)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("resolver: %w", err)
} }
// Try and read out the content hash record. // Try and read out the content hash record.
ch, err := ensR.Contenthash() ch, err := ensR.Contenthash()
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("contenthash: %w", err)
} }
return goens.ContenthashToString(ch) return goens.ContenthashToString(ch)
......
...@@ -22,6 +22,7 @@ func TestENSntegration(t *testing.T) { ...@@ -22,6 +22,7 @@ func TestENSntegration(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
endpoint string endpoint string
contractAddress string
name string name string
wantAdr swarm.Address wantAdr swarm.Address
wantErr error wantErr error
...@@ -53,6 +54,12 @@ func TestENSntegration(t *testing.T) { ...@@ -53,6 +54,12 @@ func TestENSntegration(t *testing.T) {
name: "nocontent.resolver.test.swarm.eth", name: "nocontent.resolver.test.swarm.eth",
wantErr: ens.ErrResolveFailed, wantErr: ens.ErrResolveFailed,
}, },
{
desc: "invalid contract address",
contractAddress: "0xFFFFFFFF",
name: "example.resolver.test.swarm.eth",
wantErr: ens.ErrFailedToConnect,
},
{ {
desc: "ok", desc: "ok",
name: "example.resolver.test.swarm.eth", name: "example.resolver.test.swarm.eth",
...@@ -65,9 +72,9 @@ func TestENSntegration(t *testing.T) { ...@@ -65,9 +72,9 @@ func TestENSntegration(t *testing.T) {
tC.endpoint = defaultEndpoint tC.endpoint = defaultEndpoint
} }
ensClient, err := ens.NewClient(tC.endpoint) ensClient, err := ens.NewClient(tC.endpoint, ens.WithContractAddress(tC.contractAddress))
if err != nil { if err != nil {
if !errors.Is(err, ens.ErrFailedToConnect) { if !errors.Is(err, tC.wantErr) {
t.Errorf("got %v, want %v", err, tC.wantErr) t.Errorf("got %v, want %v", err, tC.wantErr)
} }
return return
......
...@@ -8,9 +8,10 @@ import ( ...@@ -8,9 +8,10 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/rpc" "github.com/ethereum/go-ethereum/rpc"
goens "github.com/wealdtech/go-ens/v3"
"github.com/ethersphere/bee/pkg/resolver/client/ens" "github.com/ethersphere/bee/pkg/resolver/client/ens"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -20,29 +21,30 @@ func TestNewENSClient(t *testing.T) { ...@@ -20,29 +21,30 @@ func TestNewENSClient(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
endpoint string endpoint string
dialFn func(string) (*ethclient.Client, error) address string
connectFn func(string, string) (*ethclient.Client, *goens.Registry, error)
wantErr error wantErr error
wantEndpoint string wantEndpoint string
}{ }{
{ {
desc: "nil dial function", desc: "nil dial function",
endpoint: "someaddress.net", endpoint: "someaddress.net",
dialFn: nil, connectFn: nil,
wantErr: ens.ErrNotImplemented, wantErr: ens.ErrNotImplemented,
}, },
{ {
desc: "error in dial function", desc: "error in dial function",
endpoint: "someaddress.com", endpoint: "someaddress.com",
dialFn: func(string) (*ethclient.Client, error) { connectFn: func(s1, s2 string) (*ethclient.Client, *goens.Registry, error) {
return nil, errors.New("dial error") return nil, nil, errors.New("dial error")
}, },
wantErr: ens.ErrFailedToConnect, wantErr: ens.ErrFailedToConnect,
}, },
{ {
desc: "regular endpoint", desc: "regular endpoint",
endpoint: "someaddress.org", endpoint: "someaddress.org",
dialFn: func(string) (*ethclient.Client, error) { connectFn: func(s1, s2 string) (*ethclient.Client, *goens.Registry, error) {
return &ethclient.Client{}, nil return &ethclient.Client{}, nil, nil
}, },
wantEndpoint: "someaddress.org", wantEndpoint: "someaddress.org",
}, },
...@@ -50,7 +52,8 @@ func TestNewENSClient(t *testing.T) { ...@@ -50,7 +52,8 @@ func TestNewENSClient(t *testing.T) {
for _, tC := range testCases { for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) { t.Run(tC.desc, func(t *testing.T) {
cl, err := ens.NewClient(tC.endpoint, cl, err := ens.NewClient(tC.endpoint,
ens.WithDialFunc(tC.dialFn), ens.WithConnectFunc(tC.connectFn),
ens.WithContractAddress(tC.address),
) )
if err != nil { if err != nil {
if !errors.Is(err, tC.wantErr) { if !errors.Is(err, tC.wantErr) {
...@@ -75,8 +78,8 @@ func TestClose(t *testing.T) { ...@@ -75,8 +78,8 @@ func TestClose(t *testing.T) {
ethCl := ethclient.NewClient(rpc.DialInProc(rpcServer)) ethCl := ethclient.NewClient(rpc.DialInProc(rpcServer))
cl, err := ens.NewClient("", cl, err := ens.NewClient("",
ens.WithDialFunc(func(string) (*ethclient.Client, error) { ens.WithConnectFunc(func(endpoint, contractAddr string) (*ethclient.Client, *goens.Registry, error) {
return ethCl, nil return ethCl, nil, nil
}), }),
) )
if err != nil { if err != nil {
...@@ -94,8 +97,8 @@ func TestClose(t *testing.T) { ...@@ -94,8 +97,8 @@ func TestClose(t *testing.T) {
}) })
t.Run("not connected", func(t *testing.T) { t.Run("not connected", func(t *testing.T) {
cl, err := ens.NewClient("", cl, err := ens.NewClient("",
ens.WithDialFunc(func(string) (*ethclient.Client, error) { ens.WithConnectFunc(func(endpoint, contractAddr string) (*ethclient.Client, *goens.Registry, error) {
return nil, nil return nil, nil, nil
}), }),
) )
if err != nil { if err != nil {
...@@ -114,12 +117,15 @@ func TestClose(t *testing.T) { ...@@ -114,12 +117,15 @@ func TestClose(t *testing.T) {
} }
func TestResolve(t *testing.T) { func TestResolve(t *testing.T) {
addr := swarm.MustParseHexAddress("aaabbbcc") testContractAddrString := "00000000000C2E074eC69A0dFb2997BA6C702e1B"
testContractAddr := common.HexToAddress(testContractAddrString)
testSwarmAddr := swarm.MustParseHexAddress("aaabbbcc")
testCases := []struct { testCases := []struct {
desc string desc string
name string name string
resolveFn func(bind.ContractBackend, string) (string, error) contractAddr string
resolveFn func(*goens.Registry, common.Address, string) (string, error)
wantErr error wantErr error
}{ }{
{ {
...@@ -129,38 +135,48 @@ func TestResolve(t *testing.T) { ...@@ -129,38 +135,48 @@ func TestResolve(t *testing.T) {
}, },
{ {
desc: "resolve function internal error", desc: "resolve function internal error",
resolveFn: func(bind.ContractBackend, string) (string, error) { resolveFn: func(*goens.Registry, common.Address, string) (string, error) {
return "", errors.New("internal error") return "", errors.New("internal error")
}, },
wantErr: ens.ErrResolveFailed, wantErr: ens.ErrResolveFailed,
}, },
{ {
desc: "resolver returns empty string", desc: "resolver returns empty string",
resolveFn: func(bind.ContractBackend, string) (string, error) { resolveFn: func(*goens.Registry, common.Address, string) (string, error) {
return "", nil return "", nil
}, },
wantErr: ens.ErrInvalidContentHash, wantErr: ens.ErrInvalidContentHash,
}, },
{ {
desc: "resolve does not prefix address with /swarm", desc: "resolve does not prefix address with /swarm",
resolveFn: func(bind.ContractBackend, string) (string, error) { resolveFn: func(*goens.Registry, common.Address, string) (string, error) {
return addr.String(), nil return testSwarmAddr.String(), nil
}, },
wantErr: ens.ErrInvalidContentHash, wantErr: ens.ErrInvalidContentHash,
}, },
{ {
desc: "resolve returns prefixed address", desc: "resolve returns prefixed address",
resolveFn: func(bind.ContractBackend, string) (string, error) { resolveFn: func(*goens.Registry, common.Address, string) (string, error) {
return ens.SwarmContentHashPrefix + addr.String(), nil return ens.SwarmContentHashPrefix + testSwarmAddr.String(), nil
}, },
wantErr: ens.ErrInvalidContentHash, wantErr: ens.ErrInvalidContentHash,
}, },
{
desc: "expect properly set contract address",
resolveFn: func(b *goens.Registry, c common.Address, s string) (string, error) {
if c != testContractAddr {
return "", errors.New("invalid contract address")
}
return ens.SwarmContentHashPrefix + testSwarmAddr.String(), nil
},
},
} }
for _, tC := range testCases { for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) { t.Run(tC.desc, func(t *testing.T) {
cl, err := ens.NewClient("example.com", cl, err := ens.NewClient("example.com",
ens.WithDialFunc(func(string) (*ethclient.Client, error) { ens.WithContractAddress(testContractAddrString),
return nil, nil ens.WithConnectFunc(func(endpoint, contractAddr string) (*ethclient.Client, *goens.Registry, error) {
return nil, nil, nil
}), }),
ens.WithResolveFunc(tC.resolveFn), ens.WithResolveFunc(tC.resolveFn),
) )
......
...@@ -5,23 +5,24 @@ ...@@ -5,23 +5,24 @@
package ens package ens
import ( import (
"github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/ethclient"
goens "github.com/wealdtech/go-ens/v3"
) )
const SwarmContentHashPrefix = swarmContentHashPrefix const SwarmContentHashPrefix = swarmContentHashPrefix
var ErrNotImplemented = errNotImplemented var ErrNotImplemented = errNotImplemented
// WithDialFunc will set the Dial function implementaton. // WithConnectFunc will set the Dial function implementaton.
func WithDialFunc(fn func(ep string) (*ethclient.Client, error)) Option { func WithConnectFunc(fn func(endpoint string, contractAddr string) (*ethclient.Client, *goens.Registry, error)) Option {
return func(c *Client) { return func(c *Client) {
c.dialFn = fn c.connectFn = fn
} }
} }
// WithResolveFunc will set the Resolve function implementation. // WithResolveFunc will set the Resolve function implementation.
func WithResolveFunc(fn func(backend bind.ContractBackend, input string) (string, error)) Option { func WithResolveFunc(fn func(registry *goens.Registry, addr common.Address, input string) (string, error)) Option {
return func(c *Client) { return func(c *Client) {
c.resolveFn = fn c.resolveFn = fn
} }
......
...@@ -74,14 +74,9 @@ func NewMultiResolver(opts ...Option) *MultiResolver { ...@@ -74,14 +74,9 @@ func NewMultiResolver(opts ...Option) *MultiResolver {
// Attempt to conect to each resolver using the connection string. // Attempt to conect to each resolver using the connection string.
for _, c := range mr.cfgs { for _, c := range mr.cfgs {
// Warn user that the resolver address field is not used.
if c.Address != "" {
log.Warningf("name resolver: connection string %q contains resolver address field, which is currently unused", c.Address)
}
// NOTE: if we want to create a specific client based on the TLD // NOTE: if we want to create a specific client based on the TLD
// we can do it here. // we can do it here.
mr.connectENSClient(c.TLD, c.Endpoint) mr.connectENSClient(c.TLD, c.Address, c.Endpoint)
} }
return mr return mr
...@@ -189,13 +184,18 @@ func getTLD(name string) string { ...@@ -189,13 +184,18 @@ func getTLD(name string) string {
return path.Ext(strings.ToLower(name)) return path.Ext(strings.ToLower(name))
} }
func (mr *MultiResolver) connectENSClient(tld string, endpoint string) { func (mr *MultiResolver) connectENSClient(tld string, address string, endpoint string) {
log := mr.logger log := mr.logger
if address == "" {
log.Debugf("name resolver: resolver for %q: connecting to endpoint %s", tld, endpoint) log.Debugf("name resolver: resolver for %q: connecting to endpoint %s", tld, endpoint)
ensCl, err := ens.NewClient(endpoint) } else {
log.Debugf("name resolver: resolver for %q: connecting to endpoint %s with contract address %s", tld, endpoint, address)
}
ensCl, err := ens.NewClient(endpoint, ens.WithContractAddress(address))
if err != nil { if err != nil {
log.Errorf("name resolver: resolver for %q domain: failed to connect to %q: %v", tld, endpoint, err) log.Errorf("name resolver: resolver for %q domain on endpoint %q: %v", tld, endpoint, err)
} else { } else {
log.Infof("name resolver: resolver for %q domain: connected to %s", tld, endpoint) log.Infof("name resolver: resolver for %q domain: connected to %s", tld, endpoint)
mr.PushResolver(tld, ensCl) mr.PushResolver(tld, ensCl)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment