Commit ec3f1675 authored by Pavle Batuta's avatar Pavle Batuta Committed by GitHub

Resolver refactor service (#670)

Broad-spectrum refactor of the resolver package.

- removed TLD constraints (all TLDs are now accepted)
- merged resolver service into multiresolver
- simplified client code
- refactored tests
- added MultiError
- other minor fixes
parent d60f4ab3
...@@ -22,7 +22,7 @@ import ( ...@@ -22,7 +22,7 @@ import (
memkeystore "github.com/ethersphere/bee/pkg/keystore/mem" memkeystore "github.com/ethersphere/bee/pkg/keystore/mem"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/node" "github.com/ethersphere/bee/pkg/node"
"github.com/ethersphere/bee/pkg/resolver" "github.com/ethersphere/bee/pkg/resolver/multiresolver"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
) )
...@@ -57,10 +57,10 @@ func (c *command) initStartCmd() (err error) { ...@@ -57,10 +57,10 @@ func (c *command) initStartCmd() (err error) {
// If the resolver is specified, resolve all connection strings // If the resolver is specified, resolve all connection strings
// and fail on any errors. // and fail on any errors.
var resolverCfgs []*resolver.ConnectionConfig var resolverCfgs []multiresolver.ConnectionConfig
resolverEndpoints := c.config.GetStringSlice(optionNameResolverEndpoints) resolverEndpoints := c.config.GetStringSlice(optionNameResolverEndpoints)
if len(resolverEndpoints) > 0 { if len(resolverEndpoints) > 0 {
resolverCfgs, err = resolver.ParseConnectionStrings(resolverEndpoints) resolverCfgs, err = multiresolver.ParseConnectionStrings(resolverEndpoints)
if err != nil { if err != nil {
return err return err
} }
......
...@@ -37,8 +37,7 @@ import ( ...@@ -37,8 +37,7 @@ import (
"github.com/ethersphere/bee/pkg/pusher" "github.com/ethersphere/bee/pkg/pusher"
"github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/recovery" "github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/resolver" "github.com/ethersphere/bee/pkg/resolver/multiresolver"
resolverSvc "github.com/ethersphere/bee/pkg/resolver/service"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
"github.com/ethersphere/bee/pkg/settlement/pseudosettle" "github.com/ethersphere/bee/pkg/settlement/pseudosettle"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
...@@ -91,7 +90,7 @@ type Options struct { ...@@ -91,7 +90,7 @@ type Options struct {
GlobalPinningEnabled bool GlobalPinningEnabled bool
PaymentThreshold uint64 PaymentThreshold uint64
PaymentTolerance uint64 PaymentTolerance uint64
ResolverConnectionCfgs []*resolver.ConnectionConfig ResolverConnectionCfgs []multiresolver.ConnectionConfig
GatewayMode bool GatewayMode bool
} }
...@@ -293,7 +292,10 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -293,7 +292,10 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
b.pullerCloser = puller b.pullerCloser = puller
multiResolver := resolverSvc.InitMultiResolver(logger, o.ResolverConnectionCfgs) multiResolver := multiresolver.NewMultiResolver(
multiresolver.WithConnectionConfigs(o.ResolverConnectionCfgs),
multiresolver.WithLogger(o.Logger),
)
b.resolverCloser = multiResolver b.resolverCloser = multiResolver
var apiService api.Service var apiService api.Service
......
...@@ -9,9 +9,9 @@ import ( ...@@ -9,9 +9,9 @@ import (
) )
// Interface is a resolver client that can connect/disconnect to an external // Interface is a resolver client that can connect/disconnect to an external
// Name Resolution Service via an edpoint. // Name Resolution Service via an endpoint.
type Interface interface { type Interface interface {
resolver.Interface resolver.Interface
Connect(endpoint string) error Endpoint() string
IsConnected() bool IsConnected() bool
} }
...@@ -5,43 +5,56 @@ ...@@ -5,43 +5,56 @@
package ens package ens
import ( import (
"errors"
"fmt" "fmt"
"strings" "strings"
"sync"
"github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/ethclient"
goens "github.com/wealdtech/go-ens/v3"
"github.com/ethersphere/bee/pkg/resolver/client" "github.com/ethersphere/bee/pkg/resolver/client"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
const swarmContentHashPrefix = "/swarm/"
// Address is the swarm bzz address. // Address is the swarm bzz address.
type Address = swarm.Address type Address = swarm.Address
// Make sure Client implements the resolver.Client interface. // Make sure Client implements the resolver.Client interface.
var _ client.Interface = (*Client)(nil) var _ client.Interface = (*Client)(nil)
type dialType func(string) (*ethclient.Client, error) var (
type resolveType func(bind.ContractBackend, string) (string, error) // ErrFailedToConnect denotes that the resolver failed to connect to the
// provided endpoint.
ErrFailedToConnect = errors.New("failed to connect")
// ErrResolveFailed denotes that a name could not be resolved.
ErrResolveFailed = errors.New("resolve failed")
// ErrInvalidContentHash denotes that the value of the contenthash record is
// not valid.
ErrInvalidContentHash = errors.New("invalid swarm content hash")
// errNotImplemented denotes that the function has not been implemented.
errNotImplemented = errors.New("function not implemented")
)
// Client is a name resolution client that can connect to ENS via an // Client is a name resolution client that can connect to ENS via an
// Ethereum endpoint. // Ethereum endpoint.
type Client struct { type Client struct {
mu sync.Mutex endpoint string
Endpoint string
ethCl *ethclient.Client ethCl *ethclient.Client
dialFn dialType dialFn func(string) (*ethclient.Client, error)
resolveFn resolveType resolveFn func(bind.ContractBackend, string) (string, error)
} }
// Option is a function that applies an option to a Client. // Option is a function that applies an option to a Client.
type Option func(*Client) type Option func(*Client)
// NewClient will return a new Client. // NewClient will return a new Client.
func NewClient(opts ...Option) *Client { func NewClient(endpoint string, opts ...Option) (client.Interface, error) {
c := &Client{ c := &Client{
dialFn: wrapDial, endpoint: endpoint,
dialFn: ethclient.Dial,
resolveFn: wrapResolve, resolveFn: wrapResolve,
} }
...@@ -50,83 +63,77 @@ func NewClient(opts ...Option) *Client { ...@@ -50,83 +63,77 @@ func NewClient(opts ...Option) *Client {
o(c) o(c)
} }
return c // Connect to the name resolution service.
}
// Connect implements the resolver.Client interface.
func (c *Client) Connect(ep string) error {
if c.dialFn == nil { if c.dialFn == nil {
return fmt.Errorf("dialFn: %w", errNotImplemented) return nil, fmt.Errorf("dialFn: %w", errNotImplemented)
} }
ethCl, err := c.dialFn(ep) ethCl, err := c.dialFn(c.endpoint)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("%v: %w", err, ErrFailedToConnect)
} }
// Lock and set the parameters.
c.mu.Lock()
c.ethCl = ethCl c.ethCl = ethCl
c.Endpoint = ep
c.mu.Unlock()
return nil return c, nil
} }
// IsConnected returns true if there is an active RPC connection with an // IsConnected returns true if there is an active RPC connection with an
// Ethereum node at the configured endpoint. // Ethereum node at the configured endpoint.
// Function obtains a write lock while interacting with the Ethereum client.
func (c *Client) IsConnected() bool { func (c *Client) IsConnected() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.ethCl != nil return c.ethCl != nil
} }
// Endpoint returns the endpoint the client was connected to.
func (c *Client) Endpoint() string {
return c.endpoint
}
// Resolve implements the resolver.Client interface. // Resolve implements the resolver.Client interface.
// Function obtains a read lock while interacting with the Ethereum client.
func (c *Client) Resolve(name string) (Address, error) { func (c *Client) Resolve(name string) (Address, error) {
if c.resolveFn == nil { if c.resolveFn == nil {
return swarm.ZeroAddress, fmt.Errorf("resolveFn: %w", errNotImplemented) return swarm.ZeroAddress, fmt.Errorf("resolveFn: %w", errNotImplemented)
} }
// Obtain our copy of the client under lock. hash, err := c.resolveFn(c.ethCl, name)
c.mu.Lock()
ethCl := c.ethCl
c.mu.Unlock()
hash, err := c.resolveFn(ethCl, name)
if err != nil { if err != nil {
return swarm.ZeroAddress, fmt.Errorf("%v: %w", err, ErrResolveFailed) return swarm.ZeroAddress, fmt.Errorf("%v: %w", err, ErrResolveFailed)
} }
// In case the implementation returns a zero address return an NameNotFound
// error.
if hash == "" {
return swarm.ZeroAddress, fmt.Errorf("name %s: %w", name, ErrNameNotFound)
}
// Ensure that the content hash string is in a valid format, eg. // Ensure that the content hash string is in a valid format, eg.
// "/swarm/<address>". // "/swarm/<address>".
if !strings.HasPrefix(hash, "/swarm/") { if !strings.HasPrefix(hash, swarmContentHashPrefix) {
return swarm.ZeroAddress, fmt.Errorf("contenthash %s: %w", hash, ErrInvalidContentHash) return swarm.ZeroAddress, fmt.Errorf("contenthash %s: %w", hash, ErrInvalidContentHash)
} }
// Trim the prefix and try to parse the result as a bzz address. // Trim the prefix and try to parse the result as a bzz address.
return swarm.ParseHexAddress(strings.TrimPrefix(hash, "/swarm/")) return swarm.ParseHexAddress(strings.TrimPrefix(hash, swarmContentHashPrefix))
} }
// Close closes the RPC connection with the client, terminating all unfinished // Close closes the RPC connection with the client, terminating all unfinished
// requests. // requests. If the connection is already closed, this call is a noop.
// Function obtains a write lock while interacting with the Ethereum client.
func (c *Client) Close() error { func (c *Client) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
if c.ethCl != nil { if c.ethCl != nil {
c.ethCl.Close() // TODO: consider mocking out the eth client. c.ethCl.Close()
} }
c.ethCl = nil c.ethCl = nil
return nil return nil
} }
func wrapResolve(backend bind.ContractBackend, name string) (string, error) {
// Connect to the ENS resolver for the provided name.
ensR, err := goens.NewResolver(backend, name)
if err != nil {
return "", err
}
// Try and read out the content hash record.
ch, err := ensR.Contenthash()
if err != nil {
return "", err
}
return goens.ContenthashToString(ch)
}
...@@ -7,53 +7,56 @@ ...@@ -7,53 +7,56 @@
package ens_test package ens_test
import ( import (
"strings" "errors"
"testing" "testing"
"github.com/ethersphere/bee/pkg/resolver/client/ens" "github.com/ethersphere/bee/pkg/resolver/client/ens"
"github.com/ethersphere/bee/pkg/swarm"
) )
func TestENSntegration(t *testing.T) { func TestENSntegration(t *testing.T) {
// TODO: consider using a stable gateway instead of INFURA. // TODO: consider using a stable gateway instead of INFURA.
defaultEndpoint := "https://goerli.infura.io/v3/59d83a5a4be74f86b9851190c802297b" defaultEndpoint := "https://goerli.infura.io/v3/59d83a5a4be74f86b9851190c802297b"
defaultAddr := swarm.MustParseHexAddress("00cb23598c2e520b6a6aae3ddc94fed4435a2909690bdd709bf9d9e7c2aadfad")
testCases := []struct { testCases := []struct {
desc string desc string
endpoint string endpoint string
name string name string
wantAdr string wantAdr swarm.Address
wantFailConnect bool wantErr error
wantFailResolve bool
}{ }{
// TODO: add a test targeting a resolver with an invalid contenthash
// record.
{ {
desc: "bad ethclient endpoint", desc: "invalid resolver endpoint",
endpoint: "fail", endpoint: "example.com",
wantFailConnect: true, wantErr: ens.ErrFailedToConnect,
}, },
{ {
desc: "no domain", desc: "no domain",
name: "idonthaveadomain", name: "idonthaveadomain",
wantFailResolve: true, wantErr: ens.ErrResolveFailed,
}, },
{ {
desc: "no eth domain", desc: "no eth domain",
name: "centralized.com", name: "centralized.com",
wantFailResolve: true, wantErr: ens.ErrResolveFailed,
}, },
{ {
desc: "not registered", desc: "not registered",
name: "unused.test.swarm.eth", name: "unused.test.swarm.eth",
wantFailResolve: true, wantErr: ens.ErrResolveFailed,
}, },
{ {
desc: "no content hash", desc: "no content hash",
name: "nocontent.resolver.test.swarm.eth", name: "nocontent.resolver.test.swarm.eth",
wantFailResolve: true, wantErr: ens.ErrResolveFailed,
}, },
{ {
desc: "ok", desc: "ok",
name: "example.resolver.test.swarm.eth", name: "example.resolver.test.swarm.eth",
wantAdr: "00cb23598c2e520b6a6aae3ddc94fed4435a2909690bdd709bf9d9e7c2aadfad", wantAdr: defaultAddr,
}, },
} }
for _, tC := range testCases { for _, tC := range testCases {
...@@ -62,34 +65,30 @@ func TestENSntegration(t *testing.T) { ...@@ -62,34 +65,30 @@ func TestENSntegration(t *testing.T) {
tC.endpoint = defaultEndpoint tC.endpoint = defaultEndpoint
} }
eC := ens.NewClient() ensClient, err := ens.NewClient(tC.endpoint)
defer eC.Close()
err := eC.Connect(tC.endpoint)
if err != nil { if err != nil {
if !tC.wantFailConnect { if !errors.Is(err, ens.ErrFailedToConnect) {
t.Fatalf("failed to connect: %v", err) t.Errorf("got %v, want %v", err, tC.wantErr)
} }
return return
} }
defer ensClient.Close()
addr, err := eC.Resolve(tC.name) addr, err := ensClient.Resolve(tC.name)
if err != nil { if err != nil {
if !tC.wantFailResolve { if !errors.Is(err, tC.wantErr) {
t.Fatalf("failed to resolve name: %v", err) t.Errorf("got %v, want %v", err, tC.wantErr)
} }
return return
} }
want := strings.ToLower(tC.wantAdr) if !addr.Equal(defaultAddr) {
got := strings.ToLower(addr.String()) t.Errorf("bad addr: got %s, want %s", addr, defaultAddr)
if got != want {
t.Errorf("bad addr: got %q, want %q", got, want)
} }
eC.Close() err = ensClient.Close()
if eC.IsConnected() { if err != nil {
t.Errorf("IsConnected: got true, want false") t.Fatal(err)
} }
}) })
} }
......
...@@ -8,159 +8,172 @@ import ( ...@@ -8,159 +8,172 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/rpc"
"github.com/ethersphere/bee/pkg/resolver/client/ens" "github.com/ethersphere/bee/pkg/resolver/client/ens"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
func TestNewClient(t *testing.T) { func TestNewENSClient(t *testing.T) {
cl := ens.NewClient() testCases := []struct {
if cl.Endpoint != "" { desc string
t.Errorf("expected no endpoint set") endpoint string
dialFn func(string) (*ethclient.Client, error)
wantErr error
wantEndpoint string
}{
{
desc: "nil dial function",
endpoint: "someaddress.net",
dialFn: nil,
wantErr: ens.ErrNotImplemented,
},
{
desc: "error in dial function",
endpoint: "someaddress.com",
dialFn: func(string) (*ethclient.Client, error) {
return nil, errors.New("dial error")
},
wantErr: ens.ErrFailedToConnect,
},
{
desc: "regular endpoint",
endpoint: "someaddress.org",
dialFn: func(string) (*ethclient.Client, error) {
return &ethclient.Client{}, nil
},
wantEndpoint: "someaddress.org",
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
cl, err := ens.NewClient(tC.endpoint,
ens.WithDialFunc(tC.dialFn),
)
if err != nil {
if !errors.Is(err, tC.wantErr) {
t.Errorf("got %v, want %v", err, tC.wantErr)
}
return
}
if got := cl.Endpoint(); got != tC.wantEndpoint {
t.Errorf("endpoint: got %v, want %v", got, tC.wantEndpoint)
}
if got := cl.IsConnected(); got != true {
t.Errorf("connected: got %v, want true", got)
}
})
} }
} }
func TestConnect(t *testing.T) { func TestClose(t *testing.T) {
ep := "test" t.Run("connected", func(t *testing.T) {
rpcServer := rpc.NewServer()
t.Run("no dial func error", func(t *testing.T) { defer rpcServer.Stop()
c := ens.NewClient( ethCl := ethclient.NewClient(rpc.DialInProc(rpcServer))
ens.WithDialFunc(nil),
)
err := c.Connect(ep)
defer c.Close()
if !errors.Is(err, ens.ErrNotImplemented) {
t.Fatal("expected correct error")
}
})
t.Run("connect error", func(t *testing.T) {
c := ens.NewClient(
ens.WithErrorDialFunc(errors.New("failed to connect")),
)
if err := c.Connect("test"); err == nil {
t.Fatal("expected error")
}
c.Close()
})
t.Run("ok", func(t *testing.T) {
c := ens.NewClient(
ens.WithNoopDialFunc(),
)
if err := c.Connect(ep); err != nil {
t.Fatal(err)
}
// Override the eth client to test connection.
ens.SetEthClient(c, &ethclient.Client{})
if c.Endpoint != ep {
t.Errorf("bad endpoint: got %q, want %q", c.Endpoint, ep)
}
if !c.IsConnected() {
t.Error("IsConnected: got false, want true")
}
// We are not really connected, so clear the client to prevent panic.
ens.SetEthClient(c, nil)
c.Close()
if c.IsConnected() {
t.Error("IsConnected: got true, want false")
}
})
}
func TestResolve(t *testing.T) {
name := "hello"
bzzAddress := swarm.MustParseHexAddress(
"6f4eeb99d0a144d78ac33cf97091a59a6291aa78929938defcf967e74326e08b",
)
t.Run("no resolve func error", func(t *testing.T) {
c := ens.NewClient(
ens.WithResolveFunc(nil),
)
_, err := c.Resolve("test")
if !errors.Is(err, ens.ErrNotImplemented) {
t.Fatal("expected correct error")
}
})
t.Run("resolve error", func(t *testing.T) { cl, err := ens.NewClient("",
c := ens.NewClient( ens.WithDialFunc(func(string) (*ethclient.Client, error) {
ens.WithNoopDialFunc(), return ethCl, nil
ens.WithErrorResolveFunc(errors.New("resolve error")), }),
) )
if err != nil {
if err := c.Connect(name); err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer c.Close()
_, err := c.Resolve(name)
if !errors.Is(err, ens.ErrResolveFailed) {
t.Error("expected resolve error")
}
})
t.Run("zero address returned", func(t *testing.T) {
c := ens.NewClient(
ens.WithNoopDialFunc(),
ens.WithZeroAdrResolveFunc(),
)
if err := c.Connect(name); err != nil { err = cl.Close()
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer c.Close()
_, err := c.Resolve(name) if cl.IsConnected() {
if !errors.Is(err, ens.ErrNameNotFound) { t.Error("IsConnected == true")
t.Error("expected name not found error")
} }
}) })
t.Run("not connected", func(t *testing.T) {
t.Run("resolved without address prefix error", func(t *testing.T) { cl, err := ens.NewClient("",
c := ens.NewClient( ens.WithDialFunc(func(string) (*ethclient.Client, error) {
ens.WithNoopDialFunc(), return nil, nil
ens.WithNoprefixAdrResolveFunc(bzzAddress), }),
) )
if err != nil {
if err := c.Connect(name); err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer c.Close()
_, err := c.Resolve(name)
if err == nil {
t.Error("expected error")
}
})
t.Run("ok", func(t *testing.T) { err = cl.Close()
c := ens.NewClient( if err != nil {
ens.WithNoopDialFunc(),
ens.WithValidAdrResolveFunc(bzzAddress),
)
if err := c.Connect(name); err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer c.Close()
addr, err := c.Resolve(name) if cl.IsConnected() {
if err != nil { t.Error("IsConnected == true")
t.Error(err)
}
want := bzzAddress.String()
got := addr.String()
if got != want {
t.Errorf("got %q, want %q", got, want)
} }
}) })
}
func TestResolve(t *testing.T) {
addr := swarm.MustParseHexAddress("aaabbbcc")
testCases := []struct {
desc string
name string
resolveFn func(bind.ContractBackend, string) (string, error)
wantErr error
}{
{
desc: "nil resolve function",
resolveFn: nil,
wantErr: ens.ErrNotImplemented,
},
{
desc: "resolve function internal error",
resolveFn: func(bind.ContractBackend, string) (string, error) {
return "", errors.New("internal error")
},
wantErr: ens.ErrResolveFailed,
},
{
desc: "resolver returns empty string",
resolveFn: func(bind.ContractBackend, string) (string, error) {
return "", nil
},
wantErr: ens.ErrInvalidContentHash,
},
{
desc: "resolve does not prefix address with /swarm",
resolveFn: func(bind.ContractBackend, string) (string, error) {
return addr.String(), nil
},
wantErr: ens.ErrInvalidContentHash,
},
{
desc: "resolve returns prefixed address",
resolveFn: func(bind.ContractBackend, string) (string, error) {
return ens.SwarmContentHashPrefix + addr.String(), nil
},
wantErr: ens.ErrInvalidContentHash,
},
}
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
cl, err := ens.NewClient("example.com",
ens.WithDialFunc(func(string) (*ethclient.Client, error) {
return nil, nil
}),
ens.WithResolveFunc(tC.resolveFn),
)
if err != nil {
t.Fatal(err)
}
_, err = cl.Resolve(tC.name)
if err != nil {
if !errors.Is(err, tC.wantErr) {
t.Errorf("got %v, want %v", err, tC.wantErr)
}
return
}
})
}
} }
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ens
import (
"errors"
)
var (
// ErrInvalidContentHash denotes that the value of the contenthash record is
// not valid.
ErrInvalidContentHash = errors.New("invalid swarm content hash")
// ErrResolveFailed is returned when a name could not be resolved.
ErrResolveFailed = errors.New("resolve failed")
// ErrNameNotFound is returned when a name resolves to an empty contenthash
// record.
ErrNameNotFound = errors.New("name not found")
)
var (
// errNotImplemented denotes that the function has not been implemented.
errNotImplemented = errors.New("function not implemented")
)
...@@ -7,17 +7,11 @@ package ens ...@@ -7,17 +7,11 @@ package ens
import ( import (
"github.com/ethereum/go-ethereum/accounts/abi/bind" "github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/ethclient" "github.com/ethereum/go-ethereum/ethclient"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/swarm"
) )
var ( const SwarmContentHashPrefix = swarmContentHashPrefix
ErrNotImplemented = errNotImplemented
)
func SetEthClient(c *Client, ethCl *ethclient.Client) { var ErrNotImplemented = errNotImplemented
c.ethCl = ethCl
}
// WithDialFunc will set the Dial function implementaton. // WithDialFunc will set the Dial function implementaton.
func WithDialFunc(fn func(ep string) (*ethclient.Client, error)) Option { func WithDialFunc(fn func(ep string) (*ethclient.Client, error)) Option {
...@@ -26,45 +20,9 @@ func WithDialFunc(fn func(ep string) (*ethclient.Client, error)) Option { ...@@ -26,45 +20,9 @@ func WithDialFunc(fn func(ep string) (*ethclient.Client, error)) Option {
} }
} }
func WithErrorDialFunc(err error) Option {
return WithDialFunc(func(ep string) (*ethclient.Client, error) {
return nil, err
})
}
func WithNoopDialFunc() Option {
return WithDialFunc(func(ep string) (*ethclient.Client, error) {
return nil, nil
})
}
// WithResolveFunc will set the Resolve function implementation. // WithResolveFunc will set the Resolve function implementation.
func WithResolveFunc(fn func(backend bind.ContractBackend, input string) (string, error)) Option { func WithResolveFunc(fn func(backend bind.ContractBackend, input string) (string, error)) Option {
return func(c *Client) { return func(c *Client) {
c.resolveFn = fn c.resolveFn = fn
} }
} }
func WithErrorResolveFunc(err error) Option {
return WithResolveFunc(func(backend bind.ContractBackend, input string) (string, error) {
return "", err
})
}
func WithZeroAdrResolveFunc() Option {
return WithResolveFunc(func(backend bind.ContractBackend, input string) (string, error) {
return swarm.ZeroAddress.String(), nil
})
}
func WithNoprefixAdrResolveFunc(addr resolver.Address) Option {
return WithResolveFunc(func(backend bind.ContractBackend, input string) (string, error) {
return addr.String(), nil
})
}
func WithValidAdrResolveFunc(addr resolver.Address) Option {
return WithResolveFunc(func(backend bind.ContractBackend, input string) (string, error) {
return "/swarm/" + addr.String(), nil
})
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ens
import (
"github.com/ethereum/go-ethereum/accounts/abi/bind"
"github.com/ethereum/go-ethereum/ethclient"
goens "github.com/wealdtech/go-ens/v3"
)
func wrapDial(ep string) (*ethclient.Client, error) {
// Open a connection to the ethereum node through the endpoint.
cl, err := ethclient.Dial(ep)
if err != nil {
return nil, err
}
// Ensure the ENS resolver contract is deployed on the network we are now
// connected to.
if _, err := goens.PublicResolverAddress(cl); err != nil {
return nil, err
}
return cl, nil
}
func wrapResolve(backend bind.ContractBackend, name string) (string, error) {
// Connect to the ENS resolver for the provided name.
ensR, err := goens.NewResolver(backend, name)
if err != nil {
return "", err
}
// Try and read out the content hash record.
ch, err := ensR.Contenthash()
if err != nil {
return "", err
}
addr, err := goens.ContenthashToString(ch)
if err != nil {
return "", err
}
return addr, nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"github.com/ethersphere/bee/pkg/resolver/client"
"github.com/ethersphere/bee/pkg/swarm"
)
// Ensure mock Client implements the Client interface.
var _ client.Interface = (*Client)(nil)
// Client is the mock resolver client implementation.
type Client struct {
isConnected bool
endpoint string
defaultAddress swarm.Address
resolveFn func(string) (swarm.Address, error)
}
// Option is a function that applies an option to a Client.
type Option func(*Client)
// NewClient construct a new mock Client.
func NewClient(opts ...Option) *Client {
cl := &Client{}
for _, o := range opts {
o(cl)
}
cl.isConnected = true
return cl
}
// WithEndpoint will set the endpoint.
func WithEndpoint(endpoint string) Option {
return func(cl *Client) {
cl.endpoint = endpoint
}
}
// WitResolveAddress will set the address returned by Resolve.
func WitResolveAddress(addr swarm.Address) Option {
return func(cl *Client) {
cl.defaultAddress = addr
}
}
// WithResolveFunc will set the Resolve function implementation.
func WithResolveFunc(fn func(string) (swarm.Address, error)) Option {
return func(cl *Client) {
cl.resolveFn = fn
}
}
// IsConnected is the mock IsConnected implementation.
func (cl *Client) IsConnected() bool {
return cl.isConnected
}
// Endpoint is the mock Endpoint implementation.
func (cl *Client) Endpoint() string {
return cl.endpoint
}
// Resolve is the mock Resolve implementation
func (cl *Client) Resolve(name string) (swarm.Address, error) {
if cl.resolveFn == nil {
return cl.defaultAddress, nil
}
return cl.resolveFn(name)
}
// Close is the mock Close implementation.
func (cl *Client) Close() error {
cl.isConnected = false
return nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package resolver
import (
"errors"
"fmt"
"strings"
)
var (
// ErrTLDTooLong denotes when a TLD in a name exceeds maximum length.
ErrTLDTooLong = fmt.Errorf("TLD exceeds maximum length of %d characters", maxTLDLength)
// ErrInvalidTLD denotes passing an invalid TLD to the MultiResolver.
ErrInvalidTLD = errors.New("invalid TLD")
// ErrResolverChainEmpty denotes trying to pop an empty resolver chain.
ErrResolverChainEmpty = errors.New("resolver chain empty")
)
// CloseError denotes that at least one resolver in the MultiResolver has
// had an error when Close was called.
type CloseError struct {
errs []error
}
func (me CloseError) add(err error) {
if err != nil {
me.errs = append(me.errs, err)
}
}
func (me CloseError) errorOrNil() error {
if len(me.errs) > 0 {
return me
}
return nil
}
// Error returns a formatted multi close error.
func (me CloseError) Error() string {
if len(me.errs) == 0 {
return ""
}
var b strings.Builder
b.WriteString("multiresolver failed to close: ")
for _, e := range me.errs {
b.WriteString(e.Error())
b.WriteString("; ")
}
return b.String()
}
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package resolver package multiresolver
import ( import (
"fmt" "fmt"
...@@ -27,7 +27,7 @@ type ConnectionConfig struct { ...@@ -27,7 +27,7 @@ type ConnectionConfig struct {
// ParseConnectionString will try to parse a connection string used to connect // ParseConnectionString will try to parse a connection string used to connect
// the Resolver to a name resolution service. The resulting config can be // the Resolver to a name resolution service. The resulting config can be
// used to initialize a resovler Service. // used to initialize a resovler Service.
func parseConnectionString(cs string) (*ConnectionConfig, error) { func parseConnectionString(cs string) (ConnectionConfig, error) {
isAllUnicodeLetters := func(s string) bool { isAllUnicodeLetters := func(s string) bool {
for _, r := range s { for _, r := range s {
if !unicode.IsLetter(r) { if !unicode.IsLetter(r) {
...@@ -48,7 +48,7 @@ func parseConnectionString(cs string) (*ConnectionConfig, error) { ...@@ -48,7 +48,7 @@ func parseConnectionString(cs string) (*ConnectionConfig, error) {
if isAllUnicodeLetters(endpoint[:i]) && len(endpoint) > i+2 && endpoint[i+1:i+3] != "//" { if isAllUnicodeLetters(endpoint[:i]) && len(endpoint) > i+2 && endpoint[i+1:i+3] != "//" {
tld = endpoint[:i] tld = endpoint[:i]
if len(tld) > maxTLDLength { if len(tld) > maxTLDLength {
return nil, fmt.Errorf("%w: %s", ErrTLDTooLong, tld) return ConnectionConfig{}, fmt.Errorf("tld %s: %w", tld, ErrTLDTooLong)
} }
endpoint = endpoint[i+1:] endpoint = endpoint[i+1:]
...@@ -60,7 +60,7 @@ func parseConnectionString(cs string) (*ConnectionConfig, error) { ...@@ -60,7 +60,7 @@ func parseConnectionString(cs string) (*ConnectionConfig, error) {
endpoint = endpoint[i+1:] endpoint = endpoint[i+1:]
} }
return &ConnectionConfig{ return ConnectionConfig{
Endpoint: endpoint, Endpoint: endpoint,
Address: addr, Address: addr,
TLD: tld, TLD: tld,
...@@ -69,8 +69,8 @@ func parseConnectionString(cs string) (*ConnectionConfig, error) { ...@@ -69,8 +69,8 @@ func parseConnectionString(cs string) (*ConnectionConfig, error) {
// ParseConnectionStrings will apply ParseConnectionString to each connection // ParseConnectionStrings will apply ParseConnectionString to each connection
// string. Returns first error found. // string. Returns first error found.
func ParseConnectionStrings(cstrs []string) ([]*ConnectionConfig, error) { func ParseConnectionStrings(cstrs []string) ([]ConnectionConfig, error) {
var res []*ConnectionConfig var res []ConnectionConfig
for _, cs := range cstrs { for _, cs := range cstrs {
cfg, err := parseConnectionString(cs) cfg, err := parseConnectionString(cs)
......
...@@ -2,20 +2,20 @@ ...@@ -2,20 +2,20 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package resolver_test package multiresolver_test
import ( import (
"errors" "errors"
"testing" "testing"
"github.com/ethersphere/bee/pkg/resolver" "github.com/ethersphere/bee/pkg/resolver/multiresolver"
) )
func TestParseConnectionStrings(t *testing.T) { func TestParseConnectionStrings(t *testing.T) {
testCases := []struct { testCases := []struct {
desc string desc string
conStrings []string conStrings []string
wantCfg []resolver.ConnectionConfig wantCfg []multiresolver.ConnectionConfig
wantErr error wantErr error
}{ }{
{ {
...@@ -25,14 +25,14 @@ func TestParseConnectionStrings(t *testing.T) { ...@@ -25,14 +25,14 @@ func TestParseConnectionStrings(t *testing.T) {
conStrings: []string{ conStrings: []string{
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff:example.com", "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff:example.com",
}, },
wantErr: resolver.ErrTLDTooLong, wantErr: multiresolver.ErrTLDTooLong,
}, },
{ {
desc: "single endpoint default tld", desc: "single endpoint default tld",
conStrings: []string{ conStrings: []string{
"https://example.com", "https://example.com",
}, },
wantCfg: []resolver.ConnectionConfig{ wantCfg: []multiresolver.ConnectionConfig{
{ {
TLD: "", TLD: "",
Endpoint: "https://example.com", Endpoint: "https://example.com",
...@@ -44,7 +44,7 @@ func TestParseConnectionStrings(t *testing.T) { ...@@ -44,7 +44,7 @@ func TestParseConnectionStrings(t *testing.T) {
conStrings: []string{ conStrings: []string{
"tld:https://example.com", "tld:https://example.com",
}, },
wantCfg: []resolver.ConnectionConfig{ wantCfg: []multiresolver.ConnectionConfig{
{ {
TLD: "tld", TLD: "tld",
Endpoint: "https://example.com", Endpoint: "https://example.com",
...@@ -56,7 +56,7 @@ func TestParseConnectionStrings(t *testing.T) { ...@@ -56,7 +56,7 @@ func TestParseConnectionStrings(t *testing.T) {
conStrings: []string{ conStrings: []string{
"0x314159265dD8dbb310642f98f50C066173C1259b@https://example.com", "0x314159265dD8dbb310642f98f50C066173C1259b@https://example.com",
}, },
wantCfg: []resolver.ConnectionConfig{ wantCfg: []multiresolver.ConnectionConfig{
{ {
TLD: "", TLD: "",
Address: "0x314159265dD8dbb310642f98f50C066173C1259b", Address: "0x314159265dD8dbb310642f98f50C066173C1259b",
...@@ -69,7 +69,7 @@ func TestParseConnectionStrings(t *testing.T) { ...@@ -69,7 +69,7 @@ func TestParseConnectionStrings(t *testing.T) {
conStrings: []string{ conStrings: []string{
"tld:0x314159265dD8dbb310642f98f50C066173C1259b@https://example.com", "tld:0x314159265dD8dbb310642f98f50C066173C1259b@https://example.com",
}, },
wantCfg: []resolver.ConnectionConfig{ wantCfg: []multiresolver.ConnectionConfig{
{ {
TLD: "tld", TLD: "tld",
Address: "0x314159265dD8dbb310642f98f50C066173C1259b", Address: "0x314159265dD8dbb310642f98f50C066173C1259b",
...@@ -85,7 +85,7 @@ func TestParseConnectionStrings(t *testing.T) { ...@@ -85,7 +85,7 @@ func TestParseConnectionStrings(t *testing.T) {
"yesyesyes:0x314159265dD8dbb310642f98f50C066173C1259b@2.2.2.2", "yesyesyes:0x314159265dD8dbb310642f98f50C066173C1259b@2.2.2.2",
"cloudflare-ethereum.org", "cloudflare-ethereum.org",
}, },
wantCfg: []resolver.ConnectionConfig{ wantCfg: []multiresolver.ConnectionConfig{
{ {
TLD: "tld", TLD: "tld",
Endpoint: "https://example.com", Endpoint: "https://example.com",
...@@ -112,12 +112,12 @@ func TestParseConnectionStrings(t *testing.T) { ...@@ -112,12 +112,12 @@ func TestParseConnectionStrings(t *testing.T) {
"testdomain:wowzers.map", "testdomain:wowzers.map",
"nonononononononononononononononononononononononononononononononononono:yes", "nonononononononononononononononononononononononononononononononononono:yes",
}, },
wantErr: resolver.ErrTLDTooLong, wantErr: multiresolver.ErrTLDTooLong,
}, },
} }
for _, tC := range testCases { for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) { t.Run(tC.desc, func(t *testing.T) {
got, err := resolver.ParseConnectionStrings(tC.conStrings) got, err := multiresolver.ParseConnectionStrings(tC.conStrings)
if err != nil { if err != nil {
if !errors.Is(err, tC.wantErr) { if !errors.Is(err, tC.wantErr) {
t.Errorf("got error %v", err) t.Errorf("got error %v", err)
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package multiresolver
import "github.com/ethersphere/bee/pkg/logging"
func GetLogger(mr *MultiResolver) logging.Logger {
return mr.logger
}
func GetCfgs(mr *MultiResolver) []ConnectionConfig {
return mr.cfgs
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package multierror
import (
"errors"
"fmt"
"strings"
)
// Ensure multierror implements Error interface.
var _ error = (*Error)(nil)
// Error is an error type to track multiple errors. This can be used to
// accumulate errors and return them as a single "error" type.
type Error struct {
Errors []error
}
// New will return a new multierror.
func New(errs ...error) *Error {
e := Error{}
e.Append(errs...)
return &e
}
func (e *Error) Error() string {
return format(e.Errors)
}
// Append will append errors to the multierror.
func (e *Error) Append(errs ...error) {
e.Errors = append(e.Errors, errs...)
}
// ErrorOrNil returns an error interface if the multierror represents a list of
// errors or nil if the list is empty.
func (e *Error) ErrorOrNil() error {
if e == nil || len(e.Errors) == 0 {
return nil
}
return e
}
// WrapErrorOrNil will wrap the given error if the multierror contains errors.
func (e *Error) WrapErrorOrNil(toWrap error) error {
if err := e.ErrorOrNil(); err != nil {
return fmt.Errorf("%v: %w", err, toWrap)
}
return nil
}
// Unwrap returns an error from Error (or nil if there are no errors).
// The error returned supports Unwrap, so that the entire chain of errors can
// be unwrapped. The order will match the order of Errors at the time of
// calling.
//
// This will perform a shallow copy of the errors slice. Any errors appended
// to this error after calling Unwrap will not be available until a new
// Unwrap is called on the multierror.Error.
func (e *Error) Unwrap() error {
if e == nil || len(e.Errors) == 0 {
return nil
}
if len(e.Errors) == 0 {
return e.Errors[0]
}
// Shallow copy the error slice.
errs := make([]error, len(e.Errors))
copy(errs, e.Errors)
return chain(errs)
}
type chain []error
// Error implements the error interface.
func (ec chain) Error() string {
return ec[0].Error()
}
// Unwrap implements errors.Unwrap by returning the next error in the chain or
// nil if there are no more errors.
func (ec chain) Unwrap() error {
if len(ec) == 1 {
return nil
}
// Return the rest of the chain.
return ec[1:]
}
// As implements errors.As by attempting to map the current value.
func (ec chain) As(target interface{}) bool {
return errors.As(ec[0], target)
}
// Is implements errors.Is by comparing the current value directly.
func (ec chain) Is(target error) bool {
return errors.Is(ec[0], target)
}
func format(errs []error) string {
if len(errs) == 1 {
return fmt.Sprintf("1 error occurred: %s", errs[0])
}
msgs := make([]string, len(errs))
for i, err := range errs {
msgs[i] = fmt.Sprintf("%s", err)
}
return fmt.Sprintf("%d errors occurred: %s",
len(errs), strings.Join(msgs, ", "))
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package multierror_test
import (
"errors"
"fmt"
"testing"
"github.com/ethersphere/bee/pkg/resolver/multiresolver/multierror"
)
// nestedError implements error and is used for tests.
type nestedError struct{}
func (*nestedError) Error() string { return "" }
func TestErrorError(t *testing.T) {
t.Run("one error", func(t *testing.T) {
want := "1 error occurred: foo"
multi := multierror.New(errors.New("foo"))
if multi.Error() != want {
t.Fatalf("got: %q, want %q", multi.Error(), want)
}
})
t.Run("multiple errors", func(t *testing.T) {
want := "2 errors occurred: foo, bar"
multi := multierror.New(
errors.New("foo"),
errors.New("bar"),
)
if multi.Error() != want {
t.Fatalf("got: %q, want %q", multi.Error(), want)
}
})
}
func TestErrorErrorOrNil(t *testing.T) {
err := multierror.New()
if err.ErrorOrNil() != nil {
t.Fatalf("bad: %#v", err.ErrorOrNil())
}
err.Errors = []error{errors.New("foo")}
if got := err.ErrorOrNil(); got == nil {
t.Fatal("should not be nil")
} else if got != err {
t.Fatalf("bad: %#v", got)
}
}
func TestErrorWrapErrorOrNil(t *testing.T) {
wrapErr := errors.New("wrapper error")
err := multierror.New()
if err.WrapErrorOrNil(wrapErr) != nil {
t.Fatalf("bad: %#v", err.ErrorOrNil())
}
multi := multierror.New(
errors.New("foo"),
errors.New("bar"),
)
if got := multi.WrapErrorOrNil(wrapErr); got == nil {
t.Fatal("should not be nil")
} else if !errors.Is(got, wrapErr) {
t.Fatalf("bad: %#v", got)
}
}
func TestErrorUnwrap(t *testing.T) {
t.Run("with errors", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
errors.New("bar"),
errors.New("baz"),
)
var current error = err
for i := 0; i < len(err.Errors); i++ {
current = errors.Unwrap(current)
if !errors.Is(current, err.Errors[i]) {
t.Fatal("should be next value")
}
}
if errors.Unwrap(current) != nil {
t.Fatal("should be nil at the end")
}
})
t.Run("with no errors", func(t *testing.T) {
err := multierror.New()
if errors.Unwrap(err) != nil {
t.Fatal("should be nil")
}
})
t.Run("with nil multierror", func(t *testing.T) {
var err *multierror.Error
if errors.Unwrap(err) != nil {
t.Fatal("should be nil")
}
})
}
func TestErrorIs(t *testing.T) {
errBar := errors.New("bar")
t.Run("with errBar", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
errBar,
errors.New("baz"),
)
if !errors.Is(err, errBar) {
t.Fatal("should be true")
}
})
t.Run("with errBar wrapped by fmt.Errorf", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
fmt.Errorf("errorf: %w", errBar),
errors.New("baz"),
)
if !errors.Is(err, errBar) {
t.Fatal("should be true")
}
})
t.Run("without errBar", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
errors.New("baz"),
)
if errors.Is(err, errBar) {
t.Fatal("should be false")
}
})
}
func TestErrorAs(t *testing.T) {
match := &nestedError{}
t.Run("with the value", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
match,
errors.New("baz"),
)
var target *nestedError
if !errors.As(err, &target) {
t.Fatal("should be true")
}
if target == nil {
t.Fatal("target should not be nil")
}
})
t.Run("with the value wrapped by fmt.Errorf", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
fmt.Errorf("errorf: %w", match),
errors.New("baz"),
)
var target *nestedError
if !errors.As(err, &target) {
t.Fatal("should be true")
}
if target == nil {
t.Fatal("target should not be nil")
}
})
t.Run("without the value", func(t *testing.T) {
err := multierror.New(
errors.New("foo"),
errors.New("baz"),
)
var target *nestedError
if errors.As(err, &target) {
t.Fatal("should be false")
}
if target != nil {
t.Fatal("target should be nil")
}
})
}
...@@ -2,24 +2,45 @@ ...@@ -2,24 +2,45 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package resolver package multiresolver
import ( import (
"errors"
"fmt" "fmt"
"io/ioutil"
"path" "path"
"strings" "strings"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/resolver/client/ens"
"github.com/ethersphere/bee/pkg/resolver/multiresolver/multierror"
) )
// Ensure MultiResolver implements Resolver interface. // Ensure MultiResolver implements Resolver interface.
var _ Interface = (*MultiResolver)(nil) var _ resolver.Interface = (*MultiResolver)(nil)
var (
// ErrTLDTooLong denotes when a TLD in a name exceeds maximum length.
ErrTLDTooLong = fmt.Errorf("TLD exceeds maximum length of %d characters", maxTLDLength)
// ErrInvalidTLD denotes passing an invalid TLD to the MultiResolver.
ErrInvalidTLD = errors.New("invalid TLD")
// ErrResolverChainEmpty denotes trying to pop an empty resolver chain.
ErrResolverChainEmpty = errors.New("resolver chain empty")
// ErrResolverChainFailed denotes that an entire name resolution chain
// for a given TLD failed.
ErrResolverChainFailed = errors.New("resolver chain failed")
// ErrCloseFailed denotes that closing the multiresolver failed.
ErrCloseFailed = errors.New("close failed")
)
type resolverMap map[string][]Interface type resolverMap map[string][]resolver.Interface
// MultiResolver performs name resolutions based on the TLD label in the name. // MultiResolver performs name resolutions based on the TLD label in the name.
type MultiResolver struct { type MultiResolver struct {
resolvers resolverMap resolvers resolverMap
logger logging.Logger
cfgs []ConnectionConfig
// ForceDefault will force all names to be resolved by the default // ForceDefault will force all names to be resolved by the default
// resolution chain, regadless of their TLD. // resolution chain, regadless of their TLD.
ForceDefault bool ForceDefault bool
...@@ -39,9 +60,47 @@ func NewMultiResolver(opts ...Option) *MultiResolver { ...@@ -39,9 +60,47 @@ func NewMultiResolver(opts ...Option) *MultiResolver {
o(mr) o(mr)
} }
// Discard log output by default.
if mr.logger == nil {
mr.logger = logging.New(ioutil.Discard, 0)
}
log := mr.logger
if len(mr.cfgs) == 0 {
log.Info("name resolver: no name resolution service provided")
return mr
}
// Attempt to conect to each resolver using the connection string.
for _, c := range mr.cfgs {
// Warn user that the resolver address field is not used.
if c.Address != "" {
log.Warningf("name resolver: connection string %q contains resolver address field, which is currently unused", c.Address)
}
// NOTE: if we want to create a specific client based on the TLD
// we can do it here.
mr.connectENSClient(c.TLD, c.Endpoint)
}
return mr return mr
} }
// WithConnectionConfigs will set the initial connection configuration.
func WithConnectionConfigs(cfgs []ConnectionConfig) Option {
return func(mr *MultiResolver) {
mr.cfgs = cfgs
}
}
// WithLogger will set the logger used by the MultiResolver.
func WithLogger(logger logging.Logger) Option {
return func(mr *MultiResolver) {
mr.logger = logger
}
}
// WithForceDefault will force resolution using the default resolver chain. // WithForceDefault will force resolution using the default resolver chain.
func WithForceDefault() Option { func WithForceDefault() Option {
return func(mr *MultiResolver) { return func(mr *MultiResolver) {
...@@ -50,27 +109,14 @@ func WithForceDefault() Option { ...@@ -50,27 +109,14 @@ func WithForceDefault() Option {
} }
// PushResolver will push a new Resolver to the name resolution chain for the // PushResolver will push a new Resolver to the name resolution chain for the
// given TLD. // given TLD. An empty TLD will push to the default resolver chain.
// TLD names should be prepended with a dot (eg ".tld"). An empty TLD will push func (mr *MultiResolver) PushResolver(tld string, r resolver.Interface) {
// to the default resolver chain.
func (mr *MultiResolver) PushResolver(tld string, r Interface) error {
if tld != "" && !isTLD(tld) {
return fmt.Errorf("tld %s: %w", tld, ErrInvalidTLD)
}
mr.resolvers[tld] = append(mr.resolvers[tld], r) mr.resolvers[tld] = append(mr.resolvers[tld], r)
return nil
} }
// PopResolver will pop the last reslover from the name resolution chain for the // PopResolver will pop the last reslover from the name resolution chain for the
// given TLD. // given TLD. An empty TLD will pop from the default resolver chain.
// TLD names should be prepended with a dot (eg ".tld"). An empty TLD will pop
// from the default resolver chain.
func (mr *MultiResolver) PopResolver(tld string) error { func (mr *MultiResolver) PopResolver(tld string) error {
if tld != "" && !isTLD(tld) {
return fmt.Errorf("tld %s: %w", tld, ErrInvalidTLD)
}
l := len(mr.resolvers[tld]) l := len(mr.resolvers[tld])
if l == 0 { if l == 0 {
return fmt.Errorf("tld %s: %w", tld, ErrResolverChainEmpty) return fmt.Errorf("tld %s: %w", tld, ErrResolverChainEmpty)
...@@ -90,7 +136,7 @@ func (mr *MultiResolver) ChainCount(tld string) int { ...@@ -90,7 +136,7 @@ func (mr *MultiResolver) ChainCount(tld string) int {
// GetChain will return the resolution chain for a given TLD. // GetChain will return the resolution chain for a given TLD.
// TLD names should be prepended with a dot (eg ".tld"). An empty TLD will // TLD names should be prepended with a dot (eg ".tld"). An empty TLD will
// return all resolvers in the default resolver chain. // return all resolvers in the default resolver chain.
func (mr *MultiResolver) GetChain(tld string) []Interface { func (mr *MultiResolver) GetChain(tld string) []resolver.Interface {
return mr.resolvers[tld] return mr.resolvers[tld]
} }
...@@ -100,8 +146,7 @@ func (mr *MultiResolver) GetChain(tld string) []Interface { ...@@ -100,8 +146,7 @@ func (mr *MultiResolver) GetChain(tld string) []Interface {
// The resolution will be performed iteratively on the resolution chain, // The resolution will be performed iteratively on the resolution chain,
// returning the result of the first Resolver that succeeds. If all resolvers // returning the result of the first Resolver that succeeds. If all resolvers
// in the chain return an error, the function will return an ErrResolveFailed. // in the chain return an error, the function will return an ErrResolveFailed.
func (mr *MultiResolver) Resolve(name string) (Address, error) { func (mr *MultiResolver) Resolve(name string) (addr resolver.Address, err error) {
tld := "" tld := ""
if !mr.ForceDefault { if !mr.ForceDefault {
tld = getTLD(name) tld = getTLD(name)
...@@ -113,35 +158,46 @@ func (mr *MultiResolver) Resolve(name string) (Address, error) { ...@@ -113,35 +158,46 @@ func (mr *MultiResolver) Resolve(name string) (Address, error) {
chain = mr.resolvers[""] chain = mr.resolvers[""]
} }
addr := swarm.ZeroAddress errs := multierror.New()
var err error
for _, res := range chain { for _, res := range chain {
addr, err = res.Resolve(name) addr, err = res.Resolve(name)
if err == nil { if err == nil {
return addr, nil return addr, nil
} }
errs.Append(err)
} }
return addr, err return addr, errs.ErrorOrNil()
} }
// Close all will call Close on all resolvers in all resolver chains. // Close all will call Close on all resolvers in all resolver chains.
func (mr *MultiResolver) Close() error { func (mr *MultiResolver) Close() error {
errs := new(CloseError) errs := multierror.New()
for _, chain := range mr.resolvers { for _, chain := range mr.resolvers {
for _, r := range chain { for _, r := range chain {
errs.add(r.Close()) if err := r.Close(); err != nil {
errs.Append(err)
}
} }
} }
return errs.errorOrNil() return errs.ErrorOrNil()
}
func isTLD(tld string) bool {
return len(tld) > 1 && tld[0] == '.'
} }
func getTLD(name string) string { func getTLD(name string) string {
return path.Ext(strings.ToLower(name)) return path.Ext(strings.ToLower(name))
} }
func (mr *MultiResolver) connectENSClient(tld string, endpoint string) {
log := mr.logger
log.Debugf("name resolver: resolver for %q: connecting to endpoint %s", tld, endpoint)
ensCl, err := ens.NewClient(endpoint)
if err != nil {
log.Errorf("name resolver: resolver for %q domain: failed to connect to %q: %v", tld, endpoint, err)
} else {
log.Infof("name resolver: resolver for %q domain: connected to %s", tld, endpoint)
mr.PushResolver(tld, ensCl)
}
}
...@@ -2,16 +2,19 @@ ...@@ -2,16 +2,19 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package resolver_test package multiresolver_test
import ( import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"reflect" "reflect"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/resolver" "github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/resolver/mock" "github.com/ethersphere/bee/pkg/resolver/mock"
"github.com/ethersphere/bee/pkg/resolver/multiresolver"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -21,11 +24,33 @@ func newAddr(s string) Address { ...@@ -21,11 +24,33 @@ func newAddr(s string) Address {
return swarm.NewAddress([]byte(s)) return swarm.NewAddress([]byte(s))
} }
func TestWithForceDefault(t *testing.T) { func TestMultiresolverOpts(t *testing.T) {
mr := resolver.NewMultiResolver( wantLog := logging.New(ioutil.Discard, 1)
resolver.WithForceDefault(), wantCfgs := []multiresolver.ConnectionConfig{
{
Address: "testadr1",
Endpoint: "testEndpoint1",
TLD: "testtld1",
},
{
Address: "testadr2",
Endpoint: "testEndpoint2",
TLD: "testtld2",
},
}
mr := multiresolver.NewMultiResolver(
multiresolver.WithLogger(wantLog),
multiresolver.WithConnectionConfigs(wantCfgs),
multiresolver.WithForceDefault(),
) )
if got := multiresolver.GetLogger(mr); got != wantLog {
t.Errorf("log: got: %v, want %v", got, wantLog)
}
if got := multiresolver.GetCfgs(mr); !reflect.DeepEqual(got, wantCfgs) {
t.Errorf("cfg: got: %v, want %v", got, wantCfgs)
}
if !mr.ForceDefault { if !mr.ForceDefault {
t.Error("did not set ForceDefault") t.Error("did not set ForceDefault")
} }
...@@ -45,29 +70,18 @@ func TestPushResolver(t *testing.T) { ...@@ -45,29 +70,18 @@ func TestPushResolver(t *testing.T) {
desc: "regular tld, named chain", desc: "regular tld, named chain",
tld: ".tld", tld: ".tld",
}, },
{
desc: "invalid tld",
tld: "invalid",
wantErr: resolver.ErrInvalidTLD,
},
} }
for _, tC := range testCases { for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) { t.Run(tC.desc, func(t *testing.T) {
mr := resolver.NewMultiResolver() mr := multiresolver.NewMultiResolver()
if mr.ChainCount(tC.tld) != 0 { if mr.ChainCount(tC.tld) != 0 {
t.Fatal("chain should start empty") t.Fatal("chain should start empty")
} }
want := mock.NewResolver() want := mock.NewResolver()
err := mr.PushResolver(tC.tld, want) mr.PushResolver(tC.tld, want)
if err != nil {
if !errors.Is(err, tC.wantErr) {
t.Fatal(err)
}
return
}
got := mr.GetChain(tC.tld)[0] got := mr.GetChain(tC.tld)[0]
if !reflect.DeepEqual(got, want) { if !reflect.DeepEqual(got, want) {
...@@ -82,20 +96,11 @@ func TestPushResolver(t *testing.T) { ...@@ -82,20 +96,11 @@ func TestPushResolver(t *testing.T) {
} }
}) })
} }
} t.Run("pop empty chain", func(t *testing.T) {
mr := multiresolver.NewMultiResolver()
func TestPopResolver(t *testing.T) { err := mr.PopResolver("")
mr := resolver.NewMultiResolver() if !errors.Is(err, multiresolver.ErrResolverChainEmpty) {
t.Errorf("got %v, want %v", err, multiresolver.ErrResolverChainEmpty)
t.Run("error on bad tld", func(t *testing.T) {
if err := mr.PopResolver("invalid"); !errors.Is(err, resolver.ErrInvalidTLD) {
t.Fatal("invalid error type")
}
})
t.Run("error on empty", func(t *testing.T) {
if err := mr.PopResolver(".tld"); !errors.Is(err, resolver.ErrResolverChainEmpty) {
t.Fatal("invalid error type")
} }
}) })
} }
...@@ -104,6 +109,7 @@ func TestResolve(t *testing.T) { ...@@ -104,6 +109,7 @@ func TestResolve(t *testing.T) {
addr := newAddr("aaaabbbbccccdddd") addr := newAddr("aaaabbbbccccdddd")
addrAlt := newAddr("ddddccccbbbbaaaa") addrAlt := newAddr("ddddccccbbbbaaaa")
errUnregisteredName := fmt.Errorf("unregistered name") errUnregisteredName := fmt.Errorf("unregistered name")
errResolutionFailed := fmt.Errorf("name resolution failed")
newOKResolver := func(addr Address) resolver.Interface { newOKResolver := func(addr Address) resolver.Interface {
return mock.NewResolver( return mock.NewResolver(
...@@ -115,8 +121,7 @@ func TestResolve(t *testing.T) { ...@@ -115,8 +121,7 @@ func TestResolve(t *testing.T) {
newErrResolver := func() resolver.Interface { newErrResolver := func() resolver.Interface {
return mock.NewResolver( return mock.NewResolver(
mock.WithResolveFunc(func(name string) (Address, error) { mock.WithResolveFunc(func(name string) (Address, error) {
err := fmt.Errorf("name resolution failed for %q", name) return swarm.ZeroAddress, errResolutionFailed
return swarm.ZeroAddress, err
}), }),
) )
} }
...@@ -181,35 +186,35 @@ func TestResolve(t *testing.T) { ...@@ -181,35 +186,35 @@ func TestResolve(t *testing.T) {
wantAdr Address wantAdr Address
wantErr error wantErr error
}{ }{
// { {
// name: "", name: "",
// wantAdr: testAdr, wantAdr: addr,
// }, },
// { {
// name: "hello", name: "hello",
// wantAdr: testAdr, wantAdr: addr,
// }, },
// { {
// name: "example.tld", name: "example.tld",
// wantAdr: testAdr, wantAdr: addr,
// }, },
// { {
// name: ".tld", name: ".tld",
// wantAdr: testAdr, wantAdr: addr,
// }, },
// { {
// name: "get.good", name: "get.good",
// wantAdr: testAdr, wantAdr: addr,
// }, },
// { {
// // Switch to the default chain: // Switch to the default chain:
// name: "this.empty", name: "this.empty",
// wantAdr: testAdr, wantAdr: addr,
// }, },
// { {
// name: "this.dies", name: "this.dies",
// wantErr: fmt.Errorf("Failed to resolve name %q", "this.dies"), wantErr: errResolutionFailed,
// }, },
{ {
name: "iam.unregistered", name: "iam.unregistered",
wantAdr: swarm.ZeroAddress, wantAdr: swarm.ZeroAddress,
...@@ -218,12 +223,10 @@ func TestResolve(t *testing.T) { ...@@ -218,12 +223,10 @@ func TestResolve(t *testing.T) {
} }
// Load the test fixture. // Load the test fixture.
mr := resolver.NewMultiResolver() mr := multiresolver.NewMultiResolver()
for _, tE := range testFixture { for _, tE := range testFixture {
for _, r := range tE.res { for _, r := range tE.res {
if err := mr.PushResolver(tE.tld, r); err != nil { mr.PushResolver(tE.tld, r)
t.Fatal(err)
}
} }
} }
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package service
import (
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/resolver/client/ens"
)
// InitMultiResolver will create a new MultiResolver, create the appropriate
// resolvers, push them to the resolver chains and attempt to connect.
func InitMultiResolver(logger logging.Logger, cfgs []*resolver.ConnectionConfig) resolver.Interface {
if len(cfgs) == 0 {
logger.Info("name resolver: no name resolution service provided")
return nil
}
// Create a new MultiResolver.
mr := resolver.NewMultiResolver()
connectENS := func(tld string, ep string) {
ensCl := ens.NewClient()
logger.Debugf("name resolver: resolver for %q: connecting to endpoint %s", tld, ep)
if err := ensCl.Connect(ep); err != nil {
logger.Errorf("name resolver: resolver for %q domain: failed to connect to %q: %v", tld, ep, err)
} else {
logger.Infof("name resolver: resolver for %q domain: connected to %s", tld, ep)
if err := mr.PushResolver(tld, ensCl); err != nil {
logger.Errorf("name resolver: failed to push resolver to %q resolver chain: %v", tld, err)
}
}
}
// Attempt to conect to each resolver using the connection string.
for _, c := range cfgs {
// Warn user that the resolver address field is not used.
if c.Address != "" {
logger.Warningf("name resolver: connection string %q contains resolver address field, which is currently unused", c.Address)
}
// Select the appropriate resolver.
switch c.TLD {
case "eth":
// FIXME: MultiResolver expects "." in front of the TLD label.
connectENS("."+c.TLD, c.Endpoint)
case "":
connectENS("", c.Endpoint)
default:
logger.Errorf("default domain resolution not supported")
}
}
return mr
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment