Commit 65b65f5f authored by Pavle Batuta's avatar Pavle Batuta Committed by GitHub

Dynamic welcome message (#439)

* Add dynamic welcome message to p2p handshake

* Add '/welcome-msg' to debug api

* Add missing license headers

* Rename to WelcomeMessageSynced to Get

* Add atomic.value instead of struct

* Simplify mocks, refactor tests

* Add request struct and length check

* Extract DebugService p2p interface

* Add check

* Remove MaxBytesReader

* Remove unused constant

* Remove out of scope test

* Set max request size

* Refactor test
parent e9d1e561
...@@ -32,7 +32,7 @@ type server struct { ...@@ -32,7 +32,7 @@ type server struct {
type Options struct { type Options struct {
Overlay swarm.Address Overlay swarm.Address
P2P p2p.Service P2P p2p.DebugService
Pingpong pingpong.Interface Pingpong pingpong.Interface
TopologyDriver topology.PeerAdder TopologyDriver topology.PeerAdder
Storer storage.Storer Storer storage.Storer
......
...@@ -13,4 +13,6 @@ type ( ...@@ -13,4 +13,6 @@ type (
PinnedChunk = pinnedChunk PinnedChunk = pinnedChunk
ListPinnedChunksResponse = listPinnedChunksResponse ListPinnedChunksResponse = listPinnedChunksResponse
TagResponse = tagResponse TagResponse = tagResponse
WelcomeMessageRequest = welcomeMessageRequest
WelcomeMessageResponse = welcomeMessageResponse
) )
...@@ -86,6 +86,13 @@ func (s *server) setupRouting() { ...@@ -86,6 +86,13 @@ func (s *server) setupRouting() {
router.Handle("/topology", jsonhttp.MethodHandler{ router.Handle("/topology", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.topologyHandler), "GET": http.HandlerFunc(s.topologyHandler),
}) })
router.Handle("/welcome-message", jsonhttp.MethodHandler{
"GET": http.HandlerFunc(s.getWelcomeMessageHandler),
"POST": web.ChainHandlers(
jsonhttp.NewMaxBodyBytesHandler(welcomeMessageMaxRequestSize),
web.FinalHandlerFunc(s.setWelcomeMessageHandler),
),
})
baseRouter.Handle("/", web.ChainHandlers( baseRouter.Handle("/", web.ChainHandlers(
logging.NewHTTPAccessLogHandler(s.Logger, logrus.InfoLevel, "debug api access"), logging.NewHTTPAccessLogHandler(s.Logger, logrus.InfoLevel, "debug api access"),
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package debugapi
import (
"encoding/json"
"net/http"
"github.com/ethersphere/bee/pkg/jsonhttp"
)
const welcomeMessageMaxRequestSize = 512
type welcomeMessageRequest struct {
WelcomeMesssage string `json:"welcome_message"`
}
type welcomeMessageResponse struct {
WelcomeMesssage string `json:"welcome_message"`
}
func (s *server) getWelcomeMessageHandler(w http.ResponseWriter, r *http.Request) {
val := s.P2P.GetWelcomeMessage()
jsonhttp.OK(w, welcomeMessageResponse{
WelcomeMesssage: val,
})
}
func (s *server) setWelcomeMessageHandler(w http.ResponseWriter, r *http.Request) {
var data welcomeMessageRequest
err := json.NewDecoder(r.Body).Decode(&data)
if err != nil {
s.Logger.Debugf("debugapi: welcome message: failed to read request: %v", err)
jsonhttp.BadRequest(w, err)
return
}
if err := s.P2P.SetWelcomeMessage(data.WelcomeMesssage); err != nil {
s.Logger.Debugf("debugapi: welcome message: failed to set: %v", err)
s.Logger.Errorf("Failed to set welcome message")
jsonhttp.InternalServerError(w, err)
return
}
jsonhttp.OK(w, nil)
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package debugapi_test
import (
"bytes"
"encoding/json"
"errors"
"net/http"
"testing"
"github.com/ethersphere/bee/pkg/debugapi"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
"github.com/ethersphere/bee/pkg/p2p/mock"
)
func TestGetWelcomeMessage(t *testing.T) {
const DefaultTestWelcomeMessage = "Hello World!"
srv := newTestServer(t, testServerOptions{
P2P: mock.New(mock.WithGetWelcomeMessageFunc(func() string {
return DefaultTestWelcomeMessage
}))})
jsonhttptest.ResponseDirect(t, srv.Client, http.MethodGet, "/welcome-message", nil, http.StatusOK, debugapi.WelcomeMessageResponse{
WelcomeMesssage: DefaultTestWelcomeMessage,
})
}
func TestSetWelcomeMessage(t *testing.T) {
testCases := []struct {
desc string
message string
wantFail bool
wantStatus int
wantMessage string
}{
{
desc: "OK",
message: "Changed value",
wantStatus: http.StatusOK,
},
{
desc: "OK - welcome message empty",
message: "",
wantStatus: http.StatusOK,
},
{
desc: "fails - request entity too large",
wantFail: true,
message: `zZZbzbzbzbBzBBZbBbZbbbBzzzZBZBbzzBBBbBzBzzZbbBzBBzBBbZz
bZZZBBbbZbbZzBbzBbzbZBZzBZZbZzZzZzbbZZBZzzbBZBzZzzBBzZZzzZbZZZzbbbzz
bBzZZBbBZBzZzBZBzbzBBbzBBzbzzzBbBbZzZBZBZzBZZbbZZBZZBzZzBZbzZBzZbBzZ
bbbBbbZzZbzbZzZzbzzzbzzbzZZzbbzbBZZbBbBZBBZzZzzbBBBBBZbZzBzzBbzBbbbz
BBzbbZBbzbzBZbzzBzbZBzzbzbbbBZBzBZzBZbzBzZzBZZZBzZZBzBZZzbzZbzzZzBBz
ZZzbZzzZZZBZBBbZZbZzBBBzbzZZbbZZBZZBBBbBZzZbZBZBBBzzZBbbbbzBzbbzBBBz
bZBBbZzBbZZBzbBbZZBzBzBzBBbzzzZBbzbZBbzBbZzbbBZBBbbZbBBbbBZbzbZzbBzB
bBbbZZbzZzbbBbzZbZZZZbzzZZbBzZZbZzZzzBzbZZ`, // 513 characters
wantStatus: http.StatusRequestEntityTooLarge,
},
}
testURL := "/welcome-message"
srv := newTestServer(t, testServerOptions{
P2P: mock.New(),
})
for _, tC := range testCases {
t.Run(tC.desc, func(t *testing.T) {
if tC.wantMessage == "" {
tC.wantMessage = http.StatusText(tC.wantStatus)
}
data, _ := json.Marshal(debugapi.WelcomeMessageRequest{
WelcomeMesssage: tC.message,
})
body := bytes.NewReader(data)
wantResponse := jsonhttp.StatusResponse{
Message: tC.wantMessage,
Code: tC.wantStatus,
}
jsonhttptest.ResponseDirect(t, srv.Client, http.MethodPost, testURL, body, tC.wantStatus, wantResponse)
if !tC.wantFail {
got := srv.P2PMock.GetWelcomeMessage()
if got != tC.message {
t.Fatalf("could not set dynamic welcome message: want %s, got %s", tC.message, got)
}
}
})
}
}
func TestSetWelcomeMessageInternalServerError(t *testing.T) {
testMessage := "NO CHANCE BYE"
testError := errors.New("Could not set value")
testURL := "/welcome-message"
srv := newTestServer(t, testServerOptions{
P2P: mock.New(mock.WithSetWelcomeMessageFunc(func(string) error {
return testError
})),
})
data, _ := json.Marshal(debugapi.WelcomeMessageRequest{
WelcomeMesssage: testMessage,
})
body := bytes.NewReader(data)
t.Run("internal server error - failed to store", func(t *testing.T) {
wantCode := http.StatusInternalServerError
wantResp := jsonhttp.StatusResponse{
Message: testError.Error(),
Code: wantCode,
}
jsonhttptest.ResponseDirect(t, srv.Client, http.MethodPost, testURL, body, wantCode, wantResp)
})
}
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/ethersphere/bee/pkg/bzz" "github.com/ethersphere/bee/pkg/bzz"
...@@ -64,7 +65,7 @@ type Service struct { ...@@ -64,7 +65,7 @@ type Service struct {
overlay swarm.Address overlay swarm.Address
lightNode bool lightNode bool
networkID uint64 networkID uint64
welcomeMessage string welcomeMessage atomic.Value
receivedHandshakes map[libp2ppeer.ID]struct{} receivedHandshakes map[libp2ppeer.ID]struct{}
receivedHandshakesMu sync.Mutex receivedHandshakesMu sync.Mutex
logger logging.Logger logger logging.Logger
...@@ -84,17 +85,19 @@ func New(signer crypto.Signer, advertisableAddresser AdvertisableAddressResolver ...@@ -84,17 +85,19 @@ func New(signer crypto.Signer, advertisableAddresser AdvertisableAddressResolver
return nil, ErrWelcomeMessageLength return nil, ErrWelcomeMessageLength
} }
return &Service{ svc := &Service{
signer: signer, signer: signer,
advertisableAddresser: advertisableAddresser, advertisableAddresser: advertisableAddresser,
overlay: overlay, overlay: overlay,
networkID: networkID, networkID: networkID,
lightNode: lighNode, lightNode: lighNode,
welcomeMessage: welcomeMessage,
receivedHandshakes: make(map[libp2ppeer.ID]struct{}), receivedHandshakes: make(map[libp2ppeer.ID]struct{}),
logger: logger, logger: logger,
Notifiee: new(network.NoopNotifiee), Notifiee: new(network.NoopNotifiee),
}, nil }
svc.welcomeMessage.Store(welcomeMessage)
return svc, nil
} }
// Handshake initiates a handshake with a peer. // Handshake initiates a handshake with a peer.
...@@ -146,6 +149,8 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI ...@@ -146,6 +149,8 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI
return nil, err return nil, err
} }
// Synced read:
welcomeMessage := s.GetWelcomeMessage()
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Ack{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.Ack{
Address: &pb.BzzAddress{ Address: &pb.BzzAddress{
Underlay: advertisableUnderlayBytes, Underlay: advertisableUnderlayBytes,
...@@ -154,7 +159,7 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI ...@@ -154,7 +159,7 @@ func (s *Service) Handshake(stream p2p.Stream, peerMultiaddr ma.Multiaddr, peerI
}, },
NetworkID: s.networkID, NetworkID: s.networkID,
Light: s.lightNode, Light: s.lightNode,
WelcomeMessage: s.welcomeMessage, WelcomeMessage: welcomeMessage,
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write ack message: %w", err) return nil, fmt.Errorf("write ack message: %w", err)
} }
...@@ -216,6 +221,8 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote ...@@ -216,6 +221,8 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote
return nil, err return nil, err
} }
welcomeMessage := s.GetWelcomeMessage()
if err := w.WriteMsgWithTimeout(messageTimeout, &pb.SynAck{ if err := w.WriteMsgWithTimeout(messageTimeout, &pb.SynAck{
Syn: &pb.Syn{ Syn: &pb.Syn{
ObservedUnderlay: fullRemoteMABytes, ObservedUnderlay: fullRemoteMABytes,
...@@ -228,7 +235,7 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote ...@@ -228,7 +235,7 @@ func (s *Service) Handle(stream p2p.Stream, remoteMultiaddr ma.Multiaddr, remote
}, },
NetworkID: s.networkID, NetworkID: s.networkID,
Light: s.lightNode, Light: s.lightNode,
WelcomeMessage: s.welcomeMessage, WelcomeMessage: welcomeMessage,
}, },
}); err != nil { }); err != nil {
return nil, fmt.Errorf("write synack message: %w", err) return nil, fmt.Errorf("write synack message: %w", err)
...@@ -259,6 +266,20 @@ func (s *Service) Disconnected(_ network.Network, c network.Conn) { ...@@ -259,6 +266,20 @@ func (s *Service) Disconnected(_ network.Network, c network.Conn) {
delete(s.receivedHandshakes, c.RemotePeer()) delete(s.receivedHandshakes, c.RemotePeer())
} }
// SetWelcomeMessage sets the new handshake welcome message.
func (s *Service) SetWelcomeMessage(msg string) (err error) {
if len(msg) > MaxWelcomeMessageLength {
return ErrWelcomeMessageLength
}
s.welcomeMessage.Store(msg)
return nil
}
// GetWelcomeMessage returns the the current handshake welcome message.
func (s *Service) GetWelcomeMessage() string {
return s.welcomeMessage.Load().(string)
}
func buildFullMA(addr ma.Multiaddr, peerID libp2ppeer.ID) (ma.Multiaddr, error) { func buildFullMA(addr ma.Multiaddr, peerID libp2ppeer.ID) (ma.Multiaddr, error) {
return ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", addr.String(), peerID.Pretty())) return ma.NewMultiaddr(fmt.Sprintf("%s/p2p/%s", addr.String(), peerID.Pretty()))
} }
......
...@@ -165,6 +165,29 @@ func TestHandshake(t *testing.T) { ...@@ -165,6 +165,29 @@ func TestHandshake(t *testing.T) {
} }
}) })
t.Run("Handshake - dynamic welcome message too long", func(t *testing.T) {
const LongMessage = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Morbi consectetur urna ut lorem sollicitudin posuere. Donec sagittis laoreet sapien."
expectedErr := handshake.ErrWelcomeMessageLength
err := handshakeService.SetWelcomeMessage(LongMessage)
if err == nil || err.Error() != expectedErr.Error() {
t.Fatal("expected:", expectedErr, "got:", err)
}
})
t.Run("Handshake - set welcome message", func(t *testing.T) {
const TestMessage = "Hi im the new test message"
err := handshakeService.SetWelcomeMessage(TestMessage)
if err != nil {
t.Fatal("Got error:", err)
}
got := handshakeService.GetWelcomeMessage()
if got != TestMessage {
t.Fatal("expected:", TestMessage, ", got:", got)
}
})
t.Run("Handshake - Syn write error", func(t *testing.T) { t.Run("Handshake - Syn write error", func(t *testing.T) {
testErr := errors.New("test error") testErr := errors.New("test error")
expectedErr := fmt.Errorf("write syn message: %w", testErr) expectedErr := fmt.Errorf("write syn message: %w", testErr)
......
...@@ -39,7 +39,8 @@ import ( ...@@ -39,7 +39,8 @@ import (
) )
var ( var (
_ p2p.Service = (*Service)(nil) _ p2p.Service = (*Service)(nil)
_ p2p.DebugService = (*Service)(nil)
) )
type Service struct { type Service struct {
...@@ -503,3 +504,13 @@ func (s *Service) Close() error { ...@@ -503,3 +504,13 @@ func (s *Service) Close() error {
} }
return s.host.Close() return s.host.Close()
} }
// SetWelcomeMessage sets the welcome message for the handshake protocol.
func (s *Service) SetWelcomeMessage(val string) error {
return s.handshakeService.SetWelcomeMessage(val)
}
// GetWelcomeMessage returns the value of the welcome message.
func (s *Service) GetWelcomeMessage() string {
return s.handshakeService.GetWelcomeMessage()
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package libp2p_test
import (
"testing"
"github.com/ethersphere/bee/pkg/p2p/libp2p"
"github.com/ethersphere/bee/pkg/p2p/libp2p/internal/handshake"
)
func TestDynamicWelcomeMessage(t *testing.T) {
const TestWelcomeMessage = "Hello World!"
svc, _ := newService(t, 1, libp2p.Options{WelcomeMessage: TestWelcomeMessage})
t.Run("Get current message - OK", func(t *testing.T) {
got := svc.GetWelcomeMessage()
if got != TestWelcomeMessage {
t.Fatalf("expected %s, got %s", TestWelcomeMessage, got)
}
})
t.Run("Set new message", func(t *testing.T) {
t.Run("OK", func(t *testing.T) {
const testMessage = "I'm the new message!"
err := svc.SetWelcomeMessage(testMessage)
if err != nil {
t.Fatal("got error:", err)
}
got := svc.GetWelcomeMessage()
if got != testMessage {
t.Fatalf("expected: %s. got %s", testMessage, got)
}
})
t.Run("fails - message too long", func(t *testing.T) {
const testMessage = `Lorem ipsum dolor sit amet, consectetur adipiscing elit.
Maecenas eu aliquam enim. Nulla tincidunt arcu nec nulla condimentum nullam sodales` // 141 characters
want := handshake.ErrWelcomeMessageLength
got := svc.SetWelcomeMessage(testMessage)
if got != want {
t.Fatalf("wrong error: want %v, got %v", want, got)
}
})
})
}
...@@ -16,52 +16,77 @@ import ( ...@@ -16,52 +16,77 @@ import (
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
) )
// Service is the mock of a P2P Service
type Service struct { type Service struct {
addProtocolFunc func(p2p.ProtocolSpec) error addProtocolFunc func(p2p.ProtocolSpec) error
connectFunc func(ctx context.Context, addr ma.Multiaddr) (address *bzz.Address, err error) connectFunc func(ctx context.Context, addr ma.Multiaddr) (address *bzz.Address, err error)
disconnectFunc func(overlay swarm.Address) error disconnectFunc func(overlay swarm.Address) error
peersFunc func() []p2p.Peer peersFunc func() []p2p.Peer
setNotifierFunc func(topology.Notifier) setNotifierFunc func(topology.Notifier)
addressesFunc func() ([]ma.Multiaddr, error) addressesFunc func() ([]ma.Multiaddr, error)
notifyCalled int32 setWelcomeMessageFunc func(string) error
} getWelcomeMessageFunc func() string
welcomeMessage string
notifyCalled int32
}
// WithAddProtocolFunc sets the mock implementation of the AddProtocol function
func WithAddProtocolFunc(f func(p2p.ProtocolSpec) error) Option { func WithAddProtocolFunc(f func(p2p.ProtocolSpec) error) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.addProtocolFunc = f s.addProtocolFunc = f
}) })
} }
// WithConnectFunc sets the mock implementation of the Connect function
func WithConnectFunc(f func(ctx context.Context, addr ma.Multiaddr) (address *bzz.Address, err error)) Option { func WithConnectFunc(f func(ctx context.Context, addr ma.Multiaddr) (address *bzz.Address, err error)) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.connectFunc = f s.connectFunc = f
}) })
} }
// WithDisconnectFunc sets the mock implementation of the Disconnect function
func WithDisconnectFunc(f func(overlay swarm.Address) error) Option { func WithDisconnectFunc(f func(overlay swarm.Address) error) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.disconnectFunc = f s.disconnectFunc = f
}) })
} }
// WithPeersFunc sets the mock implementation of the Peers function
func WithPeersFunc(f func() []p2p.Peer) Option { func WithPeersFunc(f func() []p2p.Peer) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.peersFunc = f s.peersFunc = f
}) })
} }
// WithSetNotifierFunc sets the mock implementation of the SetNotifier function
func WithSetNotifierFunc(f func(topology.Notifier)) Option { func WithSetNotifierFunc(f func(topology.Notifier)) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.setNotifierFunc = f s.setNotifierFunc = f
}) })
} }
// WithAddressesFunc sets the mock implementation of the Adresses function
func WithAddressesFunc(f func() ([]ma.Multiaddr, error)) Option { func WithAddressesFunc(f func() ([]ma.Multiaddr, error)) Option {
return optionFunc(func(s *Service) { return optionFunc(func(s *Service) {
s.addressesFunc = f s.addressesFunc = f
}) })
} }
// WithGetWelcomeMessageFunc sets the mock implementation of the GetWelcomeMessage function
func WithGetWelcomeMessageFunc(f func() string) Option {
return optionFunc(func(s *Service) {
s.getWelcomeMessageFunc = f
})
}
// WithSetWelcomeMessageFunc sets the mock implementation of the SetWelcomeMessage function
func WithSetWelcomeMessageFunc(f func(string) error) Option {
return optionFunc(func(s *Service) {
s.setWelcomeMessageFunc = f
})
}
// New will create a new mock P2P Service with the given options
func New(opts ...Option) *Service { func New(opts ...Option) *Service {
s := new(Service) s := new(Service)
for _, o := range opts { for _, o := range opts {
...@@ -126,6 +151,21 @@ func (s *Service) ConnectNotifyCalls() int32 { ...@@ -126,6 +151,21 @@ func (s *Service) ConnectNotifyCalls() int32 {
return c return c
} }
func (s *Service) SetWelcomeMessage(val string) error {
if s.setWelcomeMessageFunc != nil {
return s.setWelcomeMessageFunc(val)
}
s.welcomeMessage = val
return nil
}
func (s *Service) GetWelcomeMessage() string {
if s.getWelcomeMessageFunc != nil {
return s.getWelcomeMessageFunc()
}
return s.welcomeMessage
}
type Option interface { type Option interface {
apply(*Service) apply(*Service)
} }
......
...@@ -28,6 +28,13 @@ type Service interface { ...@@ -28,6 +28,13 @@ type Service interface {
Addresses() ([]ma.Multiaddr, error) Addresses() ([]ma.Multiaddr, error)
} }
// DebugService extends the Service with method used for debugging.
type DebugService interface {
Service
SetWelcomeMessage(val string) error
GetWelcomeMessage() string
}
// Streamer is able to create a new Stream. // Streamer is able to create a new Stream.
type Streamer interface { type Streamer interface {
NewStream(ctx context.Context, address swarm.Address, h Headers, protocol, version, stream string) (Stream, error) NewStream(ctx context.Context, address swarm.Address, h Headers, protocol, version, stream string) (Stream, error)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment