Commit f8054e91 authored by acud's avatar acud Committed by GitHub

node: add tag persistence across sessions (#573)

* Add individual tags persistence
parent b54afca6
......@@ -18,7 +18,7 @@ import (
cmdfile "github.com/ethersphere/bee/cmd/internal/file"
"github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/logging"
resolverMock "github.com/ethersphere/bee/pkg/resolver/mock"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm"
......@@ -154,7 +154,9 @@ func TestLimitWriter(t *testing.T) {
// newTestServer creates an http server to serve the bee http api endpoints.
func newTestServer(t *testing.T, storer storage.Storer) *url.URL {
t.Helper()
s := api.New(tags.NewTags(), storer, resolverMock.NewResolver(), nil, logging.New(ioutil.Discard, 0), nil)
logger := logging.New(ioutil.Discard, 0)
store := statestore.NewStateStore()
s := api.New(tags.NewTags(store, logger), storer, nil, nil, logger, nil)
ts := httptest.NewServer(s)
srvUrl, err := url.Parse(ts.URL)
if err != nil {
......
......@@ -41,7 +41,13 @@ func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) {
return
}
if created {
tag.DoneSplit(address)
_, err = tag.DoneSplit(address)
if err != nil {
s.Logger.Debugf("bytes upload: done split: %v", err)
s.Logger.Error("bytes upload: done split failed")
jsonhttp.InternalServerError(w, nil)
return
}
}
w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid))
w.Header().Set("Access-Control-Expose-Headers", SwarmTagUidHeader)
......
......@@ -6,6 +6,7 @@ package api_test
import (
"bytes"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"net/http"
"testing"
......@@ -24,13 +25,15 @@ import (
// downloading and requesting a resource that cannot be found.
func TestBytes(t *testing.T) {
var (
resource = "/bytes"
targets = "0x222"
expHash = "29a5fb121ce96194ba8b7b823a1f9c6af87e1791f824940a53b5a7efe3f790d9"
mockStorer = mock.NewStorer()
client = newTestServer(t, testServerOptions{
resource = "/bytes"
targets = "0x222"
expHash = "29a5fb121ce96194ba8b7b823a1f9c6af87e1791f824940a53b5a7efe3f790d9"
mockStorer = mock.NewStorer()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{
Storer: mockStorer,
Tags: tags.NewTags(),
Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5),
})
)
......
......@@ -9,6 +9,7 @@ import (
"context"
"encoding/json"
"fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io"
"io/ioutil"
"mime"
......@@ -33,9 +34,11 @@ func TestBzz(t *testing.T) {
bzzDownloadResource = func(addr, path string) string { return "/bzz/" + addr + "/" + path }
storer = smock.NewStorer()
ctx = context.Background()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{
Storer: storer,
Tags: tags.NewTags(),
Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5),
})
pipeWriteAll = func(r io.Reader, l int64) (swarm.Address, error) {
......
......@@ -45,7 +45,13 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
ctx := sctx.SetTag(r.Context(), tag)
// Increment the StateSplit here since we dont have a splitter for the file upload
tag.Inc(tags.StateSplit)
err = tag.Inc(tags.StateSplit)
if err != nil {
s.Logger.Debugf("chunk upload: increment tag: %v", err)
s.Logger.Error("chunk upload: increment tag")
jsonhttp.InternalServerError(w, "increment tag")
return
}
data, err := ioutil.ReadAll(r.Body)
if err != nil {
......@@ -65,11 +71,23 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
jsonhttp.BadRequest(w, "chunk write error")
return
} else if len(seen) > 0 && seen[0] {
tag.Inc(tags.StateSeen)
err := tag.Inc(tags.StateSeen)
if err != nil {
s.Logger.Debugf("chunk upload: increment tag: %v", err)
s.Logger.Error("chunk upload: increment tag")
jsonhttp.BadRequest(w, "increment tag")
return
}
}
// Indicate that the chunk is stored
tag.Inc(tags.StateStored)
err = tag.Inc(tags.StateStored)
if err != nil {
s.Logger.Debugf("chunk upload: increment tag: %v", err)
s.Logger.Error("chunk upload: increment tag")
jsonhttp.BadRequest(w, "increment tag")
return
}
w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid))
w.Header().Set("Access-Control-Expose-Headers", SwarmTagUidHeader)
......
......@@ -6,6 +6,8 @@ package api_test
import (
"bytes"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io"
"io/ioutil"
"net/http"
......@@ -35,7 +37,9 @@ func TestChunkUploadDownload(t *testing.T) {
validContent = []byte("bbaatt")
invalidContent = []byte("bbaattss")
mockValidator = validator.NewMockValidator(validHash, validContent)
tag = tags.NewTags()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger)
mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator))
client = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer,
......
......@@ -54,13 +54,19 @@ func (s *server) dirUploadHandler(w http.ResponseWriter, r *http.Request) {
reference, err := storeDir(ctx, r.Body, s.Storer, requestModePut(r), s.Logger, requestEncrypt(r))
if err != nil {
s.Logger.Debugf("dir upload, store dir err: %v", err)
s.Logger.Errorf("dir upload, store dir")
s.Logger.Debugf("dir upload: store dir err: %v", err)
s.Logger.Errorf("dir upload: store dir")
jsonhttp.InternalServerError(w, "could not store dir")
return
}
if created {
tag.DoneSplit(reference)
_, err = tag.DoneSplit(reference)
if err != nil {
s.Logger.Debugf("dir upload: done split: %v", err)
s.Logger.Error("dir upload: done split failed")
jsonhttp.InternalServerError(w, nil)
return
}
}
w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid))
jsonhttp.OK(w, fileUploadResponse{
......
......@@ -9,6 +9,7 @@ import (
"bytes"
"context"
"encoding/json"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"net/http"
"path"
......@@ -32,9 +33,11 @@ func TestDirs(t *testing.T) {
dirUploadResource = "/dirs"
fileDownloadResource = func(addr string) string { return "/files/" + addr }
storer = mock.NewStorer()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{
Storer: storer,
Tags: tags.NewTags(),
Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5),
})
)
......
......@@ -201,7 +201,13 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
return
}
if created {
tag.DoneSplit(reference)
_, err = tag.DoneSplit(reference)
if err != nil {
s.Logger.Debugf("file upload: done split: %v", err)
s.Logger.Error("file upload: done split failed")
jsonhttp.InternalServerError(w, nil)
return
}
}
w.Header().Set("ETag", fmt.Sprintf("%q", reference.String()))
w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid))
......
......@@ -8,6 +8,7 @@ import (
"bytes"
"encoding/json"
"fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io"
"io/ioutil"
"mime"
......@@ -32,9 +33,11 @@ func TestFiles(t *testing.T) {
targets = "0x222"
fileDownloadResource = func(addr string) string { return "/files/" + addr }
simpleData = []byte("this is a simple text")
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(),
Tags: tags.NewTags(),
Tags: tags.NewTags(mockStatestore, logger),
})
)
......@@ -333,9 +336,11 @@ func TestRangeRequests(t *testing.T) {
for _, upload := range uploads {
t.Run(upload.name, func(t *testing.T) {
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
client := newTestServer(t, testServerOptions{
Storer: mock.NewStorer(),
Tags: tags.NewTags(),
Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5),
})
......
......@@ -16,7 +16,6 @@ import (
"github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags"
"github.com/gorilla/mux"
)
......@@ -197,6 +196,12 @@ func (s *server) doneSplit(w http.ResponseWriter, r *http.Request) {
return
}
tag.DoneSplit(tagr.Address)
_, err = tag.DoneSplit(tagr.Address)
if err != nil {
s.Logger.Debugf("done split: failed for address %v: %v", tagr.Address, err)
s.Logger.Errorf("done split: failed for address %v", tagr.Address)
jsonhttp.InternalServerError(w, nil)
return
}
jsonhttp.OK(w, "ok")
}
......@@ -8,6 +8,9 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"net/http"
"strconv"
"strings"
......@@ -39,7 +42,9 @@ func TestTags(t *testing.T) {
someHash = swarm.MustParseHexAddress("aabbcc")
someContent = []byte("bbaatt")
someTagName = "file.jpg"
tag = tags.NewTags()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger)
mockPusher = mp.NewMockPusher(tag)
client = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(),
......
......@@ -6,6 +6,9 @@ package debugapi_test
import (
"bytes"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"net/http"
"testing"
......@@ -29,7 +32,9 @@ func TestPinChunkHandler(t *testing.T) {
data = []byte("bbaatt")
mockValidator = validator.NewMockValidator(hash, data)
mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator))
tag = tags.NewTags()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger)
debugTestServer = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer,
......
......@@ -6,7 +6,6 @@ package pipeline
import (
"context"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
......@@ -28,18 +27,31 @@ func newStoreWriter(ctx context.Context, l storage.Putter, mode storage.ModePut,
func (w *storeWriter) chainWrite(p *pipeWriteArgs) error {
tag := sctx.GetTag(w.ctx)
var c swarm.Chunk
if tag != nil {
tag.Inc(tags.StateSplit)
err := tag.Inc(tags.StateSplit)
if err != nil {
return err
}
c = swarm.NewChunk(swarm.NewAddress(p.ref), p.data).WithTagID(tag.Uid)
} else {
c = swarm.NewChunk(swarm.NewAddress(p.ref), p.data)
}
c := swarm.NewChunk(swarm.NewAddress(p.ref), p.data)
seen, err := w.l.Put(w.ctx, w.mode, c)
if err != nil {
return err
}
if tag != nil {
tag.Inc(tags.StateStored)
err := tag.Inc(tags.StateStored)
if err != nil {
return err
}
if seen[0] {
tag.Inc(tags.StateSeen)
err := tag.Inc(tags.StateSeen)
if err != nil {
return err
}
}
}
if w.next == nil {
......
......@@ -152,7 +152,10 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
binary.LittleEndian.PutUint64(head, uint64(span))
tail := s.buffer[s.cursors[lvl+1]:s.cursors[lvl]]
chunkData = append(head, tail...)
s.incrTag(tags.StateSplit)
err := s.incrTag(tags.StateSplit)
if err != nil {
return nil, err
}
c := chunkData
var encryptionKey encryption.Key
......@@ -165,7 +168,7 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
}
s.hasher.Reset()
err := s.hasher.SetSpanBytes(c[:8])
err = s.hasher.SetSpanBytes(c[:8])
if err != nil {
return nil, err
}
......@@ -188,10 +191,16 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
if err != nil {
return nil, err
} else if len(seen) > 0 && seen[0] {
s.incrTag(tags.StateSeen)
err = s.incrTag(tags.StateSeen)
if err != nil {
return nil, err
}
}
s.incrTag(tags.StateStored)
err = s.incrTag(tags.StateStored)
if err != nil {
return nil, err
}
return append(ch.Address().Bytes(), encryptionKey...), nil
}
......@@ -310,8 +319,9 @@ func (s *SimpleSplitterJob) newDataEncryption(key encryption.Key) encryption.Int
return encryption.New(key, int(swarm.ChunkSize), 0, sha3.NewLegacyKeccak256)
}
func (s *SimpleSplitterJob) incrTag(state tags.State) {
func (s *SimpleSplitterJob) incrTag(state tags.State) error {
if s.tag != nil {
s.tag.Inc(state)
return s.tag.Inc(state)
}
return nil
}
......@@ -247,7 +247,10 @@ func (db *DB) setSync(batch *leveldb.Batch, addr swarm.Address, mode storage.Mod
// run to end from db.pushIndex.DeleteInBatch
db.logger.Errorf("localstore: get tags on push sync set uid %d: %v", i.Tag, err)
} else {
t.Inc(tags.StateSynced)
err = t.Inc(tags.StateSynced)
if err != nil {
return 0, err
}
}
}
......
......@@ -19,6 +19,9 @@ package localstore
import (
"context"
"errors"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"testing"
"time"
......@@ -70,7 +73,9 @@ func TestModeSetAccess(t *testing.T) {
// as a result we should expect the tag value to remain in the pull index
// and we expect that the tag should not be incremented by pull sync set
func TestModeSetSyncPullNormalTag(t *testing.T) {
db := newTestDB(t, &Options{Tags: tags.NewTags()})
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
db := newTestDB(t, &Options{Tags: tags.NewTags(mockStatestore, logger)})
tag, err := db.tags.Create("test", 1)
if err != nil {
......@@ -83,7 +88,10 @@ func TestModeSetSyncPullNormalTag(t *testing.T) {
t.Fatal(err)
}
tag.Inc(tags.StateStored) // so we don't get an error on tag.Status later on
err = tag.Inc(tags.StateStored) // so we don't get an error on tag.Status later on
if err != nil {
t.Fatal(err)
}
item, err := db.pullIndex.Get(shed.Item{
Address: ch.Address().Bytes(),
......@@ -124,7 +132,9 @@ func TestModeSetSyncPullNormalTag(t *testing.T) {
// correctly on a normal tag (that is, a tag that is expected to show progress bars
// according to push sync progress)
func TestModeSetSyncPushNormalTag(t *testing.T) {
db := newTestDB(t, &Options{Tags: tags.NewTags()})
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
db := newTestDB(t, &Options{Tags: tags.NewTags(mockStatestore, logger)})
tag, err := db.tags.Create("test", 1)
if err != nil {
......@@ -137,7 +147,11 @@ func TestModeSetSyncPushNormalTag(t *testing.T) {
t.Fatal(err)
}
tag.Inc(tags.StateStored) // so we don't get an error on tag.Status later on
err = tag.Inc(tags.StateStored) // so we don't get an error on tag.Status later on
if err != nil {
t.Fatal(err)
}
item, err := db.pullIndex.Get(shed.Item{
Address: ch.Address().Bytes(),
BinID: 1,
......
......@@ -61,6 +61,7 @@ type Bee struct {
resolverCloser io.Closer
errorLogWriter *io.PipeWriter
tracerCloser io.Closer
tagsCloser io.Closer
stateStoreCloser io.Closer
localstoreCloser io.Closer
topologyCloser io.Closer
......@@ -243,7 +244,8 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
chunkvalidator := swarm.NewChunkValidator(soc.NewValidator(), content.NewValidator())
retrieve := retrieval.New(p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator)
tagg := tags.NewTags()
tagg := tags.NewTags(stateStore, logger)
b.tagsCloser = tagg
if err = p2ps.AddProtocol(retrieve.Protocol()); err != nil {
return nil, fmt.Errorf("retrieval service: %w", err)
......@@ -413,6 +415,10 @@ func (b *Bee) Shutdown(ctx context.Context) error {
errs.add(fmt.Errorf("tracer: %w", err))
}
if err := b.tagsCloser.Close(); err != nil {
errs.add(fmt.Errorf("tag persistence: %w", err))
}
if err := b.stateStoreCloser.Close(); err != nil {
errs.add(fmt.Errorf("statestore: %w", err))
}
......
......@@ -23,9 +23,7 @@ func (m *MockPusher) SendChunk(uid uint32) error {
if err != nil {
return err
}
ta.Inc(tags.StateSent)
return nil
return ta.Inc(tags.StateSent)
}
func (m *MockPusher) RcvdReceipt(uid uint32) error {
......@@ -33,7 +31,5 @@ func (m *MockPusher) RcvdReceipt(uid uint32) error {
if err != nil {
return err
}
ta.Inc(tags.StateSynced)
return nil
return ta.Inc(tags.StateSynced)
}
......@@ -123,7 +123,12 @@ LOOP:
}
return
}
s.setChunkAsSynced(ctx, ch)
err = s.setChunkAsSynced(ctx, ch)
if err != nil {
s.logger.Debugf("pusher: error setting chunk as synced: %v", err)
return
}
}(ctx, ch)
case <-timer.C:
// initially timer is set to go off as well as every time we hit the end of push index
......@@ -165,15 +170,19 @@ LOOP:
}
}
func (s *Service) setChunkAsSynced(ctx context.Context, ch swarm.Chunk) {
func (s *Service) setChunkAsSynced(ctx context.Context, ch swarm.Chunk) error {
if err := s.storer.Set(ctx, storage.ModeSetSyncPush, ch.Address()); err != nil {
s.logger.Errorf("pusher: error setting chunk as synced: %v", err)
s.metrics.ErrorSettingChunkToSynced.Inc()
}
t, err := s.tagg.Get(ch.TagID())
if err == nil && t != nil {
t.Inc(tags.StateSynced)
err = t.Inc(tags.StateSynced)
if err != nil {
return err
}
}
return nil
}
func (s *Service) Close() error {
......
......@@ -7,6 +7,7 @@ package pusher_test
import (
"context"
"errors"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"sync"
"testing"
......@@ -227,7 +228,8 @@ func createPusher(t *testing.T, addr swarm.Address, pushSyncService pushsync.Pus
t.Fatal(err)
}
mtags := tags.NewTags()
mockStatestore := statestore.NewStateStore()
mtags := tags.NewTags(mockStatestore, logger)
pusherStorer := &Store{
Storer: storer,
modeSet: make(map[string]storage.ModeSet),
......
......@@ -221,7 +221,10 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
// this is to make sure that the sent number does not diverge from the synced counter
t, err := ps.tagg.Get(ch.TagID())
if err == nil && t != nil {
t.Inc(tags.StateSent)
err = t.Inc(tags.StateSent)
if err != nil {
return nil, err
}
}
// if you are the closest node return a receipt immediately
......@@ -255,7 +258,10 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
// if you manage to get a tag, just increment the respective counter
t, err := ps.tagg.Get(ch.TagID())
if err == nil && t != nil {
t.Inc(tags.StateSent)
err = t.Inc(tags.StateSent)
if err != nil {
return nil, err
}
}
receiptRTTTimer := time.Now()
......
......@@ -7,6 +7,7 @@ package pushsync_test
import (
"bytes"
"context"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"testing"
"time"
......@@ -286,7 +287,8 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R
}
mockTopology := mock.NewTopologyDriver(mockOpts...)
mtag := tags.NewTags()
mockStatestore := statestore.NewStateStore()
mtag := tags.NewTags(mockStatestore, logger)
mockAccounting := accountingmock.NewAccounting()
mockPricer := accountingmock.NewPricer(fixedPrice, fixedPrice)
......
......@@ -20,10 +20,13 @@ import (
"context"
"encoding/binary"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tracing"
"github.com/opentracing/opentracing-go"
......@@ -62,18 +65,22 @@ type Tag struct {
StartedAt time.Time // tag started to calculate ETA
// end-to-end tag tracing
ctx context.Context // tracing context
span opentracing.Span // tracing root span
spanOnce sync.Once // make sure we close root span only once
ctx context.Context // tracing context
span opentracing.Span // tracing root span
spanOnce sync.Once // make sure we close root span only once
stateStore storage.StateStorer // to persist the tag
logger logging.Logger // logger instance for logging
}
// NewTag creates a new tag, and returns it
func NewTag(ctx context.Context, uid uint32, s string, total int64, tracer *tracing.Tracer) *Tag {
func NewTag(ctx context.Context, uid uint32, s string, total int64, tracer *tracing.Tracer, stateStore storage.StateStorer, logger logging.Logger) *Tag {
t := &Tag{
Uid: uid,
Name: s,
StartedAt: time.Now(),
Total: total,
Uid: uid,
Name: s,
StartedAt: time.Now(),
Total: total,
stateStore: stateStore,
logger: logger,
}
// context here is used only to store the root span `new.upload.tag` within Tag,
......@@ -95,7 +102,7 @@ func (t *Tag) FinishRootSpan() {
}
// IncN increments the count for a state
func (t *Tag) IncN(state State, n int) {
func (t *Tag) IncN(state State, n int) error {
var v *int64
switch state {
case TotalChunks:
......@@ -112,11 +119,23 @@ func (t *Tag) IncN(state State, n int) {
v = &t.Synced
}
atomic.AddInt64(v, int64(n))
// check if syncing is over and persist the tag
if state == StateSynced {
total := atomic.LoadInt64(&t.Total)
seen := atomic.LoadInt64(&t.Seen)
synced := atomic.LoadInt64(&t.Synced)
totalUnique := total - seen
if synced >= totalUnique {
return t.saveTag()
}
}
return nil
}
// Inc increments the count for a state
func (t *Tag) Inc(state State) {
t.IncN(state, 1)
func (t *Tag) Inc(state State) error {
return t.IncN(state, 1)
}
// Get returns the count for a state on a tag
......@@ -172,7 +191,7 @@ func (t *Tag) Done(s State) bool {
// DoneSplit sets total count to SPLIT count and sets the associated swarm hash for this tag
// is meant to be called when splitter finishes for input streams of unknown size
func (t *Tag) DoneSplit(address swarm.Address) int64 {
func (t *Tag) DoneSplit(address swarm.Address) (int64, error) {
total := atomic.LoadInt64(&t.Split)
atomic.StoreInt64(&t.Total, total)
......@@ -180,7 +199,12 @@ func (t *Tag) DoneSplit(address swarm.Address) int64 {
t.Address = address
}
return total
// persist the tag
err := t.saveTag()
if err != nil {
return 0, err
}
return total, nil
}
// Status returns the value of state and the total count
......@@ -220,12 +244,12 @@ func (t *Tag) ETA(state State) (time.Time, error) {
func (tag *Tag) MarshalBinary() (data []byte, err error) {
buffer := make([]byte, 4)
binary.BigEndian.PutUint32(buffer, tag.Uid)
encodeInt64Append(&buffer, tag.Total)
encodeInt64Append(&buffer, tag.Split)
encodeInt64Append(&buffer, tag.Seen)
encodeInt64Append(&buffer, tag.Stored)
encodeInt64Append(&buffer, tag.Sent)
encodeInt64Append(&buffer, tag.Synced)
encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Total))
encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Split))
encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Seen))
encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Stored))
encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Sent))
encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Synced))
intBuffer := make([]byte, 8)
......@@ -248,12 +272,12 @@ func (tag *Tag) UnmarshalBinary(buffer []byte) error {
tag.Uid = binary.BigEndian.Uint32(buffer)
buffer = buffer[4:]
tag.Total = decodeInt64Splice(&buffer)
tag.Split = decodeInt64Splice(&buffer)
tag.Seen = decodeInt64Splice(&buffer)
tag.Stored = decodeInt64Splice(&buffer)
tag.Sent = decodeInt64Splice(&buffer)
tag.Synced = decodeInt64Splice(&buffer)
atomic.AddInt64(&tag.Total, decodeInt64Splice(&buffer))
atomic.AddInt64(&tag.Split, decodeInt64Splice(&buffer))
atomic.AddInt64(&tag.Seen, decodeInt64Splice(&buffer))
atomic.AddInt64(&tag.Stored, decodeInt64Splice(&buffer))
atomic.AddInt64(&tag.Sent, decodeInt64Splice(&buffer))
atomic.AddInt64(&tag.Synced, decodeInt64Splice(&buffer))
t, n := binary.Varint(buffer)
tag.StartedAt = time.Unix(t, 0)
......@@ -280,3 +304,24 @@ func decodeInt64Splice(buffer *[]byte) int64 {
*buffer = (*buffer)[n:]
return val
}
// saveTag update the tag in the state store
func (tag *Tag) saveTag() error {
key := getKey(tag.Uid)
value, err := tag.MarshalBinary()
if err != nil {
return err
}
if tag.stateStore != nil {
err = tag.stateStore.Put(key, value)
if err != nil {
return err
}
}
return nil
}
func getKey(uid uint32) string {
return fmt.Sprintf("tags_%d", uid)
}
......@@ -18,10 +18,13 @@ package tags
import (
"context"
"io/ioutil"
"sync"
"testing"
"time"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm"
)
......@@ -31,7 +34,9 @@ var (
// TestTagSingleIncrements tests if Inc increments the tag state value
func TestTagSingleIncrements(t *testing.T) {
tg := &Tag{Total: 10}
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
tg := &Tag{Total: 10, stateStore: mockStatestore, logger: logger}
tc := []struct {
state uint32
......@@ -48,7 +53,10 @@ func TestTagSingleIncrements(t *testing.T) {
for _, tc := range tc {
for i := 0; i < tc.inc; i++ {
tg.Inc(tc.state)
err := tg.Inc(tc.state)
if err != nil {
t.Fatal(err)
}
}
}
......@@ -62,13 +70,28 @@ func TestTagSingleIncrements(t *testing.T) {
// TestTagStatus is a unit test to cover Tag.Status method functionality
func TestTagStatus(t *testing.T) {
tg := &Tag{Total: 10}
tg.Inc(StateSeen)
tg.Inc(StateSent)
tg.Inc(StateSynced)
err := tg.Inc(StateSeen)
if err != nil {
t.Fatal(err)
}
err = tg.Inc(StateSent)
if err != nil {
t.Fatal(err)
}
err = tg.Inc(StateSynced)
if err != nil {
t.Fatal(err)
}
for i := 0; i < 10; i++ {
tg.Inc(StateSplit)
tg.Inc(StateStored)
err = tg.Inc(StateSplit)
if err != nil {
t.Fatal(err)
}
err = tg.Inc(StateStored)
if err != nil {
t.Fatal(err)
}
}
for _, v := range []struct {
state State
......@@ -100,7 +123,10 @@ func TestTagETA(t *testing.T) {
maxDiff := 100000 // 100 microsecond
tg := &Tag{Total: 10, StartedAt: now}
time.Sleep(100 * time.Millisecond)
tg.Inc(StateSplit)
err := tg.Inc(StateSplit)
if err != nil {
t.Fatal(err)
}
eta, err := tg.ETA(StateSplit)
if err != nil {
t.Fatal(err)
......@@ -113,15 +139,20 @@ func TestTagETA(t *testing.T) {
// TestTagConcurrentIncrements tests Inc calls concurrently
func TestTagConcurrentIncrements(t *testing.T) {
tg := &Tag{}
n := 1000
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
tg := &Tag{stateStore: mockStatestore, logger: logger}
n := 10
wg := sync.WaitGroup{}
wg.Add(5 * n)
for _, f := range allStates {
go func(f State) {
for j := 0; j < n; j++ {
go func() {
tg.Inc(f)
err := tg.Inc(f)
if err != nil {
t.Errorf("error incrementing tag counters: %v", err)
}
wg.Done()
}()
}
......@@ -138,7 +169,9 @@ func TestTagConcurrentIncrements(t *testing.T) {
// TestTagsMultipleConcurrentIncrements tests Inc calls concurrently
func TestTagsMultipleConcurrentIncrementsSyncMap(t *testing.T) {
ts := NewTags()
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
ts := NewTags(mockStatestore, logger)
n := 100
wg := sync.WaitGroup{}
wg.Add(10 * 5 * n)
......@@ -152,7 +185,10 @@ func TestTagsMultipleConcurrentIncrementsSyncMap(t *testing.T) {
go func(tag *Tag, f State) {
for j := 0; j < n; j++ {
go func() {
tag.Inc(f)
err := tag.Inc(f)
if err != nil {
t.Errorf("error incrementing tag counters: %v", err)
}
wg.Done()
}()
}
......@@ -185,11 +221,16 @@ func TestTagsMultipleConcurrentIncrementsSyncMap(t *testing.T) {
// TestMarshallingWithAddr tests that marshalling and unmarshalling is done correctly when the
// tag Address (byte slice) contains some arbitrary value
func TestMarshallingWithAddr(t *testing.T) {
tg := NewTag(context.Background(), 111, "test/tag", 10, nil)
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
tg := NewTag(context.Background(), 111, "test/tag", 10, nil, mockStatestore, logger)
tg.Address = swarm.NewAddress([]byte{0, 1, 2, 3, 4, 5, 6})
for _, f := range allStates {
tg.Inc(f)
err := tg.Inc(f)
if err != nil {
t.Fatal(err)
}
}
b, err := tg.MarshalBinary()
......@@ -233,9 +274,14 @@ func TestMarshallingWithAddr(t *testing.T) {
// TestMarshallingNoAddress tests that marshalling and unmarshalling is done correctly
func TestMarshallingNoAddr(t *testing.T) {
tg := NewTag(context.Background(), 111, "test/tag", 10, nil)
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
tg := NewTag(context.Background(), 111, "test/tag", 10, nil, mockStatestore, logger)
for _, f := range allStates {
tg.Inc(f)
err := tg.Inc(f)
if err != nil {
t.Fatal(err)
}
}
b, err := tg.MarshalBinary()
......
......@@ -26,6 +26,8 @@ import (
"sync"
"time"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm"
)
......@@ -36,20 +38,24 @@ var (
// Tags hold tag information indexed by a unique random uint32
type Tags struct {
tags *sync.Map
tags *sync.Map
stateStore storage.StateStorer
logger logging.Logger
}
// NewTags creates a tags object
func NewTags() *Tags {
func NewTags(stateStore storage.StateStorer, logger logging.Logger) *Tags {
return &Tags{
tags: &sync.Map{},
tags: &sync.Map{},
stateStore: stateStore,
logger: logger,
}
}
// Create creates a new tag, stores it by the name and returns it
// it returns an error if the tag with this name already exists
func (ts *Tags) Create(s string, total int64) (*Tag, error) {
t := NewTag(context.Background(), TagUidFunc(), s, total, nil)
t := NewTag(context.Background(), TagUidFunc(), s, total, nil, ts.stateStore, ts.logger)
if _, loaded := ts.tags.LoadOrStore(t.Uid, t); loaded {
return nil, errExists
......@@ -74,7 +80,14 @@ func (ts *Tags) All() (t []*Tag) {
func (ts *Tags) Get(uid uint32) (*Tag, error) {
t, ok := ts.tags.Load(uid)
if !ok {
return nil, ErrNotFound
// see if the tag is present in the store
// if yes, load it in to the memory
ta, err := ts.getTagFromStore(uid)
if err != nil {
return nil, ErrNotFound
}
ts.tags.LoadOrStore(ta.Uid, ta)
return ta, nil
}
return t.(*Tag), nil
}
......@@ -143,3 +156,33 @@ func (ts *Tags) UnmarshalJSON(value []byte) error {
return err
}
// getTagFromStore get a given tag from the state store.
func (ts *Tags) getTagFromStore(uid uint32) (*Tag, error) {
key := "tags_" + strconv.Itoa(int(uid))
var data []byte
err := ts.stateStore.Get(key, &data)
if err != nil {
return nil, err
}
var ta Tag
err = ta.UnmarshalBinary(data)
if err != nil {
return nil, err
}
return &ta, nil
}
// Close is called when the node goes down. This is when all the tags in memory is persisted.
func (ts *Tags) Close() (err error) {
// store all the tags in memory
tags := ts.All()
for _, t := range tags {
ts.logger.Trace("updating tag: ", t.Uid)
err := t.saveTag()
if err != nil {
return err
}
}
return nil
}
......@@ -17,11 +17,18 @@
package tags
import (
"io/ioutil"
"testing"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm"
)
func TestAll(t *testing.T) {
ts := NewTags()
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
ts := NewTags(mockStatestore, logger)
if _, err := ts.Create("1", 1); err != nil {
t.Fatal(err)
}
......@@ -52,3 +59,81 @@ func TestAll(t *testing.T) {
t.Fatalf("expected length to be 3 got %d", len(all))
}
}
func TestPersistence(t *testing.T) {
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
ts := NewTags(mockStatestore, logger)
ta, err := ts.Create("one", 1)
if err != nil {
t.Fatal(err)
}
ta.Total = 10
ta.Seen = 2
ta.Split = 10
ta.Stored = 8
_, err = ta.DoneSplit(swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
// simulate node closing down and booting up
err = ts.Close()
if err != nil {
t.Fatal(err)
}
ts = NewTags(mockStatestore, logger)
// Get the tag after the node bootup
rcvd1, err := ts.Get(ta.Uid)
if err != nil {
t.Fatal(err)
}
// check if the values are intact after the bootup
if ta.Uid != rcvd1.Uid {
t.Fatalf("invalid uid: expected %d got %d", ta.Uid, rcvd1.Uid)
}
if ta.Total != rcvd1.Total {
t.Fatalf("invalid total: expected %d got %d", ta.Total, rcvd1.Total)
}
// See if tag is saved after syncing is over
for i := 0; i < 8; i++ {
err := ta.Inc(StateSent)
if err != nil {
t.Fatal(err)
}
err = ta.Inc(StateSynced)
if err != nil {
t.Fatal(err)
}
}
// simulate node closing down and booting up
err = ts.Close()
if err != nil {
t.Fatal(err)
}
ts = NewTags(mockStatestore, logger)
// get the tag after the node boot up
rcvd2, err := ts.Get(ta.Uid)
if err != nil {
t.Fatal(err)
}
// check if the values are intact after the bootup
if ta.Uid != rcvd2.Uid {
t.Fatalf("invalid uid: expected %d got %d", ta.Uid, rcvd2.Uid)
}
if ta.Total != rcvd2.Total {
t.Fatalf("invalid total: expected %d got %d", ta.Total, rcvd2.Total)
}
if ta.Sent != rcvd2.Sent {
t.Fatalf("invalid sent: expected %d got %d", ta.Sent, rcvd2.Sent)
}
if ta.Synced != rcvd2.Synced {
t.Fatalf("invalid synced: expected %d got %d", ta.Synced, rcvd2.Synced)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment