Commit 6ed63739 authored by acud's avatar acud Committed by GitHub

api: add pss client facing apis (#686)

* add pss websocket api for opening a persistent connection for incoming messages on a given topic
* add pss message send api to send outgoing messages
parent cf783b24
...@@ -156,7 +156,7 @@ func newTestServer(t *testing.T, storer storage.Storer) *url.URL { ...@@ -156,7 +156,7 @@ func newTestServer(t *testing.T, storer storage.Storer) *url.URL {
t.Helper() t.Helper()
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
store := statestore.NewStateStore() store := statestore.NewStateStore()
s := api.New(tags.NewTags(store, logger), storer, nil, logger, nil, api.Options{}) s := api.New(tags.NewTags(store, logger), storer, nil, nil, logger, nil, api.Options{})
ts := httptest.NewServer(s) ts := httptest.NewServer(s)
srvUrl, err := url.Parse(ts.URL) srvUrl, err := url.Parse(ts.URL)
if err != nil { if err != nil {
......
...@@ -15,6 +15,7 @@ require ( ...@@ -15,6 +15,7 @@ require (
github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect github.com/gopherjs/gopherjs v0.0.0-20200217142428-fce0ec30dd00 // indirect
github.com/gorilla/handlers v1.4.2 github.com/gorilla/handlers v1.4.2
github.com/gorilla/mux v1.7.4 github.com/gorilla/mux v1.7.4
github.com/gorilla/websocket v1.4.2
github.com/ipfs/go-log/v2 v2.1.1 // indirect github.com/ipfs/go-log/v2 v2.1.1 // indirect
github.com/kr/text v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect
github.com/libp2p/go-libp2p v0.10.0 github.com/libp2p/go-libp2p v0.10.0
......
...@@ -7,13 +7,16 @@ package api ...@@ -7,13 +7,16 @@ package api
import ( import (
"errors" "errors"
"fmt" "fmt"
"io"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
m "github.com/ethersphere/bee/pkg/metrics" m "github.com/ethersphere/bee/pkg/metrics"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/resolver" "github.com/ethersphere/bee/pkg/resolver"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -37,22 +40,28 @@ var ( ...@@ -37,22 +40,28 @@ var (
type Service interface { type Service interface {
http.Handler http.Handler
m.Collector m.Collector
io.Closer
} }
type server struct { type server struct {
Tags *tags.Tags Tags *tags.Tags
Storer storage.Storer Storer storage.Storer
Resolver resolver.Interface Resolver resolver.Interface
Pss pss.Interface
Logger logging.Logger Logger logging.Logger
Tracer *tracing.Tracer Tracer *tracing.Tracer
Options Options
http.Handler http.Handler
metrics metrics metrics metrics
wsWg sync.WaitGroup // wait for all websockets to close on exit
quit chan struct{}
} }
type Options struct { type Options struct {
CORSAllowedOrigins []string CORSAllowedOrigins []string
GatewayMode bool GatewayMode bool
WsPingPeriod time.Duration
} }
const ( const (
...@@ -61,15 +70,17 @@ const ( ...@@ -61,15 +70,17 @@ const (
) )
// New will create a and initialize a new API service. // New will create a and initialize a new API service.
func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, logger logging.Logger, tracer *tracing.Tracer, o Options) Service { func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, pss pss.Interface, logger logging.Logger, tracer *tracing.Tracer, o Options) Service {
s := &server{ s := &server{
Tags: tags, Tags: tags,
Storer: storer, Storer: storer,
Resolver: resolver, Resolver: resolver,
Pss: pss,
Options: o, Options: o,
Logger: logger, Logger: logger,
Tracer: tracer, Tracer: tracer,
metrics: newMetrics(), metrics: newMetrics(),
quit: make(chan struct{}),
} }
s.setupRouting() s.setupRouting()
...@@ -77,6 +88,26 @@ func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, lo ...@@ -77,6 +88,26 @@ func New(tags *tags.Tags, storer storage.Storer, resolver resolver.Interface, lo
return s return s
} }
// Close hangs up running websockets on shutdown.
func (s *server) Close() error {
s.Logger.Info("api shutting down")
close(s.quit)
done := make(chan struct{})
go func() {
defer close(done)
s.wsWg.Wait()
}()
select {
case <-done:
case <-time.After(5 * time.Second):
return errors.New("api shutting down with open websockets")
}
return nil
}
// getOrCreateTag attempts to get the tag if an id is supplied, and returns an error if it does not exist. // getOrCreateTag attempts to get the tag if an id is supplied, and returns an error if it does not exist.
// If no id is supplied, it will attempt to create a new tag with a generated name and return it. // If no id is supplied, it will attempt to create a new tag with a generated name and return it.
func (s *server) getOrCreateTag(tagUid string) (*tags.Tag, bool, error) { func (s *server) getOrCreateTag(tagUid string) (*tags.Tag, bool, error) {
......
...@@ -11,39 +11,50 @@ import ( ...@@ -11,39 +11,50 @@ import (
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"testing" "testing"
"time"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/resolver" "github.com/ethersphere/bee/pkg/resolver"
resolverMock "github.com/ethersphere/bee/pkg/resolver/mock" resolverMock "github.com/ethersphere/bee/pkg/resolver/mock"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/gorilla/websocket"
"resenje.org/web" "resenje.org/web"
) )
type testServerOptions struct { type testServerOptions struct {
Storer storage.Storer Storer storage.Storer
Resolver resolver.Interface Resolver resolver.Interface
Pss pss.Interface
WsPath string
Tags *tags.Tags Tags *tags.Tags
GatewayMode bool GatewayMode bool
WsPingPeriod time.Duration
Logger logging.Logger Logger logging.Logger
} }
func newTestServer(t *testing.T, o testServerOptions) *http.Client { func newTestServer(t *testing.T, o testServerOptions) (*http.Client, *websocket.Conn, string) {
if o.Logger == nil { if o.Logger == nil {
o.Logger = logging.New(ioutil.Discard, 0) o.Logger = logging.New(ioutil.Discard, 0)
} }
if o.Resolver == nil { if o.Resolver == nil {
o.Resolver = resolverMock.NewResolver() o.Resolver = resolverMock.NewResolver()
} }
s := api.New(o.Tags, o.Storer, o.Resolver, o.Logger, nil, api.Options{ if o.WsPingPeriod == 0 {
o.WsPingPeriod = 60 * time.Second
}
s := api.New(o.Tags, o.Storer, o.Resolver, o.Pss, o.Logger, nil, api.Options{
GatewayMode: o.GatewayMode, GatewayMode: o.GatewayMode,
WsPingPeriod: o.WsPingPeriod,
}) })
ts := httptest.NewServer(s) ts := httptest.NewServer(s)
t.Cleanup(ts.Close) t.Cleanup(ts.Close)
return &http.Client{ var (
httpClient = &http.Client{
Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) { Transport: web.RoundTripperFunc(func(r *http.Request) (*http.Response, error) {
u, err := url.Parse(ts.URL + r.URL.String()) u, err := url.Parse(ts.URL + r.URL.String())
if err != nil { if err != nil {
...@@ -53,6 +64,19 @@ func newTestServer(t *testing.T, o testServerOptions) *http.Client { ...@@ -53,6 +64,19 @@ func newTestServer(t *testing.T, o testServerOptions) *http.Client {
return ts.Client().Transport.RoundTrip(r) return ts.Client().Transport.RoundTrip(r)
}), }),
} }
conn *websocket.Conn
err error
)
if o.WsPath != "" {
u := url.URL{Scheme: "ws", Host: ts.Listener.Addr().String(), Path: o.WsPath}
conn, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
t.Fatalf("dial: %v. url %v", err, u.String())
}
}
return httpClient, conn, ts.Listener.Addr().String()
} }
func TestParseName(t *testing.T) { func TestParseName(t *testing.T) {
...@@ -116,7 +140,7 @@ func TestParseName(t *testing.T) { ...@@ -116,7 +140,7 @@ func TestParseName(t *testing.T) {
})) }))
} }
s := api.New(nil, nil, tC.res, tC.log, nil, api.Options{}).(*api.Server) s := api.New(nil, nil, tC.res, nil, tC.log, nil, api.Options{}).(*api.Server)
t.Run(tC.desc, func(t *testing.T) { t.Run(tC.desc, func(t *testing.T) {
got, err := s.ResolveNameOrAddress(tC.name) got, err := s.ResolveNameOrAddress(tC.name)
......
...@@ -6,11 +6,12 @@ package api_test ...@@ -6,11 +6,12 @@ package api_test
import ( import (
"bytes" "bytes"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"testing" "testing"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
...@@ -31,7 +32,7 @@ func TestBytes(t *testing.T) { ...@@ -31,7 +32,7 @@ func TestBytes(t *testing.T) {
mockStorer = mock.NewStorer() mockStorer = mock.NewStorer()
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mockStorer, Storer: mockStorer,
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
......
...@@ -9,7 +9,6 @@ import ( ...@@ -9,7 +9,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io" "io"
"io/ioutil" "io/ioutil"
"mime" "mime"
...@@ -17,6 +16,8 @@ import ( ...@@ -17,6 +16,8 @@ import (
"strings" "strings"
"testing" "testing"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/collection/entry" "github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/file/pipeline" "github.com/ethersphere/bee/pkg/file/pipeline"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
...@@ -36,7 +37,7 @@ func TestBzz(t *testing.T) { ...@@ -36,7 +37,7 @@ func TestBzz(t *testing.T) {
ctx = context.Background() ctx = context.Background()
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: storer, Storer: storer,
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
......
...@@ -6,13 +6,14 @@ package api_test ...@@ -6,13 +6,14 @@ package api_test
import ( import (
"bytes" "bytes"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
...@@ -41,7 +42,7 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -41,7 +42,7 @@ func TestChunkUploadDownload(t *testing.T) {
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger) tag = tags.NewTags(mockStatestore, logger)
mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator)) mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator))
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
Tags: tag, Tags: tag,
}) })
......
...@@ -35,7 +35,7 @@ func TestDirs(t *testing.T) { ...@@ -35,7 +35,7 @@ func TestDirs(t *testing.T) {
storer = mock.NewStorer() storer = mock.NewStorer()
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: storer, Storer: storer,
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io" "io"
"io/ioutil" "io/ioutil"
"mime" "mime"
...@@ -18,6 +17,8 @@ import ( ...@@ -18,6 +17,8 @@ import (
"strings" "strings"
"testing" "testing"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
...@@ -35,7 +36,7 @@ func TestFiles(t *testing.T) { ...@@ -35,7 +36,7 @@ func TestFiles(t *testing.T) {
simpleData = []byte("this is a simple text") simpleData = []byte("this is a simple text")
mockStatestore = statestore.NewStateStore() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
}) })
...@@ -338,7 +339,7 @@ func TestRangeRequests(t *testing.T) { ...@@ -338,7 +339,7 @@ func TestRangeRequests(t *testing.T) {
t.Run(upload.name, func(t *testing.T) { t.Run(upload.name, func(t *testing.T) {
mockStatestore := statestore.NewStateStore() mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
client := newTestServer(t, testServerOptions{ client, _, _ := newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tags.NewTags(mockStatestore, logger), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
......
...@@ -20,7 +20,7 @@ import ( ...@@ -20,7 +20,7 @@ import (
func TestGatewayMode(t *testing.T) { func TestGatewayMode(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
client := newTestServer(t, testServerOptions{ client, _, _ := newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tags.NewTags(statestore.NewStateStore(), logger), Tags: tags.NewTags(statestore.NewStateStore(), logger),
Logger: logger, Logger: logger,
......
...@@ -37,7 +37,7 @@ func TestPinChunkHandler(t *testing.T) { ...@@ -37,7 +37,7 @@ func TestPinChunkHandler(t *testing.T) {
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger) tag = tags.NewTags(mockStatestore, logger)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
Tags: tag, Tags: tag,
Logger: logger, Logger: logger,
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package api
import (
"context"
"encoding/hex"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
var (
upgrader = websocket.Upgrader{
ReadBufferSize: swarm.ChunkSize,
WriteBufferSize: swarm.ChunkSize,
}
writeDeadline = 4 * time.Second // write deadline. should be smaller than the shutdown timeout on api close
targetMaxLength = 2 // max target length in bytes, in order to prevent grieving by excess computation
)
type PssMessage struct {
Topic string
Message string
}
func (s *server) pssPostHandler(w http.ResponseWriter, r *http.Request) {
t := mux.Vars(r)["topic"]
topic := trojan.NewTopic(t)
tg := mux.Vars(r)["targets"]
var targets trojan.Targets
tgts := strings.Split(tg, ",")
for _, v := range tgts {
target, err := hex.DecodeString(v)
if err != nil || len(target) > targetMaxLength {
s.Logger.Debugf("pss send: bad targets: %v", err)
s.Logger.Error("pss send: bad targets")
jsonhttp.BadRequest(w, nil)
return
}
targets = append(targets, target)
}
payload, err := ioutil.ReadAll(r.Body)
if err != nil {
s.Logger.Debugf("pss read payload: %v", err)
s.Logger.Error("pss read payload")
jsonhttp.InternalServerError(w, nil)
return
}
err = s.Pss.Send(r.Context(), targets, topic, payload)
if err != nil {
s.Logger.Debugf("pss send payload: %v. topic: %s", err, t)
s.Logger.Error("pss send payload")
jsonhttp.InternalServerError(w, nil)
return
}
jsonhttp.OK(w, nil)
}
func (s *server) pssWsHandler(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
s.Logger.Debugf("pss ws: upgrade: %v", err)
s.Logger.Error("pss ws: cannot upgrade")
jsonhttp.InternalServerError(w, nil)
return
}
t := mux.Vars(r)["topic"]
s.wsWg.Add(1)
go s.pumpWs(conn, t)
}
func (s *server) pumpWs(conn *websocket.Conn, t string) {
defer s.wsWg.Done()
var (
dataC = make(chan []byte)
gone = make(chan struct{})
topic = trojan.NewTopic(t)
ticker = time.NewTicker(s.WsPingPeriod)
err error
)
defer func() {
ticker.Stop()
_ = conn.Close()
}()
cleanup := s.Pss.Register(topic, func(_ context.Context, m *trojan.Message) {
dataC <- m.Payload
})
defer cleanup()
conn.SetCloseHandler(func(code int, text string) error {
s.Logger.Debugf("pss handler: client gone. code %d message %s", code, text)
close(gone)
return nil
})
for {
select {
case b := <-dataC:
err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
if err != nil {
s.Logger.Debugf("pss set write deadline: %v", err)
return
}
err = conn.WriteMessage(websocket.BinaryMessage, b)
if err != nil {
s.Logger.Debugf("pss write to websocket: %v", err)
return
}
case <-s.quit:
// shutdown
err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
if err != nil {
s.Logger.Debugf("pss set write deadline: %v", err)
return
}
err = conn.WriteMessage(websocket.CloseMessage, []byte{})
if err != nil {
s.Logger.Debugf("pss write close message: %v", err)
}
return
case <-gone:
// client gone
return
case <-ticker.C:
err = conn.SetWriteDeadline(time.Now().Add(writeDeadline))
if err != nil {
s.Logger.Debugf("pss set write deadline: %v", err)
return
}
if err = conn.WriteMessage(websocket.PingMessage, nil); err != nil {
// error encountered while pinging client. client probably gone
return
}
}
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package api_test
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"sync"
"testing"
"time"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
"github.com/gorilla/websocket"
)
var (
target = trojan.Target([]byte{1})
targets = trojan.Targets([]trojan.Target{target})
payload = []byte("testdata")
topic = trojan.NewTopic("testtopic")
timeout = 10 * time.Second
)
// creates a single websocket handler for an arbitrary topic, and receives a message
func TestPssWebsocketSingleHandler(t *testing.T) {
var (
pss, cl, _ = newPssTest(t, opts{})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
done = make(chan struct{})
)
err := cl.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
t.Fatal(err)
}
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, payload, &mtx)
}
func TestPssWebsocketSingleHandlerDeregister(t *testing.T) {
// create a new pss instance, register a handle through ws, call
// pss.TryUnwrap with a chunk designated for this handler and expect
// the handler to be notified
var (
pss, cl, _ = newPssTest(t, opts{})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
done = make(chan struct{})
)
err := cl.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
t.Fatal(err)
}
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
// close the websocket before calling pss with the message
err = cl.WriteMessage(websocket.CloseMessage, []byte{})
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx)
}
func TestPssWebsocketMultiHandler(t *testing.T) {
var (
pss, cl, listener = newPssTest(t, opts{})
u = url.URL{Scheme: "ws", Host: listener, Path: "/pss/subscribe/testtopic"}
cl2, _, err = websocket.DefaultDialer.Dial(u.String(), nil)
msgContent = make([]byte, len(payload))
msgContent2 = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
done = make(chan struct{})
)
if err != nil {
t.Fatalf("dial: %v. url %v", err, u.String())
}
err = cl.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
t.Fatal(err)
}
cl.SetReadLimit(swarm.ChunkSize)
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
go waitReadMessage(t, &mtx, cl2, msgContent2, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
// close the websocket before calling pss with the message
err = cl.WriteMessage(websocket.CloseMessage, []byte{})
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx)
waitMessage(t, msgContent2, nil, &mtx)
}
// TestPssSend tests that the pss message sending over http works correctly.
func TestPssSend(t *testing.T) {
var (
logger = logging.New(ioutil.Discard, 0)
mtx sync.Mutex
recievedTargets trojan.Targets
recievedTopic trojan.Topic
recievedBytes []byte
done bool
sendFn = func(_ context.Context, targets trojan.Targets, topic trojan.Topic, bytes []byte) error {
mtx.Lock()
recievedTargets = targets
recievedTopic = topic
recievedBytes = bytes
done = true
mtx.Unlock()
return nil
}
pss = newMockPss(sendFn)
client, _, _ = newTestServer(t, testServerOptions{
Pss: pss,
Storer: mock.NewStorer(),
Logger: logger,
})
targets = fmt.Sprintf("[[%d]]", 0x12)
topic = "testtopic"
hasher = swarm.NewHasher()
_, err = hasher.Write([]byte(topic))
topicHash = hasher.Sum(nil)
)
if err != nil {
t.Fatal(err)
}
t.Run("err - bad targets", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/to/badtarget", http.StatusBadRequest,
jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "Bad Request",
Code: http.StatusBadRequest,
}),
)
})
t.Run("ok", func(t *testing.T) {
jsonhttptest.Request(t, client, http.MethodPost, "/pss/send/testtopic/12", http.StatusOK,
jsonhttptest.WithRequestBody(bytes.NewReader(payload)),
jsonhttptest.WithExpectedJSONResponse(jsonhttp.StatusResponse{
Message: "OK",
Code: http.StatusOK,
}),
)
waitDone(t, &mtx, &done)
if !bytes.Equal(recievedBytes, payload) {
t.Fatalf("payload mismatch. want %v got %v", payload, recievedBytes)
}
if targets != fmt.Sprint(recievedTargets) {
t.Fatalf("targets mismatch. want %v got %v", targets, recievedTargets)
}
if string(topicHash) != string(recievedTopic[:]) {
t.Fatalf("topic mismatch. want %v got %v", topic, string(recievedTopic[:]))
}
})
}
// TestPssPingPong tests that the websocket api adheres to the websocket standard
// and sends ping-pong messages to keep the connection alive.
// The test opens a websocket, keeps it alive for 500ms, then receives a pss message.
func TestPssPingPong(t *testing.T) {
var (
pss, cl, _ = newPssTest(t, opts{pingPeriod: 90 * time.Millisecond})
msgContent = make([]byte, len(payload))
tc swarm.Chunk
mtx sync.Mutex
pongWait = 1 * time.Millisecond
done = make(chan struct{})
)
cl.SetReadLimit(swarm.ChunkSize)
err := cl.SetReadDeadline(time.Now().Add(pongWait))
if err != nil {
t.Fatal(err)
}
defer close(done)
go waitReadMessage(t, &mtx, cl, msgContent, done)
m, err := trojan.NewMessage(topic, payload)
if err != nil {
t.Fatal(err)
}
tc, err = m.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
time.Sleep(500 * time.Millisecond) // wait to see that the websocket is kept alive
err = pss.TryUnwrap(context.Background(), tc)
if err != nil {
t.Fatal(err)
}
waitMessage(t, msgContent, nil, &mtx)
}
func waitReadMessage(t *testing.T, mtx *sync.Mutex, cl *websocket.Conn, targetContent []byte, done <-chan struct{}) {
t.Helper()
timeout := time.After(timeout)
for {
select {
case <-done:
return
case <-timeout:
t.Errorf("timed out waiting for message")
return
default:
}
msgType, message, err := cl.ReadMessage()
if err != nil {
return
}
if msgType == websocket.PongMessage {
// ignore pings
continue
}
if message != nil {
mtx.Lock()
copy(targetContent, message)
mtx.Unlock()
}
time.Sleep(50 * time.Millisecond)
}
}
func waitDone(t *testing.T, mtx *sync.Mutex, done *bool) {
for i := 0; i < 10; i++ {
mtx.Lock()
if *done {
mtx.Unlock()
return
}
mtx.Unlock()
time.Sleep(50 * time.Millisecond)
}
t.Fatal("timed out waiting for send")
}
func waitMessage(t *testing.T, data, expData []byte, mtx *sync.Mutex) {
ttl := time.After(timeout)
for {
select {
case <-ttl:
if expData == nil {
return
}
t.Fatal("timed out waiting for pss message")
default:
}
mtx.Lock()
if bytes.Equal(data, expData) {
mtx.Unlock()
return
}
mtx.Unlock()
time.Sleep(100 * time.Millisecond)
}
}
type opts struct {
pingPeriod time.Duration
}
func newPssTest(t *testing.T, o opts) (pss.Interface, *websocket.Conn, string) {
var (
logger = logging.New(ioutil.Discard, 0)
pss = pss.New(logger)
)
if o.pingPeriod == 0 {
o.pingPeriod = 10 * time.Second
}
_, cl, listener := newTestServer(t, testServerOptions{
Pss: pss,
WsPath: "/pss/subscribe/testtopic",
Storer: mock.NewStorer(),
Logger: logger,
WsPingPeriod: o.pingPeriod,
})
return pss, cl, listener
}
type pssSendFn func(context.Context, trojan.Targets, trojan.Topic, []byte) error
type mpss struct {
f pssSendFn
}
func newMockPss(f pssSendFn) *mpss {
return &mpss{f}
}
// Send arbitrary byte slice with the given topic to Targets.
func (m *mpss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, bytes []byte) error {
return m.f(ctx, targets, topic, bytes)
}
// Register a Handler for a given Topic.
func (m *mpss) Register(_ trojan.Topic, _ pss.Handler) func() {
panic("not implemented") // TODO: Implement
}
// TryUnwrap tries to unwrap a wrapped trojan message.
func (m *mpss) TryUnwrap(_ context.Context, _ swarm.Chunk) error {
panic("not implemented") // TODO: Implement
}
func (m *mpss) SetPushSyncer(pushSyncer pushsync.PushSyncer) {
panic("not implemented") // TODO: Implement
}
func (m *mpss) Close() error {
panic("not implemented") // TODO: Implement
}
...@@ -90,6 +90,15 @@ func (s *server) setupRouting() { ...@@ -90,6 +90,15 @@ func (s *server) setupRouting() {
), ),
}) })
handle(router, "/pss/send/{topic}/{targets}", jsonhttp.MethodHandler{
"POST": web.ChainHandlers(
jsonhttp.NewMaxBodyBytesHandler(swarm.ChunkSize),
web.FinalHandlerFunc(s.pssPostHandler),
),
})
handle(router, "/pss/subscribe/{topic}", http.HandlerFunc(s.pssWsHandler))
handle(router, "/tags", web.ChainHandlers( handle(router, "/tags", web.ChainHandlers(
s.gatewayModeForbidEndpointHandler, s.gatewayModeForbidEndpointHandler,
web.FinalHandler(jsonhttp.MethodHandler{ web.FinalHandler(jsonhttp.MethodHandler{
......
...@@ -8,14 +8,15 @@ import ( ...@@ -8,14 +8,15 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest" "github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
...@@ -46,7 +47,7 @@ func TestTags(t *testing.T) { ...@@ -46,7 +47,7 @@ func TestTags(t *testing.T) {
logger = logging.New(ioutil.Discard, 0) logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger) tag = tags.NewTags(mockStatestore, logger)
mockPusher = mp.NewMockPusher(tag) mockPusher = mp.NewMockPusher(tag)
client = newTestServer(t, testServerOptions{ client, _, _ = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tag, Tags: tag,
}) })
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
package logging package logging
import ( import (
"bufio"
"net" "net"
"net/http" "net/http"
"time" "time"
...@@ -89,6 +90,10 @@ func (l *responseLogger) Flush() { ...@@ -89,6 +90,10 @@ func (l *responseLogger) Flush() {
l.w.(http.Flusher).Flush() l.w.(http.Flusher).Flush()
} }
func (l *responseLogger) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return l.w.(http.Hijacker).Hijack()
}
func (l *responseLogger) CloseNotify() <-chan bool { func (l *responseLogger) CloseNotify() <-chan bool {
// staticcheck SA1019 CloseNotifier interface is required by gorilla compress handler // staticcheck SA1019 CloseNotifier interface is required by gorilla compress handler
// nolint:staticcheck // nolint:staticcheck
......
...@@ -54,6 +54,7 @@ import ( ...@@ -54,6 +54,7 @@ import (
type Bee struct { type Bee struct {
p2pService io.Closer p2pService io.Closer
p2pCancel context.CancelFunc p2pCancel context.CancelFunc
apiCloser io.Closer
apiServer *http.Server apiServer *http.Server
debugAPIServer *http.Server debugAPIServer *http.Server
resolverCloser io.Closer resolverCloser io.Closer
...@@ -66,6 +67,8 @@ type Bee struct { ...@@ -66,6 +67,8 @@ type Bee struct {
pusherCloser io.Closer pusherCloser io.Closer
pullerCloser io.Closer pullerCloser io.Closer
pullSyncCloser io.Closer pullSyncCloser io.Closer
pssCloser io.Closer
recoveryHandleCleanup func()
} }
type Options struct { type Options struct {
...@@ -247,7 +250,8 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -247,7 +250,8 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
} }
// instantiate the pss object // instantiate the pss object
psss := pss.New(logger, nil) psss := pss.New(logger)
b.pssCloser = psss
var ns storage.Storer var ns storage.Storer
if o.GlobalPinningEnabled { if o.GlobalPinningEnabled {
...@@ -262,7 +266,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -262,7 +266,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10)) pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10))
// set the pushSyncer in the PSS // set the pushSyncer in the PSS
psss.WithPushSyncer(pushSyncProtocol) psss.SetPushSyncer(pushSyncProtocol)
if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil { if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil {
return nil, fmt.Errorf("pushsync service: %w", err) return nil, fmt.Errorf("pushsync service: %w", err)
...@@ -271,7 +275,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -271,7 +275,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
if o.GlobalPinningEnabled { if o.GlobalPinningEnabled {
// register function for chunk repair upon receiving a trojan message // register function for chunk repair upon receiving a trojan message
chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol) chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol)
psss.Register(recovery.RecoveryTopic, chunkRepairHandler) b.recoveryHandleCleanup = psss.Register(recovery.RecoveryTopic, chunkRepairHandler)
} }
pushSyncPusher := pusher.New(storer, kad, pushSyncProtocol, tagg, logger) pushSyncPusher := pusher.New(storer, kad, pushSyncProtocol, tagg, logger)
...@@ -299,9 +303,10 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -299,9 +303,10 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
var apiService api.Service var apiService api.Service
if o.APIAddr != "" { if o.APIAddr != "" {
// API server // API server
apiService = api.New(tagg, ns, multiResolver, logger, tracer, api.Options{ apiService = api.New(tagg, ns, multiResolver, psss, logger, tracer, api.Options{
CORSAllowedOrigins: o.CORSAllowedOrigins, CORSAllowedOrigins: o.CORSAllowedOrigins,
GatewayMode: o.GatewayMode, GatewayMode: o.GatewayMode,
WsPingPeriod: 60 * time.Second,
}) })
apiListener, err := net.Listen("tcp", o.APIAddr) apiListener, err := net.Listen("tcp", o.APIAddr)
if err != nil { if err != nil {
...@@ -323,6 +328,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -323,6 +328,7 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
}() }()
b.apiServer = apiServer b.apiServer = apiServer
b.apiCloser = apiService
} }
if o.DebugAPIAddr != "" { if o.DebugAPIAddr != "" {
...@@ -373,6 +379,12 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -373,6 +379,12 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
func (b *Bee) Shutdown(ctx context.Context) error { func (b *Bee) Shutdown(ctx context.Context) error {
errs := new(multiError) errs := new(multiError)
if b.apiCloser != nil {
if err := b.apiCloser.Close(); err != nil {
errs.add(fmt.Errorf("api: %w", err))
}
}
var eg errgroup.Group var eg errgroup.Group
if b.apiServer != nil { if b.apiServer != nil {
eg.Go(func() error { eg.Go(func() error {
...@@ -395,6 +407,10 @@ func (b *Bee) Shutdown(ctx context.Context) error { ...@@ -395,6 +407,10 @@ func (b *Bee) Shutdown(ctx context.Context) error {
errs.add(err) errs.add(err)
} }
if b.recoveryHandleCleanup != nil {
b.recoveryHandleCleanup()
}
if err := b.pusherCloser.Close(); err != nil { if err := b.pusherCloser.Close(); err != nil {
errs.add(fmt.Errorf("pusher: %w", err)) errs.add(fmt.Errorf("pusher: %w", err))
} }
...@@ -407,6 +423,10 @@ func (b *Bee) Shutdown(ctx context.Context) error { ...@@ -407,6 +423,10 @@ func (b *Bee) Shutdown(ctx context.Context) error {
errs.add(fmt.Errorf("pull sync: %w", err)) errs.add(fmt.Errorf("pull sync: %w", err))
} }
if err := b.pssCloser.Close(); err != nil {
errs.add(fmt.Errorf("pss: %w", err))
}
b.p2pCancel() b.p2pCancel()
if err := b.p2pService.Close(); err != nil { if err := b.p2pService.Close(); err != nil {
errs.add(fmt.Errorf("p2p server: %w", err)) errs.add(fmt.Errorf("p2p server: %w", err))
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"io"
"sync" "sync"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
...@@ -22,54 +23,67 @@ var ( ...@@ -22,54 +23,67 @@ var (
) )
type Interface interface { type Interface interface {
Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error // Send arbitrary byte slice with the given topic to Targets.
Register(topic trojan.Topic, hndlr Handler) Send(context.Context, trojan.Targets, trojan.Topic, []byte) error
GetHandler(topic trojan.Topic) Handler // Register a Handler for a given Topic.
TryUnwrap(ctx context.Context, c swarm.Chunk) error Register(trojan.Topic, Handler) func()
WithPushSyncer(pushSyncer pushsync.PushSyncer) // TryUnwrap tries to unwrap a wrapped trojan message.
TryUnwrap(context.Context, swarm.Chunk) error
SetPushSyncer(pushSyncer pushsync.PushSyncer)
io.Closer
} }
// pss is the top-level struct, which takes care of message sending
type pss struct { type pss struct {
pusher pushsync.PushSyncer pusher pushsync.PushSyncer
handlers map[trojan.Topic]Handler handlers map[trojan.Topic][]*Handler
handlersMu sync.RWMutex handlersMu sync.Mutex
metrics metrics metrics metrics
logger logging.Logger logger logging.Logger
quit chan struct{}
} }
// New inits the pss struct with the storer // New returns a new pss service.
func New(logger logging.Logger, pusher pushsync.PushSyncer) Interface { func New(logger logging.Logger) Interface {
return &pss{ return &pss{
logger: logger, logger: logger,
pusher: pusher, handlers: make(map[trojan.Topic][]*Handler),
handlers: make(map[trojan.Topic]Handler),
metrics: newMetrics(), metrics: newMetrics(),
quit: make(chan struct{}),
} }
} }
func (ps *pss) WithPushSyncer(pushSyncer pushsync.PushSyncer) { func (ps *pss) Close() error {
close(ps.quit)
ps.handlersMu.Lock()
defer ps.handlersMu.Unlock()
ps.handlers = make(map[trojan.Topic][]*Handler) //unset handlers on shutdown
return nil
}
func (ps *pss) SetPushSyncer(pushSyncer pushsync.PushSyncer) {
ps.pusher = pushSyncer ps.pusher = pushSyncer
} }
// Handler defines code to be executed upon reception of a trojan message // Handler defines code to be executed upon reception of a trojan message.
type Handler func(context.Context, *trojan.Message) error type Handler func(context.Context, *trojan.Message)
// Send constructs a padded message with topic and payload, // Send constructs a padded message with topic and payload,
// wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address // wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address.
// uses push-sync to deliver message // Uses push-sync to deliver message.
func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error { func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error {
p.metrics.TotalMessagesSentCounter.Inc() p.metrics.TotalMessagesSentCounter.Inc()
//construct Trojan Chunk
m, err := trojan.NewMessage(topic, payload) m, err := trojan.NewMessage(topic, payload)
if err != nil { if err != nil {
return err return err
} }
var tc swarm.Chunk var tc swarm.Chunk
tc, err = m.Wrap(ctx, targets) tc, err = m.Wrap(ctx, targets)
if err != nil { if err != nil {
return err return err
} }
...@@ -81,33 +95,69 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top ...@@ -81,33 +95,69 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top
return nil return nil
} }
// Register allows the definition of a Handler func for a specific topic on the pss struct // Register allows the definition of a Handler func for a specific topic on the pss struct.
func (p *pss) Register(topic trojan.Topic, hndlr Handler) { func (p *pss) Register(topic trojan.Topic, handler Handler) (cleanup func()) {
p.handlersMu.Lock()
defer p.handlersMu.Unlock()
p.handlers[topic] = append(p.handlers[topic], &handler)
return func() {
p.handlersMu.Lock() p.handlersMu.Lock()
defer p.handlersMu.Unlock() defer p.handlersMu.Unlock()
p.handlers[topic] = hndlr
h := p.handlers[topic]
for i := 0; i < len(h); i++ {
if h[i] == &handler {
p.handlers[topic] = append(h[:i], h[i+1:]...)
return
}
}
}
} }
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handler func based on its topic // TryUnwrap allows unwrapping a chunk as a trojan message and calling its handlers based on the topic.
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error { func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
if !trojan.IsPotential(c) { if !trojan.IsPotential(c) {
return nil return nil
} }
m, err := trojan.Unwrap(c) // if err occurs unwrapping, there will be no handler m, err := trojan.Unwrap(c)
if err != nil { if err != nil {
return err return err
} }
h := p.GetHandler(m.Topic) h := p.getHandlers(m.Topic)
if h == nil { if h == nil {
return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler) return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler)
} }
return h(ctx, m)
ctx, cancel := context.WithCancel(ctx)
done := make(chan struct{})
var wg sync.WaitGroup
go func() {
defer cancel()
select {
case <-p.quit:
case <-done:
}
}()
for _, hh := range h {
wg.Add(1)
go func(hh Handler) {
defer wg.Done()
hh(ctx, m)
}(*hh)
}
go func() {
wg.Wait()
close(done)
}()
return nil
} }
// GetHandler returns the Handler func registered in pss for the given topic func (p *pss) getHandlers(topic trojan.Topic) []*Handler {
func (p *pss) GetHandler(topic trojan.Topic) Handler { p.handlersMu.Lock()
p.handlersMu.RLock() defer p.handlersMu.Unlock()
defer p.handlersMu.RUnlock()
return p.handlers[topic] return p.handlers[topic]
} }
...@@ -8,7 +8,10 @@ import ( ...@@ -8,7 +8,10 @@ import (
"bytes" "bytes"
"context" "context"
"io/ioutil" "io/ioutil"
"runtime"
"sync"
"testing" "testing"
"time"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pss"
...@@ -21,7 +24,7 @@ import ( ...@@ -21,7 +24,7 @@ import (
// TestSend creates a trojan chunk and sends it using push sync // TestSend creates a trojan chunk and sends it using push sync
func TestSend(t *testing.T) { func TestSend(t *testing.T) {
var err error var err error
ctx := context.TODO() ctx := context.Background()
// create a mock pushsync service to push the chunk to its destination // create a mock pushsync service to push the chunk to its destination
var receipt *pushsync.Receipt var receipt *pushsync.Receipt
...@@ -35,7 +38,8 @@ func TestSend(t *testing.T) { ...@@ -35,7 +38,8 @@ func TestSend(t *testing.T) {
return rcpt, nil return rcpt, nil
}) })
pss := pss.New(logging.New(ioutil.Discard, 0), pushSyncService) pss := pss.New(logging.New(ioutil.Discard, 0))
pss.SetPushSyncer(pushSyncService)
target := trojan.Target([]byte{1}) // arbitrary test target target := trojan.Target([]byte{1}) // arbitrary test target
targets := trojan.Targets([]trojan.Target{target}) targets := trojan.Targets([]trojan.Target{target})
...@@ -64,54 +68,12 @@ func TestSend(t *testing.T) { ...@@ -64,54 +68,12 @@ func TestSend(t *testing.T) {
} }
} }
// TestRegister verifies that handler funcs are able to be registered correctly in pss
func TestRegister(t *testing.T) {
pss := pss.New(logging.New(ioutil.Discard, 0), nil)
handlerVerifier := 0 // test variable to check handler funcs are correctly retrieved
// register first handler
testHandler := func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 1
return nil
}
testTopic := trojan.NewTopic("FIRST_HANDLER")
pss.Register(testTopic, testHandler)
registeredHandler := pss.GetHandler(testTopic)
err := registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 1 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 1 but is %d instead", handlerVerifier)
}
// register second handler
testHandler = func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 2
return nil
}
testTopic = trojan.NewTopic("SECOND_HANDLER")
pss.Register(testTopic, testHandler)
registeredHandler = pss.GetHandler(testTopic)
err = registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 2 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 2 but is %d instead", handlerVerifier)
}
}
// TestDeliver verifies that registering a handler on pss for a given topic and then submitting a trojan chunk with said topic to it // TestDeliver verifies that registering a handler on pss for a given topic and then submitting a trojan chunk with said topic to it
// results in the execution of the expected handler func // results in the execution of the expected handler func
func TestDeliver(t *testing.T) { func TestDeliver(t *testing.T) {
pss := pss.New(logging.New(ioutil.Discard, 0), nil) pss := pss.New(logging.New(ioutil.Discard, 0))
ctx := context.TODO() ctx := context.Background()
var mtx sync.Mutex
// test message // test message
topic := trojan.NewTopic("footopic") topic := trojan.NewTopic("footopic")
...@@ -130,9 +92,10 @@ func TestDeliver(t *testing.T) { ...@@ -130,9 +92,10 @@ func TestDeliver(t *testing.T) {
// create and register handler // create and register handler
var tt trojan.Topic // test variable to check handler func was correctly called var tt trojan.Topic // test variable to check handler func was correctly called
hndlr := func(ctx context.Context, m *trojan.Message) error { hndlr := func(ctx context.Context, m *trojan.Message) {
tt = m.Topic // copy the message topic to the test variable mtx.Lock()
return nil copy(tt[:], m.Topic[:]) // copy the message topic to the test variable
mtx.Unlock()
} }
pss.Register(topic, hndlr) pss.Register(topic, hndlr)
...@@ -141,28 +104,125 @@ func TestDeliver(t *testing.T) { ...@@ -141,28 +104,125 @@ func TestDeliver(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if tt != msg.Topic { runtime.Gosched() // schedule the handler goroutine
t.Fatalf("unexpected result for pss Deliver func, expected test variable to have a value of %v but is %v instead", msg.Topic, tt) for i := 0; i < 10; i++ {
mtx.Lock()
eq := bytes.Equal(tt[:], msg.Topic[:])
mtx.Unlock()
if eq {
return
}
<-time.After(50 * time.Millisecond)
} }
t.Fatalf("unexpected result for pss Deliver func, expected test variable to have a value of %v but is %v instead", msg.Topic, tt)
} }
func TestHandler(t *testing.T) { // TestRegister verifies that handler funcs are able to be registered correctly in pss
pss := pss.New(logging.New(ioutil.Discard, 0), nil) func TestRegister(t *testing.T) {
testTopic := trojan.NewTopic("TEST_TOPIC") var (
pss = pss.New(logging.New(ioutil.Discard, 0))
h1Calls = 0
h2Calls = 0
h3Calls = 0
mtx sync.Mutex
topic1 = trojan.NewTopic("one")
topic2 = trojan.NewTopic("two")
payload = []byte("payload")
target = trojan.Target([]byte{1})
targets = trojan.Targets([]trojan.Target{target})
h1 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h1Calls++
}
h2 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h2Calls++
}
h3 = func(_ context.Context, m *trojan.Message) {
mtx.Lock()
defer mtx.Unlock()
h3Calls++
}
)
_ = pss.Register(topic1, h1)
_ = pss.Register(topic2, h2)
// send a message on topic1, check that only h1 is called
msg, err := trojan.NewMessage(topic1, payload)
if err != nil {
t.Fatal(err)
}
c, err := msg.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 1)
ensureCalls(t, &mtx, &h2Calls, 0)
// verify handler is null // register another topic handler on the same topic
if pss.GetHandler(testTopic) != nil { cleanup := pss.Register(topic1, h3)
t.Errorf("handler should be null") err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
} }
// register first handler ensureCalls(t, &mtx, &h1Calls, 2)
testHandler := func(ctx context.Context, m *trojan.Message) error { return nil } ensureCalls(t, &mtx, &h2Calls, 0)
ensureCalls(t, &mtx, &h3Calls, 1)
// set handler for test topic cleanup() // remove the last handler
pss.Register(testTopic, testHandler)
if pss.GetHandler(testTopic) == nil { err = pss.TryUnwrap(context.Background(), c)
t.Errorf("handler should be registered") if err != nil {
t.Fatal(err)
} }
ensureCalls(t, &mtx, &h1Calls, 3)
ensureCalls(t, &mtx, &h2Calls, 0)
ensureCalls(t, &mtx, &h3Calls, 1)
msg, err = trojan.NewMessage(topic2, payload)
if err != nil {
t.Fatal(err)
}
c, err = msg.Wrap(context.Background(), targets)
if err != nil {
t.Fatal(err)
}
err = pss.TryUnwrap(context.Background(), c)
if err != nil {
t.Fatal(err)
}
ensureCalls(t, &mtx, &h1Calls, 3)
ensureCalls(t, &mtx, &h2Calls, 1)
ensureCalls(t, &mtx, &h3Calls, 1)
}
func ensureCalls(t *testing.T, mtx *sync.Mutex, calls *int, exp int) {
t.Helper()
for i := 0; i < 10; i++ {
mtx.Lock()
if *calls == exp {
mtx.Unlock()
return
}
mtx.Unlock()
<-time.After(100 * time.Millisecond)
}
t.Fatal("timed out waiting for value")
} }
...@@ -11,14 +11,18 @@ import ( ...@@ -11,14 +11,18 @@ import (
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
type PushSync struct { type mock struct {
sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error)
} }
func New(sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error)) *PushSync { func New(sendChunk func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error)) pushsync.PushSyncer {
return &PushSync{sendChunk: sendChunk} return &mock{sendChunk: sendChunk}
} }
func (s *PushSync) PushChunkToClosest(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) { func (s *mock) PushChunkToClosest(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
return s.sendChunk(ctx, chunk) return s.sendChunk(ctx, chunk)
} }
func (s *mock) Close() error {
return nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package recovery
var (
ErrChunkNotPresent = errChunkNotPresent
)
...@@ -6,7 +6,6 @@ package recovery ...@@ -6,7 +6,6 @@ package recovery
import ( import (
"context" "context"
"errors"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pss"
...@@ -26,10 +25,6 @@ var ( ...@@ -26,10 +25,6 @@ var (
RecoveryTopic = trojan.NewTopic(RecoveryTopicText) RecoveryTopic = trojan.NewTopic(RecoveryTopicText)
) )
var (
errChunkNotPresent = errors.New("chunk repair: chunk not present in local store for repairing")
)
// RecoveryHook defines code to be executed upon failing to retrieve chunks. // RecoveryHook defines code to be executed upon failing to retrieve chunks.
type RecoveryHook func(chunkAddress swarm.Address, targets trojan.Targets) error type RecoveryHook func(chunkAddress swarm.Address, targets trojan.Targets) error
...@@ -50,32 +45,31 @@ func NewRecoveryHook(pss PssSender) RecoveryHook { ...@@ -50,32 +45,31 @@ func NewRecoveryHook(pss PssSender) RecoveryHook {
// NewRepairHandler creates a repair function to re-upload globally pinned chunks to the network with the given store. // NewRepairHandler creates a repair function to re-upload globally pinned chunks to the network with the given store.
func NewRepairHandler(s storage.Storer, logger logging.Logger, pushSyncer pushsync.PushSyncer) pss.Handler { func NewRepairHandler(s storage.Storer, logger logging.Logger, pushSyncer pushsync.PushSyncer) pss.Handler {
return func(ctx context.Context, m *trojan.Message) error { return func(ctx context.Context, m *trojan.Message) {
chAddr := m.Payload chAddr := m.Payload
// check if the chunk exists in the local store and proceed. // check if the chunk exists in the local store and proceed.
// otherwise the Get will trigger a unnecessary network retrieve // otherwise the Get will trigger a unnecessary network retrieve
exists, err := s.Has(ctx, swarm.NewAddress(chAddr)) exists, err := s.Has(ctx, swarm.NewAddress(chAddr))
if err != nil { if err != nil {
return err return
} }
if !exists { if !exists {
return errChunkNotPresent return
} }
// retrieve the chunk from the local store // retrieve the chunk from the local store
ch, err := s.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(chAddr)) ch, err := s.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(chAddr))
if err != nil { if err != nil {
logger.Tracef("chunk repair: error while getting chunk for repairing: %v", err) logger.Tracef("chunk repair: error while getting chunk for repairing: %v", err)
return err return
} }
// push the chunk using push sync so that it reaches it destination in network // push the chunk using push sync so that it reaches it destination in network
_, err = pushSyncer.PushChunkToClosest(ctx, ch) _, err = pushSyncer.PushChunkToClosest(ctx, ch)
if err != nil { if err != nil {
logger.Tracef("chunk repair: error while sending chunk or receiving receipt: %v", err) logger.Tracef("chunk repair: error while sending chunk or receiving receipt: %v", err)
return err return
} }
return nil
} }
} }
...@@ -155,10 +155,7 @@ func TestNewRepairHandler(t *testing.T) { ...@@ -155,10 +155,7 @@ func TestNewRepairHandler(t *testing.T) {
} }
// invoke the chunk repair handler // invoke the chunk repair handler
err = repairHandler(context.Background(), &msg) repairHandler(context.Background(), &msg)
if err != nil {
t.Fatal(err)
}
// check if receipt is received // check if receipt is received
if receipt == nil { if receipt == nil {
...@@ -200,10 +197,7 @@ func TestNewRepairHandler(t *testing.T) { ...@@ -200,10 +197,7 @@ func TestNewRepairHandler(t *testing.T) {
} }
// invoke the chunk repair handler // invoke the chunk repair handler
err = repairHandler(context.Background(), &msg) repairHandler(context.Background(), &msg)
if !errors.Is(err, recovery.ErrChunkNotPresent) {
t.Fatal(err)
}
if pushServiceCalled { if pushServiceCalled {
t.Fatal("push service called even if the chunk is not present") t.Fatal("push service called even if the chunk is not present")
...@@ -243,10 +237,7 @@ func TestNewRepairHandler(t *testing.T) { ...@@ -243,10 +237,7 @@ func TestNewRepairHandler(t *testing.T) {
} }
// invoke the chunk repair handler // invoke the chunk repair handler
err = repairHandler(context.Background(), &msg) repairHandler(context.Background(), &msg)
if err != nil && err != receiptError {
t.Fatal(err)
}
if receiptError == nil { if receiptError == nil {
t.Fatal("pushsync did not generate a receipt error") t.Fatal("pushsync did not generate a receipt error")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment