Commit 9c84e3ff authored by Janoš Guljaš's avatar Janoš Guljaš Committed by GitHub

use testing.Cleanup in test instead deferred explicit cleanup function (#204)

parent af01e019
...@@ -14,20 +14,19 @@ import ( ...@@ -14,20 +14,19 @@ import (
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
) )
type bookFunc func(t *testing.T) (book addressbook.GetPutter, cleanup func()) type bookFunc func(t *testing.T) (book addressbook.GetPutter)
func TestInMem(t *testing.T) { func TestInMem(t *testing.T) {
run(t, func(t *testing.T) (addressbook.GetPutter, func()) { run(t, func(t *testing.T) addressbook.GetPutter {
store := mock.NewStateStore() store := mock.NewStateStore()
book := addressbook.New(store) book := addressbook.New(store)
return book, func() {} return book
}) })
} }
func run(t *testing.T, f bookFunc) { func run(t *testing.T, f bookFunc) {
store, cleanup := f(t) store := f(t)
defer cleanup()
addr1 := swarm.NewAddress([]byte{0, 1, 2, 3}) addr1 := swarm.NewAddress([]byte{0, 1, 2, 3})
addr2 := swarm.NewAddress([]byte{0, 1, 2, 4}) addr2 := swarm.NewAddress([]byte{0, 1, 2, 4})
......
...@@ -23,16 +23,16 @@ type testServerOptions struct { ...@@ -23,16 +23,16 @@ type testServerOptions struct {
Storer storage.Storer Storer storage.Storer
} }
func newTestServer(t *testing.T, o testServerOptions) (client *http.Client, cleanup func()) { func newTestServer(t *testing.T, o testServerOptions) *http.Client {
s := api.New(api.Options{ s := api.New(api.Options{
Pingpong: o.Pingpong, Pingpong: o.Pingpong,
Storer: o.Storer, Storer: o.Storer,
Logger: logging.New(ioutil.Discard, 0), Logger: logging.New(ioutil.Discard, 0),
}) })
ts := httptest.NewServer(s) ts := httptest.NewServer(s)
cleanup = ts.Close t.Cleanup(ts.Close)
client = &http.Client{ return &http.Client{
Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
u, err := url.Parse(ts.URL + r.URL.String()) u, err := url.Parse(ts.URL + r.URL.String())
if err != nil { if err != nil {
...@@ -42,5 +42,4 @@ func newTestServer(t *testing.T, o testServerOptions) (client *http.Client, clea ...@@ -42,5 +42,4 @@ func newTestServer(t *testing.T, o testServerOptions) (client *http.Client, clea
return ts.Client().Transport.RoundTrip(r) return ts.Client().Transport.RoundTrip(r)
}), }),
} }
return client, cleanup
} }
...@@ -21,15 +21,14 @@ import ( ...@@ -21,15 +21,14 @@ import (
// downloading and requesting a resource that cannot be found. // downloading and requesting a resource that cannot be found.
func TestBzz(t *testing.T) { func TestBzz(t *testing.T) {
var ( var (
resource = "/bzz" resource = "/bzz"
content = []byte("foo") content = []byte("foo")
expHash = "2387e8e7d8a48c2a9339c97c1dc3461a9a7aa07e994c5cb8b38fd7c1b3e6ea48" expHash = "2387e8e7d8a48c2a9339c97c1dc3461a9a7aa07e994c5cb8b38fd7c1b3e6ea48"
mockStorer = mock.NewStorer() mockStorer = mock.NewStorer()
client, cleanup = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: mockStorer, Storer: mockStorer,
}) })
) )
defer cleanup()
t.Run("upload", func(t *testing.T) { t.Run("upload", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, client, http.MethodPost, resource, bytes.NewReader(content), http.StatusOK, api.BzzPostResponse{ jsonhttptest.ResponseDirect(t, client, http.MethodPost, resource, bytes.NewReader(content), http.StatusOK, api.BzzPostResponse{
......
...@@ -32,11 +32,10 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -32,11 +32,10 @@ func TestChunkUploadDownload(t *testing.T) {
invalidContent = []byte("bbaattss") invalidContent = []byte("bbaattss")
mockValidator = validator.NewMockValidator(validHash, validContent) mockValidator = validator.NewMockValidator(validHash, validContent)
mockValidatingStorer = mock.NewValidatingStorer(mockValidator) mockValidatingStorer = mock.NewValidatingStorer(mockValidator)
client, cleanup = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
}) })
) )
defer cleanup()
t.Run("invalid hash", func(t *testing.T) { t.Run("invalid hash", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, client, http.MethodPost, resource(invalidHash), bytes.NewReader(validContent), http.StatusBadRequest, jsonhttp.StatusResponse{ jsonhttptest.ResponseDirect(t, client, http.MethodPost, resource(invalidHash), bytes.NewReader(validContent), http.StatusBadRequest, jsonhttp.StatusResponse{
......
...@@ -36,10 +36,9 @@ func TestPingpong(t *testing.T) { ...@@ -36,10 +36,9 @@ func TestPingpong(t *testing.T) {
return rtt, nil return rtt, nil
}) })
client, cleanup := newTestServer(t, testServerOptions{ client := newTestServer(t, testServerOptions{
Pingpong: pingpongService, Pingpong: pingpongService,
}) })
defer cleanup()
t.Run("ok", func(t *testing.T) { t.Run("ok", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, client, http.MethodPost, "/pingpong/"+peerID.String(), nil, http.StatusOK, api.PingpongResponse{ jsonhttptest.ResponseDirect(t, client, http.MethodPost, "/pingpong/"+peerID.String(), nil, http.StatusOK, api.PingpongResponse{
......
...@@ -21,7 +21,6 @@ func TestHasChunkHandler(t *testing.T) { ...@@ -21,7 +21,6 @@ func TestHasChunkHandler(t *testing.T) {
testServer := newTestServer(t, testServerOptions{ testServer := newTestServer(t, testServerOptions{
Storer: mockStorer, Storer: mockStorer,
}) })
defer testServer.Cleanup()
key := swarm.MustParseHexAddress("aabbcc") key := swarm.MustParseHexAddress("aabbcc")
value := []byte("data data data") value := []byte("data data data")
......
...@@ -36,7 +36,6 @@ type testServer struct { ...@@ -36,7 +36,6 @@ type testServer struct {
Client *http.Client Client *http.Client
Addressbook addressbook.GetPutter Addressbook addressbook.GetPutter
TopologyDriver topology.Driver TopologyDriver topology.Driver
Cleanup func()
} }
func newTestServer(t *testing.T, o testServerOptions) *testServer { func newTestServer(t *testing.T, o testServerOptions) *testServer {
...@@ -53,7 +52,7 @@ func newTestServer(t *testing.T, o testServerOptions) *testServer { ...@@ -53,7 +52,7 @@ func newTestServer(t *testing.T, o testServerOptions) *testServer {
Storer: o.Storer, Storer: o.Storer,
}) })
ts := httptest.NewServer(s) ts := httptest.NewServer(s)
cleanup := ts.Close t.Cleanup(ts.Close)
client := &http.Client{ client := &http.Client{
Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
...@@ -69,19 +68,18 @@ func newTestServer(t *testing.T, o testServerOptions) *testServer { ...@@ -69,19 +68,18 @@ func newTestServer(t *testing.T, o testServerOptions) *testServer {
Client: client, Client: client,
Addressbook: addrbook, Addressbook: addrbook,
TopologyDriver: topologyDriver, TopologyDriver: topologyDriver,
Cleanup: cleanup,
} }
} }
func newBZZTestServer(t *testing.T, o testServerOptions) (client *http.Client, cleanup func()) { func newBZZTestServer(t *testing.T, o testServerOptions) *http.Client {
s := api.New(api.Options{ s := api.New(api.Options{
Storer: o.Storer, Storer: o.Storer,
Logger: logging.New(ioutil.Discard, 0), Logger: logging.New(ioutil.Discard, 0),
}) })
ts := httptest.NewServer(s) ts := httptest.NewServer(s)
cleanup = ts.Close t.Cleanup(ts.Close)
client = &http.Client{ return &http.Client{
Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
u, err := url.Parse(ts.URL + r.URL.String()) u, err := url.Parse(ts.URL + r.URL.String())
if err != nil { if err != nil {
...@@ -91,7 +89,6 @@ func newBZZTestServer(t *testing.T, o testServerOptions) (client *http.Client, c ...@@ -91,7 +89,6 @@ func newBZZTestServer(t *testing.T, o testServerOptions) (client *http.Client, c
return ts.Client().Transport.RoundTrip(r) return ts.Client().Transport.RoundTrip(r)
}), }),
} }
return client, cleanup
} }
func mustMultiaddr(t *testing.T, s string) multiaddr.Multiaddr { func mustMultiaddr(t *testing.T, s string) multiaddr.Multiaddr {
......
...@@ -31,7 +31,6 @@ func TestAddresses(t *testing.T) { ...@@ -31,7 +31,6 @@ func TestAddresses(t *testing.T) {
return addresses, nil return addresses, nil
})), })),
}) })
defer testServer.Cleanup()
t.Run("ok", func(t *testing.T) { t.Run("ok", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/addresses", nil, http.StatusOK, debugapi.AddressesResponse{ jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/addresses", nil, http.StatusOK, debugapi.AddressesResponse{
...@@ -56,7 +55,6 @@ func TestAddresses_error(t *testing.T) { ...@@ -56,7 +55,6 @@ func TestAddresses_error(t *testing.T) {
return nil, testErr return nil, testErr
})), })),
}) })
defer testServer.Cleanup()
jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/addresses", nil, http.StatusInternalServerError, jsonhttp.StatusResponse{ jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/addresses", nil, http.StatusInternalServerError, jsonhttp.StatusResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
......
...@@ -35,7 +35,6 @@ func TestConnect(t *testing.T) { ...@@ -35,7 +35,6 @@ func TestConnect(t *testing.T) {
return overlay, nil return overlay, nil
})), })),
}) })
defer testServer.Cleanup()
t.Run("ok", func(t *testing.T) { t.Run("ok", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodPost, "/connect"+underlay, nil, http.StatusOK, debugapi.PeerConnectResponse{ jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodPost, "/connect"+underlay, nil, http.StatusOK, debugapi.PeerConnectResponse{
...@@ -76,7 +75,6 @@ func TestConnect(t *testing.T) { ...@@ -76,7 +75,6 @@ func TestConnect(t *testing.T) {
})), })),
TopologyOpts: []topmock.Option{topmock.WithAddPeerErr(testErr)}, TopologyOpts: []topmock.Option{topmock.WithAddPeerErr(testErr)},
}) })
defer testServer.Cleanup()
jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodPost, "/connect"+underlay, nil, http.StatusInternalServerError, jsonhttp.StatusResponse{ jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodPost, "/connect"+underlay, nil, http.StatusInternalServerError, jsonhttp.StatusResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
...@@ -113,7 +111,6 @@ func TestDisconnect(t *testing.T) { ...@@ -113,7 +111,6 @@ func TestDisconnect(t *testing.T) {
return p2p.ErrPeerNotFound return p2p.ErrPeerNotFound
})), })),
}) })
defer testServer.Cleanup()
t.Run("ok", func(t *testing.T) { t.Run("ok", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodDelete, "/peers/"+address.String(), nil, http.StatusOK, jsonhttp.StatusResponse{ jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodDelete, "/peers/"+address.String(), nil, http.StatusOK, jsonhttp.StatusResponse{
...@@ -152,7 +149,6 @@ func TestPeer(t *testing.T) { ...@@ -152,7 +149,6 @@ func TestPeer(t *testing.T) {
return []p2p.Peer{{Address: overlay}} return []p2p.Peer{{Address: overlay}}
})), })),
}) })
defer testServer.Cleanup()
t.Run("ok", func(t *testing.T) { t.Run("ok", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/peers", nil, http.StatusOK, debugapi.PeersResponse{ jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/peers", nil, http.StatusOK, debugapi.PeersResponse{
......
...@@ -30,13 +30,10 @@ func TestPinChunkHandler(t *testing.T) { ...@@ -30,13 +30,10 @@ func TestPinChunkHandler(t *testing.T) {
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
}) })
// This server is used to store chunks // This server is used to store chunks
bzzTestServer, cleanup := newBZZTestServer(t, testServerOptions{ bzzTestServer := newBZZTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
}) })
defer debugTestServer.Cleanup()
defer cleanup()
// bad chunk address // bad chunk address
t.Run("pin-bad-address", func(t *testing.T) { t.Run("pin-bad-address", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, debugTestServer.Client, http.MethodPost, "/chunks-pin/abcd1100zz", nil, http.StatusBadRequest, jsonhttp.StatusResponse{ jsonhttptest.ResponseDirect(t, debugTestServer.Client, http.MethodPost, "/chunks-pin/abcd1100zz", nil, http.StatusBadRequest, jsonhttp.StatusResponse{
......
...@@ -14,7 +14,6 @@ import ( ...@@ -14,7 +14,6 @@ import (
func TestHealth(t *testing.T) { func TestHealth(t *testing.T) {
testServer := newTestServer(t, testServerOptions{}) testServer := newTestServer(t, testServerOptions{})
defer testServer.Cleanup()
jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/health", nil, http.StatusOK, debugapi.StatusResponse{ jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/health", nil, http.StatusOK, debugapi.StatusResponse{
Status: "ok", Status: "ok",
...@@ -23,7 +22,6 @@ func TestHealth(t *testing.T) { ...@@ -23,7 +22,6 @@ func TestHealth(t *testing.T) {
func TestReadiness(t *testing.T) { func TestReadiness(t *testing.T) {
testServer := newTestServer(t, testServerOptions{}) testServer := newTestServer(t, testServerOptions{})
defer testServer.Cleanup()
jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/readiness", nil, http.StatusOK, debugapi.StatusResponse{ jsonhttptest.ResponseDirect(t, testServer.Client, http.MethodGet, "/readiness", nil, http.StatusOK, debugapi.StatusResponse{
Status: "ok", Status: "ok",
......
...@@ -16,8 +16,7 @@ import ( ...@@ -16,8 +16,7 @@ import (
) )
func TestAddresses(t *testing.T) { func TestAddresses(t *testing.T) {
s, _, cleanup := newService(t, 1, libp2p.Options{}) s, _ := newService(t, 1, libp2p.Options{})
defer cleanup()
addrs, err := s.Addresses() addrs, err := s.Addresses()
if err != nil { if err != nil {
...@@ -32,11 +31,9 @@ func TestConnectDisconnect(t *testing.T) { ...@@ -32,11 +31,9 @@ func TestConnectDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -60,11 +57,9 @@ func TestDoubleConnect(t *testing.T) { ...@@ -60,11 +57,9 @@ func TestDoubleConnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -87,11 +82,9 @@ func TestDoubleDisconnect(t *testing.T) { ...@@ -87,11 +82,9 @@ func TestDoubleDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -122,11 +115,9 @@ func TestMultipleConnectDisconnect(t *testing.T) { ...@@ -122,11 +115,9 @@ func TestMultipleConnectDisconnect(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -165,11 +156,9 @@ func TestConnectDisconnectOnAllAddresses(t *testing.T) { ...@@ -165,11 +156,9 @@ func TestConnectDisconnectOnAllAddresses(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
addrs, err := s1.Addresses() addrs, err := s1.Addresses()
if err != nil { if err != nil {
...@@ -197,11 +186,9 @@ func TestDoubleConnectOnAllAddresses(t *testing.T) { ...@@ -197,11 +186,9 @@ func TestDoubleConnectOnAllAddresses(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
addrs, err := s1.Addresses() addrs, err := s1.Addresses()
if err != nil { if err != nil {
...@@ -235,11 +222,9 @@ func TestDifferentNetworkIDs(t *testing.T) { ...@@ -235,11 +222,9 @@ func TestDifferentNetworkIDs(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, _, cleanup1 := newService(t, 1, libp2p.Options{}) s1, _ := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, _, cleanup2 := newService(t, 2, libp2p.Options{}) s2, _ := newService(t, 2, libp2p.Options{})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -255,17 +240,15 @@ func TestConnectWithDisabledQUICAndWSTransports(t *testing.T) { ...@@ -255,17 +240,15 @@ func TestConnectWithDisabledQUICAndWSTransports(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{ s1, overlay1 := newService(t, 1, libp2p.Options{
DisableQUIC: true, DisableQUIC: true,
DisableWS: true, DisableWS: true,
}) })
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{ s2, overlay2 := newService(t, 1, libp2p.Options{
DisableQUIC: true, DisableQUIC: true,
DisableWS: true, DisableWS: true,
}) })
defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -282,11 +265,9 @@ func TestConnectRepeatHandshake(t *testing.T) { ...@@ -282,11 +265,9 @@ func TestConnectRepeatHandshake(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
......
...@@ -23,11 +23,9 @@ func TestHeaders(t *testing.T) { ...@@ -23,11 +23,9 @@ func TestHeaders(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
var gotHeaders p2p.Headers var gotHeaders p2p.Headers
handled := make(chan struct{}) handled := make(chan struct{})
...@@ -72,11 +70,9 @@ func TestHeaders_empty(t *testing.T) { ...@@ -72,11 +70,9 @@ func TestHeaders_empty(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
var gotHeaders p2p.Headers var gotHeaders p2p.Headers
handled := make(chan struct{}) handled := make(chan struct{})
...@@ -130,11 +126,9 @@ func TestHeadler(t *testing.T) { ...@@ -130,11 +126,9 @@ func TestHeadler(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, _, cleanup2 := newService(t, 1, libp2p.Options{}) s2, _ := newService(t, 1, libp2p.Options{})
defer cleanup2()
var gotReceivedHeaders p2p.Headers var gotReceivedHeaders p2p.Headers
handled := make(chan struct{}) handled := make(chan struct{})
......
...@@ -23,7 +23,7 @@ import ( ...@@ -23,7 +23,7 @@ import (
) )
// newService constructs a new libp2p service. // newService constructs a new libp2p service.
func newService(t *testing.T, networkID uint64, o libp2p.Options) (s *libp2p.Service, overlay swarm.Address, cleanup func()) { func newService(t *testing.T, networkID uint64, o libp2p.Options) (s *libp2p.Service, overlay swarm.Address) {
t.Helper() t.Helper()
privateKey, err := crypto.GenerateSecp256k1Key() privateKey, err := crypto.GenerateSecp256k1Key()
...@@ -49,10 +49,11 @@ func newService(t *testing.T, networkID uint64, o libp2p.Options) (s *libp2p.Ser ...@@ -49,10 +49,11 @@ func newService(t *testing.T, networkID uint64, o libp2p.Options) (s *libp2p.Ser
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return s, overlay, func() { t.Cleanup(func() {
cancel() cancel()
s.Close() s.Close()
} })
return s, overlay
} }
// expectPeers validates that peers with addresses are connected. // expectPeers validates that peers with addresses are connected.
......
...@@ -18,11 +18,9 @@ func TestNewStream(t *testing.T) { ...@@ -18,11 +18,9 @@ func TestNewStream(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, _, cleanup2 := newService(t, 1, libp2p.Options{}) s2, _ := newService(t, 1, libp2p.Options{})
defer cleanup2()
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error { if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
return nil return nil
...@@ -49,11 +47,9 @@ func TestNewStream_errNotSupported(t *testing.T) { ...@@ -49,11 +47,9 @@ func TestNewStream_errNotSupported(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, _, cleanup2 := newService(t, 1, libp2p.Options{}) s2, _ := newService(t, 1, libp2p.Options{})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -86,11 +82,9 @@ func TestNewStream_semanticVersioning(t *testing.T) { ...@@ -86,11 +82,9 @@ func TestNewStream_semanticVersioning(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, _, cleanup2 := newService(t, 1, libp2p.Options{}) s2, _ := newService(t, 1, libp2p.Options{})
defer cleanup2()
addr := serviceUnderlayAddress(t, s1) addr := serviceUnderlayAddress(t, s1)
...@@ -147,11 +141,9 @@ func TestDisconnectError(t *testing.T) { ...@@ -147,11 +141,9 @@ func TestDisconnectError(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, overlay2, cleanup2 := newService(t, 1, libp2p.Options{}) s2, overlay2 := newService(t, 1, libp2p.Options{})
defer cleanup2()
if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error { if err := s1.AddProtocol(newTestProtocol(func(_ context.Context, _ p2p.Peer, _ p2p.Stream) error {
return p2p.NewDisconnectError(errors.New("test error")) return p2p.NewDisconnectError(errors.New("test error"))
......
...@@ -34,11 +34,9 @@ func TestTracing(t *testing.T) { ...@@ -34,11 +34,9 @@ func TestTracing(t *testing.T) {
} }
defer closer2.Close() defer closer2.Close()
s1, overlay1, cleanup1 := newService(t, 1, libp2p.Options{}) s1, overlay1 := newService(t, 1, libp2p.Options{})
defer cleanup1()
s2, _, cleanup2 := newService(t, 1, libp2p.Options{}) s2, _ := newService(t, 1, libp2p.Options{})
defer cleanup2()
var handledTracingSpan string var handledTracingSpan string
handled := make(chan struct{}) handled := make(chan struct{})
......
...@@ -25,8 +25,7 @@ import ( ...@@ -25,8 +25,7 @@ import (
// TestNewDB constructs a new DB // TestNewDB constructs a new DB
// and validates if the schema is initialized properly. // and validates if the schema is initialized properly.
func TestNewDB(t *testing.T) { func TestNewDB(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
s, err := db.getSchema() s, err := db.getSchema()
if err != nil { if err != nil {
...@@ -93,13 +92,16 @@ func TestDB_persistence(t *testing.T) { ...@@ -93,13 +92,16 @@ func TestDB_persistence(t *testing.T) {
// newTestDB is a helper function that constructs a // newTestDB is a helper function that constructs a
// temporary database and returns a cleanup function that must // temporary database and returns a cleanup function that must
// be called to remove the data. // be called to remove the data.
func newTestDB(t *testing.T) (db *DB, cleanupFunc func()) { func newTestDB(t *testing.T) *DB {
t.Helper() t.Helper()
db, err := NewDB("") db, err := NewDB("")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
return db, func() { t.Cleanup(func() {
db.Close() if err := db.Close(); err != nil {
} t.Fatal(err)
}
})
return db
} }
...@@ -25,8 +25,7 @@ import ( ...@@ -25,8 +25,7 @@ import (
// TestStringField validates put and get operations // TestStringField validates put and get operations
// of the StringField. // of the StringField.
func TestStringField(t *testing.T) { func TestStringField(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
simpleString, err := db.NewStringField("simple-string") simpleString, err := db.NewStringField("simple-string")
if err != nil { if err != nil {
......
...@@ -25,8 +25,7 @@ import ( ...@@ -25,8 +25,7 @@ import (
// TestStructField validates put and get operations // TestStructField validates put and get operations
// of the StructField. // of the StructField.
func TestStructField(t *testing.T) { func TestStructField(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
complexField, err := db.NewStructField("complex-field") complexField, err := db.NewStructField("complex-field")
if err != nil { if err != nil {
......
...@@ -25,8 +25,7 @@ import ( ...@@ -25,8 +25,7 @@ import (
// TestUint64Field validates put and get operations // TestUint64Field validates put and get operations
// of the Uint64Field. // of the Uint64Field.
func TestUint64Field(t *testing.T) { func TestUint64Field(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
counter, err := db.NewUint64Field("counter") counter, err := db.NewUint64Field("counter")
if err != nil { if err != nil {
...@@ -112,8 +111,7 @@ func TestUint64Field(t *testing.T) { ...@@ -112,8 +111,7 @@ func TestUint64Field(t *testing.T) {
// TestUint64Field_Inc validates Inc operation // TestUint64Field_Inc validates Inc operation
// of the Uint64Field. // of the Uint64Field.
func TestUint64Field_Inc(t *testing.T) { func TestUint64Field_Inc(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
counter, err := db.NewUint64Field("counter") counter, err := db.NewUint64Field("counter")
if err != nil { if err != nil {
...@@ -142,8 +140,7 @@ func TestUint64Field_Inc(t *testing.T) { ...@@ -142,8 +140,7 @@ func TestUint64Field_Inc(t *testing.T) {
// TestUint64Field_IncInBatch validates IncInBatch operation // TestUint64Field_IncInBatch validates IncInBatch operation
// of the Uint64Field. // of the Uint64Field.
func TestUint64Field_IncInBatch(t *testing.T) { func TestUint64Field_IncInBatch(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
counter, err := db.NewUint64Field("counter") counter, err := db.NewUint64Field("counter")
if err != nil { if err != nil {
...@@ -196,8 +193,7 @@ func TestUint64Field_IncInBatch(t *testing.T) { ...@@ -196,8 +193,7 @@ func TestUint64Field_IncInBatch(t *testing.T) {
// TestUint64Field_Dec validates Dec operation // TestUint64Field_Dec validates Dec operation
// of the Uint64Field. // of the Uint64Field.
func TestUint64Field_Dec(t *testing.T) { func TestUint64Field_Dec(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
counter, err := db.NewUint64Field("counter") counter, err := db.NewUint64Field("counter")
if err != nil { if err != nil {
...@@ -233,8 +229,7 @@ func TestUint64Field_Dec(t *testing.T) { ...@@ -233,8 +229,7 @@ func TestUint64Field_Dec(t *testing.T) {
// TestUint64Field_DecInBatch validates DecInBatch operation // TestUint64Field_DecInBatch validates DecInBatch operation
// of the Uint64Field. // of the Uint64Field.
func TestUint64Field_DecInBatch(t *testing.T) { func TestUint64Field_DecInBatch(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
counter, err := db.NewUint64Field("counter") counter, err := db.NewUint64Field("counter")
if err != nil { if err != nil {
......
...@@ -52,8 +52,7 @@ var retrievalIndexFuncs = IndexFuncs{ ...@@ -52,8 +52,7 @@ var retrievalIndexFuncs = IndexFuncs{
// TestIndex validates put, get, fill, has and delete functions of the Index implementation. // TestIndex validates put, get, fill, has and delete functions of the Index implementation.
func TestIndex(t *testing.T) { func TestIndex(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
index, err := db.NewIndex("retrieval", retrievalIndexFuncs) index, err := db.NewIndex("retrieval", retrievalIndexFuncs)
if err != nil { if err != nil {
...@@ -366,8 +365,7 @@ func TestIndex(t *testing.T) { ...@@ -366,8 +365,7 @@ func TestIndex(t *testing.T) {
// TestIndex_Iterate validates index Iterate // TestIndex_Iterate validates index Iterate
// functions for correctness. // functions for correctness.
func TestIndex_Iterate(t *testing.T) { func TestIndex_Iterate(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
index, err := db.NewIndex("retrieval", retrievalIndexFuncs) index, err := db.NewIndex("retrieval", retrievalIndexFuncs)
if err != nil { if err != nil {
...@@ -549,8 +547,7 @@ func TestIndex_Iterate(t *testing.T) { ...@@ -549,8 +547,7 @@ func TestIndex_Iterate(t *testing.T) {
// TestIndex_Iterate_withPrefix validates index Iterate // TestIndex_Iterate_withPrefix validates index Iterate
// function for correctness. // function for correctness.
func TestIndex_Iterate_withPrefix(t *testing.T) { func TestIndex_Iterate_withPrefix(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
index, err := db.NewIndex("retrieval", retrievalIndexFuncs) index, err := db.NewIndex("retrieval", retrievalIndexFuncs)
if err != nil { if err != nil {
...@@ -736,8 +733,7 @@ func TestIndex_Iterate_withPrefix(t *testing.T) { ...@@ -736,8 +733,7 @@ func TestIndex_Iterate_withPrefix(t *testing.T) {
// TestIndex_count tests if Index.Count and Index.CountFrom // TestIndex_count tests if Index.Count and Index.CountFrom
// returns the correct number of items. // returns the correct number of items.
func TestIndex_count(t *testing.T) { func TestIndex_count(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
index, err := db.NewIndex("retrieval", retrievalIndexFuncs) index, err := db.NewIndex("retrieval", retrievalIndexFuncs)
if err != nil { if err != nil {
...@@ -906,8 +902,7 @@ func checkItem(t *testing.T, got, want Item) { ...@@ -906,8 +902,7 @@ func checkItem(t *testing.T, got, want Item) {
// TestIndex_firstAndLast validates that index First and Last methods // TestIndex_firstAndLast validates that index First and Last methods
// are returning expected results based on the provided prefix. // are returning expected results based on the provided prefix.
func TestIndex_firstAndLast(t *testing.T) { func TestIndex_firstAndLast(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
index, err := db.NewIndex("retrieval", retrievalIndexFuncs) index, err := db.NewIndex("retrieval", retrievalIndexFuncs)
if err != nil { if err != nil {
...@@ -1051,8 +1046,7 @@ func TestIncByteSlice(t *testing.T) { ...@@ -1051,8 +1046,7 @@ func TestIncByteSlice(t *testing.T) {
// TestIndex_HasMulti validates that HasMulti returns a correct // TestIndex_HasMulti validates that HasMulti returns a correct
// slice of booleans for provided Items. // slice of booleans for provided Items.
func TestIndex_HasMulti(t *testing.T) { func TestIndex_HasMulti(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
index, err := db.NewIndex("retrieval", retrievalIndexFuncs) index, err := db.NewIndex("retrieval", retrievalIndexFuncs)
if err != nil { if err != nil {
......
...@@ -23,8 +23,7 @@ import ( ...@@ -23,8 +23,7 @@ import (
// TestDB_schemaFieldKey validates correctness of schemaFieldKey. // TestDB_schemaFieldKey validates correctness of schemaFieldKey.
func TestDB_schemaFieldKey(t *testing.T) { func TestDB_schemaFieldKey(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
t.Run("empty name or type", func(t *testing.T) { t.Run("empty name or type", func(t *testing.T) {
_, err := db.schemaFieldKey("", "") _, err := db.schemaFieldKey("", "")
...@@ -89,8 +88,7 @@ func TestDB_schemaFieldKey(t *testing.T) { ...@@ -89,8 +88,7 @@ func TestDB_schemaFieldKey(t *testing.T) {
// TestDB_schemaIndexPrefix validates correctness of schemaIndexPrefix. // TestDB_schemaIndexPrefix validates correctness of schemaIndexPrefix.
func TestDB_schemaIndexPrefix(t *testing.T) { func TestDB_schemaIndexPrefix(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
t.Run("same name", func(t *testing.T) { t.Run("same name", func(t *testing.T) {
id1, err := db.schemaIndexPrefix("test") id1, err := db.schemaIndexPrefix("test")
...@@ -129,8 +127,7 @@ func TestDB_schemaIndexPrefix(t *testing.T) { ...@@ -129,8 +127,7 @@ func TestDB_schemaIndexPrefix(t *testing.T) {
func TestDB_RenameIndex(t *testing.T) { func TestDB_RenameIndex(t *testing.T) {
t.Run("empty names", func(t *testing.T) { t.Run("empty names", func(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
// empty names // empty names
renamed, err := db.RenameIndex("", "") renamed, err := db.RenameIndex("", "")
...@@ -161,8 +158,7 @@ func TestDB_RenameIndex(t *testing.T) { ...@@ -161,8 +158,7 @@ func TestDB_RenameIndex(t *testing.T) {
}) })
t.Run("same names", func(t *testing.T) { t.Run("same names", func(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
renamed, err := db.RenameIndex("index1", "index1") renamed, err := db.RenameIndex("index1", "index1")
if err != nil { if err != nil {
...@@ -174,8 +170,7 @@ func TestDB_RenameIndex(t *testing.T) { ...@@ -174,8 +170,7 @@ func TestDB_RenameIndex(t *testing.T) {
}) })
t.Run("unknown name", func(t *testing.T) { t.Run("unknown name", func(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
renamed, err := db.RenameIndex("index1", "index1new") renamed, err := db.RenameIndex("index1", "index1new")
if err != nil { if err != nil {
...@@ -187,8 +182,7 @@ func TestDB_RenameIndex(t *testing.T) { ...@@ -187,8 +182,7 @@ func TestDB_RenameIndex(t *testing.T) {
}) })
t.Run("valid names", func(t *testing.T) { t.Run("valid names", func(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
// initial indexes // initial indexes
key1, err := db.schemaIndexPrefix("index1") key1, err := db.schemaIndexPrefix("index1")
......
...@@ -25,8 +25,7 @@ import ( ...@@ -25,8 +25,7 @@ import (
// TestUint64Vector validates put and get operations // TestUint64Vector validates put and get operations
// of the Uint64Vector. // of the Uint64Vector.
func TestUint64Vector(t *testing.T) { func TestUint64Vector(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
bins, err := db.NewUint64Vector("bins") bins, err := db.NewUint64Vector("bins")
if err != nil { if err != nil {
...@@ -116,8 +115,7 @@ func TestUint64Vector(t *testing.T) { ...@@ -116,8 +115,7 @@ func TestUint64Vector(t *testing.T) {
// TestUint64Vector_Inc validates Inc operation // TestUint64Vector_Inc validates Inc operation
// of the Uint64Vector. // of the Uint64Vector.
func TestUint64Vector_Inc(t *testing.T) { func TestUint64Vector_Inc(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
bins, err := db.NewUint64Vector("bins") bins, err := db.NewUint64Vector("bins")
if err != nil { if err != nil {
...@@ -148,8 +146,7 @@ func TestUint64Vector_Inc(t *testing.T) { ...@@ -148,8 +146,7 @@ func TestUint64Vector_Inc(t *testing.T) {
// TestUint64Vector_IncInBatch validates IncInBatch operation // TestUint64Vector_IncInBatch validates IncInBatch operation
// of the Uint64Vector. // of the Uint64Vector.
func TestUint64Vector_IncInBatch(t *testing.T) { func TestUint64Vector_IncInBatch(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
bins, err := db.NewUint64Vector("bins") bins, err := db.NewUint64Vector("bins")
if err != nil { if err != nil {
...@@ -204,8 +201,7 @@ func TestUint64Vector_IncInBatch(t *testing.T) { ...@@ -204,8 +201,7 @@ func TestUint64Vector_IncInBatch(t *testing.T) {
// TestUint64Vector_Dec validates Dec operation // TestUint64Vector_Dec validates Dec operation
// of the Uint64Vector. // of the Uint64Vector.
func TestUint64Vector_Dec(t *testing.T) { func TestUint64Vector_Dec(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
bins, err := db.NewUint64Vector("bins") bins, err := db.NewUint64Vector("bins")
if err != nil { if err != nil {
...@@ -243,8 +239,7 @@ func TestUint64Vector_Dec(t *testing.T) { ...@@ -243,8 +239,7 @@ func TestUint64Vector_Dec(t *testing.T) {
// TestUint64Vector_DecInBatch validates DecInBatch operation // TestUint64Vector_DecInBatch validates DecInBatch operation
// of the Uint64Vector. // of the Uint64Vector.
func TestUint64Vector_DecInBatch(t *testing.T) { func TestUint64Vector_DecInBatch(t *testing.T) {
db, cleanupFunc := newTestDB(t) db := newTestDB(t)
defer cleanupFunc()
bins, err := db.NewUint64Vector("bins") bins, err := db.NewUint64Vector("bins")
if err != nil { if err != nil {
......
...@@ -15,18 +15,28 @@ import ( ...@@ -15,18 +15,28 @@ import (
) )
func TestPersistentStateStore(t *testing.T) { func TestPersistentStateStore(t *testing.T) {
test.Run(t, func(t *testing.T) (storage.StateStorer, func()) { test.Run(t, func(t *testing.T) storage.StateStorer {
dir, err := ioutil.TempDir("", "statestore_test") dir, err := ioutil.TempDir("", "statestore_test")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(func() {
if err := os.RemoveAll(dir); err != nil {
t.Fatal(err)
}
})
store, err := leveldb.NewStateStore(dir) store, err := leveldb.NewStateStore(dir)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Cleanup(func() {
if err := store.Close(); err != nil {
t.Fatal(err)
}
})
return store, func() { os.RemoveAll(dir) } return store
}) })
test.RunPersist(t, func(t *testing.T, dir string) storage.StateStorer { test.RunPersist(t, func(t *testing.T, dir string) storage.StateStorer {
......
...@@ -13,7 +13,7 @@ import ( ...@@ -13,7 +13,7 @@ import (
) )
func TestMockStateStore(t *testing.T) { func TestMockStateStore(t *testing.T) {
test.Run(t, func(t *testing.T) (storage.StateStorer, func()) { test.Run(t, func(t *testing.T) storage.StateStorer {
return mock.NewStateStore(), func() {} return mock.NewStateStore()
}) })
} }
...@@ -79,7 +79,7 @@ func RunPersist(t *testing.T, f func(t *testing.T, dir string) storage.StateStor ...@@ -79,7 +79,7 @@ func RunPersist(t *testing.T, f func(t *testing.T, dir string) storage.StateStor
testStoreIterator(t, persistedStore, "some_other_prefix", 1000) testStoreIterator(t, persistedStore, "some_other_prefix", 1000)
} }
func Run(t *testing.T, f func(t *testing.T) (storage.StateStorer, func())) { func Run(t *testing.T, f func(t *testing.T) storage.StateStorer) {
t.Helper() t.Helper()
t.Run("test_put_get", func(t *testing.T) { testPutGet(t, f) }) t.Run("test_put_get", func(t *testing.T) { testPutGet(t, f) })
...@@ -87,13 +87,11 @@ func Run(t *testing.T, f func(t *testing.T) (storage.StateStorer, func())) { ...@@ -87,13 +87,11 @@ func Run(t *testing.T, f func(t *testing.T) (storage.StateStorer, func())) {
t.Run("test_iterator", func(t *testing.T) { testIterator(t, f) }) t.Run("test_iterator", func(t *testing.T) { testIterator(t, f) })
} }
func testDelete(t *testing.T, f func(t *testing.T) (storage.StateStorer, func())) { func testDelete(t *testing.T, f func(t *testing.T) storage.StateStorer) {
t.Helper() t.Helper()
// create a store // create a store
store, cleanup := f(t) store := f(t)
defer store.Close()
defer cleanup()
// insert some values // insert some values
insertValues(t, store, key1, key2, value1, value2) insertValues(t, store, key1, key2, value1, value2)
...@@ -114,13 +112,11 @@ func testDelete(t *testing.T, f func(t *testing.T) (storage.StateStorer, func()) ...@@ -114,13 +112,11 @@ func testDelete(t *testing.T, f func(t *testing.T) (storage.StateStorer, func())
testEmpty(t, store) testEmpty(t, store)
} }
func testPutGet(t *testing.T, f func(t *testing.T) (storage.StateStorer, func())) { func testPutGet(t *testing.T, f func(t *testing.T) storage.StateStorer) {
t.Helper() t.Helper()
// create a store // create a store
store, cleanup := f(t) store := f(t)
defer store.Close()
defer cleanup()
// insert some values // insert some values
insertValues(t, store, key1, key2, value1, value2) insertValues(t, store, key1, key2, value1, value2)
...@@ -129,13 +125,11 @@ func testPutGet(t *testing.T, f func(t *testing.T) (storage.StateStorer, func()) ...@@ -129,13 +125,11 @@ func testPutGet(t *testing.T, f func(t *testing.T) (storage.StateStorer, func())
testPersistedValues(t, store, key1, key2, value1, value2) testPersistedValues(t, store, key1, key2, value1, value2)
} }
func testIterator(t *testing.T, f func(t *testing.T) (storage.StateStorer, func())) { func testIterator(t *testing.T, f func(t *testing.T) storage.StateStorer) {
t.Helper() t.Helper()
// create a store // create a store
store, cleanup := f(t) store := f(t)
defer store.Close()
defer cleanup()
// insert some values // insert some values
insert(t, store, "some_prefix", 1000) insert(t, store, "some_prefix", 1000)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment