Commit 47a1a4f2 authored by Janos Guljas's avatar Janos Guljas

support network id in handshake protocol

parent 377f98a1
...@@ -6,7 +6,7 @@ GOLANGCI_LINT ?= golangci-lint ...@@ -6,7 +6,7 @@ GOLANGCI_LINT ?= golangci-lint
LDFLAGS ?= -s -w -X github.com/ethersphere/bee.commit="$(COMMIT)" LDFLAGS ?= -s -w -X github.com/ethersphere/bee.commit="$(COMMIT)"
.PHONY: all .PHONY: all
all: lint vet test binary all: build lint vet test binary
.PHONY: binary .PHONY: binary
binary: export CGO_ENABLED=0 binary: export CGO_ENABLED=0
......
...@@ -89,7 +89,7 @@ func (c *command) initStartCmd() (err error) { ...@@ -89,7 +89,7 @@ func (c *command) initStartCmd() (err error) {
DisableWS: c.config.GetBool(optionNameP2PDisableWS), DisableWS: c.config.GetBool(optionNameP2PDisableWS),
DisableQUIC: c.config.GetBool(optionNameP2PDisableQUIC), DisableQUIC: c.config.GetBool(optionNameP2PDisableQUIC),
Bootnodes: c.config.GetStringSlice(optionNameBootnodes), Bootnodes: c.config.GetStringSlice(optionNameBootnodes),
NetworkID: c.config.GetInt(optionNameNetworkID), NetworkID: c.config.GetInt32(optionNameNetworkID),
ConnectionsLow: c.config.GetInt(optionNameConnectionsLow), ConnectionsLow: c.config.GetInt(optionNameConnectionsLow),
ConnectionsHigh: c.config.GetInt(optionNameConnectionsHigh), ConnectionsHigh: c.config.GetInt(optionNameConnectionsHigh),
ConnectionsGrace: c.config.GetDuration(optionNameConnectionsGrace), ConnectionsGrace: c.config.GetDuration(optionNameConnectionsGrace),
...@@ -143,7 +143,7 @@ func (c *command) initStartCmd() (err error) { ...@@ -143,7 +143,7 @@ func (c *command) initStartCmd() (err error) {
cmd.Flags().Bool(optionNameP2PDisableQUIC, false, "disable P2P QUIC protocol") cmd.Flags().Bool(optionNameP2PDisableQUIC, false, "disable P2P QUIC protocol")
cmd.Flags().StringSlice(optionNameBootnodes, nil, "initial nodes to connect to") cmd.Flags().StringSlice(optionNameBootnodes, nil, "initial nodes to connect to")
cmd.Flags().String(optionNameDebugAPIAddr, "", "debug HTTP API listen address, e.g. 127.0.0.1:6060") cmd.Flags().String(optionNameDebugAPIAddr, "", "debug HTTP API listen address, e.g. 127.0.0.1:6060")
cmd.Flags().Int(optionNameNetworkID, 1, "ID of the Swarm network") cmd.Flags().Int32(optionNameNetworkID, 1, "ID of the Swarm network")
cmd.Flags().Int(optionNameConnectionsLow, 200, "low watermark governing the number of connections that'll be maintained") cmd.Flags().Int(optionNameConnectionsLow, 200, "low watermark governing the number of connections that'll be maintained")
cmd.Flags().Int(optionNameConnectionsHigh, 400, "high watermark governing the number of connections that'll be maintained") cmd.Flags().Int(optionNameConnectionsHigh, 400, "high watermark governing the number of connections that'll be maintained")
cmd.Flags().Duration(optionNameConnectionsGrace, time.Minute, "the amount of time a newly opened connection is given before it becomes subject to pruning") cmd.Flags().Duration(optionNameConnectionsGrace, time.Minute, "the amount of time a newly opened connection is given before it becomes subject to pruning")
......
...@@ -20,14 +20,16 @@ const ( ...@@ -20,14 +20,16 @@ const (
) )
type Service struct { type Service struct {
overlay string overlay string
logger Logger networkID int32
logger Logger
} }
func New(overlay string, logger Logger) *Service { func New(overlay string, networkID int32, logger Logger) *Service {
return &Service{ return &Service{
overlay: overlay, overlay: overlay,
logger: logger, networkID: networkID,
logger: logger,
} }
} }
...@@ -35,38 +37,56 @@ type Logger interface { ...@@ -35,38 +37,56 @@ type Logger interface {
Tracef(format string, args ...interface{}) Tracef(format string, args ...interface{})
} }
func (s *Service) Handshake(stream p2p.Stream) (overlay string, err error) { func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
var resp ShakeHand var resp ShakeHand
if err := w.WriteMsg(&ShakeHand{Address: s.overlay}); err != nil { if err := w.WriteMsg(&ShakeHand{
return "", fmt.Errorf("handshake handler: write message: %w", err) Address: s.overlay,
NetworkID: s.networkID,
}); err != nil {
return nil, fmt.Errorf("handshake handler: write message: %w", err)
} }
s.logger.Tracef("handshake: sent request %s", s.overlay) s.logger.Tracef("handshake: sent request %s", s.overlay)
if err := r.ReadMsg(&resp); err != nil { if err := r.ReadMsg(&resp); err != nil {
return "", fmt.Errorf("handshake handler: read message: %w", err) return nil, fmt.Errorf("handshake handler: read message: %w", err)
} }
s.logger.Tracef("handshake: read response: %s", resp.Address) s.logger.Tracef("handshake: read response: %s", resp.Address)
return resp.Address, nil return &Info{
Address: resp.Address,
NetworkID: resp.NetworkID,
Light: resp.Light,
}, nil
} }
func (s *Service) Handle(stream p2p.Stream) (overlay string, err error) { func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
w, r := protobuf.NewWriterAndReader(stream) w, r := protobuf.NewWriterAndReader(stream)
defer stream.Close() defer stream.Close()
var req ShakeHand var req ShakeHand
if err := r.ReadMsg(&req); err != nil { if err := r.ReadMsg(&req); err != nil {
return "", fmt.Errorf("read message: %w", err) return nil, fmt.Errorf("read message: %w", err)
} }
s.logger.Tracef("handshake: received request %s", req.Address) s.logger.Tracef("handshake: received request %s", req.Address)
if err := w.WriteMsg(&ShakeHand{ if err := w.WriteMsg(&ShakeHand{
Address: s.overlay, Address: s.overlay,
NetworkID: s.networkID,
}); err != nil { }); err != nil {
return "", fmt.Errorf("write message: %w", err) return nil, fmt.Errorf("write message: %w", err)
} }
s.logger.Tracef("handshake: handled response: %s", s.overlay) s.logger.Tracef("handshake: handled response: %s", s.overlay)
return req.Address, nil return &Info{
Address: req.Address,
NetworkID: req.NetworkID,
Light: req.Light,
}, nil
}
type Info struct {
Address string
NetworkID int32
Light bool
} }
...@@ -23,7 +23,9 @@ var _ = math.Inf ...@@ -23,7 +23,9 @@ var _ = math.Inf
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type ShakeHand struct { type ShakeHand struct {
Address string `protobuf:"bytes,1,opt,name=PeerID,proto3" json:"PeerID,omitempty"` Address string `protobuf:"bytes,1,opt,name=Address,proto3" json:"Address,omitempty"`
NetworkID int32 `protobuf:"varint,2,opt,name=NetworkID,proto3" json:"NetworkID,omitempty"`
Light bool `protobuf:"varint,3,opt,name=Light,proto3" json:"Light,omitempty"`
} }
func (m *ShakeHand) Reset() { *m = ShakeHand{} } func (m *ShakeHand) Reset() { *m = ShakeHand{} }
...@@ -66,6 +68,20 @@ func (m *ShakeHand) GetAddress() string { ...@@ -66,6 +68,20 @@ func (m *ShakeHand) GetAddress() string {
return "" return ""
} }
func (m *ShakeHand) GetNetworkID() int32 {
if m != nil {
return m.NetworkID
}
return 0
}
func (m *ShakeHand) GetLight() bool {
if m != nil {
return m.Light
}
return false
}
func init() { func init() {
proto.RegisterType((*ShakeHand)(nil), "handshake.ShakeHand") proto.RegisterType((*ShakeHand)(nil), "handshake.ShakeHand")
} }
...@@ -73,14 +89,17 @@ func init() { ...@@ -73,14 +89,17 @@ func init() {
func init() { proto.RegisterFile("handshake.proto", fileDescriptor_a77305914d5d202f) } func init() { proto.RegisterFile("handshake.proto", fileDescriptor_a77305914d5d202f) }
var fileDescriptor_a77305914d5d202f = []byte{ var fileDescriptor_a77305914d5d202f = []byte{
// 108 bytes of a gzipped FileDescriptorProto // 148 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcf, 0x48, 0xcc, 0x4b, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcf, 0x48, 0xcc, 0x4b,
0x29, 0xce, 0x48, 0xcc, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x84, 0x0b, 0x28, 0x29, 0xce, 0x48, 0xcc, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x84, 0x0b, 0x28,
0xa9, 0x72, 0x71, 0x06, 0x83, 0x18, 0x1e, 0x89, 0x79, 0x29, 0x42, 0x12, 0x5c, 0xec, 0x8e, 0x29, 0x45, 0x72, 0x71, 0x06, 0x83, 0x18, 0x1e, 0x89, 0x79, 0x29, 0x42, 0x12, 0x5c, 0xec, 0x8e, 0x29,
0x29, 0x45, 0xa9, 0xc5, 0xc5, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x93, 0xc4, 0x29, 0x45, 0xa9, 0xc5, 0xc5, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x90, 0x0c,
0x89, 0x47, 0x72, 0x8c, 0x17, 0x1e, 0xc9, 0x31, 0x3e, 0x78, 0x24, 0xc7, 0x38, 0xe1, 0xb1, 0x1c, 0x17, 0xa7, 0x5f, 0x6a, 0x49, 0x79, 0x7e, 0x51, 0xb6, 0xa7, 0x8b, 0x04, 0x93, 0x02, 0xa3, 0x06,
0xc3, 0x85, 0xc7, 0x72, 0x0c, 0x37, 0x1e, 0xcb, 0x31, 0x24, 0xb1, 0x81, 0x8d, 0x34, 0x06, 0x04, 0x6b, 0x10, 0x42, 0x40, 0x48, 0x84, 0x8b, 0xd5, 0x27, 0x33, 0x3d, 0xa3, 0x44, 0x82, 0x59, 0x81,
0x00, 0x00, 0xff, 0xff, 0x5d, 0x34, 0x69, 0xba, 0x65, 0x00, 0x00, 0x00, 0x51, 0x83, 0x23, 0x08, 0xc2, 0x71, 0x92, 0x38, 0xf1, 0x48, 0x8e, 0xf1, 0xc2, 0x23, 0x39, 0xc6,
0x07, 0x8f, 0xe4, 0x18, 0x27, 0x3c, 0x96, 0x63, 0xb8, 0xf0, 0x58, 0x8e, 0xe1, 0xc6, 0x63, 0x39,
0x86, 0x24, 0x36, 0xb0, 0x33, 0x8c, 0x01, 0x01, 0x00, 0x00, 0xff, 0xff, 0x62, 0x1c, 0xa2, 0x06,
0x99, 0x00, 0x00, 0x00,
} }
func (m *ShakeHand) Marshal() (dAtA []byte, err error) { func (m *ShakeHand) Marshal() (dAtA []byte, err error) {
...@@ -103,6 +122,21 @@ func (m *ShakeHand) MarshalToSizedBuffer(dAtA []byte) (int, error) { ...@@ -103,6 +122,21 @@ func (m *ShakeHand) MarshalToSizedBuffer(dAtA []byte) (int, error) {
_ = i _ = i
var l int var l int
_ = l _ = l
if m.Light {
i--
if m.Light {
dAtA[i] = 1
} else {
dAtA[i] = 0
}
i--
dAtA[i] = 0x18
}
if m.NetworkID != 0 {
i = encodeVarintHandshake(dAtA, i, uint64(m.NetworkID))
i--
dAtA[i] = 0x10
}
if len(m.Address) > 0 { if len(m.Address) > 0 {
i -= len(m.Address) i -= len(m.Address)
copy(dAtA[i:], m.Address) copy(dAtA[i:], m.Address)
...@@ -134,6 +168,12 @@ func (m *ShakeHand) Size() (n int) { ...@@ -134,6 +168,12 @@ func (m *ShakeHand) Size() (n int) {
if l > 0 { if l > 0 {
n += 1 + l + sovHandshake(uint64(l)) n += 1 + l + sovHandshake(uint64(l))
} }
if m.NetworkID != 0 {
n += 1 + sovHandshake(uint64(m.NetworkID))
}
if m.Light {
n += 2
}
return n return n
} }
...@@ -174,7 +214,7 @@ func (m *ShakeHand) Unmarshal(dAtA []byte) error { ...@@ -174,7 +214,7 @@ func (m *ShakeHand) Unmarshal(dAtA []byte) error {
switch fieldNum { switch fieldNum {
case 1: case 1:
if wireType != 2 { if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field PeerID", wireType) return fmt.Errorf("proto: wrong wireType = %d for field Address", wireType)
} }
var stringLen uint64 var stringLen uint64
for shift := uint(0); ; shift += 7 { for shift := uint(0); ; shift += 7 {
...@@ -204,6 +244,45 @@ func (m *ShakeHand) Unmarshal(dAtA []byte) error { ...@@ -204,6 +244,45 @@ func (m *ShakeHand) Unmarshal(dAtA []byte) error {
} }
m.Address = string(dAtA[iNdEx:postIndex]) m.Address = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex iNdEx = postIndex
case 2:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field NetworkID", wireType)
}
m.NetworkID = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.NetworkID |= int32(b&0x7F) << shift
if b < 0x80 {
break
}
}
case 3:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Light", wireType)
}
var v int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
v |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
m.Light = bool(v != 0)
default: default:
iNdEx = preIndex iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:]) skippy, err := skipHandshake(dAtA[iNdEx:])
......
...@@ -4,5 +4,7 @@ package handshake; ...@@ -4,5 +4,7 @@ package handshake;
message ShakeHand { message ShakeHand {
string Address = 1; string Address = 1;
int32 NetworkID = 2;
bool Light = 3;
} }
...@@ -45,6 +45,7 @@ func init() { ...@@ -45,6 +45,7 @@ func init() {
type Service struct { type Service struct {
host host.Host host host.Host
metrics metrics metrics metrics
networkID int32
handshakeService *handshake.Service handshakeService *handshake.Service
peers *peerRegistry peers *peerRegistry
logger Logger logger Logger
...@@ -56,7 +57,7 @@ type Options struct { ...@@ -56,7 +57,7 @@ type Options struct {
DisableWS bool DisableWS bool
DisableQUIC bool DisableQUIC bool
Bootnodes []string Bootnodes []string
NetworkID int // TODO: to be used in the handshake protocol NetworkID int32
ConnectionsLow int ConnectionsLow int
ConnectionsHigh int ConnectionsHigh int
ConnectionsGrace time.Duration ConnectionsGrace time.Duration
...@@ -197,7 +198,8 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -197,7 +198,8 @@ func New(ctx context.Context, o Options) (*Service, error) {
s := &Service{ s := &Service{
host: h, host: h,
metrics: newMetrics(), metrics: newMetrics(),
handshakeService: handshake.New(overlay, o.Logger), networkID: o.NetworkID,
handshakeService: handshake.New(overlay, o.NetworkID, o.Logger),
peers: newPeerRegistry(), peers: newPeerRegistry(),
logger: o.Logger, logger: o.Logger,
} }
...@@ -212,11 +214,16 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -212,11 +214,16 @@ func New(ctx context.Context, o Options) (*Service, error) {
s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) { s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) {
peerID := stream.Conn().RemotePeer() peerID := stream.Conn().RemotePeer()
overlay, err := s.handshakeService.Handle(stream) i, err := s.handshakeService.Handle(stream)
if err != nil { if err != nil {
s.logger.Errorf("handshake with peer %s: %w", peerID, err) s.logger.Errorf("handshake with peer %s: %w", peerID, err)
return
} }
s.peers.add(peerID, overlay) if i.NetworkID != s.networkID {
s.logger.Errorf("handshake with peer %s: invalid network id %v", peerID, i.NetworkID)
return
}
s.peers.add(peerID, i.Address)
s.metrics.HandledStreamCount.Inc() s.metrics.HandledStreamCount.Inc()
s.logger.Infof("peer %q connected", overlay) s.logger.Infof("peer %q connected", overlay)
}) })
...@@ -297,14 +304,17 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (err error) { ...@@ -297,14 +304,17 @@ func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (err error) {
} }
defer stream.Close() defer stream.Close()
overlay, err := s.handshakeService.Handshake(stream) i, err := s.handshakeService.Handshake(stream)
if err != nil { if err != nil {
return err return err
} }
if i.NetworkID != s.networkID {
return fmt.Errorf("invalid network id %v", i.NetworkID)
}
s.peers.add(info.ID, overlay) s.peers.add(info.ID, i.Address)
s.metrics.CreatedConnectionCount.Inc() s.metrics.CreatedConnectionCount.Inc()
s.logger.Infof("peer %q connected", overlay) s.logger.Infof("peer %q connected", i.Address)
return nil return nil
} }
func (s *Service) NewStream(ctx context.Context, overlay, protocolName, streamName, version string) (p2p.Stream, error) { func (s *Service) NewStream(ctx context.Context, overlay, protocolName, streamName, version string) (p2p.Stream, error) {
......
...@@ -17,6 +17,7 @@ import ( ...@@ -17,6 +17,7 @@ import (
"github.com/ethersphere/bee/pkg/pingpong" "github.com/ethersphere/bee/pkg/pingpong"
) )
func TestPing(t *testing.T) { func TestPing(t *testing.T) {
logger := logging.New(ioutil.Discard) logger := logging.New(ioutil.Discard)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment