Commit ec3f1675 authored by Pavle Batuta's avatar Pavle Batuta Committed by GitHub

Resolver refactor service (#670)

Broad-spectrum refactor of the resolver package.

- removed TLD constraints (all TLDs are now accepted)
- merged resolver service into multiresolver
- simplified client code
- refactored tests
- added MultiError
- other minor fixes
parent d60f4ab3
......@@ -22,7 +22,7 @@ import (
memkeystore "github.com/ethersphere/bee/pkg/keystore/mem"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/node"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/resolver/multiresolver"
"github.com/sirupsen/logrus"
"github.com/spf13/cobra"
)
......@@ -57,10 +57,10 @@ func (c *command) initStartCmd() (err error) {
// If the resolver is specified, resolve all connection strings
// and fail on any errors.
var resolverCfgs []*resolver.ConnectionConfig
var resolverCfgs []multiresolver.ConnectionConfig
resolverEndpoints := c.config.GetStringSlice(optionNameResolverEndpoints)
if len(resolverEndpoints) > 0 {
resolverCfgs, err = resolver.ParseConnectionStrings(resolverEndpoints)
resolverCfgs, err = multiresolver.ParseConnectionStrings(resolverEndpoints)
if err != nil {
return err
}
......
......@@ -37,8 +37,7 @@ import (
"github.com/ethersphere/bee/pkg/pusher"
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/resolver"
resolverSvc "github.com/ethersphere/bee/pkg/resolver/service"
"github.com/ethersphere/bee/pkg/resolver/multiresolver"
"github.com/ethersphere/bee/pkg/retrieval"
"github.com/ethersphere/bee/pkg/settlement/pseudosettle"
"github.com/ethersphere/bee/pkg/soc"
......@@ -91,7 +90,7 @@ type Options struct {
GlobalPinningEnabled bool
PaymentThreshold uint64
PaymentTolerance uint64
ResolverConnectionCfgs []*resolver.ConnectionConfig
ResolverConnectionCfgs []multiresolver.ConnectionConfig
GatewayMode bool
}
......@@ -293,7 +292,10 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
b.pullerCloser = puller
multiResolver := resolverSvc.InitMultiResolver(logger, o.ResolverConnectionCfgs)
multiResolver := multiresolver.NewMultiResolver(
multiresolver.WithConnectionConfigs(o.ResolverConnectionCfgs),
multiresolver.WithLogger(o.Logger),
)
b.resolverCloser = multiResolver
var apiService api.Service
......
......@@ -9,9 +9,9 @@ import (
)
// Interface is a resolver client that can connect/disconnect to an external
// Name Resolution Service via an edpoint.
// Name Resolution Service via an endpoint.
type Interface interface {
resolver.Interface
Connect(endpoint string) error
Endpoint() string
IsConnected() bool
}
......@@ -5,43 +5,56 @@
package ens
import (
"errors"
"fmt"
"strings"
"sync"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/ethclient"
goens "github.com/wealdtech/go-ens/v3"
"github.com/ethersphere/bee/pkg/resolver/client"
"github.com/ethersphere/bee/pkg/swarm"
)
const swarmContentHashPrefix = "/swarm/"
// Address is the swarm bzz address.
type Address = swarm.Address
// Make sure Client implements the resolver.Client interface.
var _ client.Interface = (*Client)(nil)
type dialType func(string) (*ethclient.Client, error)
type resolveType func(bind.ContractBackend, string) (string, error)
var (
// ErrFailedToConnect denotes that the resolver failed to connect to the
// provided endpoint.
ErrFailedToConnect = errors.New("failed to connect")
// ErrResolveFailed denotes that a name could not be resolved.
ErrResolveFailed = errors.New("resolve failed")
// ErrInvalidContentHash denotes that the value of the contenthash record is
// not valid.
ErrInvalidContentHash = errors.New("invalid swarm content hash")
// errNotImplemented denotes that the function has not been implemented.
errNotImplemented = errors.New("function not implemented")
)
// Client is a name resolution client that can connect to ENS via an
// Ethereum endpoint.
type Client struct {
mu sync.Mutex
Endpoint string
endpoint string
ethCl *ethclient.Client
dialFn dialType
resolveFn resolveType
dialFn func(string) (*ethclient.Client, error)
resolveFn func(bind.ContractBackend, string) (string, error)
}
// Option is a function that applies an option to a Client.
type Option func(*Client)
// NewClient will return a new Client.
func NewClient(opts ...Option) *Client {
func NewClient(endpoint string, opts ...Option) (client.Interface, error) {
c := &Client{
dialFn: wrapDial,
endpoint: endpoint,
dialFn: ethclient.Dial,
resolveFn: wrapResolve,
}
......@@ -50,83 +63,77 @@ func NewClient(opts ...Option) *Client {
o(c)
}
return c
}
// Connect implements the resolver.Client interface.
func (c *Client) Connect(ep string) error {
// Connect to the name resolution service.
if c.dialFn == nil {
return fmt.Errorf("dialFn: %w", errNotImplemented)
return nil, fmt.Errorf("dialFn: %w", errNotImplemented)
}
ethCl, err := c.dialFn(ep)
ethCl, err := c.dialFn(c.endpoint)
if err != nil {
return err
return nil, fmt.Errorf("%v: %w", err, ErrFailedToConnect)
}
// Lock and set the parameters.
c.mu.Lock()
c.ethCl = ethCl
c.Endpoint = ep
c.mu.Unlock()
return nil
return c, nil
}
// IsConnected returns true if there is an active RPC connection with an
// Ethereum node at the configured endpoint.
// Function obtains a write lock while interacting with the Ethereum client.
func (c *Client) IsConnected() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.ethCl != nil
}
// Endpoint returns the endpoint the client was connected to.
func (c *Client) Endpoint() string {
return c.endpoint
}
// Resolve implements the resolver.Client interface.
// Function obtains a read lock while interacting with the Ethereum client.
func (c *Client) Resolve(name string) (Address, error) {
if c.resolveFn == nil {
return swarm.ZeroAddress, fmt.Errorf("resolveFn: %w", errNotImplemented)
}
// Obtain our copy of the client under lock.
c.mu.Lock()
ethCl := c.ethCl
c.mu.Unlock()
hash, err := c.resolveFn(ethCl, name)
hash, err := c.resolveFn(c.ethCl, name)
if err != nil {
return swarm.ZeroAddress, fmt.Errorf("%v: %w", err, ErrResolveFailed)
}
// In case the implementation returns a zero address return an NameNotFound
// error.
if hash == "" {
return swarm.ZeroAddress, fmt.Errorf("name %s: %w", name, ErrNameNotFound)
}
// Ensure that the content hash string is in a valid format, eg.
// "/swarm/<address>".
if !strings.HasPrefix(hash, "/swarm/") {
if !strings.HasPrefix(hash, swarmContentHashPrefix) {
return swarm.ZeroAddress, fmt.Errorf("contenthash %s: %w", hash, ErrInvalidContentHash)
}
// Trim the prefix and try to parse the result as a bzz address.
return swarm.ParseHexAddress(strings.TrimPrefix(hash, "/swarm/"))
return swarm.ParseHexAddress(strings.TrimPrefix(hash, swarmContentHashPrefix))
}
// Close closes the RPC connection with the client, terminating all unfinished
// requests.
// Function obtains a write lock while interacting with the Ethereum client.
// requests. If the connection is already closed, this call is a noop.
func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.ethCl != nil {
c.ethCl.Close() // TODO: consider mocking out the eth client.
c.ethCl.Close()
}
c.ethCl = nil
return nil
}
func wrapResolve(backend bind.ContractBackend, name string) (string, error) {
// Connect to the ENS resolver for the provided name.
ensR, err := goens.NewResolver(backend, name)
if err != nil {
return "", err
}
// Try and read out the content hash record.
ch, err := ensR.Contenthash()
if err != nil {
return "", err
}
return goens.ContenthashToString(ch)
}
......@@ -7,53 +7,56 @@
package ens_test
import (
"strings"
"errors"
"testing"
"github.com/ethersphere/bee/pkg/resolver/client/ens"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestENSntegration(t *testing.T) {
// TODO: consider using a stable gateway instead of INFURA.
defaultEndpoint := "https://goerli.infura.io/v3/59d83a5a4be74f86b9851190c802297b"
defaultAddr := swarm.MustParseHexAddress("00cb23598c2e520b6a6aae3ddc94fed4435a2909690bdd709bf9d9e7c2aadfad")
testCases := []struct {
desc string
endpoint string
name string
wantAdr string
wantFailConnect bool
wantFailResolve bool
desc string
endpoint string
name string
wantAdr swarm.Address
wantErr error
}{
// TODO: add a test targeting a resolver with an invalid contenthash
// record.
{
desc: "bad ethclient endpoint",
endpoint: "fail",
wantFailConnect: true,
desc: "invalid resolver endpoint",
endpoint: "example.com",
wantErr: ens.ErrFailedToConnect,
},
{
desc: "no domain",
name: "idonthaveadomain",
wantFailResolve: true,
desc: "no domain",
name: "idonthaveadomain",
wantErr: ens.ErrResolveFailed,
},
{
desc: "no eth domain",
name: "centralized.com",
wantFailResolve: true,
desc: "no eth domain",
name: "centralized.com",
wantErr: ens.ErrResolveFailed,
},
{
desc: "not registered",
name: "unused.test.swarm.eth",
wantFailResolve: true,
desc: "not registered",
name: "unused.test.swarm.eth",
wantErr: ens.ErrResolveFailed,
},
{
desc: "no content hash",
name: "nocontent.resolver.test.swarm.eth",
wantFailResolve: true,
desc: "no content hash",
name: "nocontent.resolver.test.swarm.eth",
wantErr: ens.ErrResolveFailed,
},
{
desc: "ok",
name: "example.resolver.test.swarm.eth",
wantAdr: "00cb23598c2e520b6a6aae3ddc94fed4435a2909690bdd709bf9d9e7c2aadfad",
wantAdr: defaultAddr,
},
}
for _, tC := range testCases {
......@@ -62,34 +65,30 @@ func TestENSntegration(t *testing.T) {
tC.endpoint = defaultEndpoint
}
eC := ens.NewClient()
defer eC.Close()
err := eC.Connect(tC.endpoint)
ensClient, err := ens.NewClient(tC.endpoint)
if err != nil {
if !tC.wantFailConnect {
t.Fatalf("failed to connect: %v", err)
if !errors.Is(err, ens.ErrFailedToConnect) {
t.Errorf("got %v, want %v", err, tC.wantErr)
}
return
}
defer ensClient.Close()
addr, err := eC.Resolve(tC.name)
addr, err := ensClient.Resolve(tC.name)
if err != nil {
if !tC.wantFailResolve {
t.Fatalf("failed to resolve name: %v", err)
if !errors.Is(err, tC.wantErr) {
t.Errorf("got %v, want %v", err, tC.wantErr)
}
return
}
want := strings.ToLower(tC.wantAdr)
got := strings.ToLower(addr.String())
if got != want {
t.Errorf("bad addr: got %q, want %q", got, want)
if !addr.Equal(defaultAddr) {
t.Errorf("bad addr: got %s, want %s", addr, defaultAddr)
}
eC.Close()
if eC.IsConnected() {
t.Errorf("IsConnected: got true, want false")
err = ensClient.Close()
if err != nil {
t.Fatal(err)
}
})
}
......
......@@ -8,159 +8,172 @@ import (
"errors"
"testing"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/rpc"
"github.com/ethersphere/bee/pkg/resolver/client/ens"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestNewClient(t *testing.T) {
cl := ens.NewClient()
if cl.Endpoint != "" {
t.Errorf("expected no endpoint set")
func TestNewENSClient(t *testing.T) {
testCases := []struct {
desc string
endpoint string
dialFn func(string) (*ethclient.Client, error)
wantErr error
wantEndpoint string
}{
{
desc: "nil dial function",
endpoint: "someaddress.net",
dialFn: nil,
wantErr: ens.ErrNotImplemented,
},
{
desc: "error in dial function",
endpoint: "someaddress.com",
dialFn: func(string) (*ethclient.Client, error) {
return nil, errors.New("dial error")
},
wantErr: ens.ErrFailedToConnect,
},
{
desc: "regular endpoint",
endpoint: "someaddress.org",
dialFn: func(string) (*ethclient.Client, error) {
return &ethclient.Client{}, nil
},
wantEndpoint: "someaddress.org",
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
cl, err := ens.NewClient(tC.endpoint,
ens.WithDialFunc(tC.dialFn),
)
if err != nil {
if !errors.Is(err, tC.wantErr) {
t.Errorf("got %v, want %v", err, tC.wantErr)
}
return
}
if got := cl.Endpoint(); got != tC.wantEndpoint {
t.Errorf("endpoint: got %v, want %v", got, tC.wantEndpoint)
}
if got := cl.IsConnected(); got != true {
t.Errorf("connected: got %v, want true", got)
}
})
}
}
func TestConnect(t *testing.T) {
ep := "test"
t.Run("no dial func error", func(t *testing.T) {
c := ens.NewClient(
ens.WithDialFunc(nil),
)
err := c.Connect(ep)
defer c.Close()
if !errors.Is(err, ens.ErrNotImplemented) {
t.Fatal("expected correct error")
}
})
t.Run("connect error", func(t *testing.T) {
c := ens.NewClient(
ens.WithErrorDialFunc(errors.New("failed to connect")),
)
if err := c.Connect("test"); err == nil {
t.Fatal("expected error")
}
c.Close()
})
t.Run("ok", func(t *testing.T) {
c := ens.NewClient(
ens.WithNoopDialFunc(),
)
if err := c.Connect(ep); err != nil {
t.Fatal(err)
}
// Override the eth client to test connection.
ens.SetEthClient(c, &ethclient.Client{})
if c.Endpoint != ep {
t.Errorf("bad endpoint: got %q, want %q", c.Endpoint, ep)
}
if !c.IsConnected() {
t.Error("IsConnected: got false, want true")
}
// We are not really connected, so clear the client to prevent panic.
ens.SetEthClient(c, nil)
c.Close()
if c.IsConnected() {
t.Error("IsConnected: got true, want false")
}
})
}
func TestResolve(t *testing.T) {
name := "hello"
bzzAddress := swarm.MustParseHexAddress(
"6f4eeb99d0a144d78ac33cf97091a59a6291aa78929938defcf967e74326e08b",
)
t.Run("no resolve func error", func(t *testing.T) {
c := ens.NewClient(
ens.WithResolveFunc(nil),
)
_, err := c.Resolve("test")
if !errors.Is(err, ens.ErrNotImplemented) {
t.Fatal("expected correct error")
}
})
func TestClose(t *testing.T) {
t.Run("connected", func(t *testing.T) {
rpcServer := rpc.NewServer()
defer rpcServer.Stop()
ethCl := ethclient.NewClient(rpc.DialInProc(rpcServer))
t.Run("resolve error", func(t *testing.T) {
c := ens.NewClient(
ens.WithNoopDialFunc(),
ens.WithErrorResolveFunc(errors.New("resolve error")),
cl, err := ens.NewClient("",
ens.WithDialFunc(func(string) (*ethclient.Client, error) {
return ethCl, nil
}),
)
if err := c.Connect(name); err != nil {
if err != nil {
t.Fatal(err)
}
defer c.Close()
_, err := c.Resolve(name)
if !errors.Is(err, ens.ErrResolveFailed) {
t.Error("expected resolve error")
}
})
t.Run("zero address returned", func(t *testing.T) {
c := ens.NewClient(
ens.WithNoopDialFunc(),
ens.WithZeroAdrResolveFunc(),
)
if err := c.Connect(name); err != nil {
err = cl.Close()
if err != nil {
t.Fatal(err)
}
defer c.Close()
_, err := c.Resolve(name)
if !errors.Is(err, ens.ErrNameNotFound) {
t.Error("expected name not found error")
if cl.IsConnected() {
t.Error("IsConnected == true")
}
})
t.Run("resolved without address prefix error", func(t *testing.T) {
c := ens.NewClient(
ens.WithNoopDialFunc(),
ens.WithNoprefixAdrResolveFunc(bzzAddress),
t.Run("not connected", func(t *testing.T) {
cl, err := ens.NewClient("",
ens.WithDialFunc(func(string) (*ethclient.Client, error) {
return nil, nil
}),
)
if err := c.Connect(name); err != nil {
if err != nil {
t.Fatal(err)
}
defer c.Close()
_, err := c.Resolve(name)
if err == nil {
t.Error("expected error")
}
})
t.Run("ok", func(t *testing.T) {
c := ens.NewClient(
ens.WithNoopDialFunc(),
ens.WithValidAdrResolveFunc(bzzAddress),
)
if err := c.Connect(name); err != nil {
err = cl.Close()
if err != nil {
t.Fatal(err)
}
defer c.Close()
addr, err := c.Resolve(name)
if err != nil {
t.Error(err)
}
want := bzzAddress.String()
got := addr.String()
if got != want {
t.Errorf("got %q, want %q", got, want)
if cl.IsConnected() {
t.Error("IsConnected == true")
}
})
}
func TestResolve(t *testing.T) {
addr := swarm.MustParseHexAddress("aaabbbcc")
testCases := []struct {
desc string
name string
resolveFn func(bind.ContractBackend, string) (string, error)
wantErr error
}{
{
desc: "nil resolve function",
resolveFn: nil,
wantErr: ens.ErrNotImplemented,
},
{
desc: "resolve function internal error",
resolveFn: func(bind.ContractBackend, string) (string, error) {
return "", errors.New("internal error")
},
wantErr: ens.ErrResolveFailed,
},
{
desc: "resolver returns empty string",
resolveFn: func(bind.ContractBackend, string) (string, error) {
return "", nil
},
wantErr: ens.ErrInvalidContentHash,
},
{
desc: "resolve does not prefix address with /swarm",
resolveFn: func(bind.ContractBackend, string) (string, error) {
return addr.String(), nil
},
wantErr: ens.ErrInvalidContentHash,
},
{
desc: "resolve returns prefixed address",
resolveFn: func(bind.ContractBackend, string) (string, error) {
return ens.SwarmContentHashPrefix + addr.String(), nil
},
wantErr: ens.ErrInvalidContentHash,
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
cl, err := ens.NewClient("example.com",
ens.WithDialFunc(func(string) (*ethclient.Client, error) {
return nil, nil
}),
ens.WithResolveFunc(tC.resolveFn),
)
if err != nil {
t.Fatal(err)
}
_, err = cl.Resolve(tC.name)
if err != nil {
if !errors.Is(err, tC.wantErr) {
t.Errorf("got %v, want %v", err, tC.wantErr)
}
return
}
})
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ens
import (
"errors"
)
var (
// ErrInvalidContentHash denotes that the value of the contenthash record is
// not valid.
ErrInvalidContentHash = errors.New("invalid swarm content hash")
// ErrResolveFailed is returned when a name could not be resolved.
ErrResolveFailed = errors.New("resolve failed")
// ErrNameNotFound is returned when a name resolves to an empty contenthash
// record.
ErrNameNotFound = errors.New("name not found")
)
var (
// errNotImplemented denotes that the function has not been implemented.
errNotImplemented = errors.New("function not implemented")
)
......@@ -7,17 +7,11 @@ package ens
import (
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/swarm"
)
var (
ErrNotImplemented = errNotImplemented
)
const SwarmContentHashPrefix = swarmContentHashPrefix
func SetEthClient(c *Client, ethCl *ethclient.Client) {
c.ethCl = ethCl
}
var ErrNotImplemented = errNotImplemented
// WithDialFunc will set the Dial function implementaton.
func WithDialFunc(fn func(ep string) (*ethclient.Client, error)) Option {
......@@ -26,45 +20,9 @@ func WithDialFunc(fn func(ep string) (*ethclient.Client, error)) Option {
}
}
func WithErrorDialFunc(err error) Option {
return WithDialFunc(func(ep string) (*ethclient.Client, error) {
return nil, err
})
}
func WithNoopDialFunc() Option {
return WithDialFunc(func(ep string) (*ethclient.Client, error) {
return nil, nil
})
}
// WithResolveFunc will set the Resolve function implementation.
func WithResolveFunc(fn func(backend bind.ContractBackend, input string) (string, error)) Option {
return func(c *Client) {
c.resolveFn = fn
}
}
func WithErrorResolveFunc(err error) Option {
return WithResolveFunc(func(backend bind.ContractBackend, input string) (string, error) {
return "", err
})
}
func WithZeroAdrResolveFunc() Option {
return WithResolveFunc(func(backend bind.ContractBackend, input string) (string, error) {
return swarm.ZeroAddress.String(), nil
})
}
func WithNoprefixAdrResolveFunc(addr resolver.Address) Option {
return WithResolveFunc(func(backend bind.ContractBackend, input string) (string, error) {
return addr.String(), nil
})
}
func WithValidAdrResolveFunc(addr resolver.Address) Option {
return WithResolveFunc(func(backend bind.ContractBackend, input string) (string, error) {
return "/swarm/" + addr.String(), nil
})
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ens
import (
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/ethclient"
goens "github.com/wealdtech/go-ens/v3"
)
func wrapDial(ep string) (*ethclient.Client, error) {
// Open a connection to the ethereum node through the endpoint.
cl, err := ethclient.Dial(ep)
if err != nil {
return nil, err
}
// Ensure the ENS resolver contract is deployed on the network we are now
// connected to.
if _, err := goens.PublicResolverAddress(cl); err != nil {
return nil, err
}
return cl, nil
}
func wrapResolve(backend bind.ContractBackend, name string) (string, error) {
// Connect to the ENS resolver for the provided name.
ensR, err := goens.NewResolver(backend, name)
if err != nil {
return "", err
}
// Try and read out the content hash record.
ch, err := ensR.Contenthash()
if err != nil {
return "", err
}
addr, err := goens.ContenthashToString(ch)
if err != nil {
return "", err
}
return addr, nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"github.com/ethersphere/bee/pkg/resolver/client"
"github.com/ethersphere/bee/pkg/swarm"
)
// Ensure mock Client implements the Client interface.
var _ client.Interface = (*Client)(nil)
// Client is the mock resolver client implementation.
type Client struct {
isConnected bool
endpoint string
defaultAddress swarm.Address
resolveFn func(string) (swarm.Address, error)
}
// Option is a function that applies an option to a Client.
type Option func(*Client)
// NewClient construct a new mock Client.
func NewClient(opts ...Option) *Client {
cl := &Client{}
for _, o := range opts {
o(cl)
}
cl.isConnected = true
return cl
}
// WithEndpoint will set the endpoint.
func WithEndpoint(endpoint string) Option {
return func(cl *Client) {
cl.endpoint = endpoint
}
}
// WitResolveAddress will set the address returned by Resolve.
func WitResolveAddress(addr swarm.Address) Option {
return func(cl *Client) {
cl.defaultAddress = addr
}
}
// WithResolveFunc will set the Resolve function implementation.
func WithResolveFunc(fn func(string) (swarm.Address, error)) Option {
return func(cl *Client) {
cl.resolveFn = fn
}
}
// IsConnected is the mock IsConnected implementation.
func (cl *Client) IsConnected() bool {
return cl.isConnected
}
// Endpoint is the mock Endpoint implementation.
func (cl *Client) Endpoint() string {
return cl.endpoint
}
// Resolve is the mock Resolve implementation
func (cl *Client) Resolve(name string) (swarm.Address, error) {
if cl.resolveFn == nil {
return cl.defaultAddress, nil
}
return cl.resolveFn(name)
}
// Close is the mock Close implementation.
func (cl *Client) Close() error {
cl.isConnected = false
return nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package resolver
import (
"errors"
"fmt"
"strings"
)
var (
// ErrTLDTooLong denotes when a TLD in a name exceeds maximum length.
ErrTLDTooLong = fmt.Errorf("TLD exceeds maximum length of %d characters", maxTLDLength)
// ErrInvalidTLD denotes passing an invalid TLD to the MultiResolver.
ErrInvalidTLD = errors.New("invalid TLD")
// ErrResolverChainEmpty denotes trying to pop an empty resolver chain.
ErrResolverChainEmpty = errors.New("resolver chain empty")
)
// CloseError denotes that at least one resolver in the MultiResolver has
// had an error when Close was called.
type CloseError struct {
errs []error
}
func (me CloseError) add(err error) {
if err != nil {
me.errs = append(me.errs, err)
}
}
func (me CloseError) errorOrNil() error {
if len(me.errs) > 0 {
return me
}
return nil
}
// Error returns a formatted multi close error.
func (me CloseError) Error() string {
if len(me.errs) == 0 {
return ""
}
var b strings.Builder
b.WriteString("multiresolver failed to close: ")
for _, e := range me.errs {
b.WriteString(e.Error())
b.WriteString("; ")
}
return b.String()
}
......@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package resolver
package multiresolver
import (
"fmt"
......@@ -27,7 +27,7 @@ type ConnectionConfig struct {
// ParseConnectionString will try to parse a connection string used to connect
// the Resolver to a name resolution service. The resulting config can be
// used to initialize a resovler Service.
func parseConnectionString(cs string) (*ConnectionConfig, error) {
func parseConnectionString(cs string) (ConnectionConfig, error) {
isAllUnicodeLetters := func(s string) bool {
for _, r := range s {
if !unicode.IsLetter(r) {
......@@ -48,7 +48,7 @@ func parseConnectionString(cs string) (*ConnectionConfig, error) {
if isAllUnicodeLetters(endpoint[:i]) && len(endpoint) > i+2 && endpoint[i+1:i+3] != "//" {
tld = endpoint[:i]
if len(tld) > maxTLDLength {
return nil, fmt.Errorf("%w: %s", ErrTLDTooLong, tld)
return ConnectionConfig{}, fmt.Errorf("tld %s: %w", tld, ErrTLDTooLong)
}
endpoint = endpoint[i+1:]
......@@ -60,7 +60,7 @@ func parseConnectionString(cs string) (*ConnectionConfig, error) {
endpoint = endpoint[i+1:]
}
return &ConnectionConfig{
return ConnectionConfig{
Endpoint: endpoint,
Address: addr,
TLD: tld,
......@@ -69,8 +69,8 @@ func parseConnectionString(cs string) (*ConnectionConfig, error) {
// ParseConnectionStrings will apply ParseConnectionString to each connection
// string. Returns first error found.
func ParseConnectionStrings(cstrs []string) ([]*ConnectionConfig, error) {
var res []*ConnectionConfig
func ParseConnectionStrings(cstrs []string) ([]ConnectionConfig, error) {
var res []ConnectionConfig
for _, cs := range cstrs {
cfg, err := parseConnectionString(cs)
......
......@@ -2,20 +2,20 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package resolver_test
package multiresolver_test
import (
"errors"
"testing"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/resolver/multiresolver"
)
func TestParseConnectionStrings(t *testing.T) {
testCases := []struct {
desc string
conStrings []string
wantCfg []resolver.ConnectionConfig
wantCfg []multiresolver.ConnectionConfig
wantErr error
}{
{
......@@ -25,14 +25,14 @@ func TestParseConnectionStrings(t *testing.T) {
conStrings: []string{
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff:example.com",
},
wantErr: resolver.ErrTLDTooLong,
wantErr: multiresolver.ErrTLDTooLong,
},
{
desc: "single endpoint default tld",
conStrings: []string{
"https://example.com",
},
wantCfg: []resolver.ConnectionConfig{
wantCfg: []multiresolver.ConnectionConfig{
{
TLD: "",
Endpoint: "https://example.com",
......@@ -44,7 +44,7 @@ func TestParseConnectionStrings(t *testing.T) {
conStrings: []string{
"tld:https://example.com",
},
wantCfg: []resolver.ConnectionConfig{
wantCfg: []multiresolver.ConnectionConfig{
{
TLD: "tld",
Endpoint: "https://example.com",
......@@ -56,7 +56,7 @@ func TestParseConnectionStrings(t *testing.T) {
conStrings: []string{
"0x314159265dD8dbb310642f98f50C066173C1259b@https://example.com",
},
wantCfg: []resolver.ConnectionConfig{
wantCfg: []multiresolver.ConnectionConfig{
{
TLD: "",
Address: "0x314159265dD8dbb310642f98f50C066173C1259b",
......@@ -69,7 +69,7 @@ func TestParseConnectionStrings(t *testing.T) {
conStrings: []string{
"tld:0x314159265dD8dbb310642f98f50C066173C1259b@https://example.com",
},
wantCfg: []resolver.ConnectionConfig{
wantCfg: []multiresolver.ConnectionConfig{
{
TLD: "tld",
Address: "0x314159265dD8dbb310642f98f50C066173C1259b",
......@@ -85,7 +85,7 @@ func TestParseConnectionStrings(t *testing.T) {
"yesyesyes:0x314159265dD8dbb310642f98f50C066173C1259b@2.2.2.2",
"cloudflare-ethereum.org",
},
wantCfg: []resolver.ConnectionConfig{
wantCfg: []multiresolver.ConnectionConfig{
{
TLD: "tld",
Endpoint: "https://example.com",
......@@ -112,12 +112,12 @@ func TestParseConnectionStrings(t *testing.T) {
"testdomain:wowzers.map",
"nonononononononononononononononononononononononononononononononononono:yes",
},
wantErr: resolver.ErrTLDTooLong,
wantErr: multiresolver.ErrTLDTooLong,
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
got, err := resolver.ParseConnectionStrings(tC.conStrings)
got, err := multiresolver.ParseConnectionStrings(tC.conStrings)
if err != nil {
if !errors.Is(err, tC.wantErr) {
t.Errorf("got error %v", err)
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package multiresolver
import "github.com/ethersphere/bee/pkg/logging"
func GetLogger(mr *MultiResolver) logging.Logger {
return mr.logger
}
func GetCfgs(mr *MultiResolver) []ConnectionConfig {
return mr.cfgs
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package multierror
import (
"errors"
"fmt"
"strings"
)
// Ensure multierror implements Error interface.
var _ error = (*Error)(nil)
// Error is an error type to track multiple errors. This can be used to
// accumulate errors and return them as a single "error" type.
type Error struct {
Errors []error
}
// New will return a new multierror.
func New(errs ...error) *Error {
e := Error{}
e.Append(errs...)
return &e
}
func (e *Error) Error() string {
return format(e.Errors)
}
// Append will append errors to the multierror.
func (e *Error) Append(errs ...error) {
e.Errors = append(e.Errors, errs...)
}
// ErrorOrNil returns an error interface if the multierror represents a list of
// errors or nil if the list is empty.
func (e *Error) ErrorOrNil() error {
if e == nil || len(e.Errors) == 0 {
return nil
}
return e
}
// WrapErrorOrNil will wrap the given error if the multierror contains errors.
func (e *Error) WrapErrorOrNil(toWrap error) error {
if err := e.ErrorOrNil(); err != nil {
return fmt.Errorf("%v: %w", err, toWrap)
}
return nil
}
// Unwrap returns an error from Error (or nil if there are no errors).
// The error returned supports Unwrap, so that the entire chain of errors can
// be unwrapped. The order will match the order of Errors at the time of
// calling.
//
// This will perform a shallow copy of the errors slice. Any errors appended
// to this error after calling Unwrap will not be available until a new
// Unwrap is called on the multierror.Error.
func (e *Error) Unwrap() error {
if e == nil || len(e.Errors) == 0 {
return nil
}
if len(e.Errors) == 0 {
return e.Errors[0]
}
// Shallow copy the error slice.
errs := make([]error, len(e.Errors))
copy(errs, e.Errors)
return chain(errs)
}
type chain []error
// Error implements the error interface.
func (ec chain) Error() string {
return ec[0].Error()
}
// Unwrap implements errors.Unwrap by returning the next error in the chain or
// nil if there are no more errors.
func (ec chain) Unwrap() error {
if len(ec) == 1 {
return nil
}
// Return the rest of the chain.
return ec[1:]
}
// As implements errors.As by attempting to map the current value.
func (ec chain) As(target interface{}) bool {
return errors.As(ec[0], target)
}
// Is implements errors.Is by comparing the current value directly.
func (ec chain) Is(target error) bool {
return errors.Is(ec[0], target)
}
func format(errs []error) string {
if len(errs) == 1 {
return fmt.Sprintf("1 error occurred: %s", errs[0])
}
msgs := make([]string, len(errs))
for i, err := range errs {
msgs[i] = fmt.Sprintf("%s", err)
}
return fmt.Sprintf("%d errors occurred: %s",
len(errs), strings.Join(msgs, ", "))
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package multierror_test
import (
"errors"
"fmt"
"testing"
"github.com/ethersphere/bee/pkg/resolver/multiresolver/multierror"
)
// nestedError implements error and is used for tests.
type nestedError struct{}
func (*nestedError) Error() string { return "" }
func TestErrorError(t *testing.T) {
t.Run("one error", func(t *testing.T) {
want := "1 error occurred: foo"
multi := multierror.New(errors.New("foo"))
if multi.Error() != want {
t.Fatalf("got: %q, want %q", multi.Error(), want)
}
})
t.Run("multiple errors", func(t *testing.T) {
want := "2 errors occurred: foo, bar"
multi := multierror.New(
errors.New("foo"),
errors.New("bar"),
)
if multi.Error() != want {
t.Fatalf("got: %q, want %q", multi.Error(), want)
}
})
}
func TestErrorErrorOrNil(t *testing.T) {
err := multierror.New()
if err.ErrorOrNil() != nil {
t.Fatalf("bad: %#v", err.ErrorOrNil())
}
err.Errors = []error{errors.New("foo")}
if got := err.ErrorOrNil(); got == nil {
t.Fatal("should not be nil")
} else if got != err {
t.Fatalf("bad: %#v", got)
}
}
func TestErrorWrapErrorOrNil(t *testing.T) {
wrapErr := errors.New("wrapper error")
err := multierror.New()
if err.WrapErrorOrNil(wrapErr) != nil {
t.Fatalf("bad: %#v", err.ErrorOrNil())
}
multi := multierror.New(
errors.New("foo"),
errors.New("bar"),
)
if got := multi.WrapErrorOrNil(wrapErr); got == nil {
t.Fatal("should not be nil")
} else if !errors.Is(got, wrapErr) {
t.Fatalf("bad: %#v", got)
}
}
func TestErrorUnwrap(t *testing.T) {
t.Run("with errors", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
errors.New("bar"),
errors.New("baz"),
)
var current error = err
for i := 0; i < len(err.Errors); i++ {
current = errors.Unwrap(current)
if !errors.Is(current, err.Errors[i]) {
t.Fatal("should be next value")
}
}
if errors.Unwrap(current) != nil {
t.Fatal("should be nil at the end")
}
})
t.Run("with no errors", func(t *testing.T) {
err := multierror.New()
if errors.Unwrap(err) != nil {
t.Fatal("should be nil")
}
})
t.Run("with nil multierror", func(t *testing.T) {
var err *multierror.Error
if errors.Unwrap(err) != nil {
t.Fatal("should be nil")
}
})
}
func TestErrorIs(t *testing.T) {
errBar := errors.New("bar")
t.Run("with errBar", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
errBar,
errors.New("baz"),
)
if !errors.Is(err, errBar) {
t.Fatal("should be true")
}
})
t.Run("with errBar wrapped by fmt.Errorf", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
fmt.Errorf("errorf: %w", errBar),
errors.New("baz"),
)
if !errors.Is(err, errBar) {
t.Fatal("should be true")
}
})
t.Run("without errBar", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
errors.New("baz"),
)
if errors.Is(err, errBar) {
t.Fatal("should be false")
}
})
}
func TestErrorAs(t *testing.T) {
match := &nestedError{}
t.Run("with the value", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
match,
errors.New("baz"),
)
var target *nestedError
if !errors.As(err, &target) {
t.Fatal("should be true")
}
if target == nil {
t.Fatal("target should not be nil")
}
})
t.Run("with the value wrapped by fmt.Errorf", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
fmt.Errorf("errorf: %w", match),
errors.New("baz"),
)
var target *nestedError
if !errors.As(err, &target) {
t.Fatal("should be true")
}
if target == nil {
t.Fatal("target should not be nil")
}
})
t.Run("without the value", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
errors.New("baz"),
)
var target *nestedError
if errors.As(err, &target) {
t.Fatal("should be false")
}
if target != nil {
t.Fatal("target should be nil")
}
})
}
......@@ -2,24 +2,45 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package resolver
package multiresolver
import (
"errors"
"fmt"
"io/ioutil"
"path"
"strings"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/resolver/client/ens"
"github.com/ethersphere/bee/pkg/resolver/multiresolver/multierror"
)
// Ensure MultiResolver implements Resolver interface.
var _ Interface = (*MultiResolver)(nil)
var _ resolver.Interface = (*MultiResolver)(nil)
var (
// ErrTLDTooLong denotes when a TLD in a name exceeds maximum length.
ErrTLDTooLong = fmt.Errorf("TLD exceeds maximum length of %d characters", maxTLDLength)
// ErrInvalidTLD denotes passing an invalid TLD to the MultiResolver.
ErrInvalidTLD = errors.New("invalid TLD")
// ErrResolverChainEmpty denotes trying to pop an empty resolver chain.
ErrResolverChainEmpty = errors.New("resolver chain empty")
// ErrResolverChainFailed denotes that an entire name resolution chain
// for a given TLD failed.
ErrResolverChainFailed = errors.New("resolver chain failed")
// ErrCloseFailed denotes that closing the multiresolver failed.
ErrCloseFailed = errors.New("close failed")
)
type resolverMap map[string][]Interface
type resolverMap map[string][]resolver.Interface
// MultiResolver performs name resolutions based on the TLD label in the name.
type MultiResolver struct {
resolvers resolverMap
logger logging.Logger
cfgs []ConnectionConfig
// ForceDefault will force all names to be resolved by the default
// resolution chain, regadless of their TLD.
ForceDefault bool
......@@ -39,9 +60,47 @@ func NewMultiResolver(opts ...Option) *MultiResolver {
o(mr)
}
// Discard log output by default.
if mr.logger == nil {
mr.logger = logging.New(ioutil.Discard, 0)
}
log := mr.logger
if len(mr.cfgs) == 0 {
log.Info("name resolver: no name resolution service provided")
return mr
}
// Attempt to conect to each resolver using the connection string.
for _, c := range mr.cfgs {
// Warn user that the resolver address field is not used.
if c.Address != "" {
log.Warningf("name resolver: connection string %q contains resolver address field, which is currently unused", c.Address)
}
// NOTE: if we want to create a specific client based on the TLD
// we can do it here.
mr.connectENSClient(c.TLD, c.Endpoint)
}
return mr
}
// WithConnectionConfigs will set the initial connection configuration.
func WithConnectionConfigs(cfgs []ConnectionConfig) Option {
return func(mr *MultiResolver) {
mr.cfgs = cfgs
}
}
// WithLogger will set the logger used by the MultiResolver.
func WithLogger(logger logging.Logger) Option {
return func(mr *MultiResolver) {
mr.logger = logger
}
}
// WithForceDefault will force resolution using the default resolver chain.
func WithForceDefault() Option {
return func(mr *MultiResolver) {
......@@ -50,27 +109,14 @@ func WithForceDefault() Option {
}
// PushResolver will push a new Resolver to the name resolution chain for the
// given TLD.
// TLD names should be prepended with a dot (eg ".tld"). An empty TLD will push
// to the default resolver chain.
func (mr *MultiResolver) PushResolver(tld string, r Interface) error {
if tld != "" && !isTLD(tld) {
return fmt.Errorf("tld %s: %w", tld, ErrInvalidTLD)
}
// given TLD. An empty TLD will push to the default resolver chain.
func (mr *MultiResolver) PushResolver(tld string, r resolver.Interface) {
mr.resolvers[tld] = append(mr.resolvers[tld], r)
return nil
}
// PopResolver will pop the last reslover from the name resolution chain for the
// given TLD.
// TLD names should be prepended with a dot (eg ".tld"). An empty TLD will pop
// from the default resolver chain.
// given TLD. An empty TLD will pop from the default resolver chain.
func (mr *MultiResolver) PopResolver(tld string) error {
if tld != "" && !isTLD(tld) {
return fmt.Errorf("tld %s: %w", tld, ErrInvalidTLD)
}
l := len(mr.resolvers[tld])
if l == 0 {
return fmt.Errorf("tld %s: %w", tld, ErrResolverChainEmpty)
......@@ -90,7 +136,7 @@ func (mr *MultiResolver) ChainCount(tld string) int {
// GetChain will return the resolution chain for a given TLD.
// TLD names should be prepended with a dot (eg ".tld"). An empty TLD will
// return all resolvers in the default resolver chain.
func (mr *MultiResolver) GetChain(tld string) []Interface {
func (mr *MultiResolver) GetChain(tld string) []resolver.Interface {
return mr.resolvers[tld]
}
......@@ -100,8 +146,7 @@ func (mr *MultiResolver) GetChain(tld string) []Interface {
// The resolution will be performed iteratively on the resolution chain,
// returning the result of the first Resolver that succeeds. If all resolvers
// in the chain return an error, the function will return an ErrResolveFailed.
func (mr *MultiResolver) Resolve(name string) (Address, error) {
func (mr *MultiResolver) Resolve(name string) (addr resolver.Address, err error) {
tld := ""
if !mr.ForceDefault {
tld = getTLD(name)
......@@ -113,35 +158,46 @@ func (mr *MultiResolver) Resolve(name string) (Address, error) {
chain = mr.resolvers[""]
}
addr := swarm.ZeroAddress
var err error
errs := multierror.New()
for _, res := range chain {
addr, err = res.Resolve(name)
if err == nil {
return addr, nil
}
errs.Append(err)
}
return addr, err
return addr, errs.ErrorOrNil()
}
// Close all will call Close on all resolvers in all resolver chains.
func (mr *MultiResolver) Close() error {
errs := new(CloseError)
errs := multierror.New()
for _, chain := range mr.resolvers {
for _, r := range chain {
errs.add(r.Close())
if err := r.Close(); err != nil {
errs.Append(err)
}
}
}
return errs.errorOrNil()
}
func isTLD(tld string) bool {
return len(tld) > 1 && tld[0] == '.'
return errs.ErrorOrNil()
}
func getTLD(name string) string {
return path.Ext(strings.ToLower(name))
}
func (mr *MultiResolver) connectENSClient(tld string, endpoint string) {
log := mr.logger
log.Debugf("name resolver: resolver for %q: connecting to endpoint %s", tld, endpoint)
ensCl, err := ens.NewClient(endpoint)
if err != nil {
log.Errorf("name resolver: resolver for %q domain: failed to connect to %q: %v", tld, endpoint, err)
} else {
log.Infof("name resolver: resolver for %q domain: connected to %s", tld, endpoint)
mr.PushResolver(tld, ensCl)
}
}
......@@ -2,16 +2,19 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package resolver_test
package multiresolver_test
import (
"errors"
"fmt"
"io/ioutil"
"reflect"
"testing"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/resolver/mock"
"github.com/ethersphere/bee/pkg/resolver/multiresolver"
"github.com/ethersphere/bee/pkg/swarm"
)
......@@ -21,11 +24,33 @@ func newAddr(s string) Address {
return swarm.NewAddress([]byte(s))
}
func TestWithForceDefault(t *testing.T) {
mr := resolver.NewMultiResolver(
resolver.WithForceDefault(),
func TestMultiresolverOpts(t *testing.T) {
wantLog := logging.New(ioutil.Discard, 1)
wantCfgs := []multiresolver.ConnectionConfig{
{
Address: "testadr1",
Endpoint: "testEndpoint1",
TLD: "testtld1",
},
{
Address: "testadr2",
Endpoint: "testEndpoint2",
TLD: "testtld2",
},
}
mr := multiresolver.NewMultiResolver(
multiresolver.WithLogger(wantLog),
multiresolver.WithConnectionConfigs(wantCfgs),
multiresolver.WithForceDefault(),
)
if got := multiresolver.GetLogger(mr); got != wantLog {
t.Errorf("log: got: %v, want %v", got, wantLog)
}
if got := multiresolver.GetCfgs(mr); !reflect.DeepEqual(got, wantCfgs) {
t.Errorf("cfg: got: %v, want %v", got, wantCfgs)
}
if !mr.ForceDefault {
t.Error("did not set ForceDefault")
}
......@@ -45,29 +70,18 @@ func TestPushResolver(t *testing.T) {
desc: "regular tld, named chain",
tld: ".tld",
},
{
desc: "invalid tld",
tld: "invalid",
wantErr: resolver.ErrInvalidTLD,
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
mr := resolver.NewMultiResolver()
mr := multiresolver.NewMultiResolver()
if mr.ChainCount(tC.tld) != 0 {
t.Fatal("chain should start empty")
}
want := mock.NewResolver()
err := mr.PushResolver(tC.tld, want)
if err != nil {
if !errors.Is(err, tC.wantErr) {
t.Fatal(err)
}
return
}
mr.PushResolver(tC.tld, want)
got := mr.GetChain(tC.tld)[0]
if !reflect.DeepEqual(got, want) {
......@@ -82,20 +96,11 @@ func TestPushResolver(t *testing.T) {
}
})
}
}
func TestPopResolver(t *testing.T) {
mr := resolver.NewMultiResolver()
t.Run("error on bad tld", func(t *testing.T) {
if err := mr.PopResolver("invalid"); !errors.Is(err, resolver.ErrInvalidTLD) {
t.Fatal("invalid error type")
}
})
t.Run("error on empty", func(t *testing.T) {
if err := mr.PopResolver(".tld"); !errors.Is(err, resolver.ErrResolverChainEmpty) {
t.Fatal("invalid error type")
t.Run("pop empty chain", func(t *testing.T) {
mr := multiresolver.NewMultiResolver()
err := mr.PopResolver("")
if !errors.Is(err, multiresolver.ErrResolverChainEmpty) {
t.Errorf("got %v, want %v", err, multiresolver.ErrResolverChainEmpty)
}
})
}
......@@ -104,6 +109,7 @@ func TestResolve(t *testing.T) {
addr := newAddr("aaaabbbbccccdddd")
addrAlt := newAddr("ddddccccbbbbaaaa")
errUnregisteredName := fmt.Errorf("unregistered name")
errResolutionFailed := fmt.Errorf("name resolution failed")
newOKResolver := func(addr Address) resolver.Interface {
return mock.NewResolver(
......@@ -115,8 +121,7 @@ func TestResolve(t *testing.T) {
newErrResolver := func() resolver.Interface {
return mock.NewResolver(
mock.WithResolveFunc(func(name string) (Address, error) {
err := fmt.Errorf("name resolution failed for %q", name)
return swarm.ZeroAddress, err
return swarm.ZeroAddress, errResolutionFailed
}),
)
}
......@@ -181,35 +186,35 @@ func TestResolve(t *testing.T) {
wantAdr Address
wantErr error
}{
// {
// name: "",
// wantAdr: testAdr,
// },
// {
// name: "hello",
// wantAdr: testAdr,
// },
// {
// name: "example.tld",
// wantAdr: testAdr,
// },
// {
// name: ".tld",
// wantAdr: testAdr,
// },
// {
// name: "get.good",
// wantAdr: testAdr,
// },
// {
// // Switch to the default chain:
// name: "this.empty",
// wantAdr: testAdr,
// },
// {
// name: "this.dies",
// wantErr: fmt.Errorf("Failed to resolve name %q", "this.dies"),
// },
{
name: "",
wantAdr: addr,
},
{
name: "hello",
wantAdr: addr,
},
{
name: "example.tld",
wantAdr: addr,
},
{
name: ".tld",
wantAdr: addr,
},
{
name: "get.good",
wantAdr: addr,
},
{
// Switch to the default chain:
name: "this.empty",
wantAdr: addr,
},
{
name: "this.dies",
wantErr: errResolutionFailed,
},
{
name: "iam.unregistered",
wantAdr: swarm.ZeroAddress,
......@@ -218,12 +223,10 @@ func TestResolve(t *testing.T) {
}
// Load the test fixture.
mr := resolver.NewMultiResolver()
mr := multiresolver.NewMultiResolver()
for _, tE := range testFixture {
for _, r := range tE.res {
if err := mr.PushResolver(tE.tld, r); err != nil {
t.Fatal(err)
}
mr.PushResolver(tE.tld, r)
}
}
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package service
import (
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/resolver/client/ens"
)
// InitMultiResolver will create a new MultiResolver, create the appropriate
// resolvers, push them to the resolver chains and attempt to connect.
func InitMultiResolver(logger logging.Logger, cfgs []*resolver.ConnectionConfig) resolver.Interface {
if len(cfgs) == 0 {
logger.Info("name resolver: no name resolution service provided")
return nil
}
// Create a new MultiResolver.
mr := resolver.NewMultiResolver()
connectENS := func(tld string, ep string) {
ensCl := ens.NewClient()
logger.Debugf("name resolver: resolver for %q: connecting to endpoint %s", tld, ep)
if err := ensCl.Connect(ep); err != nil {
logger.Errorf("name resolver: resolver for %q domain: failed to connect to %q: %v", tld, ep, err)
} else {
logger.Infof("name resolver: resolver for %q domain: connected to %s", tld, ep)
if err := mr.PushResolver(tld, ensCl); err != nil {
logger.Errorf("name resolver: failed to push resolver to %q resolver chain: %v", tld, err)
}
}
}
// Attempt to conect to each resolver using the connection string.
for _, c := range cfgs {
// Warn user that the resolver address field is not used.
if c.Address != "" {
logger.Warningf("name resolver: connection string %q contains resolver address field, which is currently unused", c.Address)
}
// Select the appropriate resolver.
switch c.TLD {
case "eth":
// FIXME: MultiResolver expects "." in front of the TLD label.
connectENS("."+c.TLD, c.Endpoint)
case "":
connectENS("", c.Endpoint)
default:
logger.Errorf("default domain resolution not supported")
}
}
return mr
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment