Commit 2ab40671 authored by Zahoor Mohamed's avatar Zahoor Mohamed Committed by GitHub

Chunk repair (#479)

Chunk repair request and uploading the lost chunk to netwrk
parent f2b4bdc5
...@@ -26,27 +26,29 @@ import ( ...@@ -26,27 +26,29 @@ import (
func (c *command) initStartCmd() (err error) { func (c *command) initStartCmd() (err error) {
const ( const (
optionNameDataDir = "data-dir" optionNameDataDir = "data-dir"
optionNameDBCapacity = "db-capacity" optionNameDBCapacity = "db-capacity"
optionNamePassword = "password" optionNamePassword = "password"
optionNamePasswordFile = "password-file" optionNamePasswordFile = "password-file"
optionNameAPIAddr = "api-addr" optionNameAPIAddr = "api-addr"
optionNameP2PAddr = "p2p-addr" optionNameP2PAddr = "p2p-addr"
optionNameNATAddr = "nat-addr" optionNameNATAddr = "nat-addr"
optionNameP2PWSEnable = "p2p-ws-enable" optionNameP2PWSEnable = "p2p-ws-enable"
optionNameP2PQUICEnable = "p2p-quic-enable" optionNameP2PQUICEnable = "p2p-quic-enable"
optionNameDebugAPIEnable = "debug-api-enable" optionNameDebugAPIEnable = "debug-api-enable"
optionNameDebugAPIAddr = "debug-api-addr" optionNameDebugAPIAddr = "debug-api-addr"
optionNameBootnodes = "bootnode" optionNameBootnodes = "bootnode"
optionNameNetworkID = "network-id" optionNameNetworkID = "network-id"
optionWelcomeMessage = "welcome-message" optionWelcomeMessage = "welcome-message"
optionCORSAllowedOrigins = "cors-allowed-origins" optionCORSAllowedOrigins = "cors-allowed-origins"
optionNameTracingEnabled = "tracing-enable" optionNameTracingEnabled = "tracing-enable"
optionNameTracingEndpoint = "tracing-endpoint" optionNameTracingEndpoint = "tracing-endpoint"
optionNameTracingServiceName = "tracing-service-name" optionNameTracingServiceName = "tracing-service-name"
optionNameVerbosity = "verbosity" optionNameVerbosity = "verbosity"
optionNamePaymentThreshold = "payment-threshold" optionNameDisconnectThreshold = "disconnect-threshold"
optionNamePaymentTolerance = "payment-tolerance" optionNameGlobalPinningEnabled = "global-pinning-enable"
optionNamePaymentThreshold = "payment-threshold"
optionNamePaymentTolerance = "payment-tolerance"
) )
cmd := &cobra.Command{ cmd := &cobra.Command{
...@@ -114,23 +116,27 @@ Welcome to the Swarm.... Bzzz Bzzzz Bzzzz ...@@ -114,23 +116,27 @@ Welcome to the Swarm.... Bzzz Bzzzz Bzzzz
} }
b, err := node.NewBee(c.config.GetString(optionNameP2PAddr), logger, node.Options{ b, err := node.NewBee(c.config.GetString(optionNameP2PAddr), logger, node.Options{
DataDir: c.config.GetString(optionNameDataDir), DataDir: c.config.GetString(optionNameDataDir),
DBCapacity: c.config.GetUint64(optionNameDBCapacity), DBCapacity: c.config.GetUint64(optionNameDBCapacity),
Password: password, Password: password,
APIAddr: c.config.GetString(optionNameAPIAddr), APIAddr: c.config.GetString(optionNameAPIAddr),
DebugAPIAddr: debugAPIAddr, DebugAPIAddr: debugAPIAddr,
NATAddr: c.config.GetString(optionNameNATAddr), Addr: c.config.GetString(optionNameP2PAddr),
EnableWS: c.config.GetBool(optionNameP2PWSEnable), NATAddr: c.config.GetString(optionNameNATAddr),
EnableQUIC: c.config.GetBool(optionNameP2PQUICEnable), EnableWS: c.config.GetBool(optionNameP2PWSEnable),
NetworkID: c.config.GetUint64(optionNameNetworkID), EnableQUIC: c.config.GetBool(optionNameP2PQUICEnable),
WelcomeMessage: c.config.GetString(optionWelcomeMessage), NetworkID: c.config.GetUint64(optionNameNetworkID),
Bootnodes: c.config.GetStringSlice(optionNameBootnodes), WelcomeMessage: c.config.GetString(optionWelcomeMessage),
CORSAllowedOrigins: c.config.GetStringSlice(optionCORSAllowedOrigins), Bootnodes: c.config.GetStringSlice(optionNameBootnodes),
TracingEnabled: c.config.GetBool(optionNameTracingEnabled), CORSAllowedOrigins: c.config.GetStringSlice(optionCORSAllowedOrigins),
TracingEndpoint: c.config.GetString(optionNameTracingEndpoint), TracingEnabled: c.config.GetBool(optionNameTracingEnabled),
TracingServiceName: c.config.GetString(optionNameTracingServiceName), TracingEndpoint: c.config.GetString(optionNameTracingEndpoint),
PaymentThreshold: c.config.GetUint64(optionNamePaymentThreshold), TracingServiceName: c.config.GetString(optionNameTracingServiceName),
PaymentTolerance: c.config.GetUint64(optionNamePaymentTolerance), Logger: logger,
DisconnectThreshold: c.config.GetUint64(optionNameDisconnectThreshold),
GlobalPinningEnabled: c.config.GetBool(optionNameGlobalPinningEnabled),
PaymentThreshold: c.config.GetUint64(optionNamePaymentThreshold),
PaymentTolerance: c.config.GetUint64(optionNamePaymentTolerance),
}) })
if err != nil { if err != nil {
return err return err
......
...@@ -6,19 +6,20 @@ package api ...@@ -6,19 +6,20 @@ package api
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"github.com/gorilla/mux"
"github.com/ethersphere/bee/pkg/collection/entry" "github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/encryption" "github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/manifest/jsonmanifest" "github.com/ethersphere/bee/pkg/manifest/jsonmanifest"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/gorilla/mux"
) )
const ( const (
...@@ -29,7 +30,7 @@ const ( ...@@ -29,7 +30,7 @@ const (
func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) { func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(context.WithValue(r.Context(), targetsContextKey{}, targets)) r = r.WithContext(sctx.SetTargets(r.Context(), targets))
ctx := r.Context() ctx := r.Context()
addressHex := mux.Vars(r)["address"] addressHex := mux.Vars(r)["address"]
......
...@@ -9,12 +9,14 @@ import ( ...@@ -9,12 +9,14 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/netstore"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings" "strings"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
...@@ -94,7 +96,7 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -94,7 +96,7 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) { func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) {
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(context.WithValue(r.Context(), targetsContextKey{}, targets)) r = r.WithContext(sctx.SetTargets(r.Context(), targets))
addr := mux.Vars(r)["addr"] addr := mux.Vars(r)["addr"]
ctx := r.Context() ctx := r.Context()
...@@ -115,6 +117,11 @@ func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) { ...@@ -115,6 +117,11 @@ func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if errors.Is(err, netstore.ErrRecoveryAttempt) {
s.Logger.Trace("chunk: chunk recovery initiated. addr %s", address)
jsonhttp.Accepted(w, "chunk recovery initiated. retry after sometime.")
return
}
s.Logger.Debugf("chunk: chunk read error: %v ,addr %s", err, address) s.Logger.Debugf("chunk: chunk read error: %v ,addr %s", err, address)
s.Logger.Error("chunk: chunk read error") s.Logger.Error("chunk: chunk read error")
jsonhttp.InternalServerError(w, "chunk read error") jsonhttp.InternalServerError(w, "chunk read error")
......
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
"github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/splitter" "github.com/ethersphere/bee/pkg/file/splitter"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
...@@ -41,8 +42,6 @@ const ( ...@@ -41,8 +42,6 @@ const (
EncryptHeader = "swarm-encrypt" EncryptHeader = "swarm-encrypt"
) )
type targetsContextKey struct{}
// fileUploadResponse is returned when an HTTP request to upload a file is successful // fileUploadResponse is returned when an HTTP request to upload a file is successful
type fileUploadResponse struct { type fileUploadResponse struct {
Reference swarm.Address `json:"reference"` Reference swarm.Address `json:"reference"`
...@@ -283,9 +282,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -283,9 +282,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
} }
toDecrypt := len(address.Bytes()) == (swarm.HashSize + encryption.KeyLength) toDecrypt := len(address.Bytes()) == (swarm.HashSize + encryption.KeyLength)
targets := r.URL.Query().Get("targets")
r = r.WithContext(context.WithValue(r.Context(), targetsContextKey{}, targets)) targets := r.URL.Query().Get("targets")
sctx.SetTargets(r.Context(), targets)
// read entry. // read entry.
j := joiner.NewSimpleJoiner(s.Storer) j := joiner.NewSimpleJoiner(s.Storer)
...@@ -351,7 +350,7 @@ func (s *server) downloadHandler( ...@@ -351,7 +350,7 @@ func (s *server) downloadHandler(
) { ) {
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(context.WithValue(r.Context(), targetsContextKey{}, targets)) sctx.SetTargets(r.Context(), targets)
ctx := r.Context() ctx := r.Context()
toDecrypt := len(reference.Bytes()) == (swarm.HashSize + encryption.KeyLength) toDecrypt := len(reference.Bytes()) == (swarm.HashSize + encryption.KeyLength)
......
...@@ -109,7 +109,6 @@ func (db *DB) set(mode storage.ModeSet, addrs ...swarm.Address) (err error) { ...@@ -109,7 +109,6 @@ func (db *DB) set(mode storage.ModeSet, addrs ...swarm.Address) (err error) {
return err return err
} }
} }
default: default:
return ErrInvalidMode return ErrInvalidMode
} }
......
...@@ -10,21 +10,29 @@ import ( ...@@ -10,21 +10,29 @@ import (
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
type store struct { type store struct {
storage.Storer storage.Storer
retrieval retrieval.Interface retrieval retrieval.Interface
logger logging.Logger validator swarm.Validator
validator swarm.Validator logger logging.Logger
recoveryCallback recovery.RecoveryHook // this is the callback to be executed when a chunk fails to be retrieved
} }
var (
ErrRecoveryAttempt = errors.New("failed to retrieve chunk, recovery initiated")
)
// New returns a new NetStore that wraps a given Storer. // New returns a new NetStore that wraps a given Storer.
func New(s storage.Storer, r retrieval.Interface, logger logging.Logger, validator swarm.Validator) storage.Storer { func New(s storage.Storer, rcb recovery.RecoveryHook, r retrieval.Interface, logger logging.Logger,
return &store{Storer: s, retrieval: r, logger: logger, validator: validator} validator swarm.Validator) storage.Storer {
return &store{Storer: s, recoveryCallback: rcb, retrieval: r, logger: logger, validator: validator}
} }
// Get retrieves a given chunk address. // Get retrieves a given chunk address.
...@@ -34,11 +42,23 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres ...@@ -34,11 +42,23 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
if err != nil { if err != nil {
if errors.Is(err, storage.ErrNotFound) { if errors.Is(err, storage.ErrNotFound) {
// request from network // request from network
ch, err := s.retrieval.RetrieveChunk(ctx, addr) ch, err = s.retrieval.RetrieveChunk(ctx, addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("netstore retrieve chunk: %w", err) if s.recoveryCallback == nil {
return nil, err
}
targets, err := sctx.GetTargets(ctx)
if err != nil {
return nil, err
}
go func() {
err := s.recoveryCallback(addr, targets)
if err != nil {
s.logger.Debugf("netstore: error while recovering chunk: %v", err)
}
}()
return nil, ErrRecoveryAttempt
} }
_, err = s.Storer.Put(ctx, storage.ModePutRequest, ch) _, err = s.Storer.Put(ctx, storage.ModePutRequest, ch)
if err != nil { if err != nil {
return nil, fmt.Errorf("netstore retrieve put: %w", err) return nil, fmt.Errorf("netstore retrieve put: %w", err)
......
...@@ -7,16 +7,21 @@ package netstore_test ...@@ -7,16 +7,21 @@ package netstore_test
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt"
"io/ioutil" "io/ioutil"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time"
validatormock "github.com/ethersphere/bee/pkg/content/mock" validatormock "github.com/ethersphere/bee/pkg/content/mock"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
) )
var chunkData = []byte("mockdata") var chunkData = []byte("mockdata")
...@@ -24,7 +29,7 @@ var chunkData = []byte("mockdata") ...@@ -24,7 +29,7 @@ var chunkData = []byte("mockdata")
// TestNetstoreRetrieval verifies that a chunk is asked from the network whenever // TestNetstoreRetrieval verifies that a chunk is asked from the network whenever
// it is not found locally // it is not found locally
func TestNetstoreRetrieval(t *testing.T) { func TestNetstoreRetrieval(t *testing.T) {
retrieve, store, nstore := newRetrievingNetstore() retrieve, store, nstore := newRetrievingNetstore(nil)
addr := swarm.MustParseHexAddress("000001") addr := swarm.MustParseHexAddress("000001")
_, err := nstore.Get(context.Background(), storage.ModeGetRequest, addr) _, err := nstore.Get(context.Background(), storage.ModeGetRequest, addr)
if err != nil { if err != nil {
...@@ -68,7 +73,7 @@ func TestNetstoreRetrieval(t *testing.T) { ...@@ -68,7 +73,7 @@ func TestNetstoreRetrieval(t *testing.T) {
// TestNetstoreNoRetrieval verifies that a chunk is not requested from the network // TestNetstoreNoRetrieval verifies that a chunk is not requested from the network
// whenever it is found locally. // whenever it is found locally.
func TestNetstoreNoRetrieval(t *testing.T) { func TestNetstoreNoRetrieval(t *testing.T) {
retrieve, store, nstore := newRetrievingNetstore() retrieve, store, nstore := newRetrievingNetstore(nil)
addr := swarm.MustParseHexAddress("000001") addr := swarm.MustParseHexAddress("000001")
// store should have the chunk in advance // store should have the chunk in advance
...@@ -92,26 +97,105 @@ func TestNetstoreNoRetrieval(t *testing.T) { ...@@ -92,26 +97,105 @@ func TestNetstoreNoRetrieval(t *testing.T) {
} }
} }
func TestRecovery(t *testing.T) {
hookWasCalled := make(chan bool, 1)
rec := &mockRecovery{
hookC: hookWasCalled,
}
retrieve, _, nstore := newRetrievingNetstore(rec)
addr := swarm.MustParseHexAddress("deadbeef")
retrieve.failure = true
ctx := context.Background()
ctx = sctx.SetTargets(ctx, "be, cd")
_, err := nstore.Get(ctx, storage.ModeGetRequest, addr)
if err != nil && !errors.Is(err, netstore.ErrRecoveryAttempt) {
t.Fatal(err)
}
select {
case <-hookWasCalled:
break
case <-time.After(100 * time.Millisecond):
t.Fatal("recovery hook was not called")
}
}
func TestInvalidRecoveryFunction(t *testing.T) {
retrieve, _, nstore := newRetrievingNetstore(nil)
addr := swarm.MustParseHexAddress("deadbeef")
retrieve.failure = true
ctx := context.Background()
ctx = sctx.SetTargets(ctx, "be, cd")
_, err := nstore.Get(ctx, storage.ModeGetRequest, addr)
if err != nil && err.Error() != "chunk not found" {
t.Fatal(err)
}
}
func TestInvalidTargets(t *testing.T) {
hookWasCalled := make(chan bool, 1)
rec := &mockRecovery{
hookC: hookWasCalled,
}
retrieve, _, nstore := newRetrievingNetstore(rec)
addr := swarm.MustParseHexAddress("deadbeef")
retrieve.failure = true
ctx := context.Background()
ctx = sctx.SetTargets(ctx, "gh")
_, err := nstore.Get(ctx, storage.ModeGetRequest, addr)
if err != nil && !errors.Is(err, sctx.ErrTargetPrefix) {
t.Fatal(err)
}
}
// returns a mock retrieval protocol, a mock local storage and a netstore // returns a mock retrieval protocol, a mock local storage and a netstore
func newRetrievingNetstore() (ret *retrievalMock, mockStore, ns storage.Storer) { func newRetrievingNetstore(rec *mockRecovery) (ret *retrievalMock, mockStore, ns storage.Storer) {
retrieve := &retrievalMock{} retrieve := &retrievalMock{}
store := mock.NewStorer() store := mock.NewStorer()
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
validator := swarm.NewChunkValidator(validatormock.NewValidator(true)) validator := swarm.NewChunkValidator(validatormock.NewValidator(true))
nstore := netstore.New(store, retrieve, logger, validator)
var nstore storage.Storer
if rec != nil {
nstore = netstore.New(store, rec.recovery, retrieve, logger, validator)
} else {
nstore = netstore.New(store, nil, retrieve, logger, validator)
}
return retrieve, store, nstore return retrieve, store, nstore
} }
type retrievalMock struct { type retrievalMock struct {
called bool called bool
callCount int32 callCount int32
failure bool
addr swarm.Address addr swarm.Address
} }
func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) { func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
if r.failure {
return nil, fmt.Errorf("chunk not found")
}
r.called = true r.called = true
atomic.AddInt32(&r.callCount, 1) atomic.AddInt32(&r.callCount, 1)
r.addr = addr r.addr = addr
return swarm.NewChunk(addr, chunkData), nil return swarm.NewChunk(addr, chunkData), nil
} }
type mockRecovery struct {
hookC chan bool
}
// Send mocks the pss Send function
func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets trojan.Targets) error {
mr.hookC <- true
return nil
}
func (r *mockRecovery) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
return nil, fmt.Errorf("chunk not found")
}
...@@ -31,11 +31,13 @@ import ( ...@@ -31,11 +31,13 @@ import (
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/p2p/libp2p" "github.com/ethersphere/bee/pkg/p2p/libp2p"
"github.com/ethersphere/bee/pkg/pingpong" "github.com/ethersphere/bee/pkg/pingpong"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/puller" "github.com/ethersphere/bee/pkg/puller"
"github.com/ethersphere/bee/pkg/pullsync" "github.com/ethersphere/bee/pkg/pullsync"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage" "github.com/ethersphere/bee/pkg/pullsync/pullstorage"
"github.com/ethersphere/bee/pkg/pusher" "github.com/ethersphere/bee/pkg/pusher"
"github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
"github.com/ethersphere/bee/pkg/settlement/pseudosettle" "github.com/ethersphere/bee/pkg/settlement/pseudosettle"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
...@@ -66,24 +68,27 @@ type Bee struct { ...@@ -66,24 +68,27 @@ type Bee struct {
} }
type Options struct { type Options struct {
DataDir string DataDir string
DBCapacity uint64 DBCapacity uint64
Password string Password string
APIAddr string APIAddr string
DebugAPIAddr string DebugAPIAddr string
NATAddr string Addr string
EnableWS bool NATAddr string
EnableQUIC bool EnableWS bool
NetworkID uint64 EnableQUIC bool
WelcomeMessage string NetworkID uint64
Bootnodes []string WelcomeMessage string
CORSAllowedOrigins []string Bootnodes []string
TracingEnabled bool CORSAllowedOrigins []string
TracingEndpoint string Logger logging.Logger
TracingServiceName string TracingEnabled bool
DisconnectThreshold uint64 TracingEndpoint string
PaymentThreshold uint64 TracingServiceName string
PaymentTolerance uint64 DisconnectThreshold uint64
GlobalPinningEnabled bool
PaymentThreshold uint64
PaymentTolerance uint64
} }
func NewBee(addr string, logger logging.Logger, o Options) (*Bee, error) { func NewBee(addr string, logger logging.Logger, o Options) (*Bee, error) {
...@@ -276,22 +281,41 @@ func NewBee(addr string, logger logging.Logger, o Options) (*Bee, error) { ...@@ -276,22 +281,41 @@ func NewBee(addr string, logger logging.Logger, o Options) (*Bee, error) {
return nil, fmt.Errorf("retrieval service: %w", err) return nil, fmt.Errorf("retrieval service: %w", err)
} }
ns := netstore.New(storer, retrieve, logger, chunkvalidator) // instantiate the pss object
psss := pss.New(logger, nil)
var ns storage.Storer
if o.GlobalPinningEnabled {
// create recovery callback for content repair
recoverFunc := recovery.NewRecoveryHook(psss)
ns = netstore.New(storer, recoverFunc, retrieve, logger, chunkvalidator)
} else {
ns = netstore.New(storer, nil, retrieve, logger, chunkvalidator)
}
retrieve.SetStorer(ns) retrieve.SetStorer(ns)
pushSyncProtocol := pushsync.New(pushsync.Options{ pushSyncProtocol := pushsync.New(pushsync.Options{
Streamer: p2ps, Streamer: p2ps,
Storer: storer, Storer: storer,
ClosestPeerer: kad, ClosestPeerer: kad,
Tagger: tagg, DeliveryCallback: psss.TryUnwrap,
Logger: logger, Tagger: tagg,
Logger: logger,
}) })
// set the pushSyncer in the PSS
psss.WithPushSyncer(pushSyncProtocol)
if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil { if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil {
return nil, fmt.Errorf("pushsync service: %w", err) return nil, fmt.Errorf("pushsync service: %w", err)
} }
if o.GlobalPinningEnabled {
// register function for chunk repair upon receiving a trojan message
chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol)
psss.Register(recovery.RecoveryTopic, chunkRepairHandler)
}
pushSyncPusher := pusher.New(pusher.Options{ pushSyncPusher := pusher.New(pusher.Options{
Storer: storer, Storer: storer,
PeerSuggester: kad, PeerSuggester: kad,
......
...@@ -26,6 +26,7 @@ type Interface interface { ...@@ -26,6 +26,7 @@ type Interface interface {
Register(topic trojan.Topic, hndlr Handler) Register(topic trojan.Topic, hndlr Handler)
GetHandler(topic trojan.Topic) Handler GetHandler(topic trojan.Topic) Handler
TryUnwrap(ctx context.Context, c swarm.Chunk) error TryUnwrap(ctx context.Context, c swarm.Chunk) error
WithPushSyncer(pushSyncer pushsync.PushSyncer)
} }
// pss is the top-level struct, which takes care of message sending // pss is the top-level struct, which takes care of message sending
...@@ -47,8 +48,12 @@ func New(logger logging.Logger, pusher pushsync.PushSyncer) Interface { ...@@ -47,8 +48,12 @@ func New(logger logging.Logger, pusher pushsync.PushSyncer) Interface {
} }
} }
func (ps *pss) WithPushSyncer(pushSyncer pushsync.PushSyncer) {
ps.pusher = pushSyncer
}
// Handler defines code to be executed upon reception of a trojan message // Handler defines code to be executed upon reception of a trojan message
type Handler func(*trojan.Message) type Handler func(context.Context, *trojan.Message) error
// Send constructs a padded message with topic and payload, // Send constructs a padded message with topic and payload,
// wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address // wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address
...@@ -64,6 +69,7 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top ...@@ -64,6 +69,7 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top
var tc swarm.Chunk var tc swarm.Chunk
tc, err = m.Wrap(targets) tc, err = m.Wrap(targets)
if err != nil { if err != nil {
return err return err
} }
...@@ -84,6 +90,7 @@ func (p *pss) Register(topic trojan.Topic, hndlr Handler) { ...@@ -84,6 +90,7 @@ func (p *pss) Register(topic trojan.Topic, hndlr Handler) {
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handler func based on its topic // TryUnwrap allows unwrapping a chunk as a trojan message and calling its handler func based on its topic
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error { func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
if !trojan.IsPotential(c) { if !trojan.IsPotential(c) {
return nil return nil
} }
...@@ -95,8 +102,7 @@ func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error { ...@@ -95,8 +102,7 @@ func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
if h == nil { if h == nil {
return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler) return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler)
} }
h(m) return h(ctx, m)
return nil
} }
// GetHandler returns the Handler func registered in pss for the given topic // GetHandler returns the Handler func registered in pss for the given topic
......
...@@ -71,28 +71,36 @@ func TestRegister(t *testing.T) { ...@@ -71,28 +71,36 @@ func TestRegister(t *testing.T) {
handlerVerifier := 0 // test variable to check handler funcs are correctly retrieved handlerVerifier := 0 // test variable to check handler funcs are correctly retrieved
// register first handler // register first handler
testHandler := func(m *trojan.Message) { testHandler := func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 1 handlerVerifier = 1
return nil
} }
testTopic := trojan.NewTopic("FIRST_HANDLER") testTopic := trojan.NewTopic("FIRST_HANDLER")
pss.Register(testTopic, testHandler) pss.Register(testTopic, testHandler)
registeredHandler := pss.GetHandler(testTopic) registeredHandler := pss.GetHandler(testTopic)
registeredHandler(&trojan.Message{}) // call handler to verify the retrieved func is correct err := registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 1 { if handlerVerifier != 1 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 1 but is %d instead", handlerVerifier) t.Fatalf("unexpected handler retrieved, verifier variable should be 1 but is %d instead", handlerVerifier)
} }
// register second handler // register second handler
testHandler = func(m *trojan.Message) { testHandler = func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 2 handlerVerifier = 2
return nil
} }
testTopic = trojan.NewTopic("SECOND_HANDLER") testTopic = trojan.NewTopic("SECOND_HANDLER")
pss.Register(testTopic, testHandler) pss.Register(testTopic, testHandler)
registeredHandler = pss.GetHandler(testTopic) registeredHandler = pss.GetHandler(testTopic)
registeredHandler(&trojan.Message{}) // call handler to verify the retrieved func is correct err = registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 2 { if handlerVerifier != 2 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 2 but is %d instead", handlerVerifier) t.Fatalf("unexpected handler retrieved, verifier variable should be 2 but is %d instead", handlerVerifier)
...@@ -124,8 +132,9 @@ func TestDeliver(t *testing.T) { ...@@ -124,8 +132,9 @@ func TestDeliver(t *testing.T) {
// create and register handler // create and register handler
var tt trojan.Topic // test variable to check handler func was correctly called var tt trojan.Topic // test variable to check handler func was correctly called
hndlr := func(m *trojan.Message) { hndlr := func(ctx context.Context, m *trojan.Message) error {
tt = m.Topic // copy the message topic to the test variable tt = m.Topic // copy the message topic to the test variable
return nil
} }
pss.Register(topic, hndlr) pss.Register(topic, hndlr)
...@@ -149,7 +158,7 @@ func TestHandler(t *testing.T) { ...@@ -149,7 +158,7 @@ func TestHandler(t *testing.T) {
} }
// register first handler // register first handler
testHandler := func(m *trojan.Message) {} testHandler := func(ctx context.Context, m *trojan.Message) error { return nil }
// set handler for test topic // set handler for test topic
pss.Register(testTopic, testHandler) pss.Register(testTopic, testHandler)
......
...@@ -35,32 +35,35 @@ type Receipt struct { ...@@ -35,32 +35,35 @@ type Receipt struct {
} }
type PushSync struct { type PushSync struct {
streamer p2p.Streamer streamer p2p.Streamer
storer storage.Putter storer storage.Putter
peerSuggester topology.ClosestPeerer peerSuggester topology.ClosestPeerer
tagg *tags.Tags tagg *tags.Tags
logger logging.Logger deliveryCallback func(context.Context, swarm.Chunk) error // callback func to be invoked to deliver chunks to PSS
metrics metrics logger logging.Logger
metrics metrics
} }
type Options struct { type Options struct {
Streamer p2p.Streamer Streamer p2p.Streamer
Storer storage.Putter Storer storage.Putter
ClosestPeerer topology.ClosestPeerer ClosestPeerer topology.ClosestPeerer
Tagger *tags.Tags Tagger *tags.Tags
Logger logging.Logger DeliveryCallback func(context.Context, swarm.Chunk) error
Logger logging.Logger
} }
var timeToWaitForReceipt = 3 * time.Second // time to wait to get a receipt for a chunk var timeToWaitForReceipt = 3 * time.Second // time to wait to get a receipt for a chunk
func New(o Options) *PushSync { func New(o Options) *PushSync {
ps := &PushSync{ ps := &PushSync{
streamer: o.Streamer, streamer: o.Streamer,
storer: o.Storer, storer: o.Storer,
peerSuggester: o.ClosestPeerer, peerSuggester: o.ClosestPeerer,
tagg: o.Tagger, tagg: o.Tagger,
logger: o.Logger, deliveryCallback: o.DeliveryCallback,
metrics: newMetrics(), logger: o.Logger,
metrics: newMetrics(),
} }
return ps return ps
} }
...@@ -101,21 +104,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -101,21 +104,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
if err != nil { if err != nil {
// If i am the closest peer then store the chunk and send receipt // If i am the closest peer then store the chunk and send receipt
if errors.Is(err, topology.ErrWantSelf) { if errors.Is(err, topology.ErrWantSelf) {
return ps.handleDeliveryResponse(ctx, w, p, chunk)
// Store the chunk in the local store
_, err := ps.storer.Put(ctx, storage.ModePutSync, chunk)
if err != nil {
return fmt.Errorf("chunk store: %w", err)
}
ps.metrics.TotalChunksStoredInDB.Inc()
// Send a receipt immediately once the storage of the chunk is successfully
receipt := &pb.Receipt{Address: chunk.Address().Bytes()}
err = ps.sendReceipt(w, receipt)
if err != nil {
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
}
return nil
} }
return err return err
} }
...@@ -123,17 +112,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -123,17 +112,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
// This is a special situation in that the other peer thinks thats we are the closest node // This is a special situation in that the other peer thinks thats we are the closest node
// and we think that the sending peer // and we think that the sending peer
if p.Address.Equal(peer) { if p.Address.Equal(peer) {
return ps.handleDeliveryResponse(ctx, w, p, chunk)
// Store the chunk in the local store
_, err := ps.storer.Put(ctx, storage.ModePutSync, chunk)
if err != nil {
return fmt.Errorf("chunk store: %w", err)
}
ps.metrics.TotalChunksStoredInDB.Inc()
// Send a receipt immediately once the storage of the chunk is successfully
receipt := &pb.Receipt{Address: chunk.Address().Bytes()}
return ps.sendReceipt(w, receipt)
} }
// Forward chunk to closest peer // Forward chunk to closest peer
...@@ -276,3 +255,29 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re ...@@ -276,3 +255,29 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
return rec, nil return rec, nil
} }
func (ps *PushSync) deliverToPSS(ctx context.Context, ch swarm.Chunk) error {
// if callback is defined, call it for every new, valid chunk
if ps.deliveryCallback != nil {
return ps.deliveryCallback(ctx, ch)
}
return nil
}
func (ps *PushSync) handleDeliveryResponse(ctx context.Context, w protobuf.Writer, p p2p.Peer, chunk swarm.Chunk) error {
// Store the chunk in the local store
_, err := ps.storer.Put(ctx, storage.ModePutSync, chunk)
if err != nil {
return fmt.Errorf("chunk store: %w", err)
}
ps.metrics.TotalChunksStoredInDB.Inc()
// Send a receipt immediately once the storage of the chunk is successfully
receipt := &pb.Receipt{Address: chunk.Address().Bytes()}
err = ps.sendReceipt(w, receipt)
if err != nil {
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
}
// since all PSS messages comes through push sync, deliver them here if this node is the destination
return ps.deliverToPSS(ctx, chunk)
}
...@@ -7,9 +7,6 @@ package pushsync_test ...@@ -7,9 +7,6 @@ package pushsync_test
import ( import (
"bytes" "bytes"
"context" "context"
"io/ioutil"
"testing"
"github.com/ethersphere/bee/pkg/localstore" "github.com/ethersphere/bee/pkg/localstore"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
...@@ -20,6 +17,9 @@ import ( ...@@ -20,6 +17,9 @@ import (
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/topology"
"github.com/ethersphere/bee/pkg/topology/mock" "github.com/ethersphere/bee/pkg/topology/mock"
"io/ioutil"
"testing"
"time"
) )
// TestSendChunkAndGetReceipt inserts a chunk as uploaded chunk in db. This triggers sending a chunk to the closest node // TestSendChunkAndGetReceipt inserts a chunk as uploaded chunk in db. This triggers sending a chunk to the closest node
...@@ -36,14 +36,14 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) { ...@@ -36,14 +36,14 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) {
// peer is the node responding to the chunk receipt message // peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to // mock should return ErrWantSelf since there's no one to forward to
psPeer, storerPeer, _ := createPushSyncNode(t, closestPeer, nil, mock.WithClosestPeerErr(topology.ErrWantSelf)) psPeer, storerPeer, _ := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close() defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol())) recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()))
// pivot node needs the streamer since the chunk is intercepted by // pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream // the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, _ := createPushSyncNode(t, pivotNode, recorder, mock.WithClosestPeer(closestPeer)) psPivot, storerPivot, _ := createPushSyncNode(t, pivotNode, recorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivot.Close() defer storerPivot.Close()
// Trigger the sending of chunk to the closest node // Trigger the sending of chunk to the closest node
...@@ -77,14 +77,14 @@ func TestPushChunkToClosest(t *testing.T) { ...@@ -77,14 +77,14 @@ func TestPushChunkToClosest(t *testing.T) {
// peer is the node responding to the chunk receipt message // peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to // mock should return ErrWantSelf since there's no one to forward to
psPeer, storerPeer, _ := createPushSyncNode(t, closestPeer, nil, mock.WithClosestPeerErr(topology.ErrWantSelf)) psPeer, storerPeer, _ := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close() defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol())) recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()))
// pivot node needs the streamer since the chunk is intercepted by // pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream // the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, pivotTags := createPushSyncNode(t, pivotNode, recorder, mock.WithClosestPeer(closestPeer)) psPivot, storerPivot, pivotTags := createPushSyncNode(t, pivotNode, recorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivot.Close() defer storerPivot.Close()
ta, err := pivotTags.Create("test", 1, false) ta, err := pivotTags.Create("test", 1, false)
...@@ -130,7 +130,7 @@ func TestPushChunkToClosest(t *testing.T) { ...@@ -130,7 +130,7 @@ func TestPushChunkToClosest(t *testing.T) {
// TestHandler expect a chunk from a node on a stream. It then stores the chunk in the local store and // TestHandler expect a chunk from a node on a stream. It then stores the chunk in the local store and
// sends back a receipt. This is tested by intercepting the incoming stream for proper messages. // sends back a receipt. This is tested by intercepting the incoming stream for proper messages.
// It also sends the chunk to the closest peerand receives a receipt. // It also sends the chunk to the closest peer and receives a receipt.
// //
// Chunk moves from TriggerPeer -> PivotPeer -> ClosestPeer // Chunk moves from TriggerPeer -> PivotPeer -> ClosestPeer
// //
...@@ -145,20 +145,27 @@ func TestHandler(t *testing.T) { ...@@ -145,20 +145,27 @@ func TestHandler(t *testing.T) {
triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000")
closestPeer := swarm.MustParseHexAddress("f000000000000000000000000000000000000000000000000000000000000000") closestPeer := swarm.MustParseHexAddress("f000000000000000000000000000000000000000000000000000000000000000")
// mock call back function to see if pss message is delivered when it is received in the destination (closestPeer in this testcase)
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
pssDeliver := func(ctx context.Context, ch swarm.Chunk) error {
hookWasCalled <- true
return nil
}
// Create the closest peer // Create the closest peer
psClosestPeer, closestStorerPeerDB, _ := createPushSyncNode(t, closestPeer, nil, mock.WithClosestPeerErr(topology.ErrWantSelf)) psClosestPeer, closestStorerPeerDB, _ := createPushSyncNode(t, closestPeer, nil, pssDeliver, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer closestStorerPeerDB.Close() defer closestStorerPeerDB.Close()
closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol())) closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol()))
// creating the pivot peer // creating the pivot peer
psPivot, storerPivotDB, _ := createPushSyncNode(t, pivotPeer, closestRecorder, mock.WithClosestPeer(closestPeer)) psPivot, storerPivotDB, _ := createPushSyncNode(t, pivotPeer, closestRecorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivotDB.Close() defer storerPivotDB.Close()
pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol())) pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol()))
// Creating the trigger peer // Creating the trigger peer
psTriggerPeer, triggerStorerDB, _ := createPushSyncNode(t, triggerPeer, pivotRecorder, mock.WithClosestPeer(pivotPeer)) psTriggerPeer, triggerStorerDB, _ := createPushSyncNode(t, triggerPeer, pivotRecorder, nil, mock.WithClosestPeer(pivotPeer))
defer triggerStorerDB.Close() defer triggerStorerDB.Close()
receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk) receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk)
...@@ -182,9 +189,17 @@ func TestHandler(t *testing.T) { ...@@ -182,9 +189,17 @@ func TestHandler(t *testing.T) {
// In the received stream, check if a receipt is sent from pivot peer and check for its correctness. // In the received stream, check if a receipt is sent from pivot peer and check for its correctness.
waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunkAddress, nil) waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunkAddress, nil)
// check if the pss delivery hook is called
select {
case <-hookWasCalled:
break
case <-time.After(100 * time.Millisecond):
t.Fatal("recovery hook was not called")
}
} }
func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags) { func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, pssDeliver func(context.Context, swarm.Chunk) error, mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
storer, err := localstore.New("", addr.Bytes(), nil, logger) storer, err := localstore.New("", addr.Bytes(), nil, logger)
...@@ -196,11 +211,12 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R ...@@ -196,11 +211,12 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R
mtag := tags.NewTags() mtag := tags.NewTags()
ps := pushsync.New(pushsync.Options{ ps := pushsync.New(pushsync.Options{
Streamer: recorder, Streamer: recorder,
Storer: storer, Storer: storer,
Tagger: mtag, Tagger: mtag,
ClosestPeerer: mockTopology, DeliveryCallback: pssDeliver,
Logger: logger, ClosestPeerer: mockTopology,
Logger: logger,
}) })
return ps, storer, mtag return ps, storer, mtag
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package recovery
import (
"context"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
)
const (
// RecoveryTopicText is the string used to construct the recovery topic.
RecoveryTopicText = "RECOVERY"
)
var (
// RecoveryTopic is the topic used for repairing globally pinned chunks.
RecoveryTopic = trojan.NewTopic(RecoveryTopicText)
)
// RecoveryHook defines code to be executed upon failing to retrieve chunks.
type RecoveryHook func(chunkAddress swarm.Address, targets trojan.Targets) error
// sender is the function call for sending trojan chunks.
type PssSender interface {
Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error
}
// NewRecoveryHook returns a new RecoveryHook with the sender function defined.
func NewRecoveryHook(pss PssSender) RecoveryHook {
return func(chunkAddress swarm.Address, targets trojan.Targets) error {
payload := chunkAddress
ctx := context.Background()
err := pss.Send(ctx, targets, RecoveryTopic, payload.Bytes())
return err
}
}
// NewRepairHandler creates a repair function to re-upload globally pinned chunks to the network with the given store.
func NewRepairHandler(s storage.Storer, logger logging.Logger, pushSyncer pushsync.PushSyncer) pss.Handler {
return func(ctx context.Context, m *trojan.Message) error {
chAddr := m.Payload
ch, err := s.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(chAddr))
if err != nil {
logger.Tracef("chunk repair: error while getting chunk for repairing: %v", err)
return err
}
// push the chunk using push sync so that it reaches it destination in network
_, err = pushSyncer.PushChunkToClosest(ctx, ch)
if err != nil {
logger.Tracef("chunk repair: error while sending chunk or receiving receipt: %v", err)
return err
}
return nil
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package recovery_test
import (
"context"
"errors"
"io/ioutil"
"testing"
"time"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pushsync"
pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock"
"github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/retrieval"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock"
storemock "github.com/ethersphere/bee/pkg/storage/mock"
chunktesting "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
"github.com/ethersphere/bee/pkg/trojan"
)
// TestRecoveryHook tests that a recovery hook can be created and called.
func TestRecoveryHook(t *testing.T) {
// test variables needed to be correctly set for any recovery hook to reach the sender func
chunkAddr := chunktesting.GenerateTestRandomChunk().Address()
targets := trojan.Targets{[]byte{0xED}}
//setup the sender
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
pssSender := &mockPssSender{
hookC: hookWasCalled,
}
// create recovery hook and call it
recoveryHook := recovery.NewRecoveryHook(pssSender)
if err := recoveryHook(chunkAddr, targets); err != nil {
t.Fatal(err)
}
select {
case <-hookWasCalled:
break
case <-time.After(100 * time.Millisecond):
t.Fatal("recovery hook was not called")
}
}
// RecoveryHookTestCase is a struct used as test cases for the TestRecoveryHookCalls func.
type recoveryHookTestCase struct {
name string
ctx context.Context
expectsFailure bool
}
// TestRecoveryHookCalls verifies that recovery hooks are being called as expected when net store attempts to get a chunk.
func TestRecoveryHookCalls(t *testing.T) {
// generate test chunk, store and publisher
c := chunktesting.GenerateTestRandomChunk()
ref := c.Address()
target := "BE"
// test cases variables
dummyContext := context.Background() // has no publisher
targetContext := sctx.SetTargets(context.Background(), target)
for _, tc := range []recoveryHookTestCase{
{
name: "no targets in context",
ctx: dummyContext,
expectsFailure: true,
},
{
name: "targets set in context",
ctx: targetContext,
expectsFailure: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
// setup the sender
pssSender := &mockPssSender{
hookC: hookWasCalled,
}
recoverFunc := recovery.NewRecoveryHook(pssSender)
ns := newTestNetStore(t, recoverFunc)
// fetch test chunk
_, err := ns.Get(tc.ctx, storage.ModeGetRequest, ref)
if err != nil && !errors.Is(err, netstore.ErrRecoveryAttempt) && err.Error() != "error decoding prefix string" {
t.Fatal(err)
}
// checks whether the callback is invoked or the test case times out
select {
case <-hookWasCalled:
if !tc.expectsFailure {
return
}
t.Fatal("recovery hook was unexpectedly called")
case <-time.After(1000 * time.Millisecond):
if tc.expectsFailure {
return
}
t.Fatal("recovery hook was not called when expected")
}
})
}
}
// TestNewRepairHandler tests the function of repairing a chunk when a request for chunk repair is received.
func TestNewRepairHandler(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
t.Run("repair-chunk", func(t *testing.T) {
// generate test chunk, store and publisher
c1 := chunktesting.GenerateTestRandomChunk()
// create a mock storer and put a chunk that will be repaired
mockStorer := storemock.NewStorer()
defer mockStorer.Close()
_, err := mockStorer.Put(context.Background(), storage.ModePutRequest, c1)
if err != nil {
t.Fatal(err)
}
// create a mock pushsync service to push the chunk to its destination
var receipt *pushsync.Receipt
pushSyncService := pushsyncmock.New(func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
receipt = &pushsync.Receipt{
Address: swarm.NewAddress(chunk.Address().Bytes()),
}
return receipt, nil
})
// create the chunk repair handler
repairHandler := recovery.NewRepairHandler(mockStorer, logger, pushSyncService)
//create a trojan message to trigger the repair of the chunk
testTopic := trojan.NewTopic("foo")
maxPayload := make([]byte, swarm.SectionSize)
var msg trojan.Message
copy(maxPayload, c1.Address().Bytes())
if msg, err = trojan.NewMessage(testTopic, maxPayload); err != nil {
t.Fatal(err)
}
// invoke the chunk repair handler
err = repairHandler(context.Background(), &msg)
if err != nil {
t.Fatal(err)
}
// check if receipt is received
if receipt == nil {
t.Fatal("receipt not received")
}
if !receipt.Address.Equal(c1.Address()) {
t.Fatalf("invalid address in receipt: expected %s received %s", c1.Address(), receipt.Address)
}
})
t.Run("repair-chunk-not-present", func(t *testing.T) {
// generate test chunk, store and publisher
c2 := chunktesting.GenerateTestRandomChunk()
// create a mock storer
mockStorer := storemock.NewStorer()
defer mockStorer.Close()
// create a mock pushsync service
pushServiceCalled := false
pushSyncService := pushsyncmock.New(func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
pushServiceCalled = true
return nil, nil
})
// create the chunk repair handler
repairHandler := recovery.NewRepairHandler(mockStorer, logger, pushSyncService)
//create a trojan message to trigger the repair of the chunk
testTopic := trojan.NewTopic("foo")
maxPayload := make([]byte, swarm.SectionSize)
var msg trojan.Message
copy(maxPayload, c2.Address().Bytes())
msg, err := trojan.NewMessage(testTopic, maxPayload)
if err != nil {
t.Fatal(err)
}
// invoke the chunk repair handler
err = repairHandler(context.Background(), &msg)
if err != nil && err.Error() != "storage: not found" {
t.Fatal(err)
}
if pushServiceCalled {
t.Fatal("push service called even if the chunk is not present")
}
})
t.Run("repair-chunk-closest-peer-not-present", func(t *testing.T) {
// generate test chunk, store and publisher
c3 := chunktesting.GenerateTestRandomChunk()
// create a mock storer
mockStorer := storemock.NewStorer()
defer mockStorer.Close()
_, err := mockStorer.Put(context.Background(), storage.ModePutRequest, c3)
if err != nil {
t.Fatal(err)
}
// create a mock pushsync service
var receiptError error
pushSyncService := pushsyncmock.New(func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
receiptError = errors.New("invalid receipt")
return nil, receiptError
})
// create the chunk repair handler
repairHandler := recovery.NewRepairHandler(mockStorer, logger, pushSyncService)
//create a trojan message to trigger the repair of the chunk
testTopic := trojan.NewTopic("foo")
maxPayload := make([]byte, swarm.SectionSize)
var msg trojan.Message
copy(maxPayload, c3.Address().Bytes())
msg, err = trojan.NewMessage(testTopic, maxPayload)
if err != nil {
t.Fatal(err)
}
// invoke the chunk repair handler
err = repairHandler(context.Background(), &msg)
if err != nil && err != receiptError {
t.Fatal(err)
}
if receiptError == nil {
t.Fatal("pushsync did not generate a receipt error")
}
})
}
// newTestNetStore creates a test store with a set RemoteGet func.
func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.Storer {
t.Helper()
storer := mock.NewStorer()
logger := logging.New(ioutil.Discard, 5)
mockStorer := storemock.NewStorer()
serverMockAccounting := accountingmock.NewAccounting()
price := uint64(12345)
pricerMock := accountingmock.NewPricer(price, price)
peerID := swarm.MustParseHexAddress("deadbeef")
ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(peerID, 0)
return nil
}}
server := retrieval.New(retrieval.Options{
Storer: mockStorer,
Logger: logger,
Accounting: serverMockAccounting,
})
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
)
retrieve := retrieval.New(retrieval.Options{
Streamer: recorder,
ChunkPeerer: ps,
Storer: mockStorer,
Logger: logger,
Accounting: serverMockAccounting,
Pricer: pricerMock,
})
ns := netstore.New(storer, recoveryFunc, retrieve, logger, nil)
return ns
}
type mockPeerSuggester struct {
eachPeerRevFunc func(f topology.EachPeerFunc) error
}
func (s mockPeerSuggester) EachPeer(topology.EachPeerFunc) error {
return errors.New("not implemented")
}
func (s mockPeerSuggester) EachPeerRev(f topology.EachPeerFunc) error {
return s.eachPeerRevFunc(f)
}
type mockPssSender struct {
hookC chan bool
}
// Send mocks the pss Send function
func (mp *mockPssSender) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error {
mp.hookC <- true
return nil
}
...@@ -4,12 +4,25 @@ ...@@ -4,12 +4,25 @@
package sctx package sctx
import "context" import (
"context"
"encoding/hex"
"errors"
"strings"
"github.com/ethersphere/bee/pkg/trojan"
)
var (
// ErrTargetPrefix is returned when target prefix decoding fails.
ErrTargetPrefix = errors.New("error decoding prefix string")
)
type ( type (
HTTPRequestIDKey struct{} HTTPRequestIDKey struct{}
requestHostKey struct{} requestHostKey struct{}
tagKey struct{} tagKey struct{}
targetsContextKey struct{}
) )
// SetHost sets the http request host in the context // SetHost sets the http request host in the context
...@@ -39,3 +52,32 @@ func GetTag(ctx context.Context) uint32 { ...@@ -39,3 +52,32 @@ func GetTag(ctx context.Context) uint32 {
} }
return 0 return 0
} }
// SetTargets set the target string in the context to be used downstream in netstore
func SetTargets(ctx context.Context, targets string) context.Context {
return context.WithValue(ctx, targetsContextKey{}, targets)
}
// GetTargets returns the specific target pinners for a corresponding chunk by
// reading the prefix targets sent in the download API.
func GetTargets(ctx context.Context) (trojan.Targets, error) {
targetString, ok := ctx.Value(targetsContextKey{}).(string)
if !ok {
return nil, ErrTargetPrefix
}
prefixes := strings.Split(targetString, ",")
var targets trojan.Targets
for _, prefix := range prefixes {
var target trojan.Target
target, err := hex.DecodeString(prefix)
if err != nil {
continue
}
targets = append(targets, target)
}
if len(targets) <= 0 {
return nil, ErrTargetPrefix
}
return targets, nil
}
...@@ -30,6 +30,8 @@ type MockStorer struct { ...@@ -30,6 +30,8 @@ type MockStorer struct {
morePull chan struct{} morePull chan struct{}
mtx sync.Mutex mtx sync.Mutex
quit chan struct{} quit chan struct{}
baseAddress []byte
bins []uint64
} }
func WithSubscribePullChunks(chs ...storage.Descriptor) Option { func WithSubscribePullChunks(chs ...storage.Descriptor) Option {
...@@ -41,6 +43,18 @@ func WithSubscribePullChunks(chs ...storage.Descriptor) Option { ...@@ -41,6 +43,18 @@ func WithSubscribePullChunks(chs ...storage.Descriptor) Option {
}) })
} }
func WithBaseAddress(a swarm.Address) Option {
return optionFunc(func(m *MockStorer) {
m.baseAddress = a.Bytes()
})
}
func WithTags(t *tags.Tags) Option {
return optionFunc(func(m *MockStorer) {
m.tags = t
})
}
func WithPartialInterval(v bool) Option { func WithPartialInterval(v bool) Option {
return optionFunc(func(m *MockStorer) { return optionFunc(func(m *MockStorer) {
m.partialInterval = v m.partialInterval = v
...@@ -54,6 +68,7 @@ func NewStorer(opts ...Option) *MockStorer { ...@@ -54,6 +68,7 @@ func NewStorer(opts ...Option) *MockStorer {
modeSetMu: sync.Mutex{}, modeSetMu: sync.Mutex{},
morePull: make(chan struct{}), morePull: make(chan struct{}),
quit: make(chan struct{}), quit: make(chan struct{}),
bins: make([]uint64, swarm.MaxBins),
} }
for _, v := range opts { for _, v := range opts {
...@@ -74,6 +89,16 @@ func NewValidatingStorer(v swarm.Validator, tags *tags.Tags) *MockStorer { ...@@ -74,6 +89,16 @@ func NewValidatingStorer(v swarm.Validator, tags *tags.Tags) *MockStorer {
} }
} }
func NewTagsStorer(tags *tags.Tags) *MockStorer {
return &MockStorer{
store: make(map[string][]byte),
modeSet: make(map[string]storage.ModeSet),
modeSetMu: sync.Mutex{},
pinSetMu: sync.Mutex{},
tags: tags,
}
}
func (m *MockStorer) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (ch swarm.Chunk, err error) { func (m *MockStorer) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (ch swarm.Chunk, err error) {
m.mtx.Lock() m.mtx.Lock()
defer m.mtx.Unlock() defer m.mtx.Unlock()
...@@ -104,6 +129,8 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm ...@@ -104,6 +129,8 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm
if yes { if yes {
exist = append(exist, true) exist = append(exist, true)
} else { } else {
po := swarm.Proximity(ch.Address().Bytes(), m.baseAddress)
m.bins[po]++
exist = append(exist, false) exist = append(exist, false)
} }
...@@ -186,7 +213,7 @@ func (m *MockStorer) GetModeSet(addr swarm.Address) (mode storage.ModeSet) { ...@@ -186,7 +213,7 @@ func (m *MockStorer) GetModeSet(addr swarm.Address) (mode storage.ModeSet) {
} }
func (m *MockStorer) LastPullSubscriptionBinID(bin uint8) (id uint64, err error) { func (m *MockStorer) LastPullSubscriptionBinID(bin uint8) (id uint64, err error) {
panic("not implemented") // TODO: Implement return m.bins[bin], nil
} }
func (m *MockStorer) SubscribePull(ctx context.Context, bin uint8, since, until uint64) (<-chan storage.Descriptor, <-chan struct{}, func()) { func (m *MockStorer) SubscribePull(ctx context.Context, bin uint8, since, until uint64) (<-chan storage.Descriptor, <-chan struct{}, func()) {
......
...@@ -11,8 +11,9 @@ import ( ...@@ -11,8 +11,9 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/encryption"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
"github.com/ethersphere/bee/pkg/encryption"
) )
const ( const (
......
...@@ -42,7 +42,7 @@ const ( ...@@ -42,7 +42,7 @@ const (
NonceSize = 32 NonceSize = 32
LengthSize = 2 LengthSize = 2
TopicSize = 32 TopicSize = 32
MinerTimeout = 5 // seconds after which the mining will fail MinerTimeout = 20 // seconds after which the mining will fail
) )
// NewTopic creates a new Topic variable with the given input string // NewTopic creates a new Topic variable with the given input string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment