Commit 7297a0f1 authored by Petar Radovic's avatar Petar Radovic Committed by Janoš Guljaš

add hanshake protocol

parent 1bf3600f
//go:generate sh -c "protoc -I . -I \"$(go list -f '{{ .Dir }}' -m github.com/gogo/protobuf)/protobuf\" --gogofaster_out=. handshake.proto"
package handshake
import (
"fmt"
"io"
"log"
"github.com/janos/bee/pkg/p2p"
"github.com/janos/bee/pkg/p2p/protobuf"
)
const (
ProtocolName = "handshake"
StreamName = "handshake"
StreamVersion = "1.0.0"
)
type Service struct {
overlay string
}
func New(overlay string) *Service {
return &Service{overlay: overlay}
}
func (s *Service) Handshake(stream p2p.Stream) (overlay string, err error) {
w, r := protobuf.NewWriterAndReader(stream)
var resp ShakeHand
if err := w.WriteMsg(&ShakeHand{Address: s.overlay}); err != nil {
return "", fmt.Errorf("handshake handler: write message: %v\n", err)
}
log.Printf("sent handshake req %s\n", s.overlay)
if err := r.ReadMsg(&resp); err != nil {
if err == io.EOF {
return "", nil
}
return "", fmt.Errorf("handshake handler: read message: %v\n", err)
}
log.Printf("read handshake resp: %s\n", resp.Address)
return resp.Address, nil
}
func (s *Service) Handler(stream p2p.Stream) string {
w, r := protobuf.NewWriterAndReader(stream)
defer stream.Close()
var req ShakeHand
if err := r.ReadMsg(&req); err != nil {
if err == io.EOF {
return ""
}
log.Printf("handshake handler: read message: %v\n", err)
return ""
}
log.Printf("received handshake req %s\n", req.Address)
if err := w.WriteMsg(&ShakeHand{
Address: s.overlay,
}); err != nil {
log.Printf("handshake handler: write message: %v\n", err)
}
log.Printf("sent handshake resp: %s\n", s.overlay)
return req.Address
}
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: handshake.proto
package handshake
import (
fmt "fmt"
proto "github.com/gogo/protobuf/proto"
io "io"
math "math"
math_bits "math/bits"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type ShakeHand struct {
Address string `protobuf:"bytes,1,opt,name=PeerID,proto3" json:"PeerID,omitempty"`
}
func (m *ShakeHand) Reset() { *m = ShakeHand{} }
func (m *ShakeHand) String() string { return proto.CompactTextString(m) }
func (*ShakeHand) ProtoMessage() {}
func (*ShakeHand) Descriptor() ([]byte, []int) {
return fileDescriptor_a77305914d5d202f, []int{0}
}
func (m *ShakeHand) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *ShakeHand) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_ShakeHand.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *ShakeHand) XXX_Merge(src proto.Message) {
xxx_messageInfo_ShakeHand.Merge(m, src)
}
func (m *ShakeHand) XXX_Size() int {
return m.Size()
}
func (m *ShakeHand) XXX_DiscardUnknown() {
xxx_messageInfo_ShakeHand.DiscardUnknown(m)
}
var xxx_messageInfo_ShakeHand proto.InternalMessageInfo
func (m *ShakeHand) GetAddress() string {
if m != nil {
return m.Address
}
return ""
}
func init() {
proto.RegisterType((*ShakeHand)(nil), "handshake.ShakeHand")
}
func init() { proto.RegisterFile("handshake.proto", fileDescriptor_a77305914d5d202f) }
var fileDescriptor_a77305914d5d202f = []byte{
// 108 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xcf, 0x48, 0xcc, 0x4b,
0x29, 0xce, 0x48, 0xcc, 0x4e, 0xd5, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x84, 0x0b, 0x28,
0xa9, 0x72, 0x71, 0x06, 0x83, 0x18, 0x1e, 0x89, 0x79, 0x29, 0x42, 0x12, 0x5c, 0xec, 0x8e, 0x29,
0x29, 0x45, 0xa9, 0xc5, 0xc5, 0x12, 0x8c, 0x0a, 0x8c, 0x1a, 0x9c, 0x41, 0x30, 0xae, 0x93, 0xc4,
0x89, 0x47, 0x72, 0x8c, 0x17, 0x1e, 0xc9, 0x31, 0x3e, 0x78, 0x24, 0xc7, 0x38, 0xe1, 0xb1, 0x1c,
0xc3, 0x85, 0xc7, 0x72, 0x0c, 0x37, 0x1e, 0xcb, 0x31, 0x24, 0xb1, 0x81, 0x8d, 0x34, 0x06, 0x04,
0x00, 0x00, 0xff, 0xff, 0x5d, 0x34, 0x69, 0xba, 0x65, 0x00, 0x00, 0x00,
}
func (m *ShakeHand) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *ShakeHand) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *ShakeHand) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.Address) > 0 {
i -= len(m.Address)
copy(dAtA[i:], m.Address)
i = encodeVarintHandshake(dAtA, i, uint64(len(m.Address)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int {
offset -= sovHandshake(v)
base := offset
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return base
}
func (m *ShakeHand) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.Address)
if l > 0 {
n += 1 + l + sovHandshake(uint64(l))
}
return n
}
func sovHandshake(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7
}
func sozHandshake(x uint64) (n int) {
return sovHandshake(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *ShakeHand) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: ShakeHand: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: ShakeHand: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field PeerID", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthHandshake
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthHandshake
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Address = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipHandshake(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
depth := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowHandshake
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowHandshake
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
case 1:
iNdEx += 8
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowHandshake
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if length < 0 {
return 0, ErrInvalidLengthHandshake
}
iNdEx += length
case 3:
depth++
case 4:
if depth == 0 {
return 0, ErrUnexpectedEndOfGroupHandshake
}
depth--
case 5:
iNdEx += 4
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
if iNdEx < 0 {
return 0, ErrInvalidLengthHandshake
}
if depth == 0 {
return iNdEx, nil
}
}
return 0, io.ErrUnexpectedEOF
}
var (
ErrInvalidLengthHandshake = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowHandshake = fmt.Errorf("proto: integer overflow")
ErrUnexpectedEndOfGroupHandshake = fmt.Errorf("proto: unexpected end of group")
)
syntax = "proto3";
package handshake;
message ShakeHand {
string Address = 1;
}
...@@ -10,12 +10,15 @@ import ( ...@@ -10,12 +10,15 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math/rand"
"net" "net"
"os" "os"
"strconv"
"time" "time"
"github.com/janos/bee/pkg/p2p" "github.com/janos/bee/pkg/p2p"
handshake "github.com/janos/bee/pkg/p2p/libp2p/internal/handshake"
"github.com/libp2p/go-libp2p" "github.com/libp2p/go-libp2p"
autonat "github.com/libp2p/go-libp2p-autonat-svc" autonat "github.com/libp2p/go-libp2p-autonat-svc"
connmgr "github.com/libp2p/go-libp2p-connmgr" connmgr "github.com/libp2p/go-libp2p-connmgr"
...@@ -23,7 +26,7 @@ import ( ...@@ -23,7 +26,7 @@ import (
"github.com/libp2p/go-libp2p-core/helpers" "github.com/libp2p/go-libp2p-core/helpers"
"github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/host"
"github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer" libp2ppeer "github.com/libp2p/go-libp2p-core/peer"
protocol "github.com/libp2p/go-libp2p-core/protocol" protocol "github.com/libp2p/go-libp2p-core/protocol"
libp2pquic "github.com/libp2p/go-libp2p-quic-transport" libp2pquic "github.com/libp2p/go-libp2p-quic-transport"
secio "github.com/libp2p/go-libp2p-secio" secio "github.com/libp2p/go-libp2p-secio"
...@@ -35,8 +38,11 @@ import ( ...@@ -35,8 +38,11 @@ import (
var _ p2p.Service = new(Service) var _ p2p.Service = new(Service)
type Service struct { type Service struct {
host host.Host host host.Host
metrics metrics metrics metrics
handshakeService *handshake.Service
overlayToPeerID map[string]libp2ppeer.ID
peerIDToOverlay map[libp2ppeer.ID]string
} }
type Options struct { type Options struct {
...@@ -72,7 +78,6 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -72,7 +78,6 @@ func New(ctx context.Context, o Options) (*Service, error) {
} }
var listenAddrs []string var listenAddrs []string
if ip4Addr != "" { if ip4Addr != "" {
listenAddrs = append(listenAddrs, fmt.Sprintf("/ip4/%s/tcp/%s", ip4Addr, port)) listenAddrs = append(listenAddrs, fmt.Sprintf("/ip4/%s/tcp/%s", ip4Addr, port))
if !o.DisableWS { if !o.DisableWS {
...@@ -174,11 +179,29 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -174,11 +179,29 @@ func New(ctx context.Context, o Options) (*Service, error) {
return nil, fmt.Errorf("autonat: %w", err) return nil, fmt.Errorf("autonat: %w", err)
} }
overlay := strconv.Itoa(rand.Int())
s := &Service{ s := &Service{
host: h, host: h,
metrics: newMetrics(), metrics: newMetrics(),
overlayToPeerID: make(map[string]libp2ppeer.ID),
peerIDToOverlay: make(map[libp2ppeer.ID]string),
handshakeService: handshake.New(overlay),
}
// Construct protocols.
id := protocol.ID(p2p.NewSwarmStreamName(handshake.ProtocolName, handshake.StreamName, handshake.StreamVersion))
matcher, err := helpers.MultistreamSemverMatcher(id)
if err != nil {
return nil, fmt.Errorf("match semver %s: %w", id, err)
} }
s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) {
s.metrics.HandledStreamCount.Inc()
overlay := s.handshakeService.Handler(stream)
s.addAddresses(overlay, stream.Conn().RemotePeer())
})
// TODO: be more resilient on connection errors and connect in parallel // TODO: be more resilient on connection errors and connect in parallel
for _, a := range o.Bootnodes { for _, a := range o.Bootnodes {
addr, err := ma.NewMultiaddr(a) addr, err := ma.NewMultiaddr(a)
...@@ -186,7 +209,8 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -186,7 +209,8 @@ func New(ctx context.Context, o Options) (*Service, error) {
return nil, fmt.Errorf("bootnode %s: %w", a, err) return nil, fmt.Errorf("bootnode %s: %w", a, err)
} }
if _, err := s.Connect(ctx, addr); err != nil { err = s.Connect(ctx, addr)
if err != nil {
return nil, fmt.Errorf("connect to bootnode %s: %w", a, err) return nil, fmt.Errorf("connect to bootnode %s: %w", a, err)
} }
} }
...@@ -205,12 +229,17 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) { ...@@ -205,12 +229,17 @@ func (s *Service) AddProtocol(p p2p.ProtocolSpec) (err error) {
if err != nil { if err != nil {
return fmt.Errorf("match semver %s: %w", id, err) return fmt.Errorf("match semver %s: %w", id, err)
} }
s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) { s.host.SetStreamHandlerMatch(id, matcher, func(stream network.Stream) {
overlay, ok := s.peerIDToOverlay[stream.Conn().RemotePeer()]
if !ok {
// todo: handle better
fmt.Printf("Could not fetch handshake for peerID %s\n", stream)
return
}
s.metrics.HandledStreamCount.Inc() s.metrics.HandledStreamCount.Inc()
ss.Handler(p2p.Peer{ ss.Handler(p2p.Peer{Address: overlay}, stream)
Addr: stream.Conn().RemoteMultiaddr(),
Stream: stream,
})
}) })
} }
return nil return nil
...@@ -231,29 +260,46 @@ func (s *Service) Addresses() (addrs []string, err error) { ...@@ -231,29 +260,46 @@ func (s *Service) Addresses() (addrs []string, err error) {
return addrs, nil return addrs, nil
} }
func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (peerID string, err error) { func (s *Service) Connect(ctx context.Context, addr ma.Multiaddr) (err error) {
// Extract the peer ID from the multiaddr. // Extract the peer ID from the multiaddr.
info, err := peer.AddrInfoFromP2pAddr(addr) info, err := libp2ppeer.AddrInfoFromP2pAddr(addr)
if err != nil { if err != nil {
return "", err return err
} }
if err := s.host.Connect(ctx, *info); err != nil { if err := s.host.Connect(ctx, *info); err != nil {
return "", err return err
}
stream, err := s.newStreamForPeerID(ctx, info.ID, handshake.ProtocolName, handshake.StreamName, handshake.StreamVersion)
if err != nil {
return fmt.Errorf("new stream: %w", err)
}
defer stream.Close()
overlay, err := s.handshakeService.Handshake(stream)
if err != nil {
return err
} }
s.addAddresses(overlay, info.ID)
s.metrics.CreatedConnectionCount.Inc() s.metrics.CreatedConnectionCount.Inc()
fmt.Println("handshake handshake finished")
return nil
}
func (s *Service) NewStream(ctx context.Context, overlay, protocolName, streamName, version string) (p2p.Stream, error) {
peerID, ok := s.overlayToPeerID[overlay]
if !ok {
fmt.Printf("Could not fetch peerID for handshake %s\n", overlay)
return nil, nil
}
return info.ID.String(), nil return s.newStreamForPeerID(ctx, peerID, protocolName, streamName, version)
} }
func (s *Service) NewStream(ctx context.Context, peerID, protocolName, streamName, version string) (p2p.Stream, error) { func (s *Service) newStreamForPeerID(ctx context.Context, peerID libp2ppeer.ID, protocolName, streamName, version string) (p2p.Stream, error) {
id, err := peer.Decode(peerID)
if err != nil {
return nil, fmt.Errorf("decode peer id %q: %w", peerID, err)
}
swarmStreamName := p2p.NewSwarmStreamName(protocolName, streamName, version) swarmStreamName := p2p.NewSwarmStreamName(protocolName, streamName, version)
st, err := s.host.NewStream(ctx, id, protocol.ID(swarmStreamName)) st, err := s.host.NewStream(ctx, peerID, protocol.ID(swarmStreamName))
if err != nil { if err != nil {
if err == multistream.ErrNotSupported || err == multistream.ErrIncorrectVersion { if err == multistream.ErrNotSupported || err == multistream.ErrIncorrectVersion {
return nil, p2p.NewIncompatibleStreamError(err) return nil, p2p.NewIncompatibleStreamError(err)
...@@ -264,6 +310,11 @@ func (s *Service) NewStream(ctx context.Context, peerID, protocolName, streamNam ...@@ -264,6 +310,11 @@ func (s *Service) NewStream(ctx context.Context, peerID, protocolName, streamNam
return st, nil return st, nil
} }
func (s *Service) addAddresses(overlay string, peerID libp2ppeer.ID) {
s.overlayToPeerID[overlay] = peerID
s.peerIDToOverlay[peerID] = overlay
}
func (s *Service) Close() error { func (s *Service) Close() error {
return s.host.Close() return s.host.Close()
} }
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"sync" "sync"
"github.com/janos/bee/pkg/p2p" "github.com/janos/bee/pkg/p2p"
ma "github.com/multiformats/go-multiaddr"
) )
type Recorder struct { type Recorder struct {
...@@ -27,13 +26,13 @@ func NewRecorder(protocols ...p2p.ProtocolSpec) *Recorder { ...@@ -27,13 +26,13 @@ func NewRecorder(protocols ...p2p.ProtocolSpec) *Recorder {
} }
} }
func (r *Recorder) NewStream(_ context.Context, peerID, protocolName, streamName, version string) (p2p.Stream, error) { func (r *Recorder) NewStream(_ context.Context, overlay, protocolName, streamName, version string) (p2p.Stream, error) {
recordIn := newRecord() recordIn := newRecord()
recordOut := newRecord() recordOut := newRecord()
streamOut := newStream(recordIn, recordOut) streamOut := newStream(recordIn, recordOut)
streamIn := newStream(recordOut, recordIn) streamIn := newStream(recordOut, recordIn)
var handler func(p2p.Peer) var handler func(p2p.Peer, p2p.Stream)
for _, p := range r.protocols { for _, p := range r.protocols {
if p.Name == protocolName { if p.Name == protocolName {
for _, s := range p.StreamSpecs { for _, s := range p.StreamSpecs {
...@@ -46,13 +45,9 @@ func (r *Recorder) NewStream(_ context.Context, peerID, protocolName, streamName ...@@ -46,13 +45,9 @@ func (r *Recorder) NewStream(_ context.Context, peerID, protocolName, streamName
if handler == nil { if handler == nil {
return nil, fmt.Errorf("unsupported protocol stream %q %q %q", protocolName, streamName, version) return nil, fmt.Errorf("unsupported protocol stream %q %q %q", protocolName, streamName, version)
} }
go handler(p2p.Peer{Address: overlay}, streamIn)
go handler(p2p.Peer{ id := overlay + p2p.NewSwarmStreamName(protocolName, streamName, version)
Addr: ma.StringCast(peerID),
Stream: streamIn,
})
id := peerID + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock() r.recordsMu.Lock()
defer r.recordsMu.Unlock() defer r.recordsMu.Unlock()
......
...@@ -14,11 +14,11 @@ import ( ...@@ -14,11 +14,11 @@ import (
type Service interface { type Service interface {
AddProtocol(ProtocolSpec) error AddProtocol(ProtocolSpec) error
Connect(ctx context.Context, addr ma.Multiaddr) (peerID string, err error) Connect(ctx context.Context, addr ma.Multiaddr) (err error)
} }
type Streamer interface { type Streamer interface {
NewStream(ctx context.Context, peerID, protocol, stream, version string) (Stream, error) NewStream(ctx context.Context, address, protocol, stream, version string) (Stream, error)
} }
type Stream interface { type Stream interface {
...@@ -26,11 +26,6 @@ type Stream interface { ...@@ -26,11 +26,6 @@ type Stream interface {
io.Closer io.Closer
} }
type Peer struct {
Addr ma.Multiaddr
Stream Stream
}
type ProtocolSpec struct { type ProtocolSpec struct {
Name string Name string
StreamSpecs []StreamSpec StreamSpecs []StreamSpec
...@@ -39,7 +34,7 @@ type ProtocolSpec struct { ...@@ -39,7 +34,7 @@ type ProtocolSpec struct {
type StreamSpec struct { type StreamSpec struct {
Name string Name string
Version string Version string
Handler func(Peer) Handler func(Peer, Stream)
} }
type IncompatibleStreamError struct { type IncompatibleStreamError struct {
......
package p2p
type Peer struct {
Address string
}
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:generate sh -c "protoc -I . -I \"$(go list -f '{{ .Dir }}' -m github.com/gogo/protobuf)/protobuf\" --gogofaster_out=. pingpong.proto" //go:generate sh -c "protoc -I . -I \"$(go list -f '{{ .Dir }}' -m github.com/gogo/protobuf)/protobuf\" --gogofaster_out=. pingpong.proto"
package pingpong package pingpong
import ( import (
...@@ -59,8 +58,8 @@ func (s *Service) Protocol() p2p.ProtocolSpec { ...@@ -59,8 +58,8 @@ func (s *Service) Protocol() p2p.ProtocolSpec {
} }
} }
func (s *Service) Ping(ctx context.Context, peerID string, msgs ...string) (rtt time.Duration, err error) { func (s *Service) Ping(ctx context.Context, address string, msgs ...string) (rtt time.Duration, err error) {
stream, err := s.streamer.NewStream(ctx, peerID, protocolName, streamName, streamVersion) stream, err := s.streamer.NewStream(ctx, address, protocolName, streamName, streamVersion)
if err != nil { if err != nil {
return 0, fmt.Errorf("new stream: %w", err) return 0, fmt.Errorf("new stream: %w", err)
} }
...@@ -91,10 +90,11 @@ func (s *Service) Ping(ctx context.Context, peerID string, msgs ...string) (rtt ...@@ -91,10 +90,11 @@ func (s *Service) Ping(ctx context.Context, peerID string, msgs ...string) (rtt
return time.Since(start) / time.Duration(len(msgs)), nil return time.Since(start) / time.Duration(len(msgs)), nil
} }
func (s *Service) Handler(p p2p.Peer) { func (s *Service) Handler(peer p2p.Peer, stream p2p.Stream) {
w, r := protobuf.NewWriterAndReader(p.Stream) w, r := protobuf.NewWriterAndReader(stream)
defer p.Stream.Close() defer stream.Close()
fmt.Printf("Initiate pinpong for peer %s", peer)
var ping Ping var ping Ping
for { for {
if err := r.ReadMsg(&ping); err != nil { if err := r.ReadMsg(&ping); err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment