Commit 0a40f19f authored by metacertain's avatar metacertain Committed by GitHub

Check for validity before concluding retrieving chunk (#427)

* Check for validity before concluding retrieving chunk

* Removed leftover code

* removed draft comment

* Fix Retrieve test with Fake Validator function

* GoFmt

* Remove second validation

* Retrieve and netstore validation and test fixes

* Content mock validator

* Validator mock configurable return

* Fix retrieval validity check before crediting

* Removed empty comment

* Change name for ChunkValidator & ChunkValidators -> Validator & ChunkValidator

* Remove whitespace
parent eb7ea591
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mock
import (
"github.com/ethersphere/bee/pkg/swarm"
)
var _ swarm.Validator = (*Validator)(nil)
type Validator struct {
rv bool
}
// NewValidator constructs a new Validator
func NewValidator(rv bool) swarm.Validator {
return &Validator{rv: rv}
}
// Validate returns rv from mock struct
func (v *Validator) Validate(ch swarm.Chunk) (valid bool) {
return v.rv
}
......@@ -8,7 +8,7 @@ import (
"github.com/ethersphere/bee/pkg/swarm"
)
var _ swarm.ChunkValidator = (*Validator)(nil)
var _ swarm.Validator = (*Validator)(nil)
// ContentAddressValidator validates that the address of a given chunk
// is the content address of its contents.
......@@ -16,7 +16,7 @@ type Validator struct {
}
// NewContentAddressValidator constructs a new ContentAddressValidator
func NewValidator() swarm.ChunkValidator {
func NewValidator() swarm.Validator {
return &Validator{}
}
......
......@@ -17,15 +17,14 @@ import (
type store struct {
storage.Storer
retrieval retrieval.Interface
validators []swarm.ChunkValidator
logger logging.Logger
retrieval retrieval.Interface
logger logging.Logger
validator swarm.Validator
}
// New returns a new NetStore that wraps a given Storer.
func New(s storage.Storer, r retrieval.Interface, logger logging.Logger, validators ...swarm.ChunkValidator) storage.Storer {
return &store{Storer: s, retrieval: r, logger: logger, validators: validators}
func New(s storage.Storer, r retrieval.Interface, logger logging.Logger, validator swarm.Validator) storage.Storer {
return &store{Storer: s, retrieval: r, logger: logger, validator: validator}
}
// Get retrieves a given chunk address.
......@@ -35,16 +34,11 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
if err != nil {
if errors.Is(err, storage.ErrNotFound) {
// request from network
data, err := s.retrieval.RetrieveChunk(ctx, addr)
ch, err := s.retrieval.RetrieveChunk(ctx, addr)
if err != nil {
return nil, fmt.Errorf("netstore retrieve chunk: %w", err)
}
ch = swarm.NewChunk(addr, data)
if !s.valid(ch) {
return nil, storage.ErrInvalidChunk
}
_, err = s.Storer.Put(ctx, storage.ModePutRequest, ch)
if err != nil {
return nil, fmt.Errorf("netstore retrieve put: %w", err)
......@@ -61,19 +55,9 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
// encountering an invalid chunk.
func (s *store) Put(ctx context.Context, mode storage.ModePut, chs ...swarm.Chunk) (exist []bool, err error) {
for _, ch := range chs {
if !s.valid(ch) {
if !s.validator.Validate(ch) {
return nil, storage.ErrInvalidChunk
}
}
return s.Storer.Put(ctx, mode, chs...)
}
// checks if a particular chunk is valid using the built in validators
func (s *store) valid(ch swarm.Chunk) (ok bool) {
for _, v := range s.validators {
if ok = v.Validate(ch); ok {
return true
}
}
return false
}
......@@ -11,6 +11,7 @@ import (
"sync/atomic"
"testing"
validatormock "github.com/ethersphere/bee/pkg/content/mock"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/netstore"
"github.com/ethersphere/bee/pkg/storage"
......@@ -100,7 +101,8 @@ func newRetrievingNetstore() (ret *retrievalMock, mockStore, ns storage.Storer)
retrieve := &retrievalMock{}
store := mock.NewStorer()
logger := logging.New(ioutil.Discard, 0)
nstore := netstore.New(store, retrieve, logger, mockValidator{})
validator := swarm.NewChunkValidator(validatormock.NewValidator(true))
nstore := netstore.New(store, retrieve, logger, validator)
return retrieve, store, nstore
}
......@@ -111,9 +113,9 @@ type retrievalMock struct {
addr swarm.Address
}
func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (data []byte, err error) {
func (r *retrievalMock) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
r.called = true
atomic.AddInt32(&r.callCount, 1)
r.addr = addr
return chunkData, nil
return swarm.NewChunk(addr, chunkData), nil
}
......@@ -44,6 +44,7 @@ import (
"github.com/ethersphere/bee/pkg/statestore/leveldb"
mockinmem "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bee/pkg/tracing"
ma "github.com/multiformats/go-multiaddr"
......@@ -241,12 +242,15 @@ func NewBee(o Options) (*Bee, error) {
DisconnectThreshold: o.DisconnectThreshold,
})
chunkvalidators := swarm.NewChunkValidator(soc.NewValidator(), content.NewValidator())
retrieve := retrieval.New(retrieval.Options{
Streamer: p2ps,
ChunkPeerer: topologyDriver,
Logger: logger,
Accounting: acc,
Pricer: accounting.NewFixedPricer(address, 10),
Validator: chunkvalidators,
})
tagg := tags.NewTags()
......@@ -254,7 +258,7 @@ func NewBee(o Options) (*Bee, error) {
return nil, fmt.Errorf("retrieval service: %w", err)
}
ns := netstore.New(storer, retrieve, logger, content.NewValidator(), soc.NewValidator())
ns := netstore.New(storer, retrieve, logger, chunkvalidators)
retrieve.SetStorer(ns)
......
......@@ -31,7 +31,7 @@ const (
var _ Interface = (*Service)(nil)
type Interface interface {
RetrieveChunk(ctx context.Context, addr swarm.Address) (data []byte, err error)
RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error)
}
type Service struct {
......@@ -42,6 +42,7 @@ type Service struct {
logger logging.Logger
accounting accounting.Interface
pricer accounting.Pricer
validator swarm.Validator
}
type Options struct {
......@@ -51,6 +52,7 @@ type Options struct {
Logger logging.Logger
Accounting accounting.Interface
Pricer accounting.Pricer
Validator swarm.Validator
}
func New(o Options) *Service {
......@@ -61,6 +63,7 @@ func New(o Options) *Service {
logger: o.Logger,
accounting: o.Accounting,
pricer: o.Pricer,
validator: o.Validator,
}
}
......@@ -82,7 +85,7 @@ const (
retrieveChunkTimeout = 10 * time.Second
)
func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (data []byte, err error) {
func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (chunk swarm.Chunk, err error) {
ctx, cancel := context.WithTimeout(ctx, maxPeers*retrieveChunkTimeout)
defer cancel()
......@@ -90,7 +93,7 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (data [
var skipPeers []swarm.Address
for i := 0; i < maxPeers; i++ {
var peer swarm.Address
data, peer, err = s.retrieveChunk(ctx, addr, skipPeers)
chunk, peer, err := s.retrieveChunk(ctx, addr, skipPeers)
if err != nil {
if peer.IsZero() {
return nil, err
......@@ -100,17 +103,18 @@ func (s *Service) RetrieveChunk(ctx context.Context, addr swarm.Address) (data [
continue
}
s.logger.Tracef("retrieval: got chunk %s from peer %s", addr, peer)
return data, nil
return chunk, nil
}
return nil, err
})
if err != nil {
return nil, err
}
return v.([]byte), nil
return v.(swarm.Chunk), nil
}
func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPeers []swarm.Address) (data []byte, peer swarm.Address, err error) {
func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPeers []swarm.Address) (chunk swarm.Chunk, peer swarm.Address, err error) {
v := ctx.Value(requestSourceContextKey{})
if src, ok := v.(string); ok {
skipAddr, err := swarm.ParseHexAddress(src)
......@@ -161,12 +165,17 @@ func (s *Service) retrieveChunk(ctx context.Context, addr swarm.Address, skipPee
}
// credit the peer after successful delivery
chunk = swarm.NewChunk(addr, d.Data)
if !s.validator.Validate(chunk) {
return nil, peer, err
}
err = s.accounting.Credit(peer, chunkPrice)
if err != nil {
return nil, peer, err
}
return d.Data, peer, nil
return chunk, peer, err
}
func (s *Service) closestPeer(addr swarm.Address, skipPeers []swarm.Address) (swarm.Address, error) {
......
......@@ -14,6 +14,7 @@ import (
"time"
accountingmock "github.com/ethersphere/bee/pkg/accounting/mock"
"github.com/ethersphere/bee/pkg/content/mock"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/p2p/protobuf"
"github.com/ethersphere/bee/pkg/p2p/streamtest"
......@@ -30,7 +31,7 @@ var testTimeout = 5 * time.Second
// TestDelivery tests that a naive request -> delivery flow works.
func TestDelivery(t *testing.T) {
logger := logging.New(ioutil.Discard, 0)
mockValidator := swarm.NewChunkValidator(mock.NewValidator(true))
mockStorer := storemock.NewStorer()
reqAddr, err := swarm.ParseHexAddress("00112233")
if err != nil {
......@@ -55,6 +56,7 @@ func TestDelivery(t *testing.T) {
Logger: logger,
Accounting: serverMockAccounting,
Pricer: pricerMock,
Validator: mockValidator,
})
recorder := streamtest.New(
streamtest.WithProtocols(server.Protocol()),
......@@ -80,6 +82,7 @@ func TestDelivery(t *testing.T) {
Logger: logger,
Accounting: clientMockAccounting,
Pricer: pricerMock,
Validator: mockValidator,
})
ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
defer cancel()
......@@ -87,7 +90,7 @@ func TestDelivery(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(v, reqData) {
if !bytes.Equal(v.Data(), reqData) {
t.Fatalf("request and response data not equal. got %s want %s", v, reqData)
}
records, err := recorder.Records(peerID, "retrieval", "1.0.0", "retrieval")
......
......@@ -7,7 +7,7 @@ import (
"github.com/ethersphere/bee/pkg/swarm"
)
var _ swarm.ChunkValidator = (*Validator)(nil)
var _ swarm.Validator = (*Validator)(nil)
// SocVaildator validates that the address of a given chunk
// is a single-owner chunk.
......@@ -15,7 +15,7 @@ type Validator struct {
}
// NewSocValidator creates a new SocValidator.
func NewValidator() swarm.ChunkValidator {
func NewValidator() swarm.Validator {
return &Validator{}
}
......
......@@ -25,7 +25,7 @@ type MockStorer struct {
pinSetMu sync.Mutex
subpull []storage.Descriptor
partialInterval bool
validator swarm.ChunkValidator
validator swarm.Validator
tags *tags.Tags
morePull chan struct{}
mtx sync.Mutex
......@@ -63,7 +63,7 @@ func NewStorer(opts ...Option) *MockStorer {
return s
}
func NewValidatingStorer(v swarm.ChunkValidator, tags *tags.Tags) *MockStorer {
func NewValidatingStorer(v swarm.Validator, tags *tags.Tags) *MockStorer {
return &MockStorer{
store: make(map[string][]byte),
modeSet: make(map[string]storage.ModeSet),
......
......@@ -182,6 +182,27 @@ func (c *chunk) WithType(t Type) Chunk {
return c
}
type ChunkValidator interface {
type Validator interface {
Validate(ch Chunk) (valid bool)
}
type chunkValidator struct {
set []Validator
Validator
}
func NewChunkValidator(v ...Validator) Validator {
return &chunkValidator{
set: v,
}
}
func (c *chunkValidator) Validate(ch Chunk) bool {
for _, v := range c.set {
if v.Validate(ch) {
return true
}
}
return false
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment