Commit 2ab40671 authored by Zahoor Mohamed's avatar Zahoor Mohamed Committed by GitHub

Chunk repair (#479)

Chunk repair request and uploading the lost chunk to netwrk
parent f2b4bdc5
...@@ -45,6 +45,8 @@ func (c *command) initStartCmd() (err error) { ...@@ -45,6 +45,8 @@ func (c *command) initStartCmd() (err error) {
optionNameTracingEndpoint = "tracing-endpoint" optionNameTracingEndpoint = "tracing-endpoint"
optionNameTracingServiceName = "tracing-service-name" optionNameTracingServiceName = "tracing-service-name"
optionNameVerbosity = "verbosity" optionNameVerbosity = "verbosity"
optionNameDisconnectThreshold = "disconnect-threshold"
optionNameGlobalPinningEnabled = "global-pinning-enable"
optionNamePaymentThreshold = "payment-threshold" optionNamePaymentThreshold = "payment-threshold"
optionNamePaymentTolerance = "payment-tolerance" optionNamePaymentTolerance = "payment-tolerance"
) )
...@@ -119,6 +121,7 @@ Welcome to the Swarm.... Bzzz Bzzzz Bzzzz ...@@ -119,6 +121,7 @@ Welcome to the Swarm.... Bzzz Bzzzz Bzzzz
Password: password, Password: password,
APIAddr: c.config.GetString(optionNameAPIAddr), APIAddr: c.config.GetString(optionNameAPIAddr),
DebugAPIAddr: debugAPIAddr, DebugAPIAddr: debugAPIAddr,
Addr: c.config.GetString(optionNameP2PAddr),
NATAddr: c.config.GetString(optionNameNATAddr), NATAddr: c.config.GetString(optionNameNATAddr),
EnableWS: c.config.GetBool(optionNameP2PWSEnable), EnableWS: c.config.GetBool(optionNameP2PWSEnable),
EnableQUIC: c.config.GetBool(optionNameP2PQUICEnable), EnableQUIC: c.config.GetBool(optionNameP2PQUICEnable),
...@@ -129,6 +132,9 @@ Welcome to the Swarm.... Bzzz Bzzzz Bzzzz ...@@ -129,6 +132,9 @@ Welcome to the Swarm.... Bzzz Bzzzz Bzzzz
TracingEnabled: c.config.GetBool(optionNameTracingEnabled), TracingEnabled: c.config.GetBool(optionNameTracingEnabled),
TracingEndpoint: c.config.GetString(optionNameTracingEndpoint), TracingEndpoint: c.config.GetString(optionNameTracingEndpoint),
TracingServiceName: c.config.GetString(optionNameTracingServiceName), TracingServiceName: c.config.GetString(optionNameTracingServiceName),
Logger: logger,
DisconnectThreshold: c.config.GetUint64(optionNameDisconnectThreshold),
GlobalPinningEnabled: c.config.GetBool(optionNameGlobalPinningEnabled),
PaymentThreshold: c.config.GetUint64(optionNamePaymentThreshold), PaymentThreshold: c.config.GetUint64(optionNamePaymentThreshold),
PaymentTolerance: c.config.GetUint64(optionNamePaymentTolerance), PaymentTolerance: c.config.GetUint64(optionNamePaymentTolerance),
}) })
......
...@@ -6,19 +6,20 @@ package api ...@@ -6,19 +6,20 @@ package api
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"github.com/gorilla/mux"
"github.com/ethersphere/bee/pkg/collection/entry" "github.com/ethersphere/bee/pkg/collection/entry"
"github.com/ethersphere/bee/pkg/encryption" "github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file" "github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/manifest/jsonmanifest" "github.com/ethersphere/bee/pkg/manifest/jsonmanifest"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/gorilla/mux"
) )
const ( const (
...@@ -29,7 +30,7 @@ const ( ...@@ -29,7 +30,7 @@ const (
func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) { func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(context.WithValue(r.Context(), targetsContextKey{}, targets)) r = r.WithContext(sctx.SetTargets(r.Context(), targets))
ctx := r.Context() ctx := r.Context()
addressHex := mux.Vars(r)["address"] addressHex := mux.Vars(r)["address"]
......
...@@ -9,12 +9,14 @@ import ( ...@@ -9,12 +9,14 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/netstore"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strings" "strings"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
...@@ -94,7 +96,7 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -94,7 +96,7 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) { func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) {
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(context.WithValue(r.Context(), targetsContextKey{}, targets)) r = r.WithContext(sctx.SetTargets(r.Context(), targets))
addr := mux.Vars(r)["addr"] addr := mux.Vars(r)["addr"]
ctx := r.Context() ctx := r.Context()
...@@ -115,6 +117,11 @@ func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) { ...@@ -115,6 +117,11 @@ func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if errors.Is(err, netstore.ErrRecoveryAttempt) {
s.Logger.Trace("chunk: chunk recovery initiated. addr %s", address)
jsonhttp.Accepted(w, "chunk recovery initiated. retry after sometime.")
return
}
s.Logger.Debugf("chunk: chunk read error: %v ,addr %s", err, address) s.Logger.Debugf("chunk: chunk read error: %v ,addr %s", err, address)
s.Logger.Error("chunk: chunk read error") s.Logger.Error("chunk: chunk read error")
jsonhttp.InternalServerError(w, "chunk read error") jsonhttp.InternalServerError(w, "chunk read error")
......
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
"github.com/ethersphere/bee/pkg/file/joiner" "github.com/ethersphere/bee/pkg/file/joiner"
"github.com/ethersphere/bee/pkg/file/splitter" "github.com/ethersphere/bee/pkg/file/splitter"
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
...@@ -41,8 +42,6 @@ const ( ...@@ -41,8 +42,6 @@ const (
EncryptHeader = "swarm-encrypt" EncryptHeader = "swarm-encrypt"
) )
type targetsContextKey struct{}
// fileUploadResponse is returned when an HTTP request to upload a file is successful // fileUploadResponse is returned when an HTTP request to upload a file is successful
type fileUploadResponse struct { type fileUploadResponse struct {
Reference swarm.Address `json:"reference"` Reference swarm.Address `json:"reference"`
...@@ -283,9 +282,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -283,9 +282,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
} }
toDecrypt := len(address.Bytes()) == (swarm.HashSize + encryption.KeyLength) toDecrypt := len(address.Bytes()) == (swarm.HashSize + encryption.KeyLength)
targets := r.URL.Query().Get("targets")
r = r.WithContext(context.WithValue(r.Context(), targetsContextKey{}, targets)) targets := r.URL.Query().Get("targets")
sctx.SetTargets(r.Context(), targets)
// read entry. // read entry.
j := joiner.NewSimpleJoiner(s.Storer) j := joiner.NewSimpleJoiner(s.Storer)
...@@ -351,7 +350,7 @@ func (s *server) downloadHandler( ...@@ -351,7 +350,7 @@ func (s *server) downloadHandler(
) { ) {
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(context.WithValue(r.Context(), targetsContextKey{}, targets)) sctx.SetTargets(r.Context(), targets)
ctx := r.Context() ctx := r.Context()
toDecrypt := len(reference.Bytes()) == (swarm.HashSize + encryption.KeyLength) toDecrypt := len(reference.Bytes()) == (swarm.HashSize + encryption.KeyLength)
......
...@@ -109,7 +109,6 @@ func (db *DB) set(mode storage.ModeSet, addrs ...swarm.Address) (err error) { ...@@ -109,7 +109,6 @@ func (db *DB) set(mode storage.ModeSet, addrs ...swarm.Address) (err error) {
return err return err
} }
} }
default: default:
return ErrInvalidMode return ErrInvalidMode
} }
......
...@@ -10,7 +10,9 @@ import ( ...@@ -10,7 +10,9 @@ import (
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -18,13 +20,19 @@ import ( ...@@ -18,13 +20,19 @@ import (
type store struct { type store struct {
storage.Storer storage.Storer
retrieval retrieval.Interface retrieval retrieval.Interface
logger logging.Logger
validator swarm.Validator validator swarm.Validator
logger logging.Logger
recoveryCallback recovery.RecoveryHook // this is the callback to be executed when a chunk fails to be retrieved
} }
var (
ErrRecoveryAttempt = errors.New("failed to retrieve chunk, recovery initiated")
)
// New returns a new NetStore that wraps a given Storer. // New returns a new NetStore that wraps a given Storer.
func New(s storage.Storer, r retrieval.Interface, logger logging.Logger, validator swarm.Validator) storage.Storer { func New(s storage.Storer, rcb recovery.RecoveryHook, r retrieval.Interface, logger logging.Logger,
return &store{Storer: s, retrieval: r, logger: logger, validator: validator} validator swarm.Validator) storage.Storer {
return &store{Storer: s, recoveryCallback: rcb, retrieval: r, logger: logger, validator: validator}
} }
// Get retrieves a given chunk address. // Get retrieves a given chunk address.
...@@ -34,11 +42,23 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres ...@@ -34,11 +42,23 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
if err != nil { if err != nil {
if errors.Is(err, storage.ErrNotFound) { if errors.Is(err, storage.ErrNotFound) {
// request from network // request from network
ch, err := s.retrieval.RetrieveChunk(ctx, addr) ch, err = s.retrieval.RetrieveChunk(ctx, addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("netstore retrieve chunk: %w", err) if s.recoveryCallback == nil {
return nil, err
}
targets, err := sctx.GetTargets(ctx)
if err != nil {
return nil, err
}
go func() {
err := s.recoveryCallback(addr, targets)
if err != nil {
s.logger.Debugf("netstore: error while recovering chunk: %v", err)
}
}()
return nil, ErrRecoveryAttempt
} }
_, err = s.Storer.Put(ctx, storage.ModePutRequest, ch) _, err = s.Storer.Put(ctx, storage.ModePutRequest, ch)
if err != nil { if err != nil {
return nil, fmt.Errorf("netstore retrieve put: %w", err) return nil, fmt.Errorf("netstore retrieve put: %w", err)
......
...@@ -7,16 +7,21 @@ package netstore_test ...@@ -7,16 +7,21 @@ package netstore_test
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt"
"io/ioutil" "io/ioutil"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time"
validatormock "github.com/ethersphere/bee/pkg/content/mock" validatormock "github.com/ethersphere/bee/pkg/content/mock"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
) )
var chunkData = []byte("mockdata") var chunkData = []byte("mockdata")
...@@ -24,7 +29,7 @@ var chunkData = []byte("mockdata") ...@@ -24,7 +29,7 @@ var chunkData = []byte("mockdata")
// TestNetstoreRetrieval verifies that a chunk is asked from the network whenever // TestNetstoreRetrieval verifies that a chunk is asked from the network whenever
// it is not found locally // it is not found locally
func TestNetstoreRetrieval(t *testing.T) { func TestNetstoreRetrieval(t *testing.T) {
retrieve, store, nstore := newRetrievingNetstore() retrieve, store, nstore := newRetrievingNetstore(nil)
addr := swarm.MustParseHexAddress("000001") addr := swarm.MustParseHexAddress("000001")
_, err := nstore.Get(context.Background(), storage.ModeGetRequest, addr) _, err := nstore.Get(context.Background(), storage.ModeGetRequest, addr)
if err != nil { if err != nil {
...@@ -68,7 +73,7 @@ func TestNetstoreRetrieval(t *testing.T) { ...@@ -68,7 +73,7 @@ func TestNetstoreRetrieval(t *testing.T) {
// TestNetstoreNoRetrieval verifies that a chunk is not requested from the network // TestNetstoreNoRetrieval verifies that a chunk is not requested from the network
// whenever it is found locally. // whenever it is found locally.
func TestNetstoreNoRetrieval(t *testing.T) { func TestNetstoreNoRetrieval(t *testing.T) {
retrieve, store, nstore := newRetrievingNetstore() retrieve, store, nstore := newRetrievingNetstore(nil)
addr := swarm.MustParseHexAddress("000001") addr := swarm.MustParseHexAddress("000001")
// store should have the chunk in advance // store should have the chunk in advance
...@@ -92,26 +97,105 @@ func TestNetstoreNoRetrieval(t *testing.T) { ...@@ -92,26 +97,105 @@ func TestNetstoreNoRetrieval(t *testing.T) {
} }
} }
func TestRecovery(t *testing.T) {
hookWasCalled := make(chan bool, 1)
rec := &mockRecovery{
hookC: hookWasCalled,
}
retrieve, _, nstore := newRetrievingNetstore(rec)
addr := swarm.MustParseHexAddress("deadbeef")
retrieve.failure = true
ctx := context.Background()
ctx = sctx.SetTargets(ctx, "be, cd")
_, err := nstore.Get(ctx, storage.ModeGetRequest, addr)
if err != nil && !errors.Is(err, netstore.ErrRecoveryAttempt) {
t.Fatal(err)
}
select {
case <-hookWasCalled:
break
case <-time.After(100 * time.Millisecond):
t.Fatal("recovery hook was not called")
}
}
func TestInvalidRecoveryFunction(t *testing.T) {
retrieve, _, nstore := newRetrievingNetstore(nil)
addr := swarm.MustParseHexAddress("deadbeef")
retrieve.failure = true
ctx := context.Background()
ctx = sctx.SetTargets(ctx, "be, cd")
_, err := nstore.Get(ctx, storage.ModeGetRequest, addr)
if err != nil && err.Error() != "chunk not found" {
t.Fatal(err)
}
}
func TestInvalidTargets(t *testing.T) {
hookWasCalled := make(chan bool, 1)
rec := &mockRecovery{
hookC: hookWasCalled,
}
retrieve, _, nstore := newRetrievingNetstore(rec)
addr := swarm.MustParseHexAddress("deadbeef")
retrieve.failure = true
ctx := context.Background()
ctx = sctx.SetTargets(ctx, "gh")
_, err := nstore.Get(ctx, storage.ModeGetRequest, addr)
if err != nil && !errors.Is(err, sctx.ErrTargetPrefix) {
t.Fatal(err)
}
}
// returns a mock retrieval protocol, a mock local storage and a netstore // returns a mock retrieval protocol, a mock local storage and a netstore
func newRetrievingNetstore() (ret *retrievalMock, mockStore, ns storage.Storer) { func newRetrievingNetstore(rec *mockRecovery) (ret *retrievalMock, mockStore, ns storage.Storer) {
retrieve := &retrievalMock{} retrieve := &retrievalMock{}
store := mock.NewStorer() store := mock.NewStorer()
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
validator := swarm.NewChunkValidator(validatormock.NewValidator(true)) validator := swarm.NewChunkValidator(validatormock.NewValidator(true))
nstore := netstore.New(store, retrieve, logger, validator)
var nstore storage.Storer
if rec != nil {
nstore = netstore.New(store, rec.recovery, retrieve, logger, validator)
} else {
nstore = netstore.New(store, nil, retrieve, logger, validator)
}
return retrieve, store, nstore return retrieve, store, nstore
} }
type retrievalMock struct { type retrievalMock struct {
called bool called bool
callCount int32 callCount int32
failure bool
addr swarm.Address addr swarm.Address
} }
func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) { func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
if r.failure {
return nil, fmt.Errorf("chunk not found")
}
r.called = true r.called = true
atomic.AddInt32(&r.callCount, 1) atomic.AddInt32(&r.callCount, 1)
r.addr = addr r.addr = addr
return swarm.NewChunk(addr, chunkData), nil return swarm.NewChunk(addr, chunkData), nil
} }
type mockRecovery struct {
hookC chan bool
}
// Send mocks the pss Send function
func (mr *mockRecovery) recovery(chunkAddress swarm.Address, targets trojan.Targets) error {
mr.hookC <- true
return nil
}
func (r *mockRecovery) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
return nil, fmt.Errorf("chunk not found")
}
...@@ -31,11 +31,13 @@ import ( ...@@ -31,11 +31,13 @@ import (
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/p2p/libp2p" "github.com/ethersphere/bee/pkg/p2p/libp2p"
"github.com/ethersphere/bee/pkg/pingpong" "github.com/ethersphere/bee/pkg/pingpong"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/puller" "github.com/ethersphere/bee/pkg/puller"
"github.com/ethersphere/bee/pkg/pullsync" "github.com/ethersphere/bee/pkg/pullsync"
"github.com/ethersphere/bee/pkg/pullsync/pullstorage" "github.com/ethersphere/bee/pkg/pullsync/pullstorage"
"github.com/ethersphere/bee/pkg/pusher" "github.com/ethersphere/bee/pkg/pusher"
"github.com/ethersphere/bee/pkg/pushsync" "github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/retrieval" "github.com/ethersphere/bee/pkg/retrieval"
"github.com/ethersphere/bee/pkg/settlement/pseudosettle" "github.com/ethersphere/bee/pkg/settlement/pseudosettle"
"github.com/ethersphere/bee/pkg/soc" "github.com/ethersphere/bee/pkg/soc"
...@@ -71,6 +73,7 @@ type Options struct { ...@@ -71,6 +73,7 @@ type Options struct {
Password string Password string
APIAddr string APIAddr string
DebugAPIAddr string DebugAPIAddr string
Addr string
NATAddr string NATAddr string
EnableWS bool EnableWS bool
EnableQUIC bool EnableQUIC bool
...@@ -78,10 +81,12 @@ type Options struct { ...@@ -78,10 +81,12 @@ type Options struct {
WelcomeMessage string WelcomeMessage string
Bootnodes []string Bootnodes []string
CORSAllowedOrigins []string CORSAllowedOrigins []string
Logger logging.Logger
TracingEnabled bool TracingEnabled bool
TracingEndpoint string TracingEndpoint string
TracingServiceName string TracingServiceName string
DisconnectThreshold uint64 DisconnectThreshold uint64
GlobalPinningEnabled bool
PaymentThreshold uint64 PaymentThreshold uint64
PaymentTolerance uint64 PaymentTolerance uint64
} }
...@@ -276,22 +281,41 @@ func NewBee(addr string, logger logging.Logger, o Options) (*Bee, error) { ...@@ -276,22 +281,41 @@ func NewBee(addr string, logger logging.Logger, o Options) (*Bee, error) {
return nil, fmt.Errorf("retrieval service: %w", err) return nil, fmt.Errorf("retrieval service: %w", err)
} }
ns := netstore.New(storer, retrieve, logger, chunkvalidator) // instantiate the pss object
psss := pss.New(logger, nil)
var ns storage.Storer
if o.GlobalPinningEnabled {
// create recovery callback for content repair
recoverFunc := recovery.NewRecoveryHook(psss)
ns = netstore.New(storer, recoverFunc, retrieve, logger, chunkvalidator)
} else {
ns = netstore.New(storer, nil, retrieve, logger, chunkvalidator)
}
retrieve.SetStorer(ns) retrieve.SetStorer(ns)
pushSyncProtocol := pushsync.New(pushsync.Options{ pushSyncProtocol := pushsync.New(pushsync.Options{
Streamer: p2ps, Streamer: p2ps,
Storer: storer, Storer: storer,
ClosestPeerer: kad, ClosestPeerer: kad,
DeliveryCallback: psss.TryUnwrap,
Tagger: tagg, Tagger: tagg,
Logger: logger, Logger: logger,
}) })
// set the pushSyncer in the PSS
psss.WithPushSyncer(pushSyncProtocol)
if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil { if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil {
return nil, fmt.Errorf("pushsync service: %w", err) return nil, fmt.Errorf("pushsync service: %w", err)
} }
if o.GlobalPinningEnabled {
// register function for chunk repair upon receiving a trojan message
chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol)
psss.Register(recovery.RecoveryTopic, chunkRepairHandler)
}
pushSyncPusher := pusher.New(pusher.Options{ pushSyncPusher := pusher.New(pusher.Options{
Storer: storer, Storer: storer,
PeerSuggester: kad, PeerSuggester: kad,
......
...@@ -26,6 +26,7 @@ type Interface interface { ...@@ -26,6 +26,7 @@ type Interface interface {
Register(topic trojan.Topic, hndlr Handler) Register(topic trojan.Topic, hndlr Handler)
GetHandler(topic trojan.Topic) Handler GetHandler(topic trojan.Topic) Handler
TryUnwrap(ctx context.Context, c swarm.Chunk) error TryUnwrap(ctx context.Context, c swarm.Chunk) error
WithPushSyncer(pushSyncer pushsync.PushSyncer)
} }
// pss is the top-level struct, which takes care of message sending // pss is the top-level struct, which takes care of message sending
...@@ -47,8 +48,12 @@ func New(logger logging.Logger, pusher pushsync.PushSyncer) Interface { ...@@ -47,8 +48,12 @@ func New(logger logging.Logger, pusher pushsync.PushSyncer) Interface {
} }
} }
func (ps *pss) WithPushSyncer(pushSyncer pushsync.PushSyncer) {
ps.pusher = pushSyncer
}
// Handler defines code to be executed upon reception of a trojan message // Handler defines code to be executed upon reception of a trojan message
type Handler func(*trojan.Message) type Handler func(context.Context, *trojan.Message) error
// Send constructs a padded message with topic and payload, // Send constructs a padded message with topic and payload,
// wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address // wraps it in a trojan chunk such that one of the targets is a prefix of the chunk address
...@@ -64,6 +69,7 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top ...@@ -64,6 +69,7 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top
var tc swarm.Chunk var tc swarm.Chunk
tc, err = m.Wrap(targets) tc, err = m.Wrap(targets)
if err != nil { if err != nil {
return err return err
} }
...@@ -84,6 +90,7 @@ func (p *pss) Register(topic trojan.Topic, hndlr Handler) { ...@@ -84,6 +90,7 @@ func (p *pss) Register(topic trojan.Topic, hndlr Handler) {
// TryUnwrap allows unwrapping a chunk as a trojan message and calling its handler func based on its topic // TryUnwrap allows unwrapping a chunk as a trojan message and calling its handler func based on its topic
func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error { func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
if !trojan.IsPotential(c) { if !trojan.IsPotential(c) {
return nil return nil
} }
...@@ -95,8 +102,7 @@ func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error { ...@@ -95,8 +102,7 @@ func (p *pss) TryUnwrap(ctx context.Context, c swarm.Chunk) error {
if h == nil { if h == nil {
return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler) return fmt.Errorf("topic %v, %w", m.Topic, ErrNoHandler)
} }
h(m) return h(ctx, m)
return nil
} }
// GetHandler returns the Handler func registered in pss for the given topic // GetHandler returns the Handler func registered in pss for the given topic
......
...@@ -71,28 +71,36 @@ func TestRegister(t *testing.T) { ...@@ -71,28 +71,36 @@ func TestRegister(t *testing.T) {
handlerVerifier := 0 // test variable to check handler funcs are correctly retrieved handlerVerifier := 0 // test variable to check handler funcs are correctly retrieved
// register first handler // register first handler
testHandler := func(m *trojan.Message) { testHandler := func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 1 handlerVerifier = 1
return nil
} }
testTopic := trojan.NewTopic("FIRST_HANDLER") testTopic := trojan.NewTopic("FIRST_HANDLER")
pss.Register(testTopic, testHandler) pss.Register(testTopic, testHandler)
registeredHandler := pss.GetHandler(testTopic) registeredHandler := pss.GetHandler(testTopic)
registeredHandler(&trojan.Message{}) // call handler to verify the retrieved func is correct err := registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 1 { if handlerVerifier != 1 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 1 but is %d instead", handlerVerifier) t.Fatalf("unexpected handler retrieved, verifier variable should be 1 but is %d instead", handlerVerifier)
} }
// register second handler // register second handler
testHandler = func(m *trojan.Message) { testHandler = func(ctx context.Context, m *trojan.Message) error {
handlerVerifier = 2 handlerVerifier = 2
return nil
} }
testTopic = trojan.NewTopic("SECOND_HANDLER") testTopic = trojan.NewTopic("SECOND_HANDLER")
pss.Register(testTopic, testHandler) pss.Register(testTopic, testHandler)
registeredHandler = pss.GetHandler(testTopic) registeredHandler = pss.GetHandler(testTopic)
registeredHandler(&trojan.Message{}) // call handler to verify the retrieved func is correct err = registeredHandler(context.Background(), &trojan.Message{}) // call handler to verify the retrieved func is correct
if err != nil {
t.Fatal(err)
}
if handlerVerifier != 2 { if handlerVerifier != 2 {
t.Fatalf("unexpected handler retrieved, verifier variable should be 2 but is %d instead", handlerVerifier) t.Fatalf("unexpected handler retrieved, verifier variable should be 2 but is %d instead", handlerVerifier)
...@@ -124,8 +132,9 @@ func TestDeliver(t *testing.T) { ...@@ -124,8 +132,9 @@ func TestDeliver(t *testing.T) {
// create and register handler // create and register handler
var tt trojan.Topic // test variable to check handler func was correctly called var tt trojan.Topic // test variable to check handler func was correctly called
hndlr := func(m *trojan.Message) { hndlr := func(ctx context.Context, m *trojan.Message) error {
tt = m.Topic // copy the message topic to the test variable tt = m.Topic // copy the message topic to the test variable
return nil
} }
pss.Register(topic, hndlr) pss.Register(topic, hndlr)
...@@ -149,7 +158,7 @@ func TestHandler(t *testing.T) { ...@@ -149,7 +158,7 @@ func TestHandler(t *testing.T) {
} }
// register first handler // register first handler
testHandler := func(m *trojan.Message) {} testHandler := func(ctx context.Context, m *trojan.Message) error { return nil }
// set handler for test topic // set handler for test topic
pss.Register(testTopic, testHandler) pss.Register(testTopic, testHandler)
......
...@@ -39,6 +39,7 @@ type PushSync struct { ...@@ -39,6 +39,7 @@ type PushSync struct {
storer storage.Putter storer storage.Putter
peerSuggester topology.ClosestPeerer peerSuggester topology.ClosestPeerer
tagg *tags.Tags tagg *tags.Tags
deliveryCallback func(context.Context, swarm.Chunk) error // callback func to be invoked to deliver chunks to PSS
logger logging.Logger logger logging.Logger
metrics metrics metrics metrics
} }
...@@ -48,6 +49,7 @@ type Options struct { ...@@ -48,6 +49,7 @@ type Options struct {
Storer storage.Putter Storer storage.Putter
ClosestPeerer topology.ClosestPeerer ClosestPeerer topology.ClosestPeerer
Tagger *tags.Tags Tagger *tags.Tags
DeliveryCallback func(context.Context, swarm.Chunk) error
Logger logging.Logger Logger logging.Logger
} }
...@@ -59,6 +61,7 @@ func New(o Options) *PushSync { ...@@ -59,6 +61,7 @@ func New(o Options) *PushSync {
storer: o.Storer, storer: o.Storer,
peerSuggester: o.ClosestPeerer, peerSuggester: o.ClosestPeerer,
tagg: o.Tagger, tagg: o.Tagger,
deliveryCallback: o.DeliveryCallback,
logger: o.Logger, logger: o.Logger,
metrics: newMetrics(), metrics: newMetrics(),
} }
...@@ -101,21 +104,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -101,21 +104,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
if err != nil { if err != nil {
// If i am the closest peer then store the chunk and send receipt // If i am the closest peer then store the chunk and send receipt
if errors.Is(err, topology.ErrWantSelf) { if errors.Is(err, topology.ErrWantSelf) {
return ps.handleDeliveryResponse(ctx, w, p, chunk)
// Store the chunk in the local store
_, err := ps.storer.Put(ctx, storage.ModePutSync, chunk)
if err != nil {
return fmt.Errorf("chunk store: %w", err)
}
ps.metrics.TotalChunksStoredInDB.Inc()
// Send a receipt immediately once the storage of the chunk is successfully
receipt := &pb.Receipt{Address: chunk.Address().Bytes()}
err = ps.sendReceipt(w, receipt)
if err != nil {
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
}
return nil
} }
return err return err
} }
...@@ -123,17 +112,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream) ...@@ -123,17 +112,7 @@ func (ps *PushSync) handler(ctx context.Context, p p2p.Peer, stream p2p.Stream)
// This is a special situation in that the other peer thinks thats we are the closest node // This is a special situation in that the other peer thinks thats we are the closest node
// and we think that the sending peer // and we think that the sending peer
if p.Address.Equal(peer) { if p.Address.Equal(peer) {
return ps.handleDeliveryResponse(ctx, w, p, chunk)
// Store the chunk in the local store
_, err := ps.storer.Put(ctx, storage.ModePutSync, chunk)
if err != nil {
return fmt.Errorf("chunk store: %w", err)
}
ps.metrics.TotalChunksStoredInDB.Inc()
// Send a receipt immediately once the storage of the chunk is successfully
receipt := &pb.Receipt{Address: chunk.Address().Bytes()}
return ps.sendReceipt(w, receipt)
} }
// Forward chunk to closest peer // Forward chunk to closest peer
...@@ -276,3 +255,29 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re ...@@ -276,3 +255,29 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
return rec, nil return rec, nil
} }
func (ps *PushSync) deliverToPSS(ctx context.Context, ch swarm.Chunk) error {
// if callback is defined, call it for every new, valid chunk
if ps.deliveryCallback != nil {
return ps.deliveryCallback(ctx, ch)
}
return nil
}
func (ps *PushSync) handleDeliveryResponse(ctx context.Context, w protobuf.Writer, p p2p.Peer, chunk swarm.Chunk) error {
// Store the chunk in the local store
_, err := ps.storer.Put(ctx, storage.ModePutSync, chunk)
if err != nil {
return fmt.Errorf("chunk store: %w", err)
}
ps.metrics.TotalChunksStoredInDB.Inc()
// Send a receipt immediately once the storage of the chunk is successfully
receipt := &pb.Receipt{Address: chunk.Address().Bytes()}
err = ps.sendReceipt(w, receipt)
if err != nil {
return fmt.Errorf("send receipt to peer %s: %w", p.Address.String(), err)
}
// since all PSS messages comes through push sync, deliver them here if this node is the destination
return ps.deliverToPSS(ctx, chunk)
}
...@@ -7,9 +7,6 @@ package pushsync_test ...@@ -7,9 +7,6 @@ package pushsync_test
import ( import (
"bytes" "bytes"
"context" "context"
"io/ioutil"
"testing"
"github.com/ethersphere/bee/pkg/localstore" "github.com/ethersphere/bee/pkg/localstore"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
...@@ -20,6 +17,9 @@ import ( ...@@ -20,6 +17,9 @@ import (
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/topology" "github.com/ethersphere/bee/pkg/topology"
"github.com/ethersphere/bee/pkg/topology/mock" "github.com/ethersphere/bee/pkg/topology/mock"
"io/ioutil"
"testing"
"time"
) )
// TestSendChunkAndGetReceipt inserts a chunk as uploaded chunk in db. This triggers sending a chunk to the closest node // TestSendChunkAndGetReceipt inserts a chunk as uploaded chunk in db. This triggers sending a chunk to the closest node
...@@ -36,14 +36,14 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) { ...@@ -36,14 +36,14 @@ func TestSendChunkAndReceiveReceipt(t *testing.T) {
// peer is the node responding to the chunk receipt message // peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to // mock should return ErrWantSelf since there's no one to forward to
psPeer, storerPeer, _ := createPushSyncNode(t, closestPeer, nil, mock.WithClosestPeerErr(topology.ErrWantSelf)) psPeer, storerPeer, _ := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close() defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol())) recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()))
// pivot node needs the streamer since the chunk is intercepted by // pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream // the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, _ := createPushSyncNode(t, pivotNode, recorder, mock.WithClosestPeer(closestPeer)) psPivot, storerPivot, _ := createPushSyncNode(t, pivotNode, recorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivot.Close() defer storerPivot.Close()
// Trigger the sending of chunk to the closest node // Trigger the sending of chunk to the closest node
...@@ -77,14 +77,14 @@ func TestPushChunkToClosest(t *testing.T) { ...@@ -77,14 +77,14 @@ func TestPushChunkToClosest(t *testing.T) {
// peer is the node responding to the chunk receipt message // peer is the node responding to the chunk receipt message
// mock should return ErrWantSelf since there's no one to forward to // mock should return ErrWantSelf since there's no one to forward to
psPeer, storerPeer, _ := createPushSyncNode(t, closestPeer, nil, mock.WithClosestPeerErr(topology.ErrWantSelf)) psPeer, storerPeer, _ := createPushSyncNode(t, closestPeer, nil, nil, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer storerPeer.Close() defer storerPeer.Close()
recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol())) recorder := streamtest.New(streamtest.WithProtocols(psPeer.Protocol()))
// pivot node needs the streamer since the chunk is intercepted by // pivot node needs the streamer since the chunk is intercepted by
// the chunk worker, then gets sent by opening a new stream // the chunk worker, then gets sent by opening a new stream
psPivot, storerPivot, pivotTags := createPushSyncNode(t, pivotNode, recorder, mock.WithClosestPeer(closestPeer)) psPivot, storerPivot, pivotTags := createPushSyncNode(t, pivotNode, recorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivot.Close() defer storerPivot.Close()
ta, err := pivotTags.Create("test", 1, false) ta, err := pivotTags.Create("test", 1, false)
...@@ -130,7 +130,7 @@ func TestPushChunkToClosest(t *testing.T) { ...@@ -130,7 +130,7 @@ func TestPushChunkToClosest(t *testing.T) {
// TestHandler expect a chunk from a node on a stream. It then stores the chunk in the local store and // TestHandler expect a chunk from a node on a stream. It then stores the chunk in the local store and
// sends back a receipt. This is tested by intercepting the incoming stream for proper messages. // sends back a receipt. This is tested by intercepting the incoming stream for proper messages.
// It also sends the chunk to the closest peerand receives a receipt. // It also sends the chunk to the closest peer and receives a receipt.
// //
// Chunk moves from TriggerPeer -> PivotPeer -> ClosestPeer // Chunk moves from TriggerPeer -> PivotPeer -> ClosestPeer
// //
...@@ -145,20 +145,27 @@ func TestHandler(t *testing.T) { ...@@ -145,20 +145,27 @@ func TestHandler(t *testing.T) {
triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000") triggerPeer := swarm.MustParseHexAddress("6000000000000000000000000000000000000000000000000000000000000000")
closestPeer := swarm.MustParseHexAddress("f000000000000000000000000000000000000000000000000000000000000000") closestPeer := swarm.MustParseHexAddress("f000000000000000000000000000000000000000000000000000000000000000")
// mock call back function to see if pss message is delivered when it is received in the destination (closestPeer in this testcase)
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
pssDeliver := func(ctx context.Context, ch swarm.Chunk) error {
hookWasCalled <- true
return nil
}
// Create the closest peer // Create the closest peer
psClosestPeer, closestStorerPeerDB, _ := createPushSyncNode(t, closestPeer, nil, mock.WithClosestPeerErr(topology.ErrWantSelf)) psClosestPeer, closestStorerPeerDB, _ := createPushSyncNode(t, closestPeer, nil, pssDeliver, mock.WithClosestPeerErr(topology.ErrWantSelf))
defer closestStorerPeerDB.Close() defer closestStorerPeerDB.Close()
closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol())) closestRecorder := streamtest.New(streamtest.WithProtocols(psClosestPeer.Protocol()))
// creating the pivot peer // creating the pivot peer
psPivot, storerPivotDB, _ := createPushSyncNode(t, pivotPeer, closestRecorder, mock.WithClosestPeer(closestPeer)) psPivot, storerPivotDB, _ := createPushSyncNode(t, pivotPeer, closestRecorder, nil, mock.WithClosestPeer(closestPeer))
defer storerPivotDB.Close() defer storerPivotDB.Close()
pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol())) pivotRecorder := streamtest.New(streamtest.WithProtocols(psPivot.Protocol()))
// Creating the trigger peer // Creating the trigger peer
psTriggerPeer, triggerStorerDB, _ := createPushSyncNode(t, triggerPeer, pivotRecorder, mock.WithClosestPeer(pivotPeer)) psTriggerPeer, triggerStorerDB, _ := createPushSyncNode(t, triggerPeer, pivotRecorder, nil, mock.WithClosestPeer(pivotPeer))
defer triggerStorerDB.Close() defer triggerStorerDB.Close()
receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk) receipt, err := psTriggerPeer.PushChunkToClosest(context.Background(), chunk)
...@@ -182,9 +189,17 @@ func TestHandler(t *testing.T) { ...@@ -182,9 +189,17 @@ func TestHandler(t *testing.T) {
// In the received stream, check if a receipt is sent from pivot peer and check for its correctness. // In the received stream, check if a receipt is sent from pivot peer and check for its correctness.
waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunkAddress, nil) waitOnRecordAndTest(t, pivotPeer, pivotRecorder, chunkAddress, nil)
// check if the pss delivery hook is called
select {
case <-hookWasCalled:
break
case <-time.After(100 * time.Millisecond):
t.Fatal("recovery hook was not called")
}
} }
func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags) { func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.Recorder, pssDeliver func(context.Context, swarm.Chunk) error, mockOpts ...mock.Option) (*pushsync.PushSync, *localstore.DB, *tags.Tags) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
storer, err := localstore.New("", addr.Bytes(), nil, logger) storer, err := localstore.New("", addr.Bytes(), nil, logger)
...@@ -199,6 +214,7 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R ...@@ -199,6 +214,7 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R
Streamer: recorder, Streamer: recorder,
Storer: storer, Storer: storer,
Tagger: mtag, Tagger: mtag,
DeliveryCallback: pssDeliver,
ClosestPeerer: mockTopology, ClosestPeerer: mockTopology,
Logger: logger, Logger: logger,
}) })
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package recovery
import (
"context"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/pss"
"github.com/ethersphere/bee/pkg/pushsync"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/trojan"
)
const (
// RecoveryTopicText is the string used to construct the recovery topic.
RecoveryTopicText = "RECOVERY"
)
var (
// RecoveryTopic is the topic used for repairing globally pinned chunks.
RecoveryTopic = trojan.NewTopic(RecoveryTopicText)
)
// RecoveryHook defines code to be executed upon failing to retrieve chunks.
type RecoveryHook func(chunkAddress swarm.Address, targets trojan.Targets) error
// sender is the function call for sending trojan chunks.
type PssSender interface {
Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error
}
// NewRecoveryHook returns a new RecoveryHook with the sender function defined.
func NewRecoveryHook(pss PssSender) RecoveryHook {
return func(chunkAddress swarm.Address, targets trojan.Targets) error {
payload := chunkAddress
ctx := context.Background()
err := pss.Send(ctx, targets, RecoveryTopic, payload.Bytes())
return err
}
}
// NewRepairHandler creates a repair function to re-upload globally pinned chunks to the network with the given store.
func NewRepairHandler(s storage.Storer, logger logging.Logger, pushSyncer pushsync.PushSyncer) pss.Handler {
return func(ctx context.Context, m *trojan.Message) error {
chAddr := m.Payload
ch, err := s.Get(ctx, storage.ModeGetRequest, swarm.NewAddress(chAddr))
if err != nil {
logger.Tracef("chunk repair: error while getting chunk for repairing: %v", err)
return err
}
// push the chunk using push sync so that it reaches it destination in network
_, err = pushSyncer.PushChunkToClosest(ctx, ch)
if err != nil {
logger.Tracef("chunk repair: error while sending chunk or receiving receipt: %v", err)
return err
}
return nil
}
}
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package recovery_test
import (
"context"
"errors"
"io/ioutil"
"testing"
"time"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
"github.com/ethersphere/bee/pkg/pushsync"
pushsyncmock "github.com/ethersphere/bee/pkg/pushsync/mock"
"github.com/ethersphere/bee/pkg/recovery"
"github.com/ethersphere/bee/pkg/retrieval"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock"
storemock "github.com/ethersphere/bee/pkg/storage/mock"
chunktesting "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/topology"
"github.com/ethersphere/bee/pkg/trojan"
)
// TestRecoveryHook tests that a recovery hook can be created and called.
func TestRecoveryHook(t *testing.T) {
// test variables needed to be correctly set for any recovery hook to reach the sender func
chunkAddr := chunktesting.GenerateTestRandomChunk().Address()
targets := trojan.Targets{[]byte{0xED}}
//setup the sender
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
pssSender := &mockPssSender{
hookC: hookWasCalled,
}
// create recovery hook and call it
recoveryHook := recovery.NewRecoveryHook(pssSender)
if err := recoveryHook(chunkAddr, targets); err != nil {
t.Fatal(err)
}
select {
case <-hookWasCalled:
break
case <-time.After(100 * time.Millisecond):
t.Fatal("recovery hook was not called")
}
}
// RecoveryHookTestCase is a struct used as test cases for the TestRecoveryHookCalls func.
type recoveryHookTestCase struct {
name string
ctx context.Context
expectsFailure bool
}
// TestRecoveryHookCalls verifies that recovery hooks are being called as expected when net store attempts to get a chunk.
func TestRecoveryHookCalls(t *testing.T) {
// generate test chunk, store and publisher
c := chunktesting.GenerateTestRandomChunk()
ref := c.Address()
target := "BE"
// test cases variables
dummyContext := context.Background() // has no publisher
targetContext := sctx.SetTargets(context.Background(), target)
for _, tc := range []recoveryHookTestCase{
{
name: "no targets in context",
ctx: dummyContext,
expectsFailure: true,
},
{
name: "targets set in context",
ctx: targetContext,
expectsFailure: false,
},
} {
t.Run(tc.name, func(t *testing.T) {
hookWasCalled := make(chan bool, 1) // channel to check if hook is called
// setup the sender
pssSender := &mockPssSender{
hookC: hookWasCalled,
}
recoverFunc := recovery.NewRecoveryHook(pssSender)
ns := newTestNetStore(t, recoverFunc)
// fetch test chunk
_, err := ns.Get(tc.ctx, storage.ModeGetRequest, ref)
if err != nil && !errors.Is(err, netstore.ErrRecoveryAttempt) && err.Error() != "error decoding prefix string" {
t.Fatal(err)
}
// checks whether the callback is invoked or the test case times out
select {
case <-hookWasCalled:
if !tc.expectsFailure {
return
}
t.Fatal("recovery hook was unexpectedly called")
case <-time.After(1000 * time.Millisecond):
if tc.expectsFailure {
return
}
t.Fatal("recovery hook was not called when expected")
}
})
}
}
// TestNewRepairHandler tests the function of repairing a chunk when a request for chunk repair is received.
func TestNewRepairHandler(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
t.Run("repair-chunk", func(t *testing.T) {
// generate test chunk, store and publisher
c1 := chunktesting.GenerateTestRandomChunk()
// create a mock storer and put a chunk that will be repaired
mockStorer := storemock.NewStorer()
defer mockStorer.Close()
_, err := mockStorer.Put(context.Background(), storage.ModePutRequest, c1)
if err != nil {
t.Fatal(err)
}
// create a mock pushsync service to push the chunk to its destination
var receipt *pushsync.Receipt
pushSyncService := pushsyncmock.New(func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
receipt = &pushsync.Receipt{
Address: swarm.NewAddress(chunk.Address().Bytes()),
}
return receipt, nil
})
// create the chunk repair handler
repairHandler := recovery.NewRepairHandler(mockStorer, logger, pushSyncService)
//create a trojan message to trigger the repair of the chunk
testTopic := trojan.NewTopic("foo")
maxPayload := make([]byte, swarm.SectionSize)
var msg trojan.Message
copy(maxPayload, c1.Address().Bytes())
if msg, err = trojan.NewMessage(testTopic, maxPayload); err != nil {
t.Fatal(err)
}
// invoke the chunk repair handler
err = repairHandler(context.Background(), &msg)
if err != nil {
t.Fatal(err)
}
// check if receipt is received
if receipt == nil {
t.Fatal("receipt not received")
}
if !receipt.Address.Equal(c1.Address()) {
t.Fatalf("invalid address in receipt: expected %s received %s", c1.Address(), receipt.Address)
}
})
t.Run("repair-chunk-not-present", func(t *testing.T) {
// generate test chunk, store and publisher
c2 := chunktesting.GenerateTestRandomChunk()
// create a mock storer
mockStorer := storemock.NewStorer()
defer mockStorer.Close()
// create a mock pushsync service
pushServiceCalled := false
pushSyncService := pushsyncmock.New(func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
pushServiceCalled = true
return nil, nil
})
// create the chunk repair handler
repairHandler := recovery.NewRepairHandler(mockStorer, logger, pushSyncService)
//create a trojan message to trigger the repair of the chunk
testTopic := trojan.NewTopic("foo")
maxPayload := make([]byte, swarm.SectionSize)
var msg trojan.Message
copy(maxPayload, c2.Address().Bytes())
msg, err := trojan.NewMessage(testTopic, maxPayload)
if err != nil {
t.Fatal(err)
}
// invoke the chunk repair handler
err = repairHandler(context.Background(), &msg)
if err != nil && err.Error() != "storage: not found" {
t.Fatal(err)
}
if pushServiceCalled {
t.Fatal("push service called even if the chunk is not present")
}
})
t.Run("repair-chunk-closest-peer-not-present", func(t *testing.T) {
// generate test chunk, store and publisher
c3 := chunktesting.GenerateTestRandomChunk()
// create a mock storer
mockStorer := storemock.NewStorer()
defer mockStorer.Close()
_, err := mockStorer.Put(context.Background(), storage.ModePutRequest, c3)
if err != nil {
t.Fatal(err)
}
// create a mock pushsync service
var receiptError error
pushSyncService := pushsyncmock.New(func(ctx context.Context, chunk swarm.Chunk) (*pushsync.Receipt, error) {
receiptError = errors.New("invalid receipt")
return nil, receiptError
})
// create the chunk repair handler
repairHandler := recovery.NewRepairHandler(mockStorer, logger, pushSyncService)
//create a trojan message to trigger the repair of the chunk
testTopic := trojan.NewTopic("foo")
maxPayload := make([]byte, swarm.SectionSize)
var msg trojan.Message
copy(maxPayload, c3.Address().Bytes())
msg, err = trojan.NewMessage(testTopic, maxPayload)
if err != nil {
t.Fatal(err)
}
// invoke the chunk repair handler
err = repairHandler(context.Background(), &msg)
if err != nil && err != receiptError {
t.Fatal(err)
}
if receiptError == nil {
t.Fatal("pushsync did not generate a receipt error")
}
})
}
// newTestNetStore creates a test store with a set RemoteGet func.
func newTestNetStore(t *testing.T, recoveryFunc recovery.RecoveryHook) storage.Storer {
t.Helper()
storer := mock.NewStorer()
logger := logging.New(ioutil.Discard, 5)
mockStorer := storemock.NewStorer()
serverMockAccounting := accountingmock.NewAccounting()
price := uint64(12345)
pricerMock := accountingmock.NewPricer(price, price)
peerID := swarm.MustParseHexAddress("deadbeef")
ps := mockPeerSuggester{eachPeerRevFunc: func(f topology.EachPeerFunc) error {
_, _, _ = f(peerID, 0)
return nil
}}
server := retrieval.New(retrieval.Options{
Storer: mockStorer,
Logger: logger,
Accounting: serverMockAccounting,
})
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
)
retrieve := retrieval.New(retrieval.Options{
Streamer: recorder,
ChunkPeerer: ps,
Storer: mockStorer,
Logger: logger,
Accounting: serverMockAccounting,
Pricer: pricerMock,
})
ns := netstore.New(storer, recoveryFunc, retrieve, logger, nil)
return ns
}
type mockPeerSuggester struct {
eachPeerRevFunc func(f topology.EachPeerFunc) error
}
func (s mockPeerSuggester) EachPeer(topology.EachPeerFunc) error {
return errors.New("not implemented")
}
func (s mockPeerSuggester) EachPeerRev(f topology.EachPeerFunc) error {
return s.eachPeerRevFunc(f)
}
type mockPssSender struct {
hookC chan bool
}
// Send mocks the pss Send function
func (mp *mockPssSender) Send(ctx context.Context, targets trojan.Targets, topic trojan.Topic, payload []byte) error {
mp.hookC <- true
return nil
}
...@@ -4,12 +4,25 @@ ...@@ -4,12 +4,25 @@
package sctx package sctx
import "context" import (
"context"
"encoding/hex"
"errors"
"strings"
"github.com/ethersphere/bee/pkg/trojan"
)
var (
// ErrTargetPrefix is returned when target prefix decoding fails.
ErrTargetPrefix = errors.New("error decoding prefix string")
)
type ( type (
HTTPRequestIDKey struct{} HTTPRequestIDKey struct{}
requestHostKey struct{} requestHostKey struct{}
tagKey struct{} tagKey struct{}
targetsContextKey struct{}
) )
// SetHost sets the http request host in the context // SetHost sets the http request host in the context
...@@ -39,3 +52,32 @@ func GetTag(ctx context.Context) uint32 { ...@@ -39,3 +52,32 @@ func GetTag(ctx context.Context) uint32 {
} }
return 0 return 0
} }
// SetTargets set the target string in the context to be used downstream in netstore
func SetTargets(ctx context.Context, targets string) context.Context {
return context.WithValue(ctx, targetsContextKey{}, targets)
}
// GetTargets returns the specific target pinners for a corresponding chunk by
// reading the prefix targets sent in the download API.
func GetTargets(ctx context.Context) (trojan.Targets, error) {
targetString, ok := ctx.Value(targetsContextKey{}).(string)
if !ok {
return nil, ErrTargetPrefix
}
prefixes := strings.Split(targetString, ",")
var targets trojan.Targets
for _, prefix := range prefixes {
var target trojan.Target
target, err := hex.DecodeString(prefix)
if err != nil {
continue
}
targets = append(targets, target)
}
if len(targets) <= 0 {
return nil, ErrTargetPrefix
}
return targets, nil
}
...@@ -30,6 +30,8 @@ type MockStorer struct { ...@@ -30,6 +30,8 @@ type MockStorer struct {
morePull chan struct{} morePull chan struct{}
mtx sync.Mutex mtx sync.Mutex
quit chan struct{} quit chan struct{}
baseAddress []byte
bins []uint64
} }
func WithSubscribePullChunks(chs ...storage.Descriptor) Option { func WithSubscribePullChunks(chs ...storage.Descriptor) Option {
...@@ -41,6 +43,18 @@ func WithSubscribePullChunks(chs ...storage.Descriptor) Option { ...@@ -41,6 +43,18 @@ func WithSubscribePullChunks(chs ...storage.Descriptor) Option {
}) })
} }
func WithBaseAddress(a swarm.Address) Option {
return optionFunc(func(m *MockStorer) {
m.baseAddress = a.Bytes()
})
}
func WithTags(t *tags.Tags) Option {
return optionFunc(func(m *MockStorer) {
m.tags = t
})
}
func WithPartialInterval(v bool) Option { func WithPartialInterval(v bool) Option {
return optionFunc(func(m *MockStorer) { return optionFunc(func(m *MockStorer) {
m.partialInterval = v m.partialInterval = v
...@@ -54,6 +68,7 @@ func NewStorer(opts ...Option) *MockStorer { ...@@ -54,6 +68,7 @@ func NewStorer(opts ...Option) *MockStorer {
modeSetMu: sync.Mutex{}, modeSetMu: sync.Mutex{},
morePull: make(chan struct{}), morePull: make(chan struct{}),
quit: make(chan struct{}), quit: make(chan struct{}),
bins: make([]uint64, swarm.MaxBins),
} }
for _, v := range opts { for _, v := range opts {
...@@ -74,6 +89,16 @@ func NewValidatingStorer(v swarm.Validator, tags *tags.Tags) *MockStorer { ...@@ -74,6 +89,16 @@ func NewValidatingStorer(v swarm.Validator, tags *tags.Tags) *MockStorer {
} }
} }
func NewTagsStorer(tags *tags.Tags) *MockStorer {
return &MockStorer{
store: make(map[string][]byte),
modeSet: make(map[string]storage.ModeSet),
modeSetMu: sync.Mutex{},
pinSetMu: sync.Mutex{},
tags: tags,
}
}
func (m *MockStorer) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (ch swarm.Chunk, err error) { func (m *MockStorer) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (ch swarm.Chunk, err error) {
m.mtx.Lock() m.mtx.Lock()
defer m.mtx.Unlock() defer m.mtx.Unlock()
...@@ -104,6 +129,8 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm ...@@ -104,6 +129,8 @@ func (m *MockStorer) Put(ctx context.Context, mode storage.ModePut, chs ...swarm
if yes { if yes {
exist = append(exist, true) exist = append(exist, true)
} else { } else {
po := swarm.Proximity(ch.Address().Bytes(), m.baseAddress)
m.bins[po]++
exist = append(exist, false) exist = append(exist, false)
} }
...@@ -186,7 +213,7 @@ func (m *MockStorer) GetModeSet(addr swarm.Address) (mode storage.ModeSet) { ...@@ -186,7 +213,7 @@ func (m *MockStorer) GetModeSet(addr swarm.Address) (mode storage.ModeSet) {
} }
func (m *MockStorer) LastPullSubscriptionBinID(bin uint8) (id uint64, err error) { func (m *MockStorer) LastPullSubscriptionBinID(bin uint8) (id uint64, err error) {
panic("not implemented") // TODO: Implement return m.bins[bin], nil
} }
func (m *MockStorer) SubscribePull(ctx context.Context, bin uint8, since, until uint64) (<-chan storage.Descriptor, <-chan struct{}, func()) { func (m *MockStorer) SubscribePull(ctx context.Context, bin uint8, since, until uint64) (<-chan storage.Descriptor, <-chan struct{}, func()) {
......
...@@ -11,8 +11,9 @@ import ( ...@@ -11,8 +11,9 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/encryption"
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
"github.com/ethersphere/bee/pkg/encryption"
) )
const ( const (
......
...@@ -42,7 +42,7 @@ const ( ...@@ -42,7 +42,7 @@ const (
NonceSize = 32 NonceSize = 32
LengthSize = 2 LengthSize = 2
TopicSize = 32 TopicSize = 32
MinerTimeout = 5 // seconds after which the mining will fail MinerTimeout = 20 // seconds after which the mining will fail
) )
// NewTopic creates a new Topic variable with the given input string // NewTopic creates a new Topic variable with the given input string
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment