Commit 0a40f19f authored by metacertain's avatar metacertain Committed by GitHub

Check for validity before concluding retrieving chunk (#427)

* Check for validity before concluding retrieving chunk

* Removed leftover code

* removed draft comment

* Fix Retrieve test with Fake Validator function

* GoFmt

* Remove second validation

* Retrieve and netstore validation and test fixes

* Content mock validator

* Validator mock configurable return

* Fix retrieval validity check before crediting

* Removed empty comment

* Change name for ChunkValidator & ChunkValidators -> Validator & ChunkValidator

* Remove whitespace
parent eb7ea591
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"github.com/ethersphere/bee/pkg/swarm"
)
var _ swarm.Validator = (*Validator)(nil)
type Validator struct {
rv bool
}
// NewValidator constructs a new Validator
func NewValidator(rv bool) swarm.Validator {
return &Validator{rv: rv}
}
// Validate returns rv from mock struct
func (v *Validator) Validate(ch swarm.Chunk) (valid bool) {
return v.rv
}
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
var _ swarm.ChunkValidator = (*Validator)(nil) var _ swarm.Validator = (*Validator)(nil)
// ContentAddressValidator validates that the address of a given chunk // ContentAddressValidator validates that the address of a given chunk
// is the content address of its contents. // is the content address of its contents.
...@@ -16,7 +16,7 @@ type Validator struct { ...@@ -16,7 +16,7 @@ type Validator struct {
} }
// NewContentAddressValidator constructs a new ContentAddressValidator // NewContentAddressValidator constructs a new ContentAddressValidator
func NewValidator() swarm.ChunkValidator { func NewValidator() swarm.Validator {
return &Validator{} return &Validator{}
} }
......
...@@ -17,15 +17,14 @@ import ( ...@@ -17,15 +17,14 @@ import (
type store struct { type store struct {
storage.Storer storage.Storer
retrieval retrieval.Interface
retrieval retrieval.Interface logger logging.Logger
validators []swarm.ChunkValidator validator swarm.Validator
logger logging.Logger
} }
// New returns a new NetStore that wraps a given Storer. // New returns a new NetStore that wraps a given Storer.
func New(s storage.Storer, r retrieval.Interface, logger logging.Logger, validators ...swarm.ChunkValidator) storage.Storer { func New(s storage.Storer, r retrieval.Interface, logger logging.Logger, validator swarm.Validator) storage.Storer {
return &store{Storer: s, retrieval: r, logger: logger, validators: validators} return &store{Storer: s, retrieval: r, logger: logger, validator: validator}
} }
// Get retrieves a given chunk address. // Get retrieves a given chunk address.
...@@ -35,16 +34,11 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres ...@@ -35,16 +34,11 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
if err != nil { if err != nil {
if errors.Is(err, storage.ErrNotFound) { if errors.Is(err, storage.ErrNotFound) {
// request from network // request from network
data, err := s.retrieval.RetrieveChunk(ctx, addr) ch, err := s.retrieval.RetrieveChunk(ctx, addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("netstore retrieve chunk: %w", err) return nil, fmt.Errorf("netstore retrieve chunk: %w", err)
} }
ch = swarm.NewChunk(addr, data)
if !s.valid(ch) {
return nil, storage.ErrInvalidChunk
}
_, err = s.Storer.Put(ctx, storage.ModePutRequest, ch) _, err = s.Storer.Put(ctx, storage.ModePutRequest, ch)
if err != nil { if err != nil {
return nil, fmt.Errorf("netstore retrieve put: %w", err) return nil, fmt.Errorf("netstore retrieve put: %w", err)
...@@ -61,19 +55,9 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres ...@@ -61,19 +55,9 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
// encountering an invalid chunk. // encountering an invalid chunk.
func (s *store) Put(ctx context.Context, mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err error) { func (s *store) Put(ctx context.Context, mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err error) {
for _, ch := range chs { for _, ch := range chs {
if !s.valid(ch) { if !s.validator.Validate(ch) {
return nil, storage.ErrInvalidChunk return nil, storage.ErrInvalidChunk
} }
} }
return s.Storer.Put(ctx, mode, chs...) return s.Storer.Put(ctx, mode, chs...)
} }
// checks if a particular chunk is valid using the built in validators
func (s *store) valid(ch swarm.Chunk) (ok bool) {
for _, v := range s.validators {
if ok = v.Validate(ch); ok {
return true
}
}
return false
}
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"sync/atomic" "sync/atomic"
"testing" "testing"
validatormock "github.com/ethersphere/bee/pkg/content/mock"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore" "github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
...@@ -100,7 +101,8 @@ func newRetrievingNetstore() (ret *retrievalMock, mockStore, ns storage.Storer) ...@@ -100,7 +101,8 @@ func newRetrievingNetstore() (ret *retrievalMock, mockStore, ns storage.Storer)
retrieve := &retrievalMock{} retrieve := &retrievalMock{}
store := mock.NewStorer() store := mock.NewStorer()
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
nstore := netstore.New(store, retrieve, logger, mockValidator{}) validator := swarm.NewChunkValidator(validatormock.NewValidator(true))
nstore := netstore.New(store, retrieve, logger, validator)
return retrieve, store, nstore return retrieve, store, nstore
} }
...@@ -111,9 +113,9 @@ type retrievalMock struct { ...@@ -111,9 +113,9 @@ type retrievalMock struct {
addr swarm.Address addr swarm.Address
} }
func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (data []byte, err error) { func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
r.called = true r.called = true
atomic.AddInt32(&r.callCount, 1) atomic.AddInt32(&r.callCount, 1)
r.addr = addr r.addr = addr
return chunkData, nil return swarm.NewChunk(addr, chunkData), nil
} }
...@@ -44,6 +44,7 @@ import ( ...@@ -44,6 +44,7 @@ import (
"github.com/ethersphere/bee/pkg/statestore/leveldb" "github.com/ethersphere/bee/pkg/statestore/leveldb"
mockinmem "github.com/ethersphere/bee/pkg/statestore/mock" mockinmem "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/tracing" "github.com/ethersphere/bee/pkg/tracing"
ma "github.com/multiformats/go-multiaddr" ma "github.com/multiformats/go-multiaddr"
...@@ -241,12 +242,15 @@ func NewBee(o Options) (*Bee, error) { ...@@ -241,12 +242,15 @@ func NewBee(o Options) (*Bee, error) {
DisconnectThreshold: o.DisconnectThreshold, DisconnectThreshold: o.DisconnectThreshold,
}) })
chunkvalidators := swarm.NewChunkValidator(soc.NewValidator(), content.NewValidator())
retrieve := retrieval.New(retrieval.Options{ retrieve := retrieval.New(retrieval.Options{
Streamer: p2ps, Streamer: p2ps,
ChunkPeerer: topologyDriver, ChunkPeerer: topologyDriver,
Logger: logger, Logger: logger,
Accounting: acc, Accounting: acc,
Pricer: accounting.NewFixedPricer(address, 10), Pricer: accounting.NewFixedPricer(address, 10),
Validator: chunkvalidators,
}) })
tagg := tags.NewTags() tagg := tags.NewTags()
...@@ -254,7 +258,7 @@ func NewBee(o Options) (*Bee, error) { ...@@ -254,7 +258,7 @@ func NewBee(o Options) (*Bee, error) {
return nil, fmt.Errorf("retrieval service: %w", err) return nil, fmt.Errorf("retrieval service: %w", err)
} }
ns := netstore.New(storer, retrieve, logger, content.NewValidator(), soc.NewValidator()) ns := netstore.New(storer, retrieve, logger, chunkvalidators)
retrieve.SetStorer(ns) retrieve.SetStorer(ns)
......
...@@ -31,7 +31,7 @@ const ( ...@@ -31,7 +31,7 @@ const (
var _ Interface = (*Service)(nil) var _ Interface = (*Service)(nil)
type Interface interface { type Interface interface {
RetrieveChunk(ctx context.Context, addr swarm.Address) (data []byte, err error) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error)
} }
type Service struct { type Service struct {
...@@ -42,6 +42,7 @@ type Service struct { ...@@ -42,6 +42,7 @@ type Service struct {
logger logging.Logger logger logging.Logger
accounting accounting.Interface accounting accounting.Interface
pricer accounting.Pricer pricer accounting.Pricer
validator swarm.Validator
} }
type Options struct { type Options struct {
...@@ -51,6 +52,7 @@ type Options struct { ...@@ -51,6 +52,7 @@ type Options struct {
Logger logging.Logger Logger logging.Logger
Accounting accounting.Interface Accounting accounting.Interface
Pricer accounting.Pricer Pricer accounting.Pricer
Validator swarm.Validator
} }
func New(o Options) *Service { func New(o Options) *Service {
...@@ -61,6 +63,7 @@ func New(o Options) *Service { ...@@ -61,6 +63,7 @@ func New(o Options) *Service {
logger: o.Logger, logger: o.Logger,
accounting: o.Accounting, accounting: o.Accounting,
pricer: o.Pricer, pricer: o.Pricer,
validator: o.Validator,
} }
} }
...@@ -82,7 +85,7 @@ const ( ...@@ -82,7 +85,7 @@ const (
retrieveChunkTimeout = 10 * time.Second retrieveChunkTimeout = 10 * time.Second
) )
func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (data []byte, err error) { func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
ctx, cancel := context.WithTimeout(ctx, maxPeers*retrieveChunkTimeout) ctx, cancel := context.WithTimeout(ctx, maxPeers*retrieveChunkTimeout)
defer cancel() defer cancel()
...@@ -90,7 +93,7 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (data [ ...@@ -90,7 +93,7 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (data [
var skipPeers []swarm.Address var skipPeers []swarm.Address
for i := 0; i < maxPeers; i++ { for i := 0; i < maxPeers; i++ {
var peer swarm.Address var peer swarm.Address
data, peer, err = s.retrieveChunk(ctx, addr, skipPeers) chunk, peer, err := s.retrieveChunk(ctx, addr, skipPeers)
if err != nil { if err != nil {
if peer.IsZero() { if peer.IsZero() {
return nil, err return nil, err
...@@ -100,17 +103,18 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (data [ ...@@ -100,17 +103,18 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (data [
continue continue
} }
s.logger.Tracef("retrieval: got chunk %s from peer %s", addr, peer) s.logger.Tracef("retrieval: got chunk %s from peer %s", addr, peer)
return data, nil return chunk, nil
} }
return nil, err return nil, err
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return v.([]byte), nil
return v.(swarm.Chunk), nil
} }
func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPeers []swarm.Address) (data []byte, peer swarm.Address, err error) { func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPeers []swarm.Address) (chunk swarm.Chunk, peer swarm.Address, err error) {
v := ctx.Value(requestSourceContextKey{}) v := ctx.Value(requestSourceContextKey{})
if src, ok := v.(string); ok { if src, ok := v.(string); ok {
skipAddr, err := swarm.ParseHexAddress(src) skipAddr, err := swarm.ParseHexAddress(src)
...@@ -161,12 +165,17 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee ...@@ -161,12 +165,17 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
} }
// credit the peer after successful delivery // credit the peer after successful delivery
chunk = swarm.NewChunk(addr, d.Data)
if !s.validator.Validate(chunk) {
return nil, peer, err
}
err = s.accounting.Credit(peer, chunkPrice) err = s.accounting.Credit(peer, chunkPrice)
if err != nil { if err != nil {
return nil, peer, err return nil, peer, err
} }
return d.Data, peer, nil return chunk, peer, err
} }
func (s *Service) closestPeer(addr swarm.Address, skipPeers []swarm.Address) (swarm.Address, error) { func (s *Service) closestPeer(addr swarm.Address, skipPeers []swarm.Address) (swarm.Address, error) {
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"time" "time"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock" accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/content/mock"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf" "github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest" "github.com/ethersphere/bee/pkg/p2p/streamtest"
...@@ -30,7 +31,7 @@ var testTimeout = 5 * time.Second ...@@ -30,7 +31,7 @@ var testTimeout = 5 * time.Second
// TestDelivery tests that a naive request -> delivery flow works. // TestDelivery tests that a naive request -> delivery flow works.
func TestDelivery(t *testing.T) { func TestDelivery(t *testing.T) {
logger := logging.New(ioutil.Discard, 0) logger := logging.New(ioutil.Discard, 0)
mockValidator := swarm.NewChunkValidator(mock.NewValidator(true))
mockStorer := storemock.NewStorer() mockStorer := storemock.NewStorer()
reqAddr, err := swarm.ParseHexAddress("00112233") reqAddr, err := swarm.ParseHexAddress("00112233")
if err != nil { if err != nil {
...@@ -55,6 +56,7 @@ func TestDelivery(t *testing.T) { ...@@ -55,6 +56,7 @@ func TestDelivery(t *testing.T) {
Logger: logger, Logger: logger,
Accounting: serverMockAccounting, Accounting: serverMockAccounting,
Pricer: pricerMock, Pricer: pricerMock,
Validator: mockValidator,
}) })
recorder := streamtest.New( recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()), streamtest.WithProtocols(server.Protocol()),
...@@ -80,6 +82,7 @@ func TestDelivery(t *testing.T) { ...@@ -80,6 +82,7 @@ func TestDelivery(t *testing.T) {
Logger: logger, Logger: logger,
Accounting: clientMockAccounting, Accounting: clientMockAccounting,
Pricer: pricerMock, Pricer: pricerMock,
Validator: mockValidator,
}) })
ctx, cancel := context.WithTimeout(context.Background(), testTimeout) ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel() defer cancel()
...@@ -87,7 +90,7 @@ func TestDelivery(t *testing.T) { ...@@ -87,7 +90,7 @@ func TestDelivery(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if !bytes.Equal(v, reqData) { if !bytes.Equal(v.Data(), reqData) {
t.Fatalf("request and response data not equal. got %s want %s", v, reqData) t.Fatalf("request and response data not equal. got %s want %s", v, reqData)
} }
records, err := recorder.Records(peerID, "retrieval", "1.0.0", "retrieval") records, err := recorder.Records(peerID, "retrieval", "1.0.0", "retrieval")
......
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
var _ swarm.ChunkValidator = (*Validator)(nil) var _ swarm.Validator = (*Validator)(nil)
// SocVaildator validates that the address of a given chunk // SocVaildator validates that the address of a given chunk
// is a single-owner chunk. // is a single-owner chunk.
...@@ -15,7 +15,7 @@ type Validator struct { ...@@ -15,7 +15,7 @@ type Validator struct {
} }
// NewSocValidator creates a new SocValidator. // NewSocValidator creates a new SocValidator.
func NewValidator() swarm.ChunkValidator { func NewValidator() swarm.Validator {
return &Validator{} return &Validator{}
} }
......
...@@ -25,7 +25,7 @@ type MockStorer struct { ...@@ -25,7 +25,7 @@ type MockStorer struct {
pinSetMu sync.Mutex pinSetMu sync.Mutex
subpull []storage.Descriptor subpull []storage.Descriptor
partialInterval bool partialInterval bool
validator swarm.ChunkValidator validator swarm.Validator
tags *tags.Tags tags *tags.Tags
morePull chan struct{} morePull chan struct{}
mtx sync.Mutex mtx sync.Mutex
...@@ -63,7 +63,7 @@ func NewStorer(opts ...Option) *MockStorer { ...@@ -63,7 +63,7 @@ func NewStorer(opts ...Option) *MockStorer {
return s return s
} }
func NewValidatingStorer(v swarm.ChunkValidator, tags *tags.Tags) *MockStorer { func NewValidatingStorer(v swarm.Validator, tags *tags.Tags) *MockStorer {
return &MockStorer{ return &MockStorer{
store: make(map[string][]byte), store: make(map[string][]byte),
modeSet: make(map[string]storage.ModeSet), modeSet: make(map[string]storage.ModeSet),
......
...@@ -182,6 +182,27 @@ func (c *chunk) WithType(t Type) Chunk { ...@@ -182,6 +182,27 @@ func (c *chunk) WithType(t Type) Chunk {
return c return c
} }
type ChunkValidator interface { type Validator interface {
Validate(ch Chunk) (valid bool) Validate(ch Chunk) (valid bool)
} }
type chunkValidator struct {
set []Validator
Validator
}
func NewChunkValidator(v ...Validator) Validator {
return &chunkValidator{
set: v,
}
}
func (c *chunkValidator) Validate(ch Chunk) bool {
for _, v := range c.set {
if v.Validate(ch) {
return true
}
}
return false
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment