Commit 2b6cbf1b authored by Petar Radovic's avatar Petar Radovic

Merge branch 'master' of github.com:ethersphere/bee

parents 9b6566e9 a40d342b
...@@ -68,7 +68,7 @@ func (c *command) initStartCmd() (err error) { ...@@ -68,7 +68,7 @@ func (c *command) initStartCmd() (err error) {
return fmt.Errorf("unknown verbosity level %q", v) return fmt.Errorf("unknown verbosity level %q", v)
} }
var libp2pPrivateKey io.ReadWriteCloser var libp2pPrivateKey, swarmPrivateKey io.ReadWriteCloser
if dataDir := c.config.GetString(optionNameDataDir); dataDir != "" { if dataDir := c.config.GetString(optionNameDataDir); dataDir != "" {
if err := os.MkdirAll(dataDir, os.ModePerm); err != nil { if err := os.MkdirAll(dataDir, os.ModePerm); err != nil {
return err return err
...@@ -78,6 +78,11 @@ func (c *command) initStartCmd() (err error) { ...@@ -78,6 +78,11 @@ func (c *command) initStartCmd() (err error) {
return err return err
} }
libp2pPrivateKey = libp2pKey libp2pPrivateKey = libp2pKey
swarmKey, err := os.OpenFile(filepath.Join(dataDir, "swarm.key"), os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
return err
}
swarmPrivateKey = swarmKey
} }
debugAPIAddr := c.config.GetString(optionNameDebugAPIAddr) debugAPIAddr := c.config.GetString(optionNameDebugAPIAddr)
...@@ -86,6 +91,7 @@ func (c *command) initStartCmd() (err error) { ...@@ -86,6 +91,7 @@ func (c *command) initStartCmd() (err error) {
} }
b, err := node.NewBee(node.Options{ b, err := node.NewBee(node.Options{
PrivateKey: swarmPrivateKey,
APIAddr: c.config.GetString(optionNameAPIAddr), APIAddr: c.config.GetString(optionNameAPIAddr),
DebugAPIAddr: debugAPIAddr, DebugAPIAddr: debugAPIAddr,
LibP2POptions: libp2p.Options{ LibP2POptions: libp2p.Options{
......
...@@ -21,6 +21,7 @@ require ( ...@@ -21,6 +21,7 @@ require (
github.com/sirupsen/logrus v1.4.2 github.com/sirupsen/logrus v1.4.2
github.com/spf13/cobra v0.0.5 github.com/spf13/cobra v0.0.5
github.com/spf13/viper v1.6.2 github.com/spf13/viper v1.6.2
golang.org/x/crypto v0.0.0-20191219195013-becbf705a915
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e
resenje.org/web v0.4.0 resenje.org/web v0.4.0
) )
...@@ -7,7 +7,6 @@ package api ...@@ -7,7 +7,6 @@ package api
import ( import (
"errors" "errors"
"net/http" "net/http"
"time"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
...@@ -16,7 +15,7 @@ import ( ...@@ -16,7 +15,7 @@ import (
) )
type pingpongResponse struct { type pingpongResponse struct {
RTT time.Duration `json:"rtt"` RTT string `json:"rtt"`
} }
func (s *server) pingpongHandler(w http.ResponseWriter, r *http.Request) { func (s *server) pingpongHandler(w http.ResponseWriter, r *http.Request) {
...@@ -46,6 +45,6 @@ func (s *server) pingpongHandler(w http.ResponseWriter, r *http.Request) { ...@@ -46,6 +45,6 @@ func (s *server) pingpongHandler(w http.ResponseWriter, r *http.Request) {
s.Logger.Infof("pingpong succeeded to peer %s", peerID) s.Logger.Infof("pingpong succeeded to peer %s", peerID)
jsonhttp.OK(w, pingpongResponse{ jsonhttp.OK(w, pingpongResponse{
RTT: rtt, RTT: rtt.String(),
}) })
} }
...@@ -43,7 +43,7 @@ func TestPingpong(t *testing.T) { ...@@ -43,7 +43,7 @@ func TestPingpong(t *testing.T) {
t.Run("ok", func(t *testing.T) { t.Run("ok", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, client, http.MethodPost, "/pingpong/"+peerID.String(), nil, http.StatusOK, api.PingpongResponse{ jsonhttptest.ResponseDirect(t, client, http.MethodPost, "/pingpong/"+peerID.String(), nil, http.StatusOK, api.PingpongResponse{
RTT: rtt, RTT: rtt.String(),
}) })
}) })
...@@ -64,7 +64,7 @@ func TestPingpong(t *testing.T) { ...@@ -64,7 +64,7 @@ func TestPingpong(t *testing.T) {
t.Run("error", func(t *testing.T) { t.Run("error", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, client, http.MethodPost, "/pingpong/"+errorPeerID.String(), nil, http.StatusInternalServerError, jsonhttp.StatusResponse{ jsonhttptest.ResponseDirect(t, client, http.MethodPost, "/pingpong/"+errorPeerID.String(), nil, http.StatusInternalServerError, jsonhttp.StatusResponse{
Code: http.StatusInternalServerError, Code: http.StatusInternalServerError,
Message: http.StatusText(http.StatusInternalServerError), // do not leek internal error Message: http.StatusText(http.StatusInternalServerError), // do not leak internal error
}) })
}) })
......
...@@ -18,9 +18,7 @@ import ( ...@@ -18,9 +18,7 @@ import (
func (s *server) setupRouting() { func (s *server) setupRouting() {
router := mux.NewRouter() router := mux.NewRouter()
router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { router.NotFoundHandler = http.HandlerFunc(jsonhttp.NotFoundHandler)
jsonhttp.NotFound(w, nil)
})
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Ethereum Swarm Bee") fmt.Fprintln(w, "Ethereum Swarm Bee")
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package crypto
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"fmt"
"github.com/btcsuite/btcd/btcec"
"github.com/ethersphere/bee/pkg/swarm"
"golang.org/x/crypto/sha3"
)
var keyTypeSecp256k1 = "secp256k1"
// GenerateSecp256k1Key generates an ECDSA private key using
// secp256k1 elliptic curve.
func GenerateSecp256k1Key() (*ecdsa.PrivateKey, error) {
return ecdsa.GenerateKey(btcec.S256(), rand.Reader)
}
// NewAddress constructs a Swarm Address from ECDSA private key.
func NewAddress(p ecdsa.PublicKey) swarm.Address {
h := sha3.Sum256(elliptic.Marshal(btcec.S256(), p.X, p.Y))
return swarm.NewAddress(h[:])
}
// privateKey holds information about Swarm private key for marshaling.
type privateKey struct {
Type string `json:"type"`
Key []byte `json:"key"`
}
// MarshalSecp256k1PrivateKey marshals secp256k1 ECDSA private key
// that can be unmarshaled by UnmarshalPrivateKey.
func MarshalSecp256k1PrivateKey(k *ecdsa.PrivateKey) ([]byte, error) {
return json.Marshal(privateKey{
Type: keyTypeSecp256k1,
Key: (*btcec.PrivateKey)(k).Serialize(),
})
}
// UnmarshalPrivateKey unmarshals ECDSA private key from encoded data.
func UnmarshalPrivateKey(data []byte) (*ecdsa.PrivateKey, error) {
var pk privateKey
if err := json.Unmarshal(data, &pk); err != nil {
return nil, err
}
switch t := pk.Type; t {
case keyTypeSecp256k1:
return decodeSecp256k1PrivateKey(pk.Key)
default:
return nil, fmt.Errorf("unknown key type %q", t)
}
}
// decodeSecp256k1PrivateKey decodes raw ECDSA private key.
func decodeSecp256k1PrivateKey(data []byte) (*ecdsa.PrivateKey, error) {
if l := len(data); l != btcec.PrivKeyBytesLen {
return nil, fmt.Errorf("secp256k1 data size %d expected %d", l, btcec.PrivKeyBytesLen)
}
privk, _ := btcec.PrivKeyFromBytes(btcec.S256(), data)
return (*ecdsa.PrivateKey)(privk), nil
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package crypto_test
import (
"bytes"
"testing"
"github.com/ethersphere/bee/pkg/crypto"
)
func TestGenerateSecp256k1Key(t *testing.T) {
k1, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
if k1 == nil {
t.Fatal("nil key")
}
k2, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
if k2 == nil {
t.Fatal("nil key")
}
if bytes.Equal(k1.D.Bytes(), k2.D.Bytes()) {
t.Fatal("two generated keys are equal")
}
}
func TestNewAddress(t *testing.T) {
k, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
a := crypto.NewAddress(k.PublicKey)
if l := len(a.Bytes()); l != 32 {
t.Errorf("got address length %v, want %v", l, 32)
}
}
func TestMarshalSecp256k1PrivateKey(t *testing.T) {
k1, err := crypto.GenerateSecp256k1Key()
if err != nil {
t.Fatal(err)
}
d, err := crypto.MarshalSecp256k1PrivateKey(k1)
if err != nil {
t.Fatal(err)
}
k2, err := crypto.UnmarshalPrivateKey(d)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(k1.D.Bytes(), k2.D.Bytes()) {
t.Fatal("marshaled and unmarshaled keys are not equal")
}
}
...@@ -16,7 +16,7 @@ import ( ...@@ -16,7 +16,7 @@ import (
) )
type peerConnectResponse struct { type peerConnectResponse struct {
Address string Address string `json:"address"`
} }
func (s *server) peerConnectHandler(w http.ResponseWriter, r *http.Request) { func (s *server) peerConnectHandler(w http.ResponseWriter, r *http.Request) {
......
...@@ -27,9 +27,7 @@ func (s *server) setupRouting() { ...@@ -27,9 +27,7 @@ func (s *server) setupRouting() {
)) ))
router := mux.NewRouter() router := mux.NewRouter()
router.NotFoundHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { router.NotFoundHandler = http.HandlerFunc(jsonhttp.NotFoundHandler)
jsonhttp.NotFound(w, nil)
})
router.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index)) router.Handle("/debug/pprof/", http.HandlerFunc(pprof.Index))
router.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) router.Handle("/debug/pprof/cmdline", http.HandlerFunc(pprof.Cmdline))
......
...@@ -15,3 +15,7 @@ type MethodHandler map[string]http.Handler ...@@ -15,3 +15,7 @@ type MethodHandler map[string]http.Handler
func (h MethodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h MethodHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
web.HandleMethods(h, `{"message":"Method Not Allowed","code":405}`, DefaultContentTypeHeader, w, r) web.HandleMethods(h, `{"message":"Method Not Allowed","code":405}`, DefaultContentTypeHeader, w, r)
} }
func NotFoundHandler(w http.ResponseWriter, _ *http.Request) {
NotFound(w, nil)
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jsonhttp_test
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/ethersphere/bee/pkg/jsonhttp"
)
func TestMethodHandler(t *testing.T) {
contentType := "application/swarm"
h := jsonhttp.MethodHandler{
"POST": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatal(err)
}
w.Header().Set("Content-Type", contentType)
fmt.Fprint(w, "got: ", string(got))
}),
}
t.Run("method allowed", func(t *testing.T) {
body := "test body"
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
statusCode := w.Result().StatusCode
if statusCode != http.StatusOK {
t.Errorf("got status code %d, want %d", statusCode, http.StatusOK)
}
wantBody := "got: " + body
gotBody := w.Body.String()
if gotBody != wantBody {
t.Errorf("got body %q, want %q", gotBody, wantBody)
}
if got := w.Header().Get("Content-Type"); got != contentType {
t.Errorf("got content type %q, want %q", got, contentType)
}
})
t.Run("method not allowed", func(t *testing.T) {
r := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
statusCode := w.Result().StatusCode
wantCode := http.StatusMethodNotAllowed
if statusCode != wantCode {
t.Errorf("got status code %d, want %d", statusCode, wantCode)
}
var m *jsonhttp.StatusResponse
if err := json.Unmarshal(w.Body.Bytes(), &m); err != nil {
t.Errorf("json unmarshal response body: %s", err)
}
if m.Code != wantCode {
t.Errorf("got message code %d, want %d", m.Code, wantCode)
}
wantMessage := http.StatusText(wantCode)
if m.Message != wantMessage {
t.Errorf("got message message %q, want %q", m.Message, wantMessage)
}
testContentType(t, w)
})
}
func TestNotFoundHandler(t *testing.T) {
w := httptest.NewRecorder()
jsonhttp.NotFoundHandler(w, nil)
statusCode := w.Result().StatusCode
wantCode := http.StatusNotFound
if statusCode != wantCode {
t.Errorf("got status code %d, want %d", statusCode, wantCode)
}
var m *jsonhttp.StatusResponse
if err := json.Unmarshal(w.Body.Bytes(), &m); err != nil {
t.Errorf("json unmarshal response body: %s", err)
}
if m.Code != wantCode {
t.Errorf("got message code %d, want %d", m.Code, wantCode)
}
wantMessage := http.StatusText(wantCode)
if m.Message != wantMessage {
t.Errorf("got message message %q, want %q", m.Message, wantMessage)
}
testContentType(t, w)
}
...@@ -36,6 +36,9 @@ type StatusResponse struct { ...@@ -36,6 +36,9 @@ type StatusResponse struct {
// Respond writes a JSON-encoded body to http.ResponseWriter. // Respond writes a JSON-encoded body to http.ResponseWriter.
func Respond(w http.ResponseWriter, statusCode int, response interface{}) { func Respond(w http.ResponseWriter, statusCode int, response interface{}) {
if statusCode == 0 {
statusCode = http.StatusOK
}
if response == nil { if response == nil {
response = &StatusResponse{ response = &StatusResponse{
Message: http.StatusText(statusCode), Message: http.StatusText(statusCode),
......
This diff is collapsed.
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package jsonhttptest_test
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/jsonhttp/jsonhttptest"
)
func TestResponse(t *testing.T) {
type response struct {
Message string `json:"message"`
}
message := "text"
wantMethod, wantPath, wantBody := http.MethodPatch, "/testing", "request body"
var gotMethod, gotPath, gotBody string
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotMethod = r.Method
gotPath = r.URL.Path
b, err := ioutil.ReadAll(r.Body)
if err != nil {
jsonhttp.InternalServerError(w, err)
return
}
gotBody = string(b)
jsonhttp.Created(w, response{
Message: message,
})
}))
defer s.Close()
c := s.Client()
t.Run("direct", func(t *testing.T) {
jsonhttptest.ResponseDirect(t, c, wantMethod, s.URL+wantPath, strings.NewReader(wantBody), http.StatusCreated, response{
Message: message,
})
if gotMethod != wantMethod {
t.Errorf("got method %s, want %s", gotMethod, wantMethod)
}
if gotPath != wantPath {
t.Errorf("got path %s, want %s", gotPath, wantPath)
}
if gotBody != wantBody {
t.Errorf("got body %s, want %s", gotBody, wantBody)
}
})
t.Run("unmarshal", func(t *testing.T) {
var r response
jsonhttptest.ResponseUnmarshal(t, c, wantMethod, s.URL+wantPath, strings.NewReader(wantBody), http.StatusCreated, &r)
if gotMethod != wantMethod {
t.Errorf("got method %s, want %s", gotMethod, wantMethod)
}
if gotPath != wantPath {
t.Errorf("got path %s, want %s", gotPath, wantPath)
}
if gotBody != wantBody {
t.Errorf("got body %s, want %s", gotBody, wantBody)
}
if r.Message != message {
t.Errorf("got messag %s, want %s", r.Message, message)
}
})
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package metrics_test
import (
"strings"
"testing"
"github.com/ethersphere/bee/pkg/metrics"
"github.com/prometheus/client_golang/prometheus"
)
func TestPrometheusCollectorsFromFields(t *testing.T) {
s := newService()
collectors := metrics.PrometheusCollectorsFromFields(s)
if l := len(collectors); l != 2 {
t.Fatalf("got %v collectors %+v, want 2", l, collectors)
}
m1 := collectors[0].(prometheus.Metric).Desc().String()
if !strings.Contains(m1, "api_request_count") {
t.Errorf("unexpected metric %s", m1)
}
m2 := collectors[1].(prometheus.Metric).Desc().String()
if !strings.Contains(m2, "api_response_duration_seconds") {
t.Errorf("unexpected metric %s", m2)
}
}
type service struct {
// valid metrics
RequestCount prometheus.Counter
ResponseDuration prometheus.Histogram
// invalid metrics
unexportedCount prometheus.Counter
UninitializedCount prometheus.Counter
}
func newService() *service {
return &service{
RequestCount: prometheus.NewCounter(prometheus.CounterOpts{
Name: "api_request_count",
Help: "Number of API requests.",
}),
ResponseDuration: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "api_response_duration_seconds",
Help: "Histogram of API response durations.",
Buckets: []float64{0.01, 0.1, 0.25, 0.5, 1, 2.5, 5, 10},
}),
unexportedCount: prometheus.NewCounter(prometheus.CounterOpts{
Name: "api_unexported_count",
Help: "This metrics should not be discoverable by metrics.PrometheusCollectorsFromFields.",
}),
}
}
...@@ -5,17 +5,22 @@ ...@@ -5,17 +5,22 @@
package node package node
import ( import (
"bytes"
"context" "context"
"crypto/ecdsa"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
"os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/debugapi" "github.com/ethersphere/bee/pkg/debugapi"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/metrics" "github.com/ethersphere/bee/pkg/metrics"
...@@ -32,6 +37,7 @@ type Bee struct { ...@@ -32,6 +37,7 @@ type Bee struct {
} }
type Options struct { type Options struct {
PrivateKey io.ReadWriteCloser
APIAddr string APIAddr string
DebugAPIAddr string DebugAPIAddr string
LibP2POptions libp2p.Options LibP2POptions libp2p.Options
...@@ -48,8 +54,50 @@ func NewBee(o Options) (*Bee, error) { ...@@ -48,8 +54,50 @@ func NewBee(o Options) (*Bee, error) {
errorLogWriter: logger.WriterLevel(logrus.ErrorLevel), errorLogWriter: logger.WriterLevel(logrus.ErrorLevel),
} }
var privateKey *ecdsa.PrivateKey
if o.PrivateKey != nil {
privateKeyData, err := ioutil.ReadAll(o.PrivateKey)
if err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("read private key: %w", err)
}
if len(privateKeyData) == 0 {
var err error
privateKey, err = crypto.GenerateSecp256k1Key()
if err != nil {
return nil, fmt.Errorf("generate secp256k1 key: %w", err)
}
d, err := crypto.MarshalSecp256k1PrivateKey(privateKey)
if err != nil {
return nil, fmt.Errorf("encode private key: %w", err)
}
if _, err := io.Copy(o.PrivateKey, bytes.NewReader(d)); err != nil {
return nil, fmt.Errorf("write private key: %w", err)
}
} else {
var err error
privateKey, err = crypto.UnmarshalPrivateKey(privateKeyData)
if err != nil {
return nil, fmt.Errorf("decode private key: %w", err)
}
}
if err := o.PrivateKey.Close(); err != nil {
return nil, fmt.Errorf("close private key: %w", err)
}
} else {
var err error
privateKey, err = crypto.GenerateSecp256k1Key()
if err != nil {
return nil, fmt.Errorf("generate secp256k1 key: %w", err)
}
}
address := crypto.NewAddress(privateKey.PublicKey)
logger.Infof("address: %s", address)
// Construct P2P service. // Construct P2P service.
p2ps, err := libp2p.New(p2pCtx, o.LibP2POptions) libP2POptions := o.LibP2POptions
libP2POptions.Overlay = address
p2ps, err := libp2p.New(p2pCtx, libP2POptions)
if err != nil { if err != nil {
return nil, fmt.Errorf("p2p service: %w", err) return nil, fmt.Errorf("p2p service: %w", err)
} }
......
...@@ -4,6 +4,15 @@ ...@@ -4,6 +4,15 @@
package p2p package p2p
import (
"errors"
"fmt"
)
// ErrPeerNotFound should be returned by p2p service methods when the requested
// peer is not found.
var ErrPeerNotFound = errors.New("peer not found")
// DisconnectError is an error that is specifically handled inside p2p. If returned by specific protocol // DisconnectError is an error that is specifically handled inside p2p. If returned by specific protocol
// handler it causes peer disconnect. // handler it causes peer disconnect.
type DisconnectError struct { type DisconnectError struct {
...@@ -25,3 +34,24 @@ func (e *DisconnectError) Unwrap() error { return e.err } ...@@ -25,3 +34,24 @@ func (e *DisconnectError) Unwrap() error { return e.err }
func (e *DisconnectError) Error() string { func (e *DisconnectError) Error() string {
return e.err.Error() return e.err.Error()
} }
// IncompatibleStreamError is the error that should be returned by p2p service
// NewStream method when the stream or its version is not supported.
type IncompatibleStreamError struct {
err error
}
// NewIncompatibleStreamError wraps the error that is the cause of stream
// incompatibility with IncompatibleStreamError that it can be detected and
// returns it.
func NewIncompatibleStreamError(err error) *IncompatibleStreamError {
return &IncompatibleStreamError{err: err}
}
// Unwrap returns an underlying error.
func (e *IncompatibleStreamError) Unwrap() error { return e.err }
// Error implements function of the standard go error interface.
func (e *IncompatibleStreamError) Error() string {
return fmt.Sprintf("incompatible stream: %v", e.err)
}
...@@ -51,10 +51,12 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) { ...@@ -51,10 +51,12 @@ func (s *Service) Handshake(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("ack: write message: %w", err) return nil, fmt.Errorf("ack: write message: %w", err)
} }
s.logger.Tracef("handshake finished for peer %s", resp.ShakeHand.Address) address := swarm.NewAddress(resp.ShakeHand.Address)
s.logger.Tracef("handshake finished for peer %s", address)
return &Info{ return &Info{
Address: swarm.NewAddress(resp.ShakeHand.Address), Address: address,
NetworkID: resp.ShakeHand.NetworkID, NetworkID: resp.ShakeHand.NetworkID,
Light: resp.ShakeHand.Light, Light: resp.ShakeHand.Light,
}, nil }, nil
...@@ -84,9 +86,11 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) { ...@@ -84,9 +86,11 @@ func (s *Service) Handle(stream p2p.Stream) (i *Info, err error) {
return nil, fmt.Errorf("ack: read message: %w", err) return nil, fmt.Errorf("ack: read message: %w", err)
} }
s.logger.Tracef("handshake finished for peer %s", req.Address) address := swarm.NewAddress(req.Address)
s.logger.Tracef("handshake finished for peer %s", address)
return &Info{ return &Info{
Address: swarm.NewAddress(req.Address), Address: address,
NetworkID: req.NetworkID, NetworkID: req.NetworkID,
Light: req.Light, Light: req.Light,
}, nil }, nil
......
...@@ -11,7 +11,6 @@ import ( ...@@ -11,7 +11,6 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"math/rand"
"net" "net"
"os" "os"
"time" "time"
...@@ -38,11 +37,6 @@ import ( ...@@ -38,11 +37,6 @@ import (
var _ p2p.Service = (*Service)(nil) var _ p2p.Service = (*Service)(nil)
func init() {
// Only temporary for fake overlay address generation.
rand.Seed(time.Now().UnixNano())
}
type Service struct { type Service struct {
host host.Host host host.Host
metrics metrics metrics metrics
...@@ -54,6 +48,7 @@ type Service struct { ...@@ -54,6 +48,7 @@ type Service struct {
type Options struct { type Options struct {
PrivateKey io.ReadWriteCloser PrivateKey io.ReadWriteCloser
Overlay swarm.Address
Addr string Addr string
DisableWS bool DisableWS bool
DisableQUIC bool DisableQUIC bool
...@@ -187,15 +182,11 @@ func New(ctx context.Context, o Options) (*Service, error) { ...@@ -187,15 +182,11 @@ func New(ctx context.Context, o Options) (*Service, error) {
return nil, fmt.Errorf("autonat: %w", err) return nil, fmt.Errorf("autonat: %w", err)
} }
// This is just a temporary way to generate an overlay address.
// TODO: proper key management and overlay address generation
overlay := make([]byte, 32)
rand.Read(overlay)
s := &Service{ s := &Service{
host: h, host: h,
metrics: newMetrics(), metrics: newMetrics(),
networkID: o.NetworkID, networkID: o.NetworkID,
handshakeService: handshake.New(swarm.NewAddress(overlay), o.NetworkID, o.Logger), handshakeService: handshake.New(o.Overlay, o.NetworkID, o.Logger),
peers: newPeerRegistry(), peers: newPeerRegistry(),
logger: o.Logger, logger: o.Logger,
} }
......
...@@ -6,7 +6,6 @@ package p2p ...@@ -6,7 +6,6 @@ package p2p
import ( import (
"context" "context"
"fmt"
"io" "io"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -39,23 +38,13 @@ type StreamSpec struct { ...@@ -39,23 +38,13 @@ type StreamSpec struct {
Handler HandlerFunc Handler HandlerFunc
} }
type HandlerFunc func(Peer, Stream) error type Peer struct {
Address swarm.Address
type HandlerMiddleware func(HandlerFunc) HandlerFunc
type IncompatibleStreamError struct {
err error
} }
func NewIncompatibleStreamError(err error) *IncompatibleStreamError { type HandlerFunc func(Peer, Stream) error
return &IncompatibleStreamError{err: err}
}
func (e *IncompatibleStreamError) Unwrap() error { return e.err }
func (e *IncompatibleStreamError) Error() string { type HandlerMiddleware func(HandlerFunc) HandlerFunc
return fmt.Sprintf("incompatible stream: %v", e.err)
}
func NewSwarmStreamName(protocol, stream, version string) string { func NewSwarmStreamName(protocol, stream, version string) string {
return "/swarm/" + protocol + "/" + stream + "/" + version return "/swarm/" + protocol + "/" + stream + "/" + version
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p2p_test
import (
"testing"
"github.com/ethersphere/bee/pkg/p2p"
)
func TestNewSwarmStreamName(t *testing.T) {
want := "/swarm/hive/peers/1.2.0"
got := p2p.NewSwarmStreamName("hive", "peers", "1.2.0")
if got != want {
t.Errorf("got %s, want %s", got, want)
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:generate sh -c "protoc -I . -I \"$(go list -f '{{ .Dir }}' -m github.com/gogo/protobuf)/protobuf\" --gogofaster_out=. test.proto"
// Package pb holds only Protocol Buffer definitions and generated code for
// testing purposes.
package pb
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: test.proto
package pb
import (
fmt "fmt"
proto "github.com/gogo/protobuf/proto"
io "io"
math "math"
math_bits "math/bits"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type Message struct {
Text string `protobuf:"bytes,1,opt,name=Text,proto3" json:"Text,omitempty"`
}
func (m *Message) Reset() { *m = Message{} }
func (m *Message) String() string { return proto.CompactTextString(m) }
func (*Message) ProtoMessage() {}
func (*Message) Descriptor() ([]byte, []int) {
return fileDescriptor_c161fcfdc0c3ff1e, []int{0}
}
func (m *Message) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Message) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Message.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Message) XXX_Merge(src proto.Message) {
xxx_messageInfo_Message.Merge(m, src)
}
func (m *Message) XXX_Size() int {
return m.Size()
}
func (m *Message) XXX_DiscardUnknown() {
xxx_messageInfo_Message.DiscardUnknown(m)
}
var xxx_messageInfo_Message proto.InternalMessageInfo
func (m *Message) GetText() string {
if m != nil {
return m.Text
}
return ""
}
func init() {
proto.RegisterType((*Message)(nil), "pb.Message")
}
func init() { proto.RegisterFile("test.proto", fileDescriptor_c161fcfdc0c3ff1e) }
var fileDescriptor_c161fcfdc0c3ff1e = []byte{
// 100 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x2a, 0x49, 0x2d, 0x2e,
0xd1, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x2a, 0x48, 0x52, 0x92, 0xe5, 0x62, 0xf7, 0x4d,
0x2d, 0x2e, 0x4e, 0x4c, 0x4f, 0x15, 0x12, 0xe2, 0x62, 0x09, 0x49, 0xad, 0x28, 0x91, 0x60, 0x54,
0x60, 0xd4, 0xe0, 0x0c, 0x02, 0xb3, 0x9d, 0x24, 0x4e, 0x3c, 0x92, 0x63, 0xbc, 0xf0, 0x48, 0x8e,
0xf1, 0xc1, 0x23, 0x39, 0xc6, 0x09, 0x8f, 0xe5, 0x18, 0x2e, 0x3c, 0x96, 0x63, 0xb8, 0xf1, 0x58,
0x8e, 0x21, 0x89, 0x0d, 0x6c, 0x86, 0x31, 0x20, 0x00, 0x00, 0xff, 0xff, 0xaa, 0xbb, 0x60, 0xa9,
0x51, 0x00, 0x00, 0x00,
}
func (m *Message) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Message) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Message) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.Text) > 0 {
i -= len(m.Text)
copy(dAtA[i:], m.Text)
i = encodeVarintTest(dAtA, i, uint64(len(m.Text)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func encodeVarintTest(dAtA []byte, offset int, v uint64) int {
offset -= sovTest(v)
base := offset
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return base
}
func (m *Message) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.Text)
if l > 0 {
n += 1 + l + sovTest(uint64(l))
}
return n
}
func sovTest(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7
}
func sozTest(x uint64) (n int) {
return sovTest(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *Message) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Message: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Message: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Text", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowTest
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthTest
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthTest
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Text = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipTest(dAtA[iNdEx:])
if err != nil {
return err
}
if skippy < 0 {
return ErrInvalidLengthTest
}
if (iNdEx + skippy) < 0 {
return ErrInvalidLengthTest
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipTest(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
depth := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
case 1:
iNdEx += 8
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowTest
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if length < 0 {
return 0, ErrInvalidLengthTest
}
iNdEx += length
case 3:
depth++
case 4:
if depth == 0 {
return 0, ErrUnexpectedEndOfGroupTest
}
depth--
case 5:
iNdEx += 4
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
if iNdEx < 0 {
return 0, ErrInvalidLengthTest
}
if depth == 0 {
return iNdEx, nil
}
}
return 0, io.ErrUnexpectedEOF
}
var (
ErrInvalidLengthTest = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowTest = fmt.Errorf("proto: integer overflow")
ErrUnexpectedEndOfGroupTest = fmt.Errorf("proto: unexpected end of group")
)
...@@ -2,16 +2,11 @@ ...@@ -2,16 +2,11 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
package p2p syntax = "proto3";
import ( package pb;
"errors"
"github.com/ethersphere/bee/pkg/swarm" message Message {
) string Text = 1;
type Peer struct {
Address swarm.Address
} }
var ErrPeerNotFound = errors.New("peer not found")
...@@ -5,10 +5,11 @@ ...@@ -5,10 +5,11 @@
package protobuf package protobuf
import ( import (
"io"
"github.com/ethersphere/bee/pkg/p2p" "github.com/ethersphere/bee/pkg/p2p"
ggio "github.com/gogo/protobuf/io" ggio "github.com/gogo/protobuf/io"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"io"
) )
const delimitedReaderMaxSize = 128 * 1024 // max message size const delimitedReaderMaxSize = 128 * 1024 // max message size
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package protobuf_test
import (
"fmt"
"io"
"testing"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/protobuf/internal/pb"
)
func TestReadMessages(t *testing.T) {
r, pipe := io.Pipe()
w := protobuf.NewWriter(pipe)
messages := []string{"first", "second", "third"}
go func() {
for _, m := range messages {
if err := w.WriteMsg(&pb.Message{
Text: m,
}); err != nil {
panic(err)
}
}
if err := pipe.Close(); err != nil {
panic(err)
}
}()
got, err := protobuf.ReadMessages(r, func() protobuf.Message { return new(pb.Message) })
if err != nil {
t.Fatal(err)
}
var gotMessages []string
for _, m := range got {
gotMessages = append(gotMessages, m.(*pb.Message).Text)
}
if fmt.Sprint(gotMessages) != fmt.Sprint(messages) {
t.Errorf("got messages %v, want %v", gotMessages, messages)
}
}
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
) )
type Recorder struct { type Recorder struct {
records map[string][]Record records map[string][]*Record
recordsMu sync.Mutex recordsMu sync.Mutex
protocols []p2p.ProtocolSpec protocols []p2p.ProtocolSpec
middlewares []p2p.HandlerMiddleware middlewares []p2p.HandlerMiddleware
...@@ -35,7 +35,7 @@ func WithMiddlewares(middlewares ...p2p.HandlerMiddleware) Option { ...@@ -35,7 +35,7 @@ func WithMiddlewares(middlewares ...p2p.HandlerMiddleware) Option {
func New(opts ...Option) *Recorder { func New(opts ...Option) *Recorder {
r := &Recorder{ r := &Recorder{
records: make(map[string][]Record), records: make(map[string][]*Record),
} }
for _, o := range opts { for _, o := range opts {
o.apply(r) o.apply(r)
...@@ -43,7 +43,7 @@ func New(opts ...Option) *Recorder { ...@@ -43,7 +43,7 @@ func New(opts ...Option) *Recorder {
return r return r
} }
func (r *Recorder) NewStream(_ context.Context, overlay swarm.Address, protocolName, streamName, version string) (p2p.Stream, error) { func (r *Recorder) NewStream(_ context.Context, addr swarm.Address, protocolName, streamName, version string) (p2p.Stream, error) {
recordIn := newRecord() recordIn := newRecord()
recordOut := newRecord() recordOut := newRecord()
streamOut := newStream(recordIn, recordOut) streamOut := newStream(recordIn, recordOut)
...@@ -65,37 +65,39 @@ func (r *Recorder) NewStream(_ context.Context, overlay swarm.Address, protocolN ...@@ -65,37 +65,39 @@ func (r *Recorder) NewStream(_ context.Context, overlay swarm.Address, protocolN
for _, m := range r.middlewares { for _, m := range r.middlewares {
handler = m(handler) handler = m(handler)
} }
record := &Record{in: recordIn, out: recordOut}
go func() { go func() {
if err := handler(p2p.Peer{Address: overlay}, streamIn); err != nil { err := handler(p2p.Peer{Address: addr}, streamIn)
panic(err) // todo: store error and export error records for inspection record.setErr(err)
}
}() }()
id := overlay.String() + p2p.NewSwarmStreamName(protocolName, streamName, version) id := addr.String() + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock() r.recordsMu.Lock()
defer r.recordsMu.Unlock() defer r.recordsMu.Unlock()
r.records[id] = append(r.records[id], Record{in: recordIn, out: recordOut}) r.records[id] = append(r.records[id], record)
return streamOut, nil return streamOut, nil
} }
func (r *Recorder) Records(peerID, protocolName, streamName, version string) ([]Record, error) { func (r *Recorder) Records(addr swarm.Address, protocolName, streamName, version string) ([]*Record, error) {
id := peerID + p2p.NewSwarmStreamName(protocolName, streamName, version) id := addr.String() + p2p.NewSwarmStreamName(protocolName, streamName, version)
r.recordsMu.Lock() r.recordsMu.Lock()
defer r.recordsMu.Unlock() defer r.recordsMu.Unlock()
records, ok := r.records[id] records, ok := r.records[id]
if !ok { if !ok {
return nil, fmt.Errorf("records not found for %q %q %q %q", peerID, protocolName, streamName, version) return nil, fmt.Errorf("records not found for %q %q %q %q", addr, protocolName, streamName, version)
} }
return records, nil return records, nil
} }
type Record struct { type Record struct {
in *record in *record
out *record out *record
err error
errMu sync.Mutex
} }
func (r *Record) In() []byte { func (r *Record) In() []byte {
...@@ -106,6 +108,20 @@ func (r *Record) Out() []byte { ...@@ -106,6 +108,20 @@ func (r *Record) Out() []byte {
return r.out.bytes() return r.out.bytes()
} }
func (r *Record) Err() error {
r.errMu.Lock()
defer r.errMu.Unlock()
return r.err
}
func (r *Record) setErr(err error) {
r.errMu.Lock()
defer r.errMu.Unlock()
r.err = err
}
type stream struct { type stream struct {
in io.WriteCloser in io.WriteCloser
out io.ReadCloser out io.ReadCloser
......
...@@ -51,10 +51,9 @@ func TestPing(t *testing.T) { ...@@ -51,10 +51,9 @@ func TestPing(t *testing.T) {
}) })
// ping // ping
peerID := "ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59c" addr := swarm.MustParseHexAddress("ca1e9f3938cc1425c6061b96ad9eb93e134dfe8734ad490164ef20af9d1cf59c")
peerIDAddress := swarm.MustParseHexAddress(peerID)
greetings := []string{"hey", "there", "fella"} greetings := []string{"hey", "there", "fella"}
rtt, err := client.Ping(context.Background(), peerIDAddress, greetings...) rtt, err := client.Ping(context.Background(), addr, greetings...)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -65,7 +64,7 @@ func TestPing(t *testing.T) { ...@@ -65,7 +64,7 @@ func TestPing(t *testing.T) {
} }
// get a record for this stream // get a record for this stream
records, err := recorder.Records(peerID, "pingpong", "pingpong", "1.0.0") records, err := recorder.Records(addr, "pingpong", "pingpong", "1.0.0")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -110,4 +109,8 @@ func TestPing(t *testing.T) { ...@@ -110,4 +109,8 @@ func TestPing(t *testing.T) {
if fmt.Sprint(gotResponses) != fmt.Sprint(wantResponses) { if fmt.Sprint(gotResponses) != fmt.Sprint(wantResponses) {
t.Errorf("got responses %v, want %v", gotResponses, wantResponses) t.Errorf("got responses %v, want %v", gotResponses, wantResponses)
} }
if err := record.Err(); err != nil {
t.Fatal(err)
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment