Commit 6ed63739 authored by acud's avatar acud Committed by GitHub

api: add pss client facing apis (#686)

* add pss websocket api for opening a persistent connection for incoming messages on a given topic
* add pss message send api to send outgoing messages
parent cf783b24
...@@ -156,7 +156,7 @@ func newTestServer(t *testing.T, storer storage.Storer) *url.URL { ...@@ -156,7 +156,7 @@ func newTestServer(t *testing.T, storer storage.Storer) *url.URL {
t.Helper() t.Helper()
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
store := statestore.NewStateStore() store := statestore.NewStateStore()
s := api.New(tags.NewTags(store, logger), storer, nil, logger, nil, api.Options{}) s := api.New(tags.NewTags(store, logger), storer, nil, nil, logger, nil, api.Options{})
ts := httptest.NewServer(s) ts := httptest.NewServer(s)
srvUrl, err := url.Parse(ts.URL) srvUrl, err := url.Parse(ts.URL)
if err != nil { if err != nil {
......
...@@ -15,6 +15,7 @@ require ( ...@@ -15,6 +15,7 @@ require (
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect
github.com/gorilla/handlers v1.4.2 github.com/gorilla/handlers v1.4.2
github.com/gorilla/mux v1.7.4 github.com/gorilla/mux v1.7.4
github.com/gorilla/websocket v1.4.2
github.com/ipfs/go-log/v2 v2.1.1 // indirect github.com/ipfs/go-log/v2 v2.1.1 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/libp2p/go-libp2p v0.10.0 github.com/libp2p/go-libp2p v0.10.0
......
...@@ -7,13 +7,16 @@ package api ...@@ -7,13 +7,16 @@ package api
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
m "github.com/ethersphere/bee/pkg/metrics" m "github.com/ethersphere/bee/pkg/metrics"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/resolver" "github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -37,22 +40,28 @@ var ( ...@@ -37,22 +40,28 @@ var (
type Service interface { type Service interface {
http.Handler http.Handler
m.Collector m.Collector
io.Closer
} }
type server struct { type server struct {
Tags *tags.Tags Tags *tags.Tags
Storer storage.Storer Storer storage.Storer
Resolver resolver.Interface Resolver resolver.Interface
Pss pss.Interface
Logger logging.Logger Logger logging.Logger
Tracer *tracing.Tracer Tracer *tracing.Tracer
Options Options
http.Handler http.Handler
metrics metrics metrics metrics
wsWg sync.WaitGroup // wait for all websockets to close on exit
quit chan struct{}
} }
type Options struct { type Options struct {
CORSAllowedOrigins []string CORSAllowedOrigins []string
GatewayMode bool GatewayMode bool
WsPingPeriod time.Duration
} }
const ( const (
...@@ -61,15 +70,17 @@ const ( ...@@ -61,15 +70,17 @@ const (
) )
// New will create a and initialize a new API service. // New will create a and initialize a new API service.
func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, logger logging.Logger, tracer *tracing.Tracer, o Options) Service { func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, pss pss.Interface, logger logging.Logger, tracer *tracing.Tracer, o Options) Service {
s := &server{ s := &server{
Tags: tags, Tags: tags,
Storer: storer, Storer: storer,
Resolver: resolver, Resolver: resolver,
Pss: pss,
Options: o, Options: o,
Logger: logger, Logger: logger,
Tracer: tracer, Tracer: tracer,
metrics: newMetrics(), metrics: newMetrics(),
quit: make(chan struct{}),
} }
s.setupRouting() s.setupRouting()
...@@ -77,6 +88,26 @@ func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, lo ...@@ -77,6 +88,26 @@ func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, lo
return s return s
} }
// Close hangs up running websockets on shutdown.
func (s *server) Close() error {
s.Logger.Info("api shutting down")
close(s.quit)
done := make(chan struct{})
go func() {
defer close(done)
s.wsWg.Wait()
}()
select {
case <-done:
case <-time.After(5 * time.Second):
return errors.New("api shutting down with open websockets")
}
return nil
}
// getOrCreateTag attempts to get the tag if an id is supplied, and returns an error if it does not exist. // getOrCreateTag attempts to get the tag if an id is supplied, and returns an error if it does not exist.
// If no id is supplied, it will attempt to create a new tag with a generated name and return it. // If no id is supplied, it will attempt to create a new tag with a generated name and return it.
func (s *server) getOrCreateTag(tagUid string) (*tags.Tag, bool, error) { func (s *server) getOrCreateTag(tagUid string) (*tags.Tag, bool, error) {
......
...@@ -11,48 +11,72 @@ import ( ...@@ -11,48 +11,72 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"time"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/resolver" "github.com/ethersphere/bee/pkg/resolver"
resolverMock "github.com/ethersphere/bee/pkg/resolver/mock" resolverMock "github.com/ethersphere/bee/pkg/resolver/mock"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/gorilla/websocket"
"resenje.org/web" "resenje.org/web"
) )
type testServerOptions struct { type testServerOptions struct {
Storer storage.Storer Storer storage.Storer
Resolver resolver.Interface Resolver resolver.Interface
Tags *tags.Tags Pss pss.Interface
GatewayMode bool WsPath string
Logger logging.Logger Tags *tags.Tags
GatewayMode bool
WsPingPeriod time.Duration
Logger logging.Logger
} }
func newTestServer(t *testing.T, o testServerOptions) *http.Client { func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket.Conn, string) {
if o.Logger == nil { if o.Logger == nil {
o.Logger = logging.New(ioutil.Discard, 0) o.Logger = logging.New(ioutil.Discard, 0)
} }
if o.Resolver == nil { if o.Resolver == nil {
o.Resolver = resolverMock.NewResolver() o.Resolver = resolverMock.NewResolver()
} }
s := api.New(o.Tags, o.Storer, o.Resolver, o.Logger, nil, api.Options{ if o.WsPingPeriod == 0 {
GatewayMode: o.GatewayMode, o.WsPingPeriod = 60 * time.Second
}
s := api.New(o.Tags, o.Storer, o.Resolver, o.Pss, o.Logger, nil, api.Options{
GatewayMode: o.GatewayMode,
WsPingPeriod: o.WsPingPeriod,
}) })
ts := httptest.NewServer(s) ts := httptest.NewServer(s)
t.Cleanup(ts.Close) t.Cleanup(ts.Close)
return &http.Client{ var (
Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { httpClient = &http.Client{
u, err := url.Parse(ts.URL + r.URL.String()) Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
if err != nil { u, err := url.Parse(ts.URL + r.URL.String())
return nil, err if err != nil {
} return nil, err
r.URL = u }
return ts.Client().Transport.RoundTrip(r) r.URL = u
}), return ts.Client().Transport.RoundTrip(r)
}),
}
conn *websocket.Conn
err error
)
if o.WsPath != "" {
u := url.URL{Scheme: "ws", Host: ts.Listener.Addr().String(), Path: o.WsPath}
conn, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
t.Fatalf("dial: %v. url %v", err, u.String())
}
} }
return httpClient, conn, ts.Listener.Addr().String()
} }
func TestParseName(t *testing.T) { func TestParseName(t *testing.T) {
...@@ -116,7 +140,7 @@ func TestParseName(t *testing.T) { ...@@ -116,7 +140,7 @@ func TestParseName(t *testing.T) {
})) }))
} }
s := api.New(nil, nil, tC.res, tC.log, nil, api.Options{}).(*api.Server) s := api.New(nil, nil, tC.res, nil, tC.log, nil, api.Options{}).(*api.Server)
t.Run(tC.desc, func(t *testing.T) { t.Run(tC.desc, func(t *testing.T) {
got, err := s.ResolveNameOrAddress(tC.name) got, err := s.ResolveNameOrAddress(tC.name)
......
...@@ -6,11 +6,12 @@ package api_test ...@@ -6,11 +6,12 @@ package api_test
import ( import (
"bytes" "bytes"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"testing" "testing"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
...@@ -31,7 +32,7 @@ func TestBytes(t *testing.T) { ...@@ -31,7 +32,7 @@ func TestBytes(t *testing.T) {
mockStorer = mock.NewStorer() mockStorer = mock.NewStorer()
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mockStorer, Storer: mockStorer,
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
......
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io" "io"
"io/ioutil" "io/ioutil"
"mime" "mime"
...@@ -17,6 +16,8 @@ import ( ...@@ -17,6 +16,8 @@ import (
"strings" "strings"
"testing" "testing"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/collection/entry" "github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/file/pipeline" "github.com/ethersphere/bee/pkg/file/pipeline"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
...@@ -36,7 +37,7 @@ func TestBzz(t *testing.T) { ...@@ -36,7 +37,7 @@ func TestBzz(t *testing.T) {
ctx = context.Background() ctx = context.Background()
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: storer, Storer: storer,
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
......
...@@ -6,13 +6,14 @@ package api_test ...@@ -6,13 +6,14 @@ package api_test
import ( import (
"bytes" "bytes"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
...@@ -41,7 +42,7 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -41,7 +42,7 @@ func TestChunkUploadDownload(t *testing.T) {
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger) tag = tags.NewTags(mockStatestore, logger)
mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator)) mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator))
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
Tags: tag, Tags: tag,
}) })
......
...@@ -35,7 +35,7 @@ func TestDirs(t *testing.T) { ...@@ -35,7 +35,7 @@ func TestDirs(t *testing.T) {
storer = mock.NewStorer() storer = mock.NewStorer()
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: storer, Storer: storer,
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io" "io"
"io/ioutil" "io/ioutil"
"mime" "mime"
...@@ -18,6 +17,8 @@ import ( ...@@ -18,6 +17,8 @@ import (
"strings" "strings"
"testing" "testing"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
...@@ -35,7 +36,7 @@ func TestFiles(t *testing.T) { ...@@ -35,7 +36,7 @@ func TestFiles(t *testing.T) {
simpleData = []byte("this is a simple text") simpleData = []byte("this is a simple text")
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
}) })
...@@ -338,7 +339,7 @@ func TestRangeRequests(t *testing.T) { ...@@ -338,7 +339,7 @@ func TestRangeRequests(t *testing.T) {
t.Run(upload.name, func(t *testing.T) { t.Run(upload.name, func(t *testing.T) {
mockStatestore := statestore.NewStateStore() mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
client := newTestServer(t, testServerOptions{ client, _, _ := newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
......
...@@ -20,7 +20,7 @@ import ( ...@@ -20,7 +20,7 @@ import (
func TestGatewayMode(t *testing.T) { func TestGatewayMode(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
client := newTestServer(t, testServerOptions{ client, _, _ := newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tags.NewTags(statestore.NewStateStore(), logger), Tags: tags.NewTags(statestore.NewStateStore(), logger),
Logger: logger, Logger: logger,
......
...@@ -37,7 +37,7 @@ func TestPinChunkHandler(t *testing.T) { ...@@ -37,7 +37,7 @@ func TestPinChunkHandler(t *testing.T) {
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger) tag = tags.NewTags(mockStatestore, logger)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
Tags: tag, Tags: tag,
Logger: logger, Logger: logger,
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package api
import (
"context"
"encoding/hex"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
var (
upgrader = websocket.Upgrader{
ReadBufferSize: swarm.ChunkSize,
WriteBufferSize: swarm.ChunkSize,
}
writeDeadline = 4 * time.Second // write deadline. should be smaller than the shutdown timeout on api close
targetMaxLength = 2 // max target length in bytes, in order to prevent grieving by excess computation
)
type PssMessage struct {
Topic string
Message string
}
func (s *server) pssPostHandler(w http.ResponseWriter, r *http.Request) {
t := mux.Vars(r)["topic"]
topic := trojan.NewTopic(t)
tg := mux.Vars(r)["targets"]
var targets trojan.Targets
tgts := strings.Split(tg, ",")
for _, v := range tgts {
target, err := hex.DecodeString(v)
if err != nil || len(target) > targetMaxLength {
s.Logger.Debugf("pss send: bad targets: %v", err)
s.Logger.Error("pss send: bad targets")
jsonhttp.BadRequest(w, nil)
return
}
targets = append(targets, target)
}
payload, err := ioutil.ReadAll(r.Body)
if err != nil {
s.Logger.Debugf("pss read payload: %v", err)
s.Logger.Error("pss read payload")
jsonhttp.InternalServerError(w, nil)
return
}
err = s.Pss.Send(r.Context(), targets, topic, payload)
if err != nil {
s.Logger.Debugf("pss send payload: %v. topic: %s", err, t)
s.Logger.Error("pss send payload")
jsonhttp.InternalServerError(w, nil)
return
}
jsonhttp.OK(w, nil)
}
func (s *server) pssWsHandler(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
s.Logger.Debugf("pss ws: upgrade: %v", err)
s.Logger.Error("pss ws: cannot upgrade")
jsonhttp.InternalServerError(w, nil)
return
}
t := mux.Vars(r)["topic"]
s.wsWg.Add(1)
go s.pumpWs(conn, t)
}
func (s *server) pumpWs(conn *websocket.Conn, t string) {
defer s.wsWg.Done()
var (
dataC = make(chan []byte)
gone = make(chan struct{})
topic = trojan.NewTopic(t)
ticker = time.NewTicker(s.WsPingPeriod)
err error
)
defer func() {
ticker.Stop()
_ = conn.Close()
}()
cleanup := s.Pss.Register(topic, func(_ context.Context, m *trojan.Message) {
dataC <- m.Payload
})
defer cleanup()
conn.SetCloseHandler(func(code int, text string) error {
s.Logger.Debugf("pss handler: client gone. code %d message %s", code, text)
close(gone)
return nil
})
for {
select {
case b := <-dataC:
err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
if err != nil {
s.Logger.Debugf("pss set write deadline: %v", err)
return
}
err = conn.WriteMessage(websocket.BinaryMessage, b)
if err != nil {
s.Logger.Debugf("pss write to websocket: %v", err)
return
}
case <-s.quit:
// shutdown
err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
if err != nil {
s.Logger.Debugf("pss set write deadline: %v", err)
return
}
err = conn.WriteMessage(websocket.CloseMessage, []byte{})
if err != nil {
s.Logger.Debugf("pss write close message: %v", err)
}
return
case <-gone:
// client gone
return
case <-ticker.C:
err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
if err != nil {
s.Logger.Debugf("pss set write deadline: %v", err)
return
}
if err = conn.WriteMessage(websocket.PingMessage, nil); err != nil {
// error encountered while pinging client. client probably gone
return
}
}
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package api_test
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"sync"
"testing"
"time"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
"github.com/gorilla/websocket"
)
var (
target = trojan.Target([]byte{1})
targets = trojan.Targets([]trojan.Target{target})
payload = []byte("testdata")
topic = trojan.NewTopic("testtopic")
timeout = 10 * time.Second
)
// creates a single websocket handler for an arbitrary topic, and receives a message
func TestPssWebsocketSingleHandler(t *testing.T) {
var (
pss, cl, _ = newPssTest(t, opts{})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
done = make(chan struct{})
)
err := cl.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
t.Fatal(err)
}
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, payload, &mtx)
}
func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
// create a new pss instance, register a handle through ws, call
// pss.TryUnwrap with a chunk designated for this handler and expect
// the handler to be notified
var (
pss, cl, _ = newPssTest(t, opts{})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
done = make(chan struct{})
)
err := cl.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
t.Fatal(err)
}
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
// close the websocket before calling pss with the message
err = cl.WriteMessage(websocket.CloseMessage, []byte{})
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx)
}
func TestPssWebsocketMultiHandler(t *testing.T) {
var (
pss, cl, listener = newPssTest(t, opts{})
u = url.URL{Scheme: "ws", Host: listener, Path: "/pss/subscribe/testtopic"}
cl2, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
msgContent = make([]byte, len(payload))
msgContent2 = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
done = make(chan struct{})
)
if err != nil {
t.Fatalf("dial: %v. url %v", err, u.String())
}
err = cl.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
t.Fatal(err)
}
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
go waitReadMessage(t, &mtx, cl2, msgContent2, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
// close the websocket before calling pss with the message
err = cl.WriteMessage(websocket.CloseMessage, []byte{})
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx)
waitMessage(t, msgContent2, nil, &mtx)
}
// TestPssSend tests that the pss message sending over http works correctly.
func TestPssSend(t *testing.T) {
var (
logger = logging.New(ioutil.Discard, 0)
mtx sync.Mutex
recievedTargets trojan.Targets
recievedTopic trojan.Topic
recievedBytes []byte
done bool
sendFn = func(_ context.Context, targets trojan.Targets, topic trojan.Topic, bytes []byte) error {
mtx.Lock()
recievedTargets = targets
recievedTopic = topic
recievedBytes = bytes
done = true
mtx.Unlock()
return nil
}
pss = newMockPss(sendFn)
client, _, _ = newTestServer(t, testServerOptions{
Pss: pss,
Storer: mock.NewStorer(),
Logger: logger,
})
targets = fmt.Sprintf("[[%d]]", 0x12)
topic = "testtopic"
hasher = swarm.NewHasher()
_, err = hasher.Write([]byte(topic))
topicHash = hasher.Sum(nil)
)
if err != nil {
t.Fatal(err)
}
t.Run("err - bad targets", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/badtarget", http.StatusBadRequest,
jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "Bad Request",
Code: http.StatusBadRequest,
}),
)
})
t.Run("ok", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/testtopic/12", http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "OK",
Code: http.StatusOK,
}),
)
waitDone(t, &mtx, &done)
if !bytes.Equal(recievedBytes, payload) {
t.Fatalf("payload mismatch. want %v got %v", payload, recievedBytes)
}
if targets != fmt.Sprint(recievedTargets) {
t.Fatalf("targets mismatch. want %v got %v", targets, recievedTargets)
}
if string(topicHash) != string(recievedTopic[:]) {
t.Fatalf("topic mismatch. want %v got %v", topic, string(recievedTopic[:]))
}
})
}
// TestPssPingPong tests that the websocket api adheres to the websocket standard
// and sends ping-pong messages to keep the connection alive.
// The test opens a websocket, keeps it alive for 500ms, then receives a pss message.
func TestPssPingPong(t *testing.T) {
var (
pss, cl, _ = newPssTest(t, opts{pingPeriod: 90 * time.Millisecond})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
pongWait = 1 * time.Millisecond
done = make(chan struct{})
)
cl.SetReadLimit(swarm.ChunkSize)
err := cl.SetReadDeadline(time.Now().Add(pongWait))
if err != nil {
t.Fatal(err)
}
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx)
}
func waitReadMessage(t *testing.T, mtx *sync.Mutex, cl *websocket.Conn, targetContent []byte, done <-chan struct{}) {
t.Helper()
timeout := time.After(timeout)
for {
select {
case <-done:
return
case <-timeout:
t.Errorf("timed out waiting for message")
return
default:
}
msgType, message, err := cl.ReadMessage()
if err != nil {
return
}
if msgType == websocket.PongMessage {
// ignore pings
continue
}
if message != nil {
mtx.Lock()
copy(targetContent, message)
mtx.Unlock()
}
time.Sleep(50 * time.Millisecond)
}
}
func waitDone(t *testing.T, mtx *sync.Mutex, done *bool) {
for i := 0; i < 10; i++ {
mtx.Lock()
if *done {
mtx.Unlock()
return
}
mtx.Unlock()
time.Sleep(50 * time.Millisecond)
}
t.Fatal("timed out waiting for send")
}
func waitMessage(t *testing.T, data, expData []byte, mtx *sync.Mutex) {
ttl := time.After(timeout)
for {
select {
case <-ttl:
if expData == nil {
return
}
t.Fatal("timed out waiting for pss message")
default:
}
mtx.Lock()
if bytes.Equal(data, expData) {
mtx.Unlock()
return
}
mtx.Unlock()
time.Sleep(100 * time.Millisecond)
}
}
type opts struct {
pingPeriod time.Duration
}
func newPssTest(t *testing.T, o opts) (pss.Interface, *websocket.Conn, string) {
var (
logger = logging.New(ioutil.Discard, 0)
pss = pss.New(logger)
)
if o.pingPeriod == 0 {
o.pingPeriod = 10 * time.Second
}
_, cl, listener := newTestServer(t, testServerOptions{
Pss: pss,
WsPath: "/pss/subscribe/testtopic",
Storer: mock.NewStorer(),
Logger: logger,
WsPingPeriod: o.pingPeriod,
})
return pss, cl, listener
}
type pssSendFn func(context.Context, trojan.Targets, trojan.Topic, []byte) error
type mpss struct {
f pssSendFn
}
func newMockPss(f pssSendFn) *mpss {
return &mpss{f}
}
// Send arbitrary byte slice with the given topic to Targets.
func (m *mpss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, bytes []byte) error {
return m.f(ctx, targets, topic, bytes)
}
// Register a Handler for a given Topic.
func (m *mpss) Register(_ trojan.Topic, _ pss.Handler) func() {
panic("not implemented") // TODO: Implement
}
// TryUnwrap tries to unwrap a wrapped trojan message.
func (m *mpss) TryUnwrap(_ context.Context, _ swarm.Chunk) error {
panic("not implemented") // TODO: Implement
}
func (m *mpss) SetPushSyncer(pushSyncer pushsync.PushSyncer) {
panic("not implemented") // TODO: Implement
}
func (m *mpss) Close() error {
panic("not implemented") // TODO: Implement
}
...@@ -90,6 +90,15 @@ func (s *server) setupRouting() { ...@@ -90,6 +90,15 @@ func (s *server) setupRouting() {
), ),
}) })
handle(router, "/pss/send/{topic}/{targets}", jsonhttp.MethodHandler{
"POST": web.ChainHandlers(
jsonhttp.NewMaxBodyBytesHandler(swarm.ChunkSize),
web.FinalHandlerFunc(s.pssPostHandler),
),
})
handle(router, "/pss/subscribe/{topic}", http.HandlerFunc(s.pssWsHandler))
handle(router, "/tags", web.ChainHandlers( handle(router, "/tags", web.ChainHandlers(
s.gatewayModeForbidEndpointHandler, s.gatewayModeForbidEndpointHandler,
web.FinalHandler(jsonhttp.MethodHandler{ web.FinalHandler(jsonhttp.MethodHandler{
......
...@@ -8,14 +8,15 @@ import ( ...@@ -8,14 +8,15 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
...@@ -46,7 +47,7 @@ func TestTags(t *testing.T) { ...@@ -46,7 +47,7 @@ func TestTags(t *testing.T) {
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger) tag = tags.NewTags(mockStatestore, logger)
mockPusher = mp.NewMockPusher(tag) mockPusher = mp.NewMockPusher(tag)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tag, Tags: tag,
}) })
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package logging package logging
import ( import (
"bufio"
"net" "net"
"net/http" "net/http"
"time" "time"
...@@ -89,6 +90,10 @@ func (l *responseLogger) Flush() { ...@@ -89,6 +90,10 @@ func (l *responseLogger) Flush() {
l.w.(http.Flusher).Flush() l.w.(http.Flusher).Flush()
} }
func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return l.w.(http.Hijacker).Hijack()
}
func (l *responseLogger) CloseNotify() <-chan bool { func (l *responseLogger) CloseNotify() <-chan bool {
// staticcheck SA1019 CloseNotifier interface is required by gorilla compress handler // staticcheck SA1019 CloseNotifier interface is required by gorilla compress handler
// nolint:staticcheck // nolint:staticcheck
......
...@@ -52,20 +52,23 @@ import ( ...@@ -52,20 +52,23 @@ import (
) )
type Bee struct { type Bee struct {
p2pService io.Closer p2pService io.Closer
p2pCancel context.CancelFunc p2pCancel context.CancelFunc
apiServer *http.Server apiCloser io.Closer
debugAPIServer *http.Server apiServer *http.Server
resolverCloser io.Closer debugAPIServer *http.Server
errorLogWriter *io.PipeWriter resolverCloser io.Closer
tracerCloser io.Closer errorLogWriter *io.PipeWriter
tagsCloser io.Closer tracerCloser io.Closer
stateStoreCloser io.Closer tagsCloser io.Closer
localstoreCloser io.Closer stateStoreCloser io.Closer
topologyCloser io.Closer localstoreCloser io.Closer
pusherCloser io.Closer topologyCloser io.Closer
pullerCloser io.Closer pusherCloser io.Closer
pullSyncCloser io.Closer pullerCloser io.Closer
pullSyncCloser io.Closer
pssCloser io.Closer
recoveryHandleCleanup func()
} }
type Options struct { type Options struct {
...@@ -247,7 +250,8 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -247,7 +250,8 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
} }
// instantiate the pss object // instantiate the pss object
psss := pss.New(logger, nil) psss := pss.New(logger)
b.pssCloser = psss
var ns storage.Storer var ns storage.Storer
if o.GlobalPinningEnabled { if o.GlobalPinningEnabled {
...@@ -262,7 +266,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -262,7 +266,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10)) pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10))
// set the pushSyncer in the PSS // set the pushSyncer in the PSS
psss.WithPushSyncer(pushSyncProtocol) psss.SetPushSyncer(pushSyncProtocol)
if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil { if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil {
return nil, fmt.Errorf("pushsync service: %w", err) return nil, fmt.Errorf("pushsync service: %w", err)
...@@ -271,7 +275,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -271,7 +275,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
if o.GlobalPinningEnabled { if o.GlobalPinningEnabled {
// register function for chunk repair upon receiving a trojan message // register function for chunk repair upon receiving a trojan message
chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol) chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol)
psss.Register(recovery.RecoveryTopic, chunkRepairHandler) b.recoveryHandleCleanup = psss.Register(recovery.RecoveryTopic, chunkRepairHandler)
} }
pushSyncPusher := pusher.New(storer, kad, pushSyncProtocol, tagg, logger) pushSyncPusher := pusher.New(storer, kad, pushSyncProtocol, tagg, logger)
...@@ -299,9 +303,10 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -299,9 +303,10 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
var apiService api.Service var apiService api.Service
if o.APIAddr != "" { if o.APIAddr != "" {
// API server // API server
apiService = api.New(tagg, ns, multiResolver, logger, tracer, api.Options{ apiService = api.New(tagg, ns, multiResolver, psss, logger, tracer, api.Options{
CORSAllowedOrigins: o.CORSAllowedOrigins, CORSAllowedOrigins: o.CORSAllowedOrigins,
GatewayMode: o.GatewayMode, GatewayMode: o.GatewayMode,
WsPingPeriod: 60 * time.Second,
}) })
apiListener, err := net.Listen("tcp", o.APIAddr) apiListener, err := net.Listen("tcp", o.APIAddr)
if err != nil { if err != nil {
...@@ -323,6 +328,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -323,6 +328,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
}() }()
b.apiServer = apiServer b.apiServer = apiServer
b.apiCloser = apiService
} }
if o.DebugAPIAddr != "" { if o.DebugAPIAddr != "" {
...@@ -373,6 +379,12 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -373,6 +379,12 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
func (b *Bee) Shutdown(ctx context.Context) error { func (b *Bee) Shutdown(ctx context.Context) error {
errs := new(multiError) errs := new(multiError)
if b.apiCloser != nil {
if err := b.apiCloser.Close(); err != nil {
errs.add(fmt.Errorf("api: %w", err))
}
}
var eg errgroup.Group var eg errgroup.Group
if b.apiServer != nil { if b.apiServer != nil {
eg.Go(func() error { eg.Go(func() error {
...@@ -395,6 +407,10 @@ func (b *Bee) Shutdown(ctx context.Context) error { ...@@ -395,6 +407,10 @@ func (b *Bee) Shutdown(ctx context.Context) error {
errs.add(err) errs.add(err)
} }
if b.recoveryHandleCleanup != nil {
b.recoveryHandleCleanup()
}
if err := b.pusherCloser.Close(); err != nil { if err := b.pusherCloser.Close(); err != nil {
errs.add(fmt.Errorf("pusher: %w", err)) errs.add(fmt.Errorf("pusher: %w", err))
} }
...@@ -407,6 +423,10 @@ func (b *Bee) Shutdown(ctx context.Context) error { ...@@ -407,6 +423,10 @@ func (b *Bee) Shutdown(ctx context.Context) error {
errs.add(fmt.Errorf("pull sync: %w", err)) errs.add(fmt.Errorf("pull sync: %w", err))
} }
if err := b.pssCloser.Close(); err != nil {
errs.add(fmt.Errorf("pss: %w", err))
}
b.p2pCancel() b.p2pCancel()
if err := b.p2pService.Close(); err != nil { if err := b.p2pService.Close(); err != nil {
errs.add(fmt.Errorf("p2p server: %w", err)) errs.add(fmt.Errorf("p2p server: %w", err))
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"sync" "sync"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
...@@ -22,54 +23,67 @@ var ( ...@@ -22,54 +23,67 @@ var (
) )
type Interface interface { type Interface interface {
Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error // Send arbitrary byte slice with the given topic to Targets.
Register(topic trojan.Topic, hndlr Handler) Send(context.Context, trojan.Targets, trojan.Topic, []byte) error
GetHandler(topic trojan.Topic) Handler // Register a Handler for a given Topic.
TryUnwrap(ctx context.Context, c swarm.Chunk) error Register(trojan.Topic, Handler) func()
WithPushSyncer(pushSyncer pushsync.PushSyncer) // TryUnwrap tries to unwrap a wrapped trojan message.
TryUnwrap(context.Context, swarm.Chunk) error
SetPushSyncer(pushSyncer pushsync.PushSyncer)
io.Closer
} }
// pss is the top-level struct, which takes care of message sending
type pss struct { type pss struct {
pusher pushsync.PushSyncer pusher pushsync.PushSyncer
handlers map[trojan.Topic]Handler handlers map[trojan.Topic][]*Handler
handlersMu sync.RWMutex handlersMu sync.Mutex
metrics metrics metrics metrics
logger logging.Logger logger logging.Logger
quit chan struct{}
} }
// New inits the pss struct with the storer // New returns a new pss service.
func New(logger logging.Logger, pusher pushsync.PushSyncer) Interface { func New(logger logging.Logger) Interface {
return &pss{ return &pss{
logger: logger, logger: logger,
pusher: pusher, handlers: make(map[trojan.Topic][]*Handler),
handlers: make(map[trojan.Topic]Handler),
metrics: newMetrics(), metrics: newMetrics(),
quit: make(chan struct{}),
} }
} }
func (ps *pss) WithPushSyncer(pushSyncer pushsync.PushSyncer) { func (ps *pss) Close() error {
close(ps.quit)
ps.handlersMu.Lock()
defer ps.handlersMu.Unlock()
ps.handlers = make(map[trojan.Topic][]*Handler) //unset handlers on shutdown
return nil
}
func (ps *pss) SetPushSyncer(pushSyncer pushsync.PushSyncer) {
ps.pusher = pushSyncer ps.pusher = pushSyncer
} }
// Handler defines code to be executed upon reception of a trojan message // Handler defines code to be executed upon reception of a trojan message.
type Handler func(context.Context, *trojan.Message) error type Handler func(context.Context, *trojan.Message)
// Send constructs a padded message with topic and payload, // Send constructs a padded message with topic and payload,
// wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address // wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address.
// uses push-sync to deliver message // Uses push-sync to deliver message.
func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error { func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error {
p.metrics.TotalMessagesSentCounter.Inc() p.metrics.TotalMessagesSentCounter.Inc()
//construct Trojan Chunk
m, err := trojan.NewMessage(topic, payload) m, err := trojan.NewMessage(topic, payload)
if err != nil { if err != nil {
return err return err
} }
var tc swarm.Chunk var tc swarm.Chunk
tc, err = m.Wrap(ctx, targets) tc, err = m.Wrap(ctx, targets)
if err != nil { if err != nil {
return err return err
} }
...@@ -81,33 +95,69 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top ...@@ -81,33 +95,69 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top
return nil return nil
} }
// Register allows the definition of a Handler func for a specific topic on the pss struct // Register allows the definition of a Handler func for a specific topic on the pss struct.
func (p *pss) Register(topic trojan.Topic, hndlr Handler) { func (p *pss) Register(topic trojan.Topic, handler Handler) (cleanup func()) {
p.handlersMu.Lock() p.handlersMu.Lock()
defer p.handlersMu.Unlock() defer p.handlersMu.Unlock()
p.handlers[topic] = hndlr
p.handlers[topic] = append(p.handlers[topic], &handler)
return func() {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
h := p.handlers[topic]
for i := 0; i < len(h); i++ {
if h[i] == &handler {
p.handlers[topic] = append(h[:i], h[i+1:]...)
return
}
}
}
} }
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handler func based on its topic // TryUnwrap allows unwrapping a chunk as a trojan message and calling its handlers based on the topic.
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error { func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
if !trojan.IsPotential(c) { if !trojan.IsPotential(c) {
return nil return nil
} }
m, err := trojan.Unwrap(c) // if err occurs unwrapping, there will be no handler m, err := trojan.Unwrap(c)
if err != nil { if err != nil {
return err return err
} }
h := p.GetHandler(m.Topic) h := p.getHandlers(m.Topic)
if h == nil { if h == nil {
return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler) return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler)
} }
return h(ctx, m)
ctx, cancel := context.WithCancel(ctx)
done := make(chan struct{})
var wg sync.WaitGroup
go func() {
defer cancel()
select {
case <-p.quit:
case <-done:
}
}()
for _, hh := range h {
wg.Add(1)
go func(hh Handler) {
defer wg.Done()
hh(ctx, m)
}(*hh)
}
go func() {
wg.Wait()
close(done)
}()
return nil
} }
// GetHandler returns the Handler func registered in pss for the given topic func (p *pss) getHandlers(topic trojan.Topic) []*Handler {
func (p *pss) GetHandler(topic trojan.Topic) Handler { p.handlersMu.Lock()
p.handlersMu.RLock() defer p.handlersMu.Unlock()
defer p.handlersMu.RUnlock()
return p.handlers[topic] return p.handlers[topic]
} }
...@@ -8,7 +8,10 @@ import ( ...@@ -8,7 +8,10 @@ import (
"bytes" "bytes"
"context" "context"
"io/ioutil" "io/ioutil"
"runtime"
"sync"
"testing" "testing"
"time"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pss"
...@@ -21,7 +24,7 @@ import ( ...@@ -21,7 +24,7 @@ import (
// TestSend creates a trojan chunk and sends it using push sync // TestSend creates a trojan chunk and sends it using push sync
func TestSend(t *testing.T) { func TestSend(t *testing.T) {
var err error var err error
ctx := context.TODO() ctx := context.Background()
// create a mock pushsync service to push the chunk to its destination // create a mock pushsync service to push the chunk to its destination
var receipt *pushsync.Receipt var receipt *pushsync.Receipt
...@@ -35,7 +38,8 @@ func TestSend(t *testing.T) { ...@@ -35,7 +38,8 @@ func TestSend(t *testing.T) {
return rcpt, nil return rcpt, nil
}) })
pss := pss.New(logging.New(ioutil.Discard, 0), pushSyncService) pss := pss.New(logging.New(ioutil.Discard, 0))
pss.SetPushSyncer(pushSyncService)
target := trojan.Target([]byte{1}) // arbitrary test target target := trojan.Target([]byte{1}) // arbitrary test target
targets := trojan.Targets([]trojan.Target{target}) targets := trojan.Targets([]trojan.Target{target})
...@@ -64,54 +68,12 @@ func TestSend(t *testing.T) { ...@@ -64,54 +68,12 @@ func TestSend(t *testing.T) {
} }
} }
// TestRegister verifies that handler funcs are able to be registered correctly in pss
func TestRegister(t *testing.T) {
pss := pss.New(logging.New(ioutil.Discard, 0), nil)
handlerVerifier := 0 // test variable to check handler funcs are correctly retrieved
// register first handler
testHandler := func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 1
return nil
}
testTopic := trojan.NewTopic("FIRST_HANDLER")
pss.Register(testTopic, testHandler)
registeredHandler := pss.GetHandler(testTopic)
err := registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 1 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 1 but is %d instead", handlerVerifier)
}
// register second handler
testHandler = func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 2
return nil
}
testTopic = trojan.NewTopic("SECOND_HANDLER")
pss.Register(testTopic, testHandler)
registeredHandler = pss.GetHandler(testTopic)
err = registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 2 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 2 but is %d instead", handlerVerifier)
}
}
// TestDeliver verifies that registering a handler on pss for a given topic and then submitting a trojan chunk with said topic to it // TestDeliver verifies that registering a handler on pss for a given topic and then submitting a trojan chunk with said topic to it
// results in the execution of the expected handler func // results in the execution of the expected handler func
func TestDeliver(t *testing.T) { func TestDeliver(t *testing.T) {
pss := pss.New(logging.New(ioutil.Discard, 0), nil) pss := pss.New(logging.New(ioutil.Discard, 0))
ctx := context.TODO() ctx := context.Background()
var mtx sync.Mutex
// test message // test message
topic := trojan.NewTopic("footopic") topic := trojan.NewTopic("footopic")
...@@ -130,9 +92,10 @@ func TestDeliver(t *testing.T) { ...@@ -130,9 +92,10 @@ func TestDeliver(t *testing.T) {
// create and register handler // create and register handler
var tt trojan.Topic // test variable to check handler func was correctly called var tt trojan.Topic // test variable to check handler func was correctly called
hndlr := func(ctx context.Context, m *trojan.Message) error { hndlr := func(ctx context.Context, m *trojan.Message) {
tt = m.Topic // copy the message topic to the test variable mtx.Lock()
return nil copy(tt[:], m.Topic[:]) // copy the message topic to the test variable
mtx.Unlock()
} }
pss.Register(topic, hndlr) pss.Register(topic, hndlr)
...@@ -141,28 +104,125 @@ func TestDeliver(t *testing.T) { ...@@ -141,28 +104,125 @@ func TestDeliver(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tt != msg.Topic { runtime.Gosched() // schedule the handler goroutine
t.Fatalf("unexpected result for pss Deliver func, expected test variable to have a value of %v but is %v instead", msg.Topic, tt) for i := 0; i < 10; i++ {
mtx.Lock()
eq := bytes.Equal(tt[:], msg.Topic[:])
mtx.Unlock()
if eq {
return
}
<-time.After(50 * time.Millisecond)
} }
t.Fatalf("unexpected result for pss Deliver func, expected test variable to have a value of %v but is %v instead", msg.Topic, tt)
} }
func TestHandler(t *testing.T) { // TestRegister verifies that handler funcs are able to be registered correctly in pss
pss := pss.New(logging.New(ioutil.Discard, 0), nil) func TestRegister(t *testing.T) {
testTopic := trojan.NewTopic("TEST_TOPIC") var (
pss = pss.New(logging.New(ioutil.Discard, 0))
h1Calls = 0
h2Calls = 0
h3Calls = 0
mtx sync.Mutex
topic1 = trojan.NewTopic("one")
topic2 = trojan.NewTopic("two")
payload = []byte("payload")
target = trojan.Target([]byte{1})
targets = trojan.Targets([]trojan.Target{target})
h1 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h1Calls++
}
h2 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h2Calls++
}
h3 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h3Calls++
}
)
_ = pss.Register(topic1, h1)
_ = pss.Register(topic2, h2)
// send a message on topic1, check that only h1 is called
msg, err := trojan.NewMessage(topic1, payload)
if err != nil {
t.Fatal(err)
}
c, err := msg.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 1)
ensureCalls(t, &mtx, &h2Calls, 0)
// verify handler is null // register another topic handler on the same topic
if pss.GetHandler(testTopic) != nil { cleanup := pss.Register(topic1, h3)
t.Errorf("handler should be null") err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
} }
// register first handler ensureCalls(t, &mtx, &h1Calls, 2)
testHandler := func(ctx context.Context, m *trojan.Message) error { return nil } ensureCalls(t, &mtx, &h2Calls, 0)
ensureCalls(t, &mtx, &h3Calls, 1)
// set handler for test topic cleanup() // remove the last handler
pss.Register(testTopic, testHandler)
if pss.GetHandler(testTopic) == nil { err = pss.TryUnwrap(context.Background(), c)
t.Errorf("handler should be registered") if err != nil {
t.Fatal(err)
} }
ensureCalls(t, &mtx, &h1Calls, 3)
ensureCalls(t, &mtx, &h2Calls, 0)
ensureCalls(t, &mtx, &h3Calls, 1)
msg, err = trojan.NewMessage(topic2, payload)
if err != nil {
t.Fatal(err)
}
c, err = msg.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 3)
ensureCalls(t, &mtx, &h2Calls, 1)
ensureCalls(t, &mtx, &h3Calls, 1)
}
func ensureCalls(t *testing.T, mtx *sync.Mutex, calls *int, exp int) {
t.Helper()
for i := 0; i < 10; i++ {
mtx.Lock()
if *calls == exp {
mtx.Unlock()
return
}
mtx.Unlock()
<-time.After(100 * time.Millisecond)
}
t.Fatal("timed out waiting for value")
} }
...@@ -11,14 +11,18 @@ import ( ...@@ -11,14 +11,18 @@ import (
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
type PushSync struct { type mock struct {
sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error)
} }
func New(sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error)) *PushSync { func New(sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error)) pushsync.PushSyncer {
return &PushSync{sendChunk: sendChunk} return &mock{sendChunk: sendChunk}
} }
func (s *PushSync) PushChunkToClosest(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) { func (s *mock) PushChunkToClosest(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
return s.sendChunk(ctx, chunk) return s.sendChunk(ctx, chunk)
} }
func (s *mock) Close() error {
return nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package recovery
var (
ErrChunkNotPresent = errChunkNotPresent
)
...@@ -6,7 +6,6 @@ package recovery ...@@ -6,7 +6,6 @@ package recovery
import ( import (
"context" "context"
"errors"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pss"
...@@ -26,10 +25,6 @@ var ( ...@@ -26,10 +25,6 @@ var (
RecoveryTopic = trojan.NewTopic(RecoveryTopicText) RecoveryTopic = trojan.NewTopic(RecoveryTopicText)
) )
var (
errChunkNotPresent = errors.New("chunk repair: chunk not present in local store for repairing")
)
// RecoveryHook defines code to be executed upon failing to retrieve chunks. // RecoveryHook defines code to be executed upon failing to retrieve chunks.
type RecoveryHook func(chunkAddress swarm.Address, targets trojan.Targets) error type RecoveryHook func(chunkAddress swarm.Address, targets trojan.Targets) error
...@@ -50,32 +45,31 @@ func NewRecoveryHook(pss PssSender) RecoveryHook { ...@@ -50,32 +45,31 @@ func NewRecoveryHook(pss PssSender) RecoveryHook {
// NewRepairHandler creates a repair function to re-upload globally pinned chunks to the network with the given store. // NewRepairHandler creates a repair function to re-upload globally pinned chunks to the network with the given store.
func NewRepairHandler(s storage.Storer, logger logging.Logger, pushSyncer pushsync.PushSyncer) pss.Handler { func NewRepairHandler(s storage.Storer, logger logging.Logger, pushSyncer pushsync.PushSyncer) pss.Handler {
return func(ctx context.Context, m *trojan.Message) error { return func(ctx context.Context, m *trojan.Message) {
chAddr := m.Payload chAddr := m.Payload
// check if the chunk exists in the local store and proceed. // check if the chunk exists in the local store and proceed.
// otherwise the Get will trigger a unnecessary network retrieve // otherwise the Get will trigger a unnecessary network retrieve
exists, err := s.Has(ctx, swarm.NewAddress(chAddr)) exists, err := s.Has(ctx, swarm.NewAddress(chAddr))
if err != nil { if err != nil {
return err return
} }
if !exists { if !exists {
return errChunkNotPresent return
} }
// retrieve the chunk from the local store // retrieve the chunk from the local store
ch, err := s.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(chAddr)) ch, err := s.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(chAddr))
if err != nil { if err != nil {
logger.Tracef("chunk repair: error while getting chunk for repairing: %v", err) logger.Tracef("chunk repair: error while getting chunk for repairing: %v", err)
return err return
} }
// push the chunk using push sync so that it reaches it destination in network // push the chunk using push sync so that it reaches it destination in network
_, err = pushSyncer.PushChunkToClosest(ctx, ch) _, err = pushSyncer.PushChunkToClosest(ctx, ch)
if err != nil { if err != nil {
logger.Tracef("chunk repair: error while sending chunk or receiving receipt: %v", err) logger.Tracef("chunk repair: error while sending chunk or receiving receipt: %v", err)
return err return
} }
return nil
} }
} }
...@@ -155,10 +155,7 @@ func TestNewRepairHandler(t *testing.T) { ...@@ -155,10 +155,7 @@ func TestNewRepairHandler(t *testing.T) {
} }
// invoke the chunk repair handler // invoke the chunk repair handler
err = repairHandler(context.Background(), &msg) repairHandler(context.Background(), &msg)
if err != nil {
t.Fatal(err)
}
// check if receipt is received // check if receipt is received
if receipt == nil { if receipt == nil {
...@@ -200,10 +197,7 @@ func TestNewRepairHandler(t *testing.T) { ...@@ -200,10 +197,7 @@ func TestNewRepairHandler(t *testing.T) {
} }
// invoke the chunk repair handler // invoke the chunk repair handler
err = repairHandler(context.Background(), &msg) repairHandler(context.Background(), &msg)
if !errors.Is(err, recovery.ErrChunkNotPresent) {
t.Fatal(err)
}
if pushServiceCalled { if pushServiceCalled {
t.Fatal("push service called even if the chunk is not present") t.Fatal("push service called even if the chunk is not present")
...@@ -243,10 +237,7 @@ func TestNewRepairHandler(t *testing.T) { ...@@ -243,10 +237,7 @@ func TestNewRepairHandler(t *testing.T) {
} }
// invoke the chunk repair handler // invoke the chunk repair handler
err = repairHandler(context.Background(), &msg) repairHandler(context.Background(), &msg)
if err != nil && err != receiptError {
t.Fatal(err)
}
if receiptError == nil { if receiptError == nil {
t.Fatal("pushsync did not generate a receipt error") t.Fatal("pushsync did not generate a receipt error")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment