Commit e3a08844 authored by Nemanja Zbiljić's avatar Nemanja Zbiljić Committed by GitHub

Return error from address iteration function (#1066)

parent 422e54c0
......@@ -123,21 +123,33 @@ func TestPinBytesHandler(t *testing.T) {
)
hashes := []string{rootHash, data1Hash, data2Hash}
sort.Strings(hashes)
expectedResponse := api.ListPinnedChunksResponse{
Chunks: []api.PinnedChunk{},
}
// NOTE: all this because we cannot rely on sort from response
for _, h := range hashes {
expectedResponse.Chunks = append(expectedResponse.Chunks, api.PinnedChunk{
Address: swarm.MustParseHexAddress(h),
PinCounter: 1,
})
}
var resp api.ListPinnedChunksResponse
jsonhttptest.Request(t, client, http.MethodGet, pinChunksResource, http.StatusOK,
jsonhttptest.WithExpectedJSONResponse(expectedResponse),
jsonhttptest.WithUnmarshalJSONResponse(&resp),
)
if len(hashes) != len(resp.Chunks) {
t.Fatalf("expected to find %d pinned chunks, got %d", len(hashes), len(resp.Chunks))
}
respChunksHashes := make([]string, 0)
for _, rc := range resp.Chunks {
respChunksHashes = append(respChunksHashes, rc.Address.String())
}
sort.Strings(respChunksHashes)
for i, h := range hashes {
if h != respChunksHashes[i] {
t.Fatalf("expected to find %s address, found %s", h, respChunksHashes[i])
}
}
})
}
......@@ -319,8 +319,8 @@ func (s *server) updatePinCount(ctx context.Context, reference swarm.Address, de
return nil
}
func (s *server) pinChunkAddressFn(ctx context.Context, reference swarm.Address) func(address swarm.Address) (stop bool) {
return func(address swarm.Address) (stop bool) {
func (s *server) pinChunkAddressFn(ctx context.Context, reference swarm.Address) func(address swarm.Address) error {
return func(address swarm.Address) error {
// NOTE: stop pinning on first error
err := s.Storer.Set(ctx, storage.ModeSetPin, address)
......@@ -330,31 +330,31 @@ func (s *server) pinChunkAddressFn(ctx context.Context, reference swarm.Address)
ch, err := s.Storer.Get(ctx, storage.ModeGetRequest, address)
if err != nil {
s.Logger.Debugf("pin traversal: storer get: for reference %s, address %s: %w", reference, address, err)
return true
return err
}
_, err = s.Storer.Put(ctx, storage.ModePutRequestPin, ch)
if err != nil {
s.Logger.Debugf("pin traversal: storer put pin: for reference %s, address %s: %w", reference, address, err)
return true
return err
}
return false
} else {
s.Logger.Debugf("pin traversal: storer set pin: for reference %s, address %s: %w", reference, address, err)
return true
return nil
}
s.Logger.Debugf("pin traversal: storer set pin: for reference %s, address %s: %w", reference, address, err)
return err
}
return false
return nil
}
}
func (s *server) unpinChunkAddressFn(ctx context.Context, reference swarm.Address) func(address swarm.Address) (stop bool) {
return func(address swarm.Address) (stop bool) {
func (s *server) unpinChunkAddressFn(ctx context.Context, reference swarm.Address) func(address swarm.Address) error {
return func(address swarm.Address) error {
_, err := s.Storer.PinCounter(address)
if err != nil {
return false
return err
}
err = s.Storer.Set(ctx, storage.ModeSetUnpin, address)
......@@ -363,6 +363,6 @@ func (s *server) unpinChunkAddressFn(ctx context.Context, reference swarm.Addres
// continue un-pinning all chunks
}
return false
return nil
}
}
......@@ -6,18 +6,11 @@ package addresses
import (
"context"
"errors"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
)
var (
// ErrStopIterator is returned iterator function marks that iteration should
// be stopped.
ErrStopIterator = errors.New("stop iterator")
)
type addressesGetterStore struct {
getter storage.Getter
fn swarm.AddressIterFunc
......@@ -29,16 +22,11 @@ func NewGetter(getter storage.Getter, fn swarm.AddressIterFunc) storage.Getter {
return &addressesGetterStore{getter, fn}
}
func (s *addressesGetterStore) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (ch swarm.Chunk, err error) {
ch, err = s.getter.Get(ctx, mode, addr)
func (s *addressesGetterStore) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Address) (swarm.Chunk, error) {
ch, err := s.getter.Get(ctx, mode, addr)
if err != nil {
return
}
stop := s.fn(ch.Address())
if stop {
return ch, ErrStopIterator
return nil, err
}
return
return ch, s.fn(ch.Address())
}
......@@ -52,11 +52,12 @@ func TestAddressesGetterIterateChunkAddresses(t *testing.T) {
foundAddresses := make(map[string]struct{})
var foundAddressesMu sync.Mutex
addressIterFunc := func(addr swarm.Address) (stop bool) {
addressIterFunc := func(addr swarm.Address) error {
foundAddressesMu.Lock()
defer foundAddressesMu.Unlock()
foundAddresses[addr.String()] = struct{}{}
foundAddressesMu.Unlock()
return false
return nil
}
addressesGetter := addresses.NewGetter(store, addressIterFunc)
......
......@@ -213,37 +213,36 @@ func (j *joiner) Seek(offset int64, whence int) (int64, error) {
func (j *joiner) IterateChunkAddresses(fn swarm.AddressIterFunc) error {
// report root address
stop := fn(j.addr)
if stop {
return nil
err := fn(j.addr)
if err != nil {
return err
}
var eg errgroup.Group
j.processChunkAddresses(fn, j.rootData, j.span, &eg)
return eg.Wait()
return j.processChunkAddresses(j.ctx, fn, j.rootData, j.span)
}
func (j *joiner) processChunkAddresses(fn swarm.AddressIterFunc, data []byte, subTrieSize int64, eg *errgroup.Group) {
func (j *joiner) processChunkAddresses(ctx context.Context, fn swarm.AddressIterFunc, data []byte, subTrieSize int64) error {
// we are at a leaf data chunk
if subTrieSize <= int64(len(data)) {
return
return nil
}
select {
case <-ctx.Done():
return ctx.Err()
default:
}
eg, ectx := errgroup.WithContext(ctx)
var wg sync.WaitGroup
for cursor := 0; cursor < len(data); cursor += j.refLength {
select {
case <-j.ctx.Done():
return
default:
}
address := swarm.NewAddress(data[cursor : cursor+j.refLength])
stop := fn(address)
if stop {
break
if err := fn(address); err != nil {
return err
}
sec := subtrieSection(data, cursor, j.refLength, subTrieSize)
......@@ -257,20 +256,22 @@ func (j *joiner) processChunkAddresses(fn swarm.AddressIterFunc, data []byte, su
eg.Go(func() error {
defer wg.Done()
ch, err := j.getter.Get(j.ctx, storage.ModeGetRequest, address)
ch, err := j.getter.Get(ectx, storage.ModeGetRequest, address)
if err != nil {
return err
}
chunkData := ch.Data()[8:]
subtrieSpan := int64(chunkToSpan(ch.Data()))
j.processChunkAddresses(fn, chunkData, subtrieSpan, eg)
return nil
return j.processChunkAddresses(ectx, fn, chunkData, subtrieSpan)
})
}(address, eg)
wg.Wait()
}
return eg.Wait()
}
func (j *joiner) Size() int64 {
......
......@@ -12,6 +12,7 @@ import (
"io"
"io/ioutil"
mrand "math/rand"
"sync"
"testing"
"time"
......@@ -800,10 +801,14 @@ func TestJoinerIterateChunkAddresses(t *testing.T) {
}
foundAddresses := make(map[string]struct{})
var foundAddressesMu sync.Mutex
err = j.IterateChunkAddresses(func(addr swarm.Address) error {
foundAddressesMu.Lock()
defer foundAddressesMu.Unlock()
err = j.IterateChunkAddresses(func(addr swarm.Address) (stop bool) {
foundAddresses[addr.String()] = struct{}{}
return false
return nil
})
if err != nil {
t.Fatal(err)
......
......@@ -27,10 +27,6 @@ var (
ErrMissingReference = errors.New("manifest: missing reference")
)
var (
errStopIterator = errors.New("manifest: stop iterator")
)
// Interface for operations with manifest.
type Interface interface {
// Type returns manifest implementation type information
......
......@@ -131,22 +131,20 @@ func (m *mantarayManifest) IterateAddresses(ctx context.Context, fn swarm.Addres
}
if node != nil {
var stop bool
if node.Reference() != nil {
ref := swarm.NewAddress(node.Reference())
stop = fn(ref)
if stop {
return errStopIterator
err = fn(ref)
if err != nil {
return err
}
}
if node.IsValueType() && node.Entry() != nil {
entry := swarm.NewAddress(node.Entry())
stop = fn(entry)
if stop {
return errStopIterator
err = fn(entry)
if err != nil {
return err
}
}
}
......@@ -156,10 +154,7 @@ func (m *mantarayManifest) IterateAddresses(ctx context.Context, fn swarm.Addres
err := m.trie.WalkNode(ctx, []byte{}, m.ls, walker)
if err != nil {
if !errors.Is(err, errStopIterator) {
return fmt.Errorf("manifest iterate addresses: %w", err)
}
// ignore error if interation stopped by caller
return fmt.Errorf("manifest iterate addresses: %w", err)
}
return nil
......
......@@ -108,9 +108,9 @@ func (m *simpleManifest) IterateAddresses(ctx context.Context, fn swarm.AddressI
}
// NOTE: making it behave same for all manifest implementation
stop := fn(m.reference)
if stop {
return nil
err := fn(m.reference)
if err != nil {
return fmt.Errorf("manifest iterate addresses: %w", err)
}
walker := func(path string, entry simple.Entry, err error) error {
......@@ -123,20 +123,12 @@ func (m *simpleManifest) IterateAddresses(ctx context.Context, fn swarm.AddressI
return err
}
stop := fn(ref)
if stop {
return errStopIterator
}
return nil
return fn(ref)
}
err := m.manifest.WalkEntry("", walker)
err = m.manifest.WalkEntry("", walker)
if err != nil {
if !errors.Is(err, errStopIterator) {
return fmt.Errorf("manifest iterate addresses: %w", err)
}
// ignore error if interation stopped by caller
return fmt.Errorf("manifest iterate addresses: %w", err)
}
return nil
......
......@@ -109,8 +109,7 @@ func (a Address) MarshalJSON() ([]byte, error) {
var ZeroAddress = NewAddress(nil)
// AddressIterFunc is a callback on every address that is found by the iterator.
// By returning a true for stop variable, iteration should stop.
type AddressIterFunc func(address Address) (stop bool)
type AddressIterFunc func(address Address) error
type Chunk interface {
Address() Address
......
......@@ -82,12 +82,8 @@ func (s *traversalService) TraverseAddresses(
if isManifest {
// process as manifest
err = m.IterateAddresses(ctx, func(manifestNodeAddr swarm.Address) (stop bool) {
err := s.traverseChunkAddressesFromManifest(ctx, manifestNodeAddr, chunkAddressFunc)
if err != nil {
stop = true
}
return
err = m.IterateAddresses(ctx, func(manifestNodeAddr swarm.Address) error {
return s.traverseChunkAddressesFromManifest(ctx, manifestNodeAddr, chunkAddressFunc)
})
if err != nil {
return fmt.Errorf("traversal: iterate chunks: %s: %w", reference, err)
......@@ -174,12 +170,8 @@ func (s *traversalService) TraverseManifestAddresses(
return ErrInvalidType
}
err = m.IterateAddresses(ctx, func(manifestNodeAddr swarm.Address) (stop bool) {
err := s.traverseChunkAddressesFromManifest(ctx, manifestNodeAddr, chunkAddressFunc)
if err != nil {
stop = true
}
return
err = m.IterateAddresses(ctx, func(manifestNodeAddr swarm.Address) error {
return s.traverseChunkAddressesFromManifest(ctx, manifestNodeAddr, chunkAddressFunc)
})
if err != nil {
return fmt.Errorf("traversal: iterate chunks: %s: %w", reference, err)
......
......@@ -14,6 +14,7 @@ import (
"path"
"sort"
"strings"
"sync"
"testing"
"time"
......@@ -577,11 +578,15 @@ func traversalCheck(t *testing.T,
foundAddressesCount := 0
foundAddresses := make(map[string]struct{})
var foundAddressesMu sync.Mutex
err := traverseFn(traversalService)(
ctx,
reference,
func(addr swarm.Address) (stop bool) {
func(addr swarm.Address) error {
foundAddressesMu.Lock()
defer foundAddressesMu.Unlock()
foundAddressesCount++
if !ignoreDuplicateHash {
if _, ok := foundAddresses[addr.String()]; ok {
......@@ -589,7 +594,7 @@ func traversalCheck(t *testing.T,
}
}
foundAddresses[addr.String()] = struct{}{}
return false
return nil
})
if err != nil {
t.Fatal(err)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment