multiresolver_test.go 5.27 KB
Newer Older
1 2 3 4
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

5
package multiresolver_test
6 7

import (
8
	"errors"
9
	"fmt"
10
	"io/ioutil"
11 12 13
	"reflect"
	"testing"

14
	"github.com/ethersphere/bee/pkg/logging"
15 16
	"github.com/ethersphere/bee/pkg/resolver"
	"github.com/ethersphere/bee/pkg/resolver/mock"
17
	"github.com/ethersphere/bee/pkg/resolver/multiresolver"
18 19 20 21 22 23 24 25 26
	"github.com/ethersphere/bee/pkg/swarm"
)

type (
	// Address is a local alias for swarm.Address, to keep test signatures short.
	Address = swarm.Address
)

// newAddr builds a swarm address whose underlying bytes are the raw bytes of s.
func newAddr(s string) Address {
	raw := []byte(s)
	return swarm.NewAddress(raw)
}

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
func TestMultiresolverOpts(t *testing.T) {
	wantLog := logging.New(ioutil.Discard, 1)
	wantCfgs := []multiresolver.ConnectionConfig{
		{
			Address:  "testadr1",
			Endpoint: "testEndpoint1",
			TLD:      "testtld1",
		},
		{
			Address:  "testadr2",
			Endpoint: "testEndpoint2",
			TLD:      "testtld2",
		},
	}

	mr := multiresolver.NewMultiResolver(
		multiresolver.WithLogger(wantLog),
		multiresolver.WithConnectionConfigs(wantCfgs),
		multiresolver.WithForceDefault(),
46 47
	)

48 49 50 51 52 53
	if got := multiresolver.GetLogger(mr); got != wantLog {
		t.Errorf("log: got: %v, want %v", got, wantLog)
	}
	if got := multiresolver.GetCfgs(mr); !reflect.DeepEqual(got, wantCfgs) {
		t.Errorf("cfg: got: %v, want %v", got, wantCfgs)
	}
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
	if !mr.ForceDefault {
		t.Error("did not set ForceDefault")
	}
}

func TestPushResolver(t *testing.T) {
	testCases := []struct {
		desc    string
		tld     string
		wantErr error
	}{
		{
			desc: "empty string, default",
			tld:  "",
		},
		{
			desc: "regular tld, named chain",
			tld:  ".tld",
		},
	}

	for _, tC := range testCases {
		t.Run(tC.desc, func(t *testing.T) {
77
			mr := multiresolver.NewMultiResolver()
78 79 80 81 82 83

			if mr.ChainCount(tC.tld) != 0 {
				t.Fatal("chain should start empty")
			}

			want := mock.NewResolver()
84
			mr.PushResolver(tC.tld, want)
85 86 87 88 89 90 91 92 93 94 95 96 97 98

			got := mr.GetChain(tC.tld)[0]
			if !reflect.DeepEqual(got, want) {
				t.Error("failed to push")
			}

			if err := mr.PopResolver(tC.tld); err != nil {
				t.Error(err)
			}
			if mr.ChainCount(tC.tld) > 0 {
				t.Error("failed to pop")
			}
		})
	}
99 100 101 102 103
	t.Run("pop empty chain", func(t *testing.T) {
		mr := multiresolver.NewMultiResolver()
		err := mr.PopResolver("")
		if !errors.Is(err, multiresolver.ErrResolverChainEmpty) {
			t.Errorf("got %v, want %v", err, multiresolver.ErrResolverChainEmpty)
104 105 106 107 108
		}
	})
}

func TestResolve(t *testing.T) {
109 110 111
	addr := newAddr("aaaabbbbccccdddd")
	addrAlt := newAddr("ddddccccbbbbaaaa")
	errUnregisteredName := fmt.Errorf("unregistered name")
112
	errResolutionFailed := fmt.Errorf("name resolution failed")
113

114
	newOKResolver := func(addr Address) resolver.Interface {
115 116
		return mock.NewResolver(
			mock.WithResolveFunc(func(_ string) (Address, error) {
117
				return addr, nil
118 119 120 121 122 123
			}),
		)
	}
	newErrResolver := func() resolver.Interface {
		return mock.NewResolver(
			mock.WithResolveFunc(func(name string) (Address, error) {
124
				return swarm.ZeroAddress, errResolutionFailed
125 126 127 128 129 130 131
			}),
		)
	}
	newUnregisteredNameResolver := func() resolver.Interface {
		return mock.NewResolver(
			mock.WithResolveFunc(func(name string) (Address, error) {
				return swarm.ZeroAddress, errUnregisteredName
132 133 134 135 136 137 138 139 140 141 142 143 144
			}),
		)
	}

	testFixture := []struct {
		tld       string
		res       []resolver.Interface
		expectAdr Address
	}{
		{
			// Default chain:
			tld: "",
			res: []resolver.Interface{
145
				newOKResolver(addr),
146
			},
147
			expectAdr: addr,
148 149 150 151 152 153
		},
		{
			tld: ".tld",
			res: []resolver.Interface{
				newErrResolver(),
				newErrResolver(),
154
				newOKResolver(addr),
155
			},
156
			expectAdr: addr,
157 158 159 160
		},
		{
			tld: ".good",
			res: []resolver.Interface{
161 162
				newOKResolver(addr),
				newOKResolver(addrAlt),
163
			},
164
			expectAdr: addr,
165 166 167 168 169
		},
		{
			tld: ".empty",
		},
		{
170
			tld: ".dies",
171 172 173 174 175
			res: []resolver.Interface{
				newErrResolver(),
				newErrResolver(),
			},
		},
176 177 178 179 180 181
		{
			tld: ".unregistered",
			res: []resolver.Interface{
				newUnregisteredNameResolver(),
			},
		},
182 183 184 185 186 187 188
	}

	testCases := []struct {
		name    string
		wantAdr Address
		wantErr error
	}{
189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
		{
			name:    "",
			wantAdr: addr,
		},
		{
			name:    "hello",
			wantAdr: addr,
		},
		{
			name:    "example.tld",
			wantAdr: addr,
		},
		{
			name:    ".tld",
			wantAdr: addr,
		},
		{
			name:    "get.good",
			wantAdr: addr,
		},
		{
			// Switch to the default chain:
			name:    "this.empty",
			wantAdr: addr,
		},
		{
			name:    "this.dies",
			wantErr: errResolutionFailed,
		},
218
		{
219 220 221
			name:    "iam.unregistered",
			wantAdr: swarm.ZeroAddress,
			wantErr: errUnregisteredName,
222 223 224 225
		},
	}

	// Load the test fixture.
226
	mr := multiresolver.NewMultiResolver()
227 228
	for _, tE := range testFixture {
		for _, r := range tE.res {
229
			mr.PushResolver(tE.tld, r)
230 231 232 233 234
		}
	}

	for _, tC := range testCases {
		t.Run(tC.name, func(t *testing.T) {
235
			addr, err := mr.Resolve(tC.name)
236 237 238 239
			if err != nil {
				if tC.wantErr == nil {
					t.Fatalf("unexpected error: got %v", err)
				}
240
				if !errors.Is(err, tC.wantErr) {
241 242 243
					t.Fatalf("got %v, want %v", err, tC.wantErr)
				}
			}
244 245
			if !addr.Equal(tC.wantAdr) {
				t.Errorf("got %q, want %q", addr, tC.wantAdr)
246 247 248
			}
		})
	}
249 250 251 252 253 254 255 256 257 258 259 260 261

	t.Run("close all", func(t *testing.T) {
		if err := mr.Close(); err != nil {
			t.Fatal(err)
		}
		for _, tE := range testFixture {
			for _, r := range mr.GetChain(tE.tld) {
				if !r.(*mock.Resolver).IsClosed {
					t.Errorf("expected %q resolver closed", tE.tld)
				}
			}
		}
	})
262
}