Commit 6ed63739 authored by acud's avatar acud Committed by GitHub

api: add pss client facing apis (#686)

* add pss websocket api for opening a persistent connection for incoming messages on a given topic
* add pss message send api to send outgoing messages
parent cf783b24
......@@ -156,7 +156,7 @@ func newTestServer(t *testing.T, storer storage.Storer) *url.URL {
t.Helper()
logger := logging.New(ioutil.Discard, 0)
store := statestore.NewStateStore()
s := api.New(tags.NewTags(store, logger), storer, nil, logger, nil, api.Options{})
s := api.New(tags.NewTags(store, logger), storer, nil, nil, logger, nil, api.Options{})
ts := httptest.NewServer(s)
srvUrl, err := url.Parse(ts.URL)
if err != nil {
......
......@@ -15,6 +15,7 @@ require (
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect
github.com/gorilla/handlers v1.4.2
github.com/gorilla/mux v1.7.4
github.com/gorilla/websocket v1.4.2
github.com/ipfs/go-log/v2 v2.1.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/libp2p/go-libp2p v0.10.0
......
......@@ -7,13 +7,16 @@ package api
import (
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/ethersphere/bee/pkg/logging"
m "github.com/ethersphere/bee/pkg/metrics"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
......@@ -37,22 +40,28 @@ var (
type Service interface {
http.Handler
m.Collector
io.Closer
}
type server struct {
Tags *tags.Tags
Storer storage.Storer
Resolver resolver.Interface
Pss pss.Interface
Logger logging.Logger
Tracer *tracing.Tracer
Options
http.Handler
metrics metrics
wsWg sync.WaitGroup // wait for all websockets to close on exit
quit chan struct{}
}
type Options struct {
CORSAllowedOrigins []string
GatewayMode bool
WsPingPeriod time.Duration
}
const (
......@@ -61,15 +70,17 @@ const (
)
// New will create a and initialize a new API service.
func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, logger logging.Logger, tracer *tracing.Tracer, o Options) Service {
func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, pss pss.Interface, logger logging.Logger, tracer *tracing.Tracer, o Options) Service {
s := &server{
Tags: tags,
Storer: storer,
Resolver: resolver,
Pss: pss,
Options: o,
Logger: logger,
Tracer: tracer,
metrics: newMetrics(),
quit: make(chan struct{}),
}
s.setupRouting()
......@@ -77,6 +88,26 @@ func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, lo
return s
}
// Close hangs up running websockets on shutdown.
func (s *server) Close() error {
s.Logger.Info("api shutting down")
close(s.quit)
done := make(chan struct{})
go func() {
defer close(done)
s.wsWg.Wait()
}()
select {
case <-done:
case <-time.After(5 * time.Second):
return errors.New("api shutting down with open websockets")
}
return nil
}
// getOrCreateTag attempts to get the tag if an id is supplied, and returns an error if it does not exist.
// If no id is supplied, it will attempt to create a new tag with a generated name and return it.
func (s *server) getOrCreateTag(tagUid string) (*tags.Tag, bool, error) {
......
......@@ -11,48 +11,72 @@ import (
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/resolver"
resolverMock "github.com/ethersphere/bee/pkg/resolver/mock"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags"
"github.com/gorilla/websocket"
"resenje.org/web"
)
type testServerOptions struct {
Storer storage.Storer
Resolver resolver.Interface
Tags *tags.Tags
GatewayMode bool
Logger logging.Logger
Storer storage.Storer
Resolver resolver.Interface
Pss pss.Interface
WsPath string
Tags *tags.Tags
GatewayMode bool
WsPingPeriod time.Duration
Logger logging.Logger
}
func newTestServer(t *testing.T, o testServerOptions) *http.Client {
func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket.Conn, string) {
if o.Logger == nil {
o.Logger = logging.New(ioutil.Discard, 0)
}
if o.Resolver == nil {
o.Resolver = resolverMock.NewResolver()
}
s := api.New(o.Tags, o.Storer, o.Resolver, o.Logger, nil, api.Options{
GatewayMode: o.GatewayMode,
if o.WsPingPeriod == 0 {
o.WsPingPeriod = 60 * time.Second
}
s := api.New(o.Tags, o.Storer, o.Resolver, o.Pss, o.Logger, nil, api.Options{
GatewayMode: o.GatewayMode,
WsPingPeriod: o.WsPingPeriod,
})
ts := httptest.NewServer(s)
t.Cleanup(ts.Close)
return &http.Client{
Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
u, err := url.Parse(ts.URL + r.URL.String())
if err != nil {
return nil, err
}
r.URL = u
return ts.Client().Transport.RoundTrip(r)
}),
var (
httpClient = &http.Client{
Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
u, err := url.Parse(ts.URL + r.URL.String())
if err != nil {
return nil, err
}
r.URL = u
return ts.Client().Transport.RoundTrip(r)
}),
}
conn *websocket.Conn
err error
)
if o.WsPath != "" {
u := url.URL{Scheme: "ws", Host: ts.Listener.Addr().String(), Path: o.WsPath}
conn, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
t.Fatalf("dial: %v. url %v", err, u.String())
}
}
return httpClient, conn, ts.Listener.Addr().String()
}
func TestParseName(t *testing.T) {
......@@ -116,7 +140,7 @@ func TestParseName(t *testing.T) {
}))
}
s := api.New(nil, nil, tC.res, tC.log, nil, api.Options{}).(*api.Server)
s := api.New(nil, nil, tC.res, nil, tC.log, nil, api.Options{}).(*api.Server)
t.Run(tC.desc, func(t *testing.T) {
got, err := s.ResolveNameOrAddress(tC.name)
......
......@@ -6,11 +6,12 @@ package api_test
import (
"bytes"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"net/http"
"testing"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
......@@ -31,7 +32,7 @@ func TestBytes(t *testing.T) {
mockStorer = mock.NewStorer()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{
client, _, _ = newTestServer(t, testServerOptions{
Storer: mockStorer,
Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5),
......
......@@ -9,7 +9,6 @@ import (
"context"
"encoding/json"
"fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io"
"io/ioutil"
"mime"
......@@ -17,6 +16,8 @@ import (
"strings"
"testing"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/file/pipeline"
"github.com/ethersphere/bee/pkg/jsonhttp"
......@@ -36,7 +37,7 @@ func TestBzz(t *testing.T) {
ctx = context.Background()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{
client, _, _ = newTestServer(t, testServerOptions{
Storer: storer,
Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5),
......
......@@ -6,13 +6,14 @@ package api_test
import (
"bytes"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io"
"io/ioutil"
"net/http"
"testing"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/api"
......@@ -41,7 +42,7 @@ func TestChunkUploadDownload(t *testing.T) {
logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger)
mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator))
client = newTestServer(t, testServerOptions{
client, _, _ = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer,
Tags: tag,
})
......
......@@ -35,7 +35,7 @@ func TestDirs(t *testing.T) {
storer = mock.NewStorer()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{
client, _, _ = newTestServer(t, testServerOptions{
Storer: storer,
Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5),
......
......@@ -8,7 +8,6 @@ import (
"bytes"
"encoding/json"
"fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io"
"io/ioutil"
"mime"
......@@ -18,6 +17,8 @@ import (
"strings"
"testing"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
......@@ -35,7 +36,7 @@ func TestFiles(t *testing.T) {
simpleData = []byte("this is a simple text")
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{
client, _, _ = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(),
Tags: tags.NewTags(mockStatestore, logger),
})
......@@ -338,7 +339,7 @@ func TestRangeRequests(t *testing.T) {
t.Run(upload.name, func(t *testing.T) {
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
client := newTestServer(t, testServerOptions{
client, _, _ := newTestServer(t, testServerOptions{
Storer: mock.NewStorer(),
Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5),
......
......@@ -20,7 +20,7 @@ import (
func TestGatewayMode(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
client := newTestServer(t, testServerOptions{
client, _, _ := newTestServer(t, testServerOptions{
Storer: mock.NewStorer(),
Tags: tags.NewTags(statestore.NewStateStore(), logger),
Logger: logger,
......
......@@ -37,7 +37,7 @@ func TestPinChunkHandler(t *testing.T) {
logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger)
client = newTestServer(t, testServerOptions{
client, _, _ = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer,
Tags: tag,
Logger: logger,
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package api
import (
"context"
"encoding/hex"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
var (
upgrader = websocket.Upgrader{
ReadBufferSize: swarm.ChunkSize,
WriteBufferSize: swarm.ChunkSize,
}
writeDeadline = 4 * time.Second // write deadline. should be smaller than the shutdown timeout on api close
targetMaxLength = 2 // max target length in bytes, in order to prevent grieving by excess computation
)
type PssMessage struct {
Topic string
Message string
}
func (s *server) pssPostHandler(w http.ResponseWriter, r *http.Request) {
t := mux.Vars(r)["topic"]
topic := trojan.NewTopic(t)
tg := mux.Vars(r)["targets"]
var targets trojan.Targets
tgts := strings.Split(tg, ",")
for _, v := range tgts {
target, err := hex.DecodeString(v)
if err != nil || len(target) > targetMaxLength {
s.Logger.Debugf("pss send: bad targets: %v", err)
s.Logger.Error("pss send: bad targets")
jsonhttp.BadRequest(w, nil)
return
}
targets = append(targets, target)
}
payload, err := ioutil.ReadAll(r.Body)
if err != nil {
s.Logger.Debugf("pss read payload: %v", err)
s.Logger.Error("pss read payload")
jsonhttp.InternalServerError(w, nil)
return
}
err = s.Pss.Send(r.Context(), targets, topic, payload)
if err != nil {
s.Logger.Debugf("pss send payload: %v. topic: %s", err, t)
s.Logger.Error("pss send payload")
jsonhttp.InternalServerError(w, nil)
return
}
jsonhttp.OK(w, nil)
}
func (s *server) pssWsHandler(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
s.Logger.Debugf("pss ws: upgrade: %v", err)
s.Logger.Error("pss ws: cannot upgrade")
jsonhttp.InternalServerError(w, nil)
return
}
t := mux.Vars(r)["topic"]
s.wsWg.Add(1)
go s.pumpWs(conn, t)
}
func (s *server) pumpWs(conn *websocket.Conn, t string) {
defer s.wsWg.Done()
var (
dataC = make(chan []byte)
gone = make(chan struct{})
topic = trojan.NewTopic(t)
ticker = time.NewTicker(s.WsPingPeriod)
err error
)
defer func() {
ticker.Stop()
_ = conn.Close()
}()
cleanup := s.Pss.Register(topic, func(_ context.Context, m *trojan.Message) {
dataC <- m.Payload
})
defer cleanup()
conn.SetCloseHandler(func(code int, text string) error {
s.Logger.Debugf("pss handler: client gone. code %d message %s", code, text)
close(gone)
return nil
})
for {
select {
case b := <-dataC:
err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
if err != nil {
s.Logger.Debugf("pss set write deadline: %v", err)
return
}
err = conn.WriteMessage(websocket.BinaryMessage, b)
if err != nil {
s.Logger.Debugf("pss write to websocket: %v", err)
return
}
case <-s.quit:
// shutdown
err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
if err != nil {
s.Logger.Debugf("pss set write deadline: %v", err)
return
}
err = conn.WriteMessage(websocket.CloseMessage, []byte{})
if err != nil {
s.Logger.Debugf("pss write close message: %v", err)
}
return
case <-gone:
// client gone
return
case <-ticker.C:
err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
if err != nil {
s.Logger.Debugf("pss set write deadline: %v", err)
return
}
if err = conn.WriteMessage(websocket.PingMessage, nil); err != nil {
// error encountered while pinging client. client probably gone
return
}
}
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package api_test
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"sync"
"testing"
"time"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
"github.com/gorilla/websocket"
)
var (
target = trojan.Target([]byte{1})
targets = trojan.Targets([]trojan.Target{target})
payload = []byte("testdata")
topic = trojan.NewTopic("testtopic")
timeout = 10 * time.Second
)
// creates a single websocket handler for an arbitrary topic, and receives a message
func TestPssWebsocketSingleHandler(t *testing.T) {
var (
pss, cl, _ = newPssTest(t, opts{})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
done = make(chan struct{})
)
err := cl.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
t.Fatal(err)
}
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, payload, &mtx)
}
func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
// create a new pss instance, register a handle through ws, call
// pss.TryUnwrap with a chunk designated for this handler and expect
// the handler to be notified
var (
pss, cl, _ = newPssTest(t, opts{})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
done = make(chan struct{})
)
err := cl.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
t.Fatal(err)
}
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
// close the websocket before calling pss with the message
err = cl.WriteMessage(websocket.CloseMessage, []byte{})
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx)
}
func TestPssWebsocketMultiHandler(t *testing.T) {
var (
pss, cl, listener = newPssTest(t, opts{})
u = url.URL{Scheme: "ws", Host: listener, Path: "/pss/subscribe/testtopic"}
cl2, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
msgContent = make([]byte, len(payload))
msgContent2 = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
done = make(chan struct{})
)
if err != nil {
t.Fatalf("dial: %v. url %v", err, u.String())
}
err = cl.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
t.Fatal(err)
}
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
go waitReadMessage(t, &mtx, cl2, msgContent2, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
// close the websocket before calling pss with the message
err = cl.WriteMessage(websocket.CloseMessage, []byte{})
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx)
waitMessage(t, msgContent2, nil, &mtx)
}
// TestPssSend tests that the pss message sending over http works correctly.
func TestPssSend(t *testing.T) {
var (
logger = logging.New(ioutil.Discard, 0)
mtx sync.Mutex
recievedTargets trojan.Targets
recievedTopic trojan.Topic
recievedBytes []byte
done bool
sendFn = func(_ context.Context, targets trojan.Targets, topic trojan.Topic, bytes []byte) error {
mtx.Lock()
recievedTargets = targets
recievedTopic = topic
recievedBytes = bytes
done = true
mtx.Unlock()
return nil
}
pss = newMockPss(sendFn)
client, _, _ = newTestServer(t, testServerOptions{
Pss: pss,
Storer: mock.NewStorer(),
Logger: logger,
})
targets = fmt.Sprintf("[[%d]]", 0x12)
topic = "testtopic"
hasher = swarm.NewHasher()
_, err = hasher.Write([]byte(topic))
topicHash = hasher.Sum(nil)
)
if err != nil {
t.Fatal(err)
}
t.Run("err - bad targets", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/badtarget", http.StatusBadRequest,
jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "Bad Request",
Code: http.StatusBadRequest,
}),
)
})
t.Run("ok", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/testtopic/12", http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "OK",
Code: http.StatusOK,
}),
)
waitDone(t, &mtx, &done)
if !bytes.Equal(recievedBytes, payload) {
t.Fatalf("payload mismatch. want %v got %v", payload, recievedBytes)
}
if targets != fmt.Sprint(recievedTargets) {
t.Fatalf("targets mismatch. want %v got %v", targets, recievedTargets)
}
if string(topicHash) != string(recievedTopic[:]) {
t.Fatalf("topic mismatch. want %v got %v", topic, string(recievedTopic[:]))
}
})
}
// TestPssPingPong tests that the websocket api adheres to the websocket standard
// and sends ping-pong messages to keep the connection alive.
// The test opens a websocket, keeps it alive for 500ms, then receives a pss message.
func TestPssPingPong(t *testing.T) {
var (
pss, cl, _ = newPssTest(t, opts{pingPeriod: 90 * time.Millisecond})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
pongWait = 1 * time.Millisecond
done = make(chan struct{})
)
cl.SetReadLimit(swarm.ChunkSize)
err := cl.SetReadDeadline(time.Now().Add(pongWait))
if err != nil {
t.Fatal(err)
}
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx)
}
func waitReadMessage(t *testing.T, mtx *sync.Mutex, cl *websocket.Conn, targetContent []byte, done <-chan struct{}) {
t.Helper()
timeout := time.After(timeout)
for {
select {
case <-done:
return
case <-timeout:
t.Errorf("timed out waiting for message")
return
default:
}
msgType, message, err := cl.ReadMessage()
if err != nil {
return
}
if msgType == websocket.PongMessage {
// ignore pings
continue
}
if message != nil {
mtx.Lock()
copy(targetContent, message)
mtx.Unlock()
}
time.Sleep(50 * time.Millisecond)
}
}
func waitDone(t *testing.T, mtx *sync.Mutex, done *bool) {
for i := 0; i < 10; i++ {
mtx.Lock()
if *done {
mtx.Unlock()
return
}
mtx.Unlock()
time.Sleep(50 * time.Millisecond)
}
t.Fatal("timed out waiting for send")
}
func waitMessage(t *testing.T, data, expData []byte, mtx *sync.Mutex) {
ttl := time.After(timeout)
for {
select {
case <-ttl:
if expData == nil {
return
}
t.Fatal("timed out waiting for pss message")
default:
}
mtx.Lock()
if bytes.Equal(data, expData) {
mtx.Unlock()
return
}
mtx.Unlock()
time.Sleep(100 * time.Millisecond)
}
}
type opts struct {
pingPeriod time.Duration
}
func newPssTest(t *testing.T, o opts) (pss.Interface, *websocket.Conn, string) {
var (
logger = logging.New(ioutil.Discard, 0)
pss = pss.New(logger)
)
if o.pingPeriod == 0 {
o.pingPeriod = 10 * time.Second
}
_, cl, listener := newTestServer(t, testServerOptions{
Pss: pss,
WsPath: "/pss/subscribe/testtopic",
Storer: mock.NewStorer(),
Logger: logger,
WsPingPeriod: o.pingPeriod,
})
return pss, cl, listener
}
type pssSendFn func(context.Context, trojan.Targets, trojan.Topic, []byte) error
type mpss struct {
f pssSendFn
}
func newMockPss(f pssSendFn) *mpss {
return &mpss{f}
}
// Send arbitrary byte slice with the given topic to Targets.
func (m *mpss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, bytes []byte) error {
return m.f(ctx, targets, topic, bytes)
}
// Register a Handler for a given Topic.
func (m *mpss) Register(_ trojan.Topic, _ pss.Handler) func() {
panic("not implemented") // TODO: Implement
}
// TryUnwrap tries to unwrap a wrapped trojan message.
func (m *mpss) TryUnwrap(_ context.Context, _ swarm.Chunk) error {
panic("not implemented") // TODO: Implement
}
func (m *mpss) SetPushSyncer(pushSyncer pushsync.PushSyncer) {
panic("not implemented") // TODO: Implement
}
func (m *mpss) Close() error {
panic("not implemented") // TODO: Implement
}
......@@ -90,6 +90,15 @@ func (s *server) setupRouting() {
),
})
handle(router, "/pss/send/{topic}/{targets}", jsonhttp.MethodHandler{
"POST": web.ChainHandlers(
jsonhttp.NewMaxBodyBytesHandler(swarm.ChunkSize),
web.FinalHandlerFunc(s.pssPostHandler),
),
})
handle(router, "/pss/subscribe/{topic}", http.HandlerFunc(s.pssWsHandler))
handle(router, "/tags", web.ChainHandlers(
s.gatewayModeForbidEndpointHandler,
web.FinalHandler(jsonhttp.MethodHandler{
......
......@@ -8,14 +8,15 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"net/http"
"strconv"
"strings"
"testing"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
......@@ -46,7 +47,7 @@ func TestTags(t *testing.T) {
logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger)
mockPusher = mp.NewMockPusher(tag)
client = newTestServer(t, testServerOptions{
client, _, _ = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(),
Tags: tag,
})
......
......@@ -5,6 +5,7 @@
package logging
import (
"bufio"
"net"
"net/http"
"time"
......@@ -89,6 +90,10 @@ func (l *responseLogger) Flush() {
l.w.(http.Flusher).Flush()
}
func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return l.w.(http.Hijacker).Hijack()
}
func (l *responseLogger) CloseNotify() <-chan bool {
// staticcheck SA1019 CloseNotifier interface is required by gorilla compress handler
// nolint:staticcheck
......
......@@ -52,20 +52,23 @@ import (
)
type Bee struct {
p2pService io.Closer
p2pCancel context.CancelFunc
apiServer *http.Server
debugAPIServer *http.Server
resolverCloser io.Closer
errorLogWriter *io.PipeWriter
tracerCloser io.Closer
tagsCloser io.Closer
stateStoreCloser io.Closer
localstoreCloser io.Closer
topologyCloser io.Closer
pusherCloser io.Closer
pullerCloser io.Closer
pullSyncCloser io.Closer
p2pService io.Closer
p2pCancel context.CancelFunc
apiCloser io.Closer
apiServer *http.Server
debugAPIServer *http.Server
resolverCloser io.Closer
errorLogWriter *io.PipeWriter
tracerCloser io.Closer
tagsCloser io.Closer
stateStoreCloser io.Closer
localstoreCloser io.Closer
topologyCloser io.Closer
pusherCloser io.Closer
pullerCloser io.Closer
pullSyncCloser io.Closer
pssCloser io.Closer
recoveryHandleCleanup func()
}
type Options struct {
......@@ -247,7 +250,8 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
}
// instantiate the pss object
psss := pss.New(logger, nil)
psss := pss.New(logger)
b.pssCloser = psss
var ns storage.Storer
if o.GlobalPinningEnabled {
......@@ -262,7 +266,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10))
// set the pushSyncer in the PSS
psss.WithPushSyncer(pushSyncProtocol)
psss.SetPushSyncer(pushSyncProtocol)
if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil {
return nil, fmt.Errorf("pushsync service: %w", err)
......@@ -271,7 +275,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
if o.GlobalPinningEnabled {
// register function for chunk repair upon receiving a trojan message
chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol)
psss.Register(recovery.RecoveryTopic, chunkRepairHandler)
b.recoveryHandleCleanup = psss.Register(recovery.RecoveryTopic, chunkRepairHandler)
}
pushSyncPusher := pusher.New(storer, kad, pushSyncProtocol, tagg, logger)
......@@ -299,9 +303,10 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
var apiService api.Service
if o.APIAddr != "" {
// API server
apiService = api.New(tagg, ns, multiResolver, logger, tracer, api.Options{
apiService = api.New(tagg, ns, multiResolver, psss, logger, tracer, api.Options{
CORSAllowedOrigins: o.CORSAllowedOrigins,
GatewayMode: o.GatewayMode,
WsPingPeriod: 60 * time.Second,
})
apiListener, err := net.Listen("tcp", o.APIAddr)
if err != nil {
......@@ -323,6 +328,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
}()
b.apiServer = apiServer
b.apiCloser = apiService
}
if o.DebugAPIAddr != "" {
......@@ -373,6 +379,12 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
func (b *Bee) Shutdown(ctx context.Context) error {
errs := new(multiError)
if b.apiCloser != nil {
if err := b.apiCloser.Close(); err != nil {
errs.add(fmt.Errorf("api: %w", err))
}
}
var eg errgroup.Group
if b.apiServer != nil {
eg.Go(func() error {
......@@ -395,6 +407,10 @@ func (b *Bee) Shutdown(ctx context.Context) error {
errs.add(err)
}
if b.recoveryHandleCleanup != nil {
b.recoveryHandleCleanup()
}
if err := b.pusherCloser.Close(); err != nil {
errs.add(fmt.Errorf("pusher: %w", err))
}
......@@ -407,6 +423,10 @@ func (b *Bee) Shutdown(ctx context.Context) error {
errs.add(fmt.Errorf("pull sync: %w", err))
}
if err := b.pssCloser.Close(); err != nil {
errs.add(fmt.Errorf("pss: %w", err))
}
b.p2pCancel()
if err := b.p2pService.Close(); err != nil {
errs.add(fmt.Errorf("p2p server: %w", err))
......
......@@ -8,6 +8,7 @@ import (
"context"
"errors"
"fmt"
"io"
"sync"
"github.com/ethersphere/bee/pkg/logging"
......@@ -22,54 +23,67 @@ var (
)
type Interface interface {
Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error
Register(topic trojan.Topic, hndlr Handler)
GetHandler(topic trojan.Topic) Handler
TryUnwrap(ctx context.Context, c swarm.Chunk) error
WithPushSyncer(pushSyncer pushsync.PushSyncer)
// Send arbitrary byte slice with the given topic to Targets.
Send(context.Context, trojan.Targets, trojan.Topic, []byte) error
// Register a Handler for a given Topic.
Register(trojan.Topic, Handler) func()
// TryUnwrap tries to unwrap a wrapped trojan message.
TryUnwrap(context.Context, swarm.Chunk) error
SetPushSyncer(pushSyncer pushsync.PushSyncer)
io.Closer
}
// pss is the top-level struct, which takes care of message sending
type pss struct {
pusher pushsync.PushSyncer
handlers map[trojan.Topic]Handler
handlersMu sync.RWMutex
handlers map[trojan.Topic][]*Handler
handlersMu sync.Mutex
metrics metrics
logger logging.Logger
quit chan struct{}
}
// New inits the pss struct with the storer
func New(logger logging.Logger, pusher pushsync.PushSyncer) Interface {
// New returns a new pss service.
func New(logger logging.Logger) Interface {
return &pss{
logger: logger,
pusher: pusher,
handlers: make(map[trojan.Topic]Handler),
handlers: make(map[trojan.Topic][]*Handler),
metrics: newMetrics(),
quit: make(chan struct{}),
}
}
func (ps *pss) WithPushSyncer(pushSyncer pushsync.PushSyncer) {
func (ps *pss) Close() error {
close(ps.quit)
ps.handlersMu.Lock()
defer ps.handlersMu.Unlock()
ps.handlers = make(map[trojan.Topic][]*Handler) //unset handlers on shutdown
return nil
}
func (ps *pss) SetPushSyncer(pushSyncer pushsync.PushSyncer) {
ps.pusher = pushSyncer
}
// Handler defines code to be executed upon reception of a trojan message
type Handler func(context.Context, *trojan.Message) error
// Handler defines code to be executed upon reception of a trojan message.
type Handler func(context.Context, *trojan.Message)
// Send constructs a padded message with topic and payload,
// wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address
// uses push-sync to deliver message
// wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address.
// Uses push-sync to deliver message.
func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error {
p.metrics.TotalMessagesSentCounter.Inc()
//construct Trojan Chunk
m, err := trojan.NewMessage(topic, payload)
if err != nil {
return err
}
var tc swarm.Chunk
tc, err = m.Wrap(ctx, targets)
if err != nil {
return err
}
......@@ -81,33 +95,69 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top
return nil
}
// Register allows the definition of a Handler func for a specific topic on the pss struct
func (p *pss) Register(topic trojan.Topic, hndlr Handler) {
// Register allows the definition of a Handler func for a specific topic on the pss struct.
func (p *pss) Register(topic trojan.Topic, handler Handler) (cleanup func()) {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
p.handlers[topic] = hndlr
p.handlers[topic] = append(p.handlers[topic], &handler)
return func() {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
h := p.handlers[topic]
for i := 0; i < len(h); i++ {
if h[i] == &handler {
p.handlers[topic] = append(h[:i], h[i+1:]...)
return
}
}
}
}
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handler func based on its topic
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handlers based on the topic.
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
if !trojan.IsPotential(c) {
return nil
}
m, err := trojan.Unwrap(c) // if err occurs unwrapping, there will be no handler
m, err := trojan.Unwrap(c)
if err != nil {
return err
}
h := p.GetHandler(m.Topic)
h := p.getHandlers(m.Topic)
if h == nil {
return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler)
}
return h(ctx, m)
ctx, cancel := context.WithCancel(ctx)
done := make(chan struct{})
var wg sync.WaitGroup
go func() {
defer cancel()
select {
case <-p.quit:
case <-done:
}
}()
for _, hh := range h {
wg.Add(1)
go func(hh Handler) {
defer wg.Done()
hh(ctx, m)
}(*hh)
}
go func() {
wg.Wait()
close(done)
}()
return nil
}
// GetHandler returns the Handler func registered in pss for the given topic
func (p *pss) GetHandler(topic trojan.Topic) Handler {
p.handlersMu.RLock()
defer p.handlersMu.RUnlock()
func (p *pss) getHandlers(topic trojan.Topic) []*Handler {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
return p.handlers[topic]
}
......@@ -8,7 +8,10 @@ import (
"bytes"
"context"
"io/ioutil"
"runtime"
"sync"
"testing"
"time"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
......@@ -21,7 +24,7 @@ import (
// TestSend creates a trojan chunk and sends it using push sync
func TestSend(t *testing.T) {
var err error
ctx := context.TODO()
ctx := context.Background()
// create a mock pushsync service to push the chunk to its destination
var receipt *pushsync.Receipt
......@@ -35,7 +38,8 @@ func TestSend(t *testing.T) {
return rcpt, nil
})
pss := pss.New(logging.New(ioutil.Discard, 0), pushSyncService)
pss := pss.New(logging.New(ioutil.Discard, 0))
pss.SetPushSyncer(pushSyncService)
target := trojan.Target([]byte{1}) // arbitrary test target
targets := trojan.Targets([]trojan.Target{target})
......@@ -64,54 +68,12 @@ func TestSend(t *testing.T) {
}
}
// TestRegister verifies that handler funcs are able to be registered correctly in pss
func TestRegister(t *testing.T) {
pss := pss.New(logging.New(ioutil.Discard, 0), nil)
handlerVerifier := 0 // test variable to check handler funcs are correctly retrieved
// register first handler
testHandler := func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 1
return nil
}
testTopic := trojan.NewTopic("FIRST_HANDLER")
pss.Register(testTopic, testHandler)
registeredHandler := pss.GetHandler(testTopic)
err := registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 1 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 1 but is %d instead", handlerVerifier)
}
// register second handler
testHandler = func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 2
return nil
}
testTopic = trojan.NewTopic("SECOND_HANDLER")
pss.Register(testTopic, testHandler)
registeredHandler = pss.GetHandler(testTopic)
err = registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 2 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 2 but is %d instead", handlerVerifier)
}
}
// TestDeliver verifies that registering a handler on pss for a given topic and then submitting a trojan chunk with said topic to it
// results in the execution of the expected handler func
func TestDeliver(t *testing.T) {
pss := pss.New(logging.New(ioutil.Discard, 0), nil)
ctx := context.TODO()
pss := pss.New(logging.New(ioutil.Discard, 0))
ctx := context.Background()
var mtx sync.Mutex
// test message
topic := trojan.NewTopic("footopic")
......@@ -130,9 +92,10 @@ func TestDeliver(t *testing.T) {
// create and register handler
var tt trojan.Topic // test variable to check handler func was correctly called
hndlr := func(ctx context.Context, m *trojan.Message) error {
tt = m.Topic // copy the message topic to the test variable
return nil
hndlr := func(ctx context.Context, m *trojan.Message) {
mtx.Lock()
copy(tt[:], m.Topic[:]) // copy the message topic to the test variable
mtx.Unlock()
}
pss.Register(topic, hndlr)
......@@ -141,28 +104,125 @@ func TestDeliver(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if tt != msg.Topic {
t.Fatalf("unexpected result for pss Deliver func, expected test variable to have a value of %v but is %v instead", msg.Topic, tt)
runtime.Gosched() // schedule the handler goroutine
for i := 0; i < 10; i++ {
mtx.Lock()
eq := bytes.Equal(tt[:], msg.Topic[:])
mtx.Unlock()
if eq {
return
}
<-time.After(50 * time.Millisecond)
}
t.Fatalf("unexpected result for pss Deliver func, expected test variable to have a value of %v but is %v instead", msg.Topic, tt)
}
func TestHandler(t *testing.T) {
pss := pss.New(logging.New(ioutil.Discard, 0), nil)
testTopic := trojan.NewTopic("TEST_TOPIC")
// TestRegister verifies that handler funcs are able to be registered correctly in pss
func TestRegister(t *testing.T) {
var (
pss = pss.New(logging.New(ioutil.Discard, 0))
h1Calls = 0
h2Calls = 0
h3Calls = 0
mtx sync.Mutex
topic1 = trojan.NewTopic("one")
topic2 = trojan.NewTopic("two")
payload = []byte("payload")
target = trojan.Target([]byte{1})
targets = trojan.Targets([]trojan.Target{target})
h1 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h1Calls++
}
h2 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h2Calls++
}
h3 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h3Calls++
}
)
_ = pss.Register(topic1, h1)
_ = pss.Register(topic2, h2)
// send a message on topic1, check that only h1 is called
msg, err := trojan.NewMessage(topic1, payload)
if err != nil {
t.Fatal(err)
}
c, err := msg.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 1)
ensureCalls(t, &mtx, &h2Calls, 0)
// verify handler is null
if pss.GetHandler(testTopic) != nil {
t.Errorf("handler should be null")
// register another topic handler on the same topic
cleanup := pss.Register(topic1, h3)
err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
}
// register first handler
testHandler := func(ctx context.Context, m *trojan.Message) error { return nil }
ensureCalls(t, &mtx, &h1Calls, 2)
ensureCalls(t, &mtx, &h2Calls, 0)
ensureCalls(t, &mtx, &h3Calls, 1)
// set handler for test topic
pss.Register(testTopic, testHandler)
cleanup() // remove the last handler
if pss.GetHandler(testTopic) == nil {
t.Errorf("handler should be registered")
err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 3)
ensureCalls(t, &mtx, &h2Calls, 0)
ensureCalls(t, &mtx, &h3Calls, 1)
msg, err = trojan.NewMessage(topic2, payload)
if err != nil {
t.Fatal(err)
}
c, err = msg.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 3)
ensureCalls(t, &mtx, &h2Calls, 1)
ensureCalls(t, &mtx, &h3Calls, 1)
}
func ensureCalls(t *testing.T, mtx *sync.Mutex, calls *int, exp int) {
t.Helper()
for i := 0; i < 10; i++ {
mtx.Lock()
if *calls == exp {
mtx.Unlock()
return
}
mtx.Unlock()
<-time.After(100 * time.Millisecond)
}
t.Fatal("timed out waiting for value")
}
......@@ -11,14 +11,18 @@ import (
"github.com/ethersphere/bee/pkg/swarm"
)
type PushSync struct {
type mock struct {
sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error)
}
func New(sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error)) *PushSync {
return &PushSync{sendChunk: sendChunk}
func New(sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error)) pushsync.PushSyncer {
return &mock{sendChunk: sendChunk}
}
func (s *PushSync) PushChunkToClosest(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
func (s *mock) PushChunkToClosest(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
return s.sendChunk(ctx, chunk)
}
func (s *mock) Close() error {
return nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package recovery
var (
ErrChunkNotPresent = errChunkNotPresent
)
......@@ -6,7 +6,6 @@ package recovery
import (
"context"
"errors"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
......@@ -26,10 +25,6 @@ var (
RecoveryTopic = trojan.NewTopic(RecoveryTopicText)
)
var (
errChunkNotPresent = errors.New("chunk repair: chunk not present in local store for repairing")
)
// RecoveryHook defines code to be executed upon failing to retrieve chunks.
type RecoveryHook func(chunkAddress swarm.Address, targets trojan.Targets) error
......@@ -50,32 +45,31 @@ func NewRecoveryHook(pss PssSender) RecoveryHook {
// NewRepairHandler creates a repair function to re-upload globally pinned chunks to the network with the given store.
func NewRepairHandler(s storage.Storer, logger logging.Logger, pushSyncer pushsync.PushSyncer) pss.Handler {
return func(ctx context.Context, m *trojan.Message) error {
return func(ctx context.Context, m *trojan.Message) {
chAddr := m.Payload
// check if the chunk exists in the local store and proceed.
// otherwise the Get will trigger a unnecessary network retrieve
exists, err := s.Has(ctx, swarm.NewAddress(chAddr))
if err != nil {
return err
return
}
if !exists {
return errChunkNotPresent
return
}
// retrieve the chunk from the local store
ch, err := s.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(chAddr))
if err != nil {
logger.Tracef("chunk repair: error while getting chunk for repairing: %v", err)
return err
return
}
// push the chunk using push sync so that it reaches it destination in network
_, err = pushSyncer.PushChunkToClosest(ctx, ch)
if err != nil {
logger.Tracef("chunk repair: error while sending chunk or receiving receipt: %v", err)
return err
return
}
return nil
}
}
......@@ -155,10 +155,7 @@ func TestNewRepairHandler(t *testing.T) {
}
// invoke the chunk repair handler
err = repairHandler(context.Background(), &msg)
if err != nil {
t.Fatal(err)
}
repairHandler(context.Background(), &msg)
// check if receipt is received
if receipt == nil {
......@@ -200,10 +197,7 @@ func TestNewRepairHandler(t *testing.T) {
}
// invoke the chunk repair handler
err = repairHandler(context.Background(), &msg)
if !errors.Is(err, recovery.ErrChunkNotPresent) {
t.Fatal(err)
}
repairHandler(context.Background(), &msg)
if pushServiceCalled {
t.Fatal("push service called even if the chunk is not present")
......@@ -243,10 +237,7 @@ func TestNewRepairHandler(t *testing.T) {
}
// invoke the chunk repair handler
err = repairHandler(context.Background(), &msg)
if err != nil && err != receiptError {
t.Fatal(err)
}
repairHandler(context.Background(), &msg)
if receiptError == nil {
t.Fatal("pushsync did not generate a receipt error")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment