Commit e3a08844 authored by Nemanja Zbiljić's avatar Nemanja Zbiljić Committed by GitHub

Return error from address iteration function (#1066)

parent 422e54c0
...@@ -123,21 +123,33 @@ func TestPinBytesHandler(t *testing.T) { ...@@ -123,21 +123,33 @@ func TestPinBytesHandler(t *testing.T) {
) )
hashes := []string{rootHash, data1Hash, data2Hash} hashes := []string{rootHash, data1Hash, data2Hash}
sort.Strings(hashes)
expectedResponse := api.ListPinnedChunksResponse{ // NOTE: all this because we cannot rely on sort from response
Chunks: []api.PinnedChunk{},
}
for _, h := range hashes { var resp api.ListPinnedChunksResponse
expectedResponse.Chunks = append(expectedResponse.Chunks, api.PinnedChunk{
Address: swarm.MustParseHexAddress(h),
PinCounter: 1,
})
}
jsonhttptest.Request(t, client, http.MethodGet, pinChunksResource, http.StatusOK, jsonhttptest.Request(t, client, http.MethodGet, pinChunksResource, http.StatusOK,
jsonhttptest.WithExpectedJSONResponse(expectedResponse), jsonhttptest.WithUnmarshalJSONResponse(&resp),
) )
if len(hashes) != len(resp.Chunks) {
t.Fatalf("expected to find %d pinned chunks, got %d", len(hashes), len(resp.Chunks))
}
respChunksHashes := make([]string, 0)
for _, rc := range resp.Chunks {
respChunksHashes = append(respChunksHashes, rc.Address.String())
}
sort.Strings(respChunksHashes)
for i, h := range hashes {
if h != respChunksHashes[i] {
t.Fatalf("expected to find %s address, found %s", h, respChunksHashes[i])
}
}
}) })
} }
...@@ -319,8 +319,8 @@ func (s *server) updatePinCount(ctx context.Context, reference swarm.Address, de ...@@ -319,8 +319,8 @@ func (s *server) updatePinCount(ctx context.Context, reference swarm.Address, de
return nil return nil
} }
func (s *server) pinChunkAddressFn(ctx context.Context, reference swarm.Address) func(address swarm.Address) (stop bool) { func (s *server) pinChunkAddressFn(ctx context.Context, reference swarm.Address) func(address swarm.Address) error {
return func(address swarm.Address) (stop bool) { return func(address swarm.Address) error {
// NOTE: stop pinning on first error // NOTE: stop pinning on first error
err := s.Storer.Set(ctx, storage.ModeSetPin, address) err := s.Storer.Set(ctx, storage.ModeSetPin, address)
...@@ -330,31 +330,31 @@ func (s *server) pinChunkAddressFn(ctx context.Context, reference swarm.Address) ...@@ -330,31 +330,31 @@ func (s *server) pinChunkAddressFn(ctx context.Context, reference swarm.Address)
ch, err := s.Storer.Get(ctx, storage.ModeGetRequest, address) ch, err := s.Storer.Get(ctx, storage.ModeGetRequest, address)
if err != nil { if err != nil {
s.Logger.Debugf("pin traversal: storer get: for reference %s, address %s: %w", reference, address, err) s.Logger.Debugf("pin traversal: storer get: for reference %s, address %s: %w", reference, address, err)
return true return err
} }
_, err = s.Storer.Put(ctx, storage.ModePutRequestPin, ch) _, err = s.Storer.Put(ctx, storage.ModePutRequestPin, ch)
if err != nil { if err != nil {
s.Logger.Debugf("pin traversal: storer put pin: for reference %s, address %s: %w", reference, address, err) s.Logger.Debugf("pin traversal: storer put pin: for reference %s, address %s: %w", reference, address, err)
return true return err
} }
return false return nil
} else {
s.Logger.Debugf("pin traversal: storer set pin: for reference %s, address %s: %w", reference, address, err)
return true
} }
s.Logger.Debugf("pin traversal: storer set pin: for reference %s, address %s: %w", reference, address, err)
return err
} }
return false return nil
} }
} }
func (s *server) unpinChunkAddressFn(ctx context.Context, reference swarm.Address) func(address swarm.Address) (stop bool) { func (s *server) unpinChunkAddressFn(ctx context.Context, reference swarm.Address) func(address swarm.Address) error {
return func(address swarm.Address) (stop bool) { return func(address swarm.Address) error {
_, err := s.Storer.PinCounter(address) _, err := s.Storer.PinCounter(address)
if err != nil { if err != nil {
return false return err
} }
err = s.Storer.Set(ctx, storage.ModeSetUnpin, address) err = s.Storer.Set(ctx, storage.ModeSetUnpin, address)
...@@ -363,6 +363,6 @@ func (s *server) unpinChunkAddressFn(ctx context.Context, reference swarm.Addres ...@@ -363,6 +363,6 @@ func (s *server) unpinChunkAddressFn(ctx context.Context, reference swarm.Addres
// continue un-pinning all chunks // continue un-pinning all chunks
} }
return false return nil
} }
} }
...@@ -6,18 +6,11 @@ package addresses ...@@ -6,18 +6,11 @@ package addresses
import ( import (
"context" "context"
"errors"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
var (
// ErrStopIterator is returned iterator function marks that iteration should
// be stopped.
ErrStopIterator = errors.New("stop iterator")
)
type addressesGetterStore struct { type addressesGetterStore struct {
getter storage.Getter getter storage.Getter
fn swarm.AddressIterFunc fn swarm.AddressIterFunc
...@@ -29,16 +22,11 @@ func NewGetter(getter storage.Getter, fn swarm.AddressIterFunc) storage.Getter { ...@@ -29,16 +22,11 @@ func NewGetter(getter storage.Getter, fn swarm.AddressIterFunc) storage.Getter {
return &addressesGetterStore{getter, fn} return &addressesGetterStore{getter, fn}
} }
func (s *addressesGetterStore) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (ch swarm.Chunk, err error) { func (s *addressesGetterStore) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (swarm.Chunk, error) {
ch, err = s.getter.Get(ctx, mode, addr) ch, err := s.getter.Get(ctx, mode, addr)
if err != nil { if err != nil {
return return nil, err
}
stop := s.fn(ch.Address())
if stop {
return ch, ErrStopIterator
} }
return return ch, s.fn(ch.Address())
} }
...@@ -52,11 +52,12 @@ func TestAddressesGetterIterateChunkAddresses(t *testing.T) { ...@@ -52,11 +52,12 @@ func TestAddressesGetterIterateChunkAddresses(t *testing.T) {
foundAddresses := make(map[string]struct{}) foundAddresses := make(map[string]struct{})
var foundAddressesMu sync.Mutex var foundAddressesMu sync.Mutex
addressIterFunc := func(addr swarm.Address) (stop bool) { addressIterFunc := func(addr swarm.Address) error {
foundAddressesMu.Lock() foundAddressesMu.Lock()
defer foundAddressesMu.Unlock()
foundAddresses[addr.String()] = struct{}{} foundAddresses[addr.String()] = struct{}{}
foundAddressesMu.Unlock() return nil
return false
} }
addressesGetter := addresses.NewGetter(store, addressIterFunc) addressesGetter := addresses.NewGetter(store, addressIterFunc)
......
...@@ -213,37 +213,36 @@ func (j *joiner) Seek(offset int64, whence int) (int64, error) { ...@@ -213,37 +213,36 @@ func (j *joiner) Seek(offset int64, whence int) (int64, error) {
func (j *joiner) IterateChunkAddresses(fn swarm.AddressIterFunc) error { func (j *joiner) IterateChunkAddresses(fn swarm.AddressIterFunc) error {
// report root address // report root address
stop := fn(j.addr) err := fn(j.addr)
if stop { if err != nil {
return nil return err
} }
var eg errgroup.Group return j.processChunkAddresses(j.ctx, fn, j.rootData, j.span)
j.processChunkAddresses(fn, j.rootData, j.span, &eg)
return eg.Wait()
} }
func (j *joiner) processChunkAddresses(fn swarm.AddressIterFunc, data []byte, subTrieSize int64, eg *errgroup.Group) { func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIterFunc, data []byte, subTrieSize int64) error {
// we are at a leaf data chunk // we are at a leaf data chunk
if subTrieSize <= int64(len(data)) { if subTrieSize <= int64(len(data)) {
return return nil
}
select {
case <-ctx.Done():
return ctx.Err()
default:
} }
eg, ectx := errgroup.WithContext(ctx)
var wg sync.WaitGroup var wg sync.WaitGroup
for cursor := 0; cursor < len(data); cursor += j.refLength { for cursor := 0; cursor < len(data); cursor += j.refLength {
select {
case <-j.ctx.Done():
return
default:
}
address := swarm.NewAddress(data[cursor : cursor+j.refLength]) address := swarm.NewAddress(data[cursor : cursor+j.refLength])
stop := fn(address) if err := fn(address); err != nil {
if stop { return err
break
} }
sec := subtrieSection(data, cursor, j.refLength, subTrieSize) sec := subtrieSection(data, cursor, j.refLength, subTrieSize)
...@@ -257,20 +256,22 @@ func (j *joiner) processChunkAddresses(fn swarm.AddressIterFunc, data []byte, su ...@@ -257,20 +256,22 @@ func (j *joiner) processChunkAddresses(fn swarm.AddressIterFunc, data []byte, su
eg.Go(func() error { eg.Go(func() error {
defer wg.Done() defer wg.Done()
ch, err := j.getter.Get(j.ctx, storage.ModeGetRequest, address) ch, err := j.getter.Get(ectx, storage.ModeGetRequest, address)
if err != nil { if err != nil {
return err return err
} }
chunkData := ch.Data()[8:] chunkData := ch.Data()[8:]
subtrieSpan := int64(chunkToSpan(ch.Data())) subtrieSpan := int64(chunkToSpan(ch.Data()))
j.processChunkAddresses(fn, chunkData, subtrieSpan, eg)
return nil return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan)
}) })
}(address, eg) }(address, eg)
wg.Wait() wg.Wait()
} }
return eg.Wait()
} }
func (j *joiner) Size() int64 { func (j *joiner) Size() int64 {
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
mrand "math/rand" mrand "math/rand"
"sync"
"testing" "testing"
"time" "time"
...@@ -800,10 +801,14 @@ func TestJoinerIterateChunkAddresses(t *testing.T) { ...@@ -800,10 +801,14 @@ func TestJoinerIterateChunkAddresses(t *testing.T) {
} }
foundAddresses := make(map[string]struct{}) foundAddresses := make(map[string]struct{})
var foundAddressesMu sync.Mutex
err = j.IterateChunkAddresses(func(addr swarm.Address) error {
foundAddressesMu.Lock()
defer foundAddressesMu.Unlock()
err = j.IterateChunkAddresses(func(addr swarm.Address) (stop bool) {
foundAddresses[addr.String()] = struct{}{} foundAddresses[addr.String()] = struct{}{}
return false return nil
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
......
...@@ -27,10 +27,6 @@ var ( ...@@ -27,10 +27,6 @@ var (
ErrMissingReference = errors.New("manifest: missing reference") ErrMissingReference = errors.New("manifest: missing reference")
) )
var (
errStopIterator = errors.New("manifest: stop iterator")
)
// Interface for operations with manifest. // Interface for operations with manifest.
type Interface interface { type Interface interface {
// Type returns manifest implementation type information // Type returns manifest implementation type information
......
...@@ -131,22 +131,20 @@ func (m *mantarayManifest) IterateAddresses(ctx context.Context, fn swarm.Addres ...@@ -131,22 +131,20 @@ func (m *mantarayManifest) IterateAddresses(ctx context.Context, fn swarm.Addres
} }
if node != nil { if node != nil {
var stop bool
if node.Reference() != nil { if node.Reference() != nil {
ref := swarm.NewAddress(node.Reference()) ref := swarm.NewAddress(node.Reference())
stop = fn(ref) err = fn(ref)
if stop { if err != nil {
return errStopIterator return err
} }
} }
if node.IsValueType() && node.Entry() != nil { if node.IsValueType() && node.Entry() != nil {
entry := swarm.NewAddress(node.Entry()) entry := swarm.NewAddress(node.Entry())
stop = fn(entry) err = fn(entry)
if stop { if err != nil {
return errStopIterator return err
} }
} }
} }
...@@ -156,10 +154,7 @@ func (m *mantarayManifest) IterateAddresses(ctx context.Context, fn swarm.Addres ...@@ -156,10 +154,7 @@ func (m *mantarayManifest) IterateAddresses(ctx context.Context, fn swarm.Addres
err := m.trie.WalkNode(ctx, []byte{}, m.ls, walker) err := m.trie.WalkNode(ctx, []byte{}, m.ls, walker)
if err != nil { if err != nil {
if !errors.Is(err, errStopIterator) { return fmt.Errorf("manifest iterate addresses: %w", err)
return fmt.Errorf("manifest iterate addresses: %w", err)
}
// ignore error if interation stopped by caller
} }
return nil return nil
......
...@@ -108,9 +108,9 @@ func (m *simpleManifest) IterateAddresses(ctx context.Context, fn swarm.AddressI ...@@ -108,9 +108,9 @@ func (m *simpleManifest) IterateAddresses(ctx context.Context, fn swarm.AddressI
} }
// NOTE: making it behave same for all manifest implementation // NOTE: making it behave same for all manifest implementation
stop := fn(m.reference) err := fn(m.reference)
if stop { if err != nil {
return nil return fmt.Errorf("manifest iterate addresses: %w", err)
} }
walker := func(path string, entry simple.Entry, err error) error { walker := func(path string, entry simple.Entry, err error) error {
...@@ -123,20 +123,12 @@ func (m *simpleManifest) IterateAddresses(ctx context.Context, fn swarm.AddressI ...@@ -123,20 +123,12 @@ func (m *simpleManifest) IterateAddresses(ctx context.Context, fn swarm.AddressI
return err return err
} }
stop := fn(ref) return fn(ref)
if stop {
return errStopIterator
}
return nil
} }
err := m.manifest.WalkEntry("", walker) err = m.manifest.WalkEntry("", walker)
if err != nil { if err != nil {
if !errors.Is(err, errStopIterator) { return fmt.Errorf("manifest iterate addresses: %w", err)
return fmt.Errorf("manifest iterate addresses: %w", err)
}
// ignore error if interation stopped by caller
} }
return nil return nil
......
...@@ -109,8 +109,7 @@ func (a Address) MarshalJSON() ([]byte, error) { ...@@ -109,8 +109,7 @@ func (a Address) MarshalJSON() ([]byte, error) {
var ZeroAddress = NewAddress(nil) var ZeroAddress = NewAddress(nil)
// AddressIterFunc is a callback on every address that is found by the iterator. // AddressIterFunc is a callback on every address that is found by the iterator.
// By returning a true for stop variable, iteration should stop. type AddressIterFunc func(address Address) error
type AddressIterFunc func(address Address) (stop bool)
type Chunk interface { type Chunk interface {
Address() Address Address() Address
......
...@@ -82,12 +82,8 @@ func (s *traversalService) TraverseAddresses( ...@@ -82,12 +82,8 @@ func (s *traversalService) TraverseAddresses(
if isManifest { if isManifest {
// process as manifest // process as manifest
err = m.IterateAddresses(ctx, func(manifestNodeAddr swarm.Address) (stop bool) { err = m.IterateAddresses(ctx, func(manifestNodeAddr swarm.Address) error {
err := s.traverseChunkAddressesFromManifest(ctx, manifestNodeAddr, chunkAddressFunc) return s.traverseChunkAddressesFromManifest(ctx, manifestNodeAddr, chunkAddressFunc)
if err != nil {
stop = true
}
return
}) })
if err != nil { if err != nil {
return fmt.Errorf("traversal: iterate chunks: %s: %w", reference, err) return fmt.Errorf("traversal: iterate chunks: %s: %w", reference, err)
...@@ -174,12 +170,8 @@ func (s *traversalService) TraverseManifestAddresses( ...@@ -174,12 +170,8 @@ func (s *traversalService) TraverseManifestAddresses(
return ErrInvalidType return ErrInvalidType
} }
err = m.IterateAddresses(ctx, func(manifestNodeAddr swarm.Address) (stop bool) { err = m.IterateAddresses(ctx, func(manifestNodeAddr swarm.Address) error {
err := s.traverseChunkAddressesFromManifest(ctx, manifestNodeAddr, chunkAddressFunc) return s.traverseChunkAddressesFromManifest(ctx, manifestNodeAddr, chunkAddressFunc)
if err != nil {
stop = true
}
return
}) })
if err != nil { if err != nil {
return fmt.Errorf("traversal: iterate chunks: %s: %w", reference, err) return fmt.Errorf("traversal: iterate chunks: %s: %w", reference, err)
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"path" "path"
"sort" "sort"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
...@@ -577,11 +578,15 @@ func traversalCheck(t *testing.T, ...@@ -577,11 +578,15 @@ func traversalCheck(t *testing.T,
foundAddressesCount := 0 foundAddressesCount := 0
foundAddresses := make(map[string]struct{}) foundAddresses := make(map[string]struct{})
var foundAddressesMu sync.Mutex
err := traverseFn(traversalService)( err := traverseFn(traversalService)(
ctx, ctx,
reference, reference,
func(addr swarm.Address) (stop bool) { func(addr swarm.Address) error {
foundAddressesMu.Lock()
defer foundAddressesMu.Unlock()
foundAddressesCount++ foundAddressesCount++
if !ignoreDuplicateHash { if !ignoreDuplicateHash {
if _, ok := foundAddresses[addr.String()]; ok { if _, ok := foundAddresses[addr.String()]; ok {
...@@ -589,7 +594,7 @@ func traversalCheck(t *testing.T, ...@@ -589,7 +594,7 @@ func traversalCheck(t *testing.T,
} }
} }
foundAddresses[addr.String()] = struct{}{} foundAddresses[addr.String()] = struct{}{}
return false return nil
}) })
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment