Commit f8054e91 authored by acud's avatar acud Committed by GitHub

node: add tag persistence across sessions (#573)

* Add individual tags persistence
parent b54afca6
...@@ -18,7 +18,7 @@ import ( ...@@ -18,7 +18,7 @@ import (
cmdfile "github.com/ethersphere/bee/cmd/internal/file" cmdfile "github.com/ethersphere/bee/cmd/internal/file"
"github.com/ethersphere/bee/pkg/api" "github.com/ethersphere/bee/pkg/api"
"github.com/ethersphere/bee/pkg/logging" "github.com/ethersphere/bee/pkg/logging"
resolverMock "github.com/ethersphere/bee/pkg/resolver/mock" statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/storage/mock" "github.com/ethersphere/bee/pkg/storage/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -154,7 +154,9 @@ func TestLimitWriter(t *testing.T) { ...@@ -154,7 +154,9 @@ func TestLimitWriter(t *testing.T) {
// newTestServer creates an http server to serve the bee http api endpoints. // newTestServer creates an http server to serve the bee http api endpoints.
func newTestServer(t *testing.T, storer storage.Storer) *url.URL { func newTestServer(t *testing.T, storer storage.Storer) *url.URL {
t.Helper() t.Helper()
s := api.New(tags.NewTags(), storer, resolverMock.NewResolver(), nil, logging.New(ioutil.Discard, 0), nil) logger := logging.New(ioutil.Discard, 0)
store := statestore.NewStateStore()
s := api.New(tags.NewTags(store, logger), storer, nil, nil, logger, nil)
ts := httptest.NewServer(s) ts := httptest.NewServer(s)
srvUrl, err := url.Parse(ts.URL) srvUrl, err := url.Parse(ts.URL)
if err != nil { if err != nil {
......
...@@ -41,7 +41,13 @@ func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -41,7 +41,13 @@ func (s *server) bytesUploadHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if created { if created {
tag.DoneSplit(address) _, err = tag.DoneSplit(address)
if err != nil {
s.Logger.Debugf("bytes upload: done split: %v", err)
s.Logger.Error("bytes upload: done split failed")
jsonhttp.InternalServerError(w, nil)
return
}
} }
w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid)) w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid))
w.Header().Set("Access-Control-Expose-Headers", SwarmTagUidHeader) w.Header().Set("Access-Control-Expose-Headers", SwarmTagUidHeader)
......
...@@ -6,6 +6,7 @@ package api_test ...@@ -6,6 +6,7 @@ package api_test
import ( import (
"bytes" "bytes"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"testing" "testing"
...@@ -28,9 +29,11 @@ func TestBytes(t *testing.T) { ...@@ -28,9 +29,11 @@ func TestBytes(t *testing.T) {
targets = "0x222" targets = "0x222"
expHash = "29a5fb121ce96194ba8b7b823a1f9c6af87e1791f824940a53b5a7efe3f790d9" expHash = "29a5fb121ce96194ba8b7b823a1f9c6af87e1791f824940a53b5a7efe3f790d9"
mockStorer = mock.NewStorer() mockStorer = mock.NewStorer()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: mockStorer, Storer: mockStorer,
Tags: tags.NewTags(), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
}) })
) )
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io" "io"
"io/ioutil" "io/ioutil"
"mime" "mime"
...@@ -33,9 +34,11 @@ func TestBzz(t *testing.T) { ...@@ -33,9 +34,11 @@ func TestBzz(t *testing.T) {
bzzDownloadResource = func(addr, path string) string { return "/bzz/" + addr + "/" + path } bzzDownloadResource = func(addr, path string) string { return "/bzz/" + addr + "/" + path }
storer = smock.NewStorer() storer = smock.NewStorer()
ctx = context.Background() ctx = context.Background()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: storer, Storer: storer,
Tags: tags.NewTags(), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
}) })
pipeWriteAll = func(r io.Reader, l int64) (swarm.Address, error) { pipeWriteAll = func(r io.Reader, l int64) (swarm.Address, error) {
......
...@@ -45,7 +45,13 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -45,7 +45,13 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
ctx := sctx.SetTag(r.Context(), tag) ctx := sctx.SetTag(r.Context(), tag)
// Increment the StateSplit here since we dont have a splitter for the file upload // Increment the StateSplit here since we dont have a splitter for the file upload
tag.Inc(tags.StateSplit) err = tag.Inc(tags.StateSplit)
if err != nil {
s.Logger.Debugf("chunk upload: increment tag: %v", err)
s.Logger.Error("chunk upload: increment tag")
jsonhttp.InternalServerError(w, "increment tag")
return
}
data, err := ioutil.ReadAll(r.Body) data, err := ioutil.ReadAll(r.Body)
if err != nil { if err != nil {
...@@ -65,11 +71,23 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -65,11 +71,23 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
jsonhttp.BadRequest(w, "chunk write error") jsonhttp.BadRequest(w, "chunk write error")
return return
} else if len(seen) > 0 && seen[0] { } else if len(seen) > 0 && seen[0] {
tag.Inc(tags.StateSeen) err := tag.Inc(tags.StateSeen)
if err != nil {
s.Logger.Debugf("chunk upload: increment tag", err)
s.Logger.Error("chunk upload: increment tag")
jsonhttp.BadRequest(w, "increment tag")
return
}
} }
// Indicate that the chunk is stored // Indicate that the chunk is stored
tag.Inc(tags.StateStored) err = tag.Inc(tags.StateStored)
if err != nil {
s.Logger.Debugf("chunk upload: increment tag", err)
s.Logger.Error("chunk upload: increment tag")
jsonhttp.BadRequest(w, "increment tag")
return
}
w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid)) w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid))
w.Header().Set("Access-Control-Expose-Headers", SwarmTagUidHeader) w.Header().Set("Access-Control-Expose-Headers", SwarmTagUidHeader)
......
...@@ -6,6 +6,8 @@ package api_test ...@@ -6,6 +6,8 @@ package api_test
import ( import (
"bytes" "bytes"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
...@@ -35,7 +37,9 @@ func TestChunkUploadDownload(t *testing.T) { ...@@ -35,7 +37,9 @@ func TestChunkUploadDownload(t *testing.T) {
validContent = []byte("bbaatt") validContent = []byte("bbaatt")
invalidContent = []byte("bbaattss") invalidContent = []byte("bbaattss")
mockValidator = validator.NewMockValidator(validHash, validContent) mockValidator = validator.NewMockValidator(validHash, validContent)
tag = tags.NewTags() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger)
mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator)) mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator))
client = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
......
...@@ -54,13 +54,19 @@ func (s *server) dirUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -54,13 +54,19 @@ func (s *server) dirUploadHandler(w http.ResponseWriter, r *http.Request) {
reference, err := storeDir(ctx, r.Body, s.Storer, requestModePut(r), s.Logger, requestEncrypt(r)) reference, err := storeDir(ctx, r.Body, s.Storer, requestModePut(r), s.Logger, requestEncrypt(r))
if err != nil { if err != nil {
s.Logger.Debugf("dir upload, store dir err: %v", err) s.Logger.Debugf("dir upload: store dir err: %v", err)
s.Logger.Errorf("dir upload, store dir") s.Logger.Errorf("dir upload: store dir")
jsonhttp.InternalServerError(w, "could not store dir") jsonhttp.InternalServerError(w, "could not store dir")
return return
} }
if created { if created {
tag.DoneSplit(reference) _, err = tag.DoneSplit(reference)
if err != nil {
s.Logger.Debugf("dir upload: done split: %v", err)
s.Logger.Error("dir upload: done split failed")
jsonhttp.InternalServerError(w, nil)
return
}
} }
w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid)) w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid))
jsonhttp.OK(w, fileUploadResponse{ jsonhttp.OK(w, fileUploadResponse{
......
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"path" "path"
...@@ -32,9 +33,11 @@ func TestDirs(t *testing.T) { ...@@ -32,9 +33,11 @@ func TestDirs(t *testing.T) {
dirUploadResource = "/dirs" dirUploadResource = "/dirs"
fileDownloadResource = func(addr string) string { return "/files/" + addr } fileDownloadResource = func(addr string) string { return "/files/" + addr }
storer = mock.NewStorer() storer = mock.NewStorer()
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: storer, Storer: storer,
Tags: tags.NewTags(), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
}) })
) )
......
...@@ -201,7 +201,13 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -201,7 +201,13 @@ func (s *server) fileUploadHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if created { if created {
tag.DoneSplit(reference) _, err = tag.DoneSplit(reference)
if err != nil {
s.Logger.Debugf("file upload: done split: %v", err)
s.Logger.Error("file upload: done split failed")
jsonhttp.InternalServerError(w, nil)
return
}
} }
w.Header().Set("ETag", fmt.Sprintf("%q", reference.String())) w.Header().Set("ETag", fmt.Sprintf("%q", reference.String()))
w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid)) w.Header().Set(SwarmTagUidHeader, fmt.Sprint(tag.Uid))
......
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io" "io"
"io/ioutil" "io/ioutil"
"mime" "mime"
...@@ -32,9 +33,11 @@ func TestFiles(t *testing.T) { ...@@ -32,9 +33,11 @@ func TestFiles(t *testing.T) {
targets = "0x222" targets = "0x222"
fileDownloadResource = func(addr string) string { return "/files/" + addr } fileDownloadResource = func(addr string) string { return "/files/" + addr }
simpleData = []byte("this is a simple text") simpleData = []byte("this is a simple text")
mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
client = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tags.NewTags(), Tags: tags.NewTags(mockStatestore, logger),
}) })
) )
...@@ -333,9 +336,11 @@ func TestRangeRequests(t *testing.T) { ...@@ -333,9 +336,11 @@ func TestRangeRequests(t *testing.T) {
for _, upload := range uploads { for _, upload := range uploads {
t.Run(upload.name, func(t *testing.T) { t.Run(upload.name, func(t *testing.T) {
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
client := newTestServer(t, testServerOptions{ client := newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
Tags: tags.NewTags(), Tags: tags.NewTags(mockStatestore, logger),
Logger: logging.New(ioutil.Discard, 5), Logger: logging.New(ioutil.Discard, 5),
}) })
......
...@@ -16,7 +16,6 @@ import ( ...@@ -16,7 +16,6 @@ import (
"github.com/ethersphere/bee/pkg/jsonhttp" "github.com/ethersphere/bee/pkg/jsonhttp"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags" "github.com/ethersphere/bee/pkg/tags"
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
...@@ -197,6 +196,12 @@ func (s *server) doneSplit(w http.ResponseWriter, r *http.Request) { ...@@ -197,6 +196,12 @@ func (s *server) doneSplit(w http.ResponseWriter, r *http.Request) {
return return
} }
tag.DoneSplit(tagr.Address) _, err = tag.DoneSplit(tagr.Address)
if err != nil {
s.Logger.Debugf("done split: failed for address %v", tagr.Address)
s.Logger.Error("done split: failed for address %v", tagr.Address)
jsonhttp.InternalServerError(w, nil)
return
}
jsonhttp.OK(w, "ok") jsonhttp.OK(w, "ok")
} }
...@@ -8,6 +8,9 @@ import ( ...@@ -8,6 +8,9 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
...@@ -39,7 +42,9 @@ func TestTags(t *testing.T) { ...@@ -39,7 +42,9 @@ func TestTags(t *testing.T) {
someHash = swarm.MustParseHexAddress("aabbcc") someHash = swarm.MustParseHexAddress("aabbcc")
someContent = []byte("bbaatt") someContent = []byte("bbaatt")
someTagName = "file.jpg" someTagName = "file.jpg"
tag = tags.NewTags() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger)
mockPusher = mp.NewMockPusher(tag) mockPusher = mp.NewMockPusher(tag)
client = newTestServer(t, testServerOptions{ client = newTestServer(t, testServerOptions{
Storer: mock.NewStorer(), Storer: mock.NewStorer(),
......
...@@ -6,6 +6,9 @@ package debugapi_test ...@@ -6,6 +6,9 @@ package debugapi_test
import ( import (
"bytes" "bytes"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"net/http" "net/http"
"testing" "testing"
...@@ -29,7 +32,9 @@ func TestPinChunkHandler(t *testing.T) { ...@@ -29,7 +32,9 @@ func TestPinChunkHandler(t *testing.T) {
data = []byte("bbaatt") data = []byte("bbaatt")
mockValidator = validator.NewMockValidator(hash, data) mockValidator = validator.NewMockValidator(hash, data)
mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator)) mockValidatingStorer = mock.NewStorer(mock.WithValidator(mockValidator))
tag = tags.NewTags() mockStatestore = statestore.NewStateStore()
logger = logging.New(ioutil.Discard, 0)
tag = tags.NewTags(mockStatestore, logger)
debugTestServer = newTestServer(t, testServerOptions{ debugTestServer = newTestServer(t, testServerOptions{
Storer: mockValidatingStorer, Storer: mockValidatingStorer,
......
...@@ -6,7 +6,6 @@ package pipeline ...@@ -6,7 +6,6 @@ package pipeline
import ( import (
"context" "context"
"github.com/ethersphere/bee/pkg/sctx" "github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/storage" "github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -28,18 +27,31 @@ func newStoreWriter(ctx context.Context, l storage.Putter, mode storage.ModePut, ...@@ -28,18 +27,31 @@ func newStoreWriter(ctx context.Context, l storage.Putter, mode storage.ModePut,
func (w *storeWriter) chainWrite(p *pipeWriteArgs) error { func (w *storeWriter) chainWrite(p *pipeWriteArgs) error {
tag := sctx.GetTag(w.ctx) tag := sctx.GetTag(w.ctx)
var c swarm.Chunk
if tag != nil { if tag != nil {
tag.Inc(tags.StateSplit) err := tag.Inc(tags.StateSplit)
if err != nil {
return err
} }
c := swarm.NewChunk(swarm.NewAddress(p.ref), p.data) c = swarm.NewChunk(swarm.NewAddress(p.ref), p.data).WithTagID(tag.Uid)
} else {
c = swarm.NewChunk(swarm.NewAddress(p.ref), p.data)
}
seen, err := w.l.Put(w.ctx, w.mode, c) seen, err := w.l.Put(w.ctx, w.mode, c)
if err != nil { if err != nil {
return err return err
} }
if tag != nil { if tag != nil {
tag.Inc(tags.StateStored) err := tag.Inc(tags.StateStored)
if err != nil {
return err
}
if seen[0] { if seen[0] {
tag.Inc(tags.StateSeen) err := tag.Inc(tags.StateSeen)
if err != nil {
return err
}
} }
} }
if w.next == nil { if w.next == nil {
......
...@@ -152,7 +152,10 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) { ...@@ -152,7 +152,10 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
binary.LittleEndian.PutUint64(head, uint64(span)) binary.LittleEndian.PutUint64(head, uint64(span))
tail := s.buffer[s.cursors[lvl+1]:s.cursors[lvl]] tail := s.buffer[s.cursors[lvl+1]:s.cursors[lvl]]
chunkData = append(head, tail...) chunkData = append(head, tail...)
s.incrTag(tags.StateSplit) err := s.incrTag(tags.StateSplit)
if err != nil {
return nil, err
}
c := chunkData c := chunkData
var encryptionKey encryption.Key var encryptionKey encryption.Key
...@@ -165,7 +168,7 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) { ...@@ -165,7 +168,7 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
} }
s.hasher.Reset() s.hasher.Reset()
err := s.hasher.SetSpanBytes(c[:8]) err = s.hasher.SetSpanBytes(c[:8])
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -188,10 +191,16 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) { ...@@ -188,10 +191,16 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} else if len(seen) > 0 && seen[0] { } else if len(seen) > 0 && seen[0] {
s.incrTag(tags.StateSeen) err = s.incrTag(tags.StateSeen)
if err != nil {
return nil, err
}
} }
s.incrTag(tags.StateStored) err = s.incrTag(tags.StateStored)
if err != nil {
return nil, err
}
return append(ch.Address().Bytes(), encryptionKey...), nil return append(ch.Address().Bytes(), encryptionKey...), nil
} }
...@@ -310,8 +319,9 @@ func (s *SimpleSplitterJob) newDataEncryption(key encryption.Key) encryption.Int ...@@ -310,8 +319,9 @@ func (s *SimpleSplitterJob) newDataEncryption(key encryption.Key) encryption.Int
return encryption.New(key, int(swarm.ChunkSize), 0, sha3.NewLegacyKeccak256) return encryption.New(key, int(swarm.ChunkSize), 0, sha3.NewLegacyKeccak256)
} }
func (s *SimpleSplitterJob) incrTag(state tags.State) { func (s *SimpleSplitterJob) incrTag(state tags.State) error {
if s.tag != nil { if s.tag != nil {
s.tag.Inc(state) return s.tag.Inc(state)
} }
return nil
} }
...@@ -247,7 +247,10 @@ func (db *DB) setSync(batch *leveldb.Batch, addr swarm.Address, mode storage.Mod ...@@ -247,7 +247,10 @@ func (db *DB) setSync(batch *leveldb.Batch, addr swarm.Address, mode storage.Mod
// run to end from db.pushIndex.DeleteInBatch // run to end from db.pushIndex.DeleteInBatch
db.logger.Errorf("localstore: get tags on push sync set uid %d: %v", i.Tag, err) db.logger.Errorf("localstore: get tags on push sync set uid %d: %v", i.Tag, err)
} else { } else {
t.Inc(tags.StateSynced) err = t.Inc(tags.StateSynced)
if err != nil {
return 0, err
}
} }
} }
......
...@@ -19,6 +19,9 @@ package localstore ...@@ -19,6 +19,9 @@ package localstore
import ( import (
"context" "context"
"errors" "errors"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil"
"testing" "testing"
"time" "time"
...@@ -70,7 +73,9 @@ func TestModeSetAccess(t *testing.T) { ...@@ -70,7 +73,9 @@ func TestModeSetAccess(t *testing.T) {
// as a result we should expect the tag value to remain in the pull index // as a result we should expect the tag value to remain in the pull index
// and we expect that the tag should not be incremented by pull sync set // and we expect that the tag should not be incremented by pull sync set
func TestModeSetSyncPullNormalTag(t *testing.T) { func TestModeSetSyncPullNormalTag(t *testing.T) {
db := newTestDB(t, &Options{Tags: tags.NewTags()}) mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
db := newTestDB(t, &Options{Tags: tags.NewTags(mockStatestore, logger)})
tag, err := db.tags.Create("test", 1) tag, err := db.tags.Create("test", 1)
if err != nil { if err != nil {
...@@ -83,7 +88,10 @@ func TestModeSetSyncPullNormalTag(t *testing.T) { ...@@ -83,7 +88,10 @@ func TestModeSetSyncPullNormalTag(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
tag.Inc(tags.StateStored) // so we don't get an error on tag.Status later on err = tag.Inc(tags.StateStored) // so we don't get an error on tag.Status later on
if err != nil {
t.Fatal(err)
}
item, err := db.pullIndex.Get(shed.Item{ item, err := db.pullIndex.Get(shed.Item{
Address: ch.Address().Bytes(), Address: ch.Address().Bytes(),
...@@ -124,7 +132,9 @@ func TestModeSetSyncPullNormalTag(t *testing.T) { ...@@ -124,7 +132,9 @@ func TestModeSetSyncPullNormalTag(t *testing.T) {
// correctly on a normal tag (that is, a tag that is expected to show progress bars // correctly on a normal tag (that is, a tag that is expected to show progress bars
// according to push sync progress) // according to push sync progress)
func TestModeSetSyncPushNormalTag(t *testing.T) { func TestModeSetSyncPushNormalTag(t *testing.T) {
db := newTestDB(t, &Options{Tags: tags.NewTags()}) mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
db := newTestDB(t, &Options{Tags: tags.NewTags(mockStatestore, logger)})
tag, err := db.tags.Create("test", 1) tag, err := db.tags.Create("test", 1)
if err != nil { if err != nil {
...@@ -137,7 +147,11 @@ func TestModeSetSyncPushNormalTag(t *testing.T) { ...@@ -137,7 +147,11 @@ func TestModeSetSyncPushNormalTag(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
tag.Inc(tags.StateStored) // so we don't get an error on tag.Status later on err = tag.Inc(tags.StateStored) // so we don't get an error on tag.Status later on
if err != nil {
t.Fatal(err)
}
item, err := db.pullIndex.Get(shed.Item{ item, err := db.pullIndex.Get(shed.Item{
Address: ch.Address().Bytes(), Address: ch.Address().Bytes(),
BinID: 1, BinID: 1,
......
...@@ -61,6 +61,7 @@ type Bee struct { ...@@ -61,6 +61,7 @@ type Bee struct {
resolverCloser io.Closer resolverCloser io.Closer
errorLogWriter *io.PipeWriter errorLogWriter *io.PipeWriter
tracerCloser io.Closer tracerCloser io.Closer
tagsCloser io.Closer
stateStoreCloser io.Closer stateStoreCloser io.Closer
localstoreCloser io.Closer localstoreCloser io.Closer
topologyCloser io.Closer topologyCloser io.Closer
...@@ -243,7 +244,8 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service, ...@@ -243,7 +244,8 @@ func NewBee(addr string, swarmAddress swarm.Address, keystore keystore.Service,
chunkvalidator := swarm.NewChunkValidator(soc.NewValidator(), content.NewValidator()) chunkvalidator := swarm.NewChunkValidator(soc.NewValidator(), content.NewValidator())
retrieve := retrieval.New(p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator) retrieve := retrieval.New(p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator)
tagg := tags.NewTags() tagg := tags.NewTags(stateStore, logger)
b.tagsCloser = tagg
if err = p2ps.AddProtocol(retrieve.Protocol()); err != nil { if err = p2ps.AddProtocol(retrieve.Protocol()); err != nil {
return nil, fmt.Errorf("retrieval service: %w", err) return nil, fmt.Errorf("retrieval service: %w", err)
...@@ -413,6 +415,10 @@ func (b *Bee) Shutdown(ctx context.Context) error { ...@@ -413,6 +415,10 @@ func (b *Bee) Shutdown(ctx context.Context) error {
errs.add(fmt.Errorf("tracer: %w", err)) errs.add(fmt.Errorf("tracer: %w", err))
} }
if err := b.tagsCloser.Close(); err != nil {
errs.add(fmt.Errorf("tag persistence: %w", err))
}
if err := b.stateStoreCloser.Close(); err != nil { if err := b.stateStoreCloser.Close(); err != nil {
errs.add(fmt.Errorf("statestore: %w", err)) errs.add(fmt.Errorf("statestore: %w", err))
} }
......
...@@ -23,9 +23,7 @@ func (m *MockPusher) SendChunk(uid uint32) error { ...@@ -23,9 +23,7 @@ func (m *MockPusher) SendChunk(uid uint32) error {
if err != nil { if err != nil {
return err return err
} }
ta.Inc(tags.StateSent) return ta.Inc(tags.StateSent)
return nil
} }
func (m *MockPusher) RcvdReceipt(uid uint32) error { func (m *MockPusher) RcvdReceipt(uid uint32) error {
...@@ -33,7 +31,5 @@ func (m *MockPusher) RcvdReceipt(uid uint32) error { ...@@ -33,7 +31,5 @@ func (m *MockPusher) RcvdReceipt(uid uint32) error {
if err != nil { if err != nil {
return err return err
} }
ta.Inc(tags.StateSynced) return ta.Inc(tags.StateSynced)
return nil
} }
...@@ -123,7 +123,12 @@ LOOP: ...@@ -123,7 +123,12 @@ LOOP:
} }
return return
} }
s.setChunkAsSynced(ctx, ch) err = s.setChunkAsSynced(ctx, ch)
if err != nil {
s.logger.Debugf("pusher: error setting chunk as synced: %v", err)
return
}
}(ctx, ch) }(ctx, ch)
case <-timer.C: case <-timer.C:
// initially timer is set to go off as well as every time we hit the end of push index // initially timer is set to go off as well as every time we hit the end of push index
...@@ -165,15 +170,19 @@ LOOP: ...@@ -165,15 +170,19 @@ LOOP:
} }
} }
func (s *Service) setChunkAsSynced(ctx context.Context, ch swarm.Chunk) { func (s *Service) setChunkAsSynced(ctx context.Context, ch swarm.Chunk) error {
if err := s.storer.Set(ctx, storage.ModeSetSyncPush, ch.Address()); err != nil { if err := s.storer.Set(ctx, storage.ModeSetSyncPush, ch.Address()); err != nil {
s.logger.Errorf("pusher: error setting chunk as synced: %v", err) s.logger.Errorf("pusher: error setting chunk as synced: %v", err)
s.metrics.ErrorSettingChunkToSynced.Inc() s.metrics.ErrorSettingChunkToSynced.Inc()
} }
t, err := s.tagg.Get(ch.TagID()) t, err := s.tagg.Get(ch.TagID())
if err == nil && t != nil { if err == nil && t != nil {
t.Inc(tags.StateSynced) err = t.Inc(tags.StateSynced)
if err != nil {
return err
}
} }
return nil
} }
func (s *Service) Close() error { func (s *Service) Close() error {
......
...@@ -7,6 +7,7 @@ package pusher_test ...@@ -7,6 +7,7 @@ package pusher_test
import ( import (
"context" "context"
"errors" "errors"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil" "io/ioutil"
"sync" "sync"
"testing" "testing"
...@@ -227,7 +228,8 @@ func createPusher(t *testing.T, addr swarm.Address, pushSyncService pushsync.Pus ...@@ -227,7 +228,8 @@ func createPusher(t *testing.T, addr swarm.Address, pushSyncService pushsync.Pus
t.Fatal(err) t.Fatal(err)
} }
mtags := tags.NewTags() mockStatestore := statestore.NewStateStore()
mtags := tags.NewTags(mockStatestore, logger)
pusherStorer := &Store{ pusherStorer := &Store{
Storer: storer, Storer: storer,
modeSet: make(map[string]storage.ModeSet), modeSet: make(map[string]storage.ModeSet),
......
...@@ -221,7 +221,10 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re ...@@ -221,7 +221,10 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
// this is to make sure that the sent number does not diverge from the synced counter // this is to make sure that the sent number does not diverge from the synced counter
t, err := ps.tagg.Get(ch.TagID()) t, err := ps.tagg.Get(ch.TagID())
if err == nil && t != nil { if err == nil && t != nil {
t.Inc(tags.StateSent) err = t.Inc(tags.StateSent)
if err != nil {
return nil, err
}
} }
// if you are the closest node return a receipt immediately // if you are the closest node return a receipt immediately
...@@ -255,7 +258,10 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re ...@@ -255,7 +258,10 @@ func (ps *PushSync) PushChunkToClosest(ctx context.Context, ch swarm.Chunk) (*Re
// if you manage to get a tag, just increment the respective counter // if you manage to get a tag, just increment the respective counter
t, err := ps.tagg.Get(ch.TagID()) t, err := ps.tagg.Get(ch.TagID())
if err == nil && t != nil { if err == nil && t != nil {
t.Inc(tags.StateSent) err = t.Inc(tags.StateSent)
if err != nil {
return nil, err
}
} }
receiptRTTTimer := time.Now() receiptRTTTimer := time.Now()
......
...@@ -7,6 +7,7 @@ package pushsync_test ...@@ -7,6 +7,7 @@ package pushsync_test
import ( import (
"bytes" "bytes"
"context" "context"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"io/ioutil" "io/ioutil"
"testing" "testing"
"time" "time"
...@@ -286,7 +287,8 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R ...@@ -286,7 +287,8 @@ func createPushSyncNode(t *testing.T, addr swarm.Address, recorder *streamtest.R
} }
mockTopology := mock.NewTopologyDriver(mockOpts...) mockTopology := mock.NewTopologyDriver(mockOpts...)
mtag := tags.NewTags() mockStatestore := statestore.NewStateStore()
mtag := tags.NewTags(mockStatestore, logger)
mockAccounting := accountingmock.NewAccounting() mockAccounting := accountingmock.NewAccounting()
mockPricer := accountingmock.NewPricer(fixedPrice, fixedPrice) mockPricer := accountingmock.NewPricer(fixedPrice, fixedPrice)
......
...@@ -20,10 +20,13 @@ import ( ...@@ -20,10 +20,13 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tracing" "github.com/ethersphere/bee/pkg/tracing"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
...@@ -65,15 +68,19 @@ type Tag struct { ...@@ -65,15 +68,19 @@ type Tag struct {
ctx context.Context // tracing context ctx context.Context // tracing context
span opentracing.Span // tracing root span span opentracing.Span // tracing root span
spanOnce sync.Once // make sure we close root span only once spanOnce sync.Once // make sure we close root span only once
stateStore storage.StateStorer // to persist the tag
logger logging.Logger // logger instance for logging
} }
// NewTag creates a new tag, and returns it // NewTag creates a new tag, and returns it
func NewTag(ctx context.Context, uid uint32, s string, total int64, tracer *tracing.Tracer) *Tag { func NewTag(ctx context.Context, uid uint32, s string, total int64, tracer *tracing.Tracer, stateStore storage.StateStorer, logger logging.Logger) *Tag {
t := &Tag{ t := &Tag{
Uid: uid, Uid: uid,
Name: s, Name: s,
StartedAt: time.Now(), StartedAt: time.Now(),
Total: total, Total: total,
stateStore: stateStore,
logger: logger,
} }
// context here is used only to store the root span `new.upload.tag` within Tag, // context here is used only to store the root span `new.upload.tag` within Tag,
...@@ -95,7 +102,7 @@ func (t *Tag) FinishRootSpan() { ...@@ -95,7 +102,7 @@ func (t *Tag) FinishRootSpan() {
} }
// IncN increments the count for a state // IncN increments the count for a state
func (t *Tag) IncN(state State, n int) { func (t *Tag) IncN(state State, n int) error {
var v *int64 var v *int64
switch state { switch state {
case TotalChunks: case TotalChunks:
...@@ -112,11 +119,23 @@ func (t *Tag) IncN(state State, n int) { ...@@ -112,11 +119,23 @@ func (t *Tag) IncN(state State, n int) {
v = &t.Synced v = &t.Synced
} }
atomic.AddInt64(v, int64(n)) atomic.AddInt64(v, int64(n))
// check if syncing is over and persist the tag
if state == StateSynced {
total := atomic.LoadInt64(&t.Total)
seen := atomic.LoadInt64(&t.Seen)
synced := atomic.LoadInt64(&t.Synced)
totalUnique := total - seen
if synced >= totalUnique {
return t.saveTag()
}
}
return nil
} }
// Inc increments the count for a state // Inc increments the count for a state
func (t *Tag) Inc(state State) { func (t *Tag) Inc(state State) error {
t.IncN(state, 1) return t.IncN(state, 1)
} }
// Get returns the count for a state on a tag // Get returns the count for a state on a tag
...@@ -172,7 +191,7 @@ func (t *Tag) Done(s State) bool { ...@@ -172,7 +191,7 @@ func (t *Tag) Done(s State) bool {
// DoneSplit sets total count to SPLIT count and sets the associated swarm hash for this tag // DoneSplit sets total count to SPLIT count and sets the associated swarm hash for this tag
// is meant to be called when splitter finishes for input streams of unknown size // is meant to be called when splitter finishes for input streams of unknown size
func (t *Tag) DoneSplit(address swarm.Address) int64 { func (t *Tag) DoneSplit(address swarm.Address) (int64, error) {
total := atomic.LoadInt64(&t.Split) total := atomic.LoadInt64(&t.Split)
atomic.StoreInt64(&t.Total, total) atomic.StoreInt64(&t.Total, total)
...@@ -180,7 +199,12 @@ func (t *Tag) DoneSplit(address swarm.Address) int64 { ...@@ -180,7 +199,12 @@ func (t *Tag) DoneSplit(address swarm.Address) int64 {
t.Address = address t.Address = address
} }
return total // persist the tag
err := t.saveTag()
if err != nil {
return 0, err
}
return total, nil
} }
// Status returns the value of state and the total count // Status returns the value of state and the total count
...@@ -220,12 +244,12 @@ func (t *Tag) ETA(state State) (time.Time, error) { ...@@ -220,12 +244,12 @@ func (t *Tag) ETA(state State) (time.Time, error) {
func (tag *Tag) MarshalBinary() (data []byte, err error) { func (tag *Tag) MarshalBinary() (data []byte, err error) {
buffer := make([]byte, 4) buffer := make([]byte, 4)
binary.BigEndian.PutUint32(buffer, tag.Uid) binary.BigEndian.PutUint32(buffer, tag.Uid)
encodeInt64Append(&buffer, tag.Total) encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Total))
encodeInt64Append(&buffer, tag.Split) encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Split))
encodeInt64Append(&buffer, tag.Seen) encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Seen))
encodeInt64Append(&buffer, tag.Stored) encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Stored))
encodeInt64Append(&buffer, tag.Sent) encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Sent))
encodeInt64Append(&buffer, tag.Synced) encodeInt64Append(&buffer, atomic.LoadInt64(&tag.Synced))
intBuffer := make([]byte, 8) intBuffer := make([]byte, 8)
...@@ -248,12 +272,12 @@ func (tag *Tag) UnmarshalBinary(buffer []byte) error { ...@@ -248,12 +272,12 @@ func (tag *Tag) UnmarshalBinary(buffer []byte) error {
tag.Uid = binary.BigEndian.Uint32(buffer) tag.Uid = binary.BigEndian.Uint32(buffer)
buffer = buffer[4:] buffer = buffer[4:]
tag.Total = decodeInt64Splice(&buffer) atomic.AddInt64(&tag.Total, decodeInt64Splice(&buffer))
tag.Split = decodeInt64Splice(&buffer) atomic.AddInt64(&tag.Split, decodeInt64Splice(&buffer))
tag.Seen = decodeInt64Splice(&buffer) atomic.AddInt64(&tag.Seen, decodeInt64Splice(&buffer))
tag.Stored = decodeInt64Splice(&buffer) atomic.AddInt64(&tag.Stored, decodeInt64Splice(&buffer))
tag.Sent = decodeInt64Splice(&buffer) atomic.AddInt64(&tag.Sent, decodeInt64Splice(&buffer))
tag.Synced = decodeInt64Splice(&buffer) atomic.AddInt64(&tag.Synced, decodeInt64Splice(&buffer))
t, n := binary.Varint(buffer) t, n := binary.Varint(buffer)
tag.StartedAt = time.Unix(t, 0) tag.StartedAt = time.Unix(t, 0)
...@@ -280,3 +304,24 @@ func decodeInt64Splice(buffer *[]byte) int64 { ...@@ -280,3 +304,24 @@ func decodeInt64Splice(buffer *[]byte) int64 {
*buffer = (*buffer)[n:] *buffer = (*buffer)[n:]
return val return val
} }
// saveTag update the tag in the state store
func (tag *Tag) saveTag() error {
key := getKey(tag.Uid)
value, err := tag.MarshalBinary()
if err != nil {
return err
}
if tag.stateStore != nil {
err = tag.stateStore.Put(key, value)
if err != nil {
return err
}
}
return nil
}
func getKey(uid uint32) string {
return fmt.Sprintf("tags_%d", uid)
}
...@@ -18,10 +18,13 @@ package tags ...@@ -18,10 +18,13 @@ package tags
import ( import (
"context" "context"
"io/ioutil"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -31,7 +34,9 @@ var ( ...@@ -31,7 +34,9 @@ var (
// TestTagSingleIncrements tests if Inc increments the tag state value // TestTagSingleIncrements tests if Inc increments the tag state value
func TestTagSingleIncrements(t *testing.T) { func TestTagSingleIncrements(t *testing.T) {
tg := &Tag{Total: 10} mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
tg := &Tag{Total: 10, stateStore: mockStatestore, logger: logger}
tc := []struct { tc := []struct {
state uint32 state uint32
...@@ -48,7 +53,10 @@ func TestTagSingleIncrements(t *testing.T) { ...@@ -48,7 +53,10 @@ func TestTagSingleIncrements(t *testing.T) {
for _, tc := range tc { for _, tc := range tc {
for i := 0; i < tc.inc; i++ { for i := 0; i < tc.inc; i++ {
tg.Inc(tc.state) err := tg.Inc(tc.state)
if err != nil {
t.Fatal(err)
}
} }
} }
...@@ -62,13 +70,28 @@ func TestTagSingleIncrements(t *testing.T) { ...@@ -62,13 +70,28 @@ func TestTagSingleIncrements(t *testing.T) {
// TestTagStatus is a unit test to cover Tag.Status method functionality // TestTagStatus is a unit test to cover Tag.Status method functionality
func TestTagStatus(t *testing.T) { func TestTagStatus(t *testing.T) {
tg := &Tag{Total: 10} tg := &Tag{Total: 10}
tg.Inc(StateSeen) err := tg.Inc(StateSeen)
tg.Inc(StateSent) if err != nil {
tg.Inc(StateSynced) t.Fatal(err)
}
err = tg.Inc(StateSent)
if err != nil {
t.Fatal(err)
}
err = tg.Inc(StateSynced)
if err != nil {
t.Fatal(err)
}
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
tg.Inc(StateSplit) err = tg.Inc(StateSplit)
tg.Inc(StateStored) if err != nil {
t.Fatal(err)
}
err = tg.Inc(StateStored)
if err != nil {
t.Fatal(err)
}
} }
for _, v := range []struct { for _, v := range []struct {
state State state State
...@@ -100,7 +123,10 @@ func TestTagETA(t *testing.T) { ...@@ -100,7 +123,10 @@ func TestTagETA(t *testing.T) {
maxDiff := 100000 // 100 microsecond maxDiff := 100000 // 100 microsecond
tg := &Tag{Total: 10, StartedAt: now} tg := &Tag{Total: 10, StartedAt: now}
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
tg.Inc(StateSplit) err := tg.Inc(StateSplit)
if err != nil {
t.Fatal(err)
}
eta, err := tg.ETA(StateSplit) eta, err := tg.ETA(StateSplit)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
...@@ -113,15 +139,20 @@ func TestTagETA(t *testing.T) { ...@@ -113,15 +139,20 @@ func TestTagETA(t *testing.T) {
// TestTagConcurrentIncrements tests Inc calls concurrently // TestTagConcurrentIncrements tests Inc calls concurrently
func TestTagConcurrentIncrements(t *testing.T) { func TestTagConcurrentIncrements(t *testing.T) {
tg := &Tag{} mockStatestore := statestore.NewStateStore()
n := 1000 logger := logging.New(ioutil.Discard, 0)
tg := &Tag{stateStore: mockStatestore, logger: logger}
n := 10
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(5 * n) wg.Add(5 * n)
for _, f := range allStates { for _, f := range allStates {
go func(f State) { go func(f State) {
for j := 0; j < n; j++ { for j := 0; j < n; j++ {
go func() { go func() {
tg.Inc(f) err := tg.Inc(f)
if err != nil {
t.Errorf("error incrementing tag counters: %v", err)
}
wg.Done() wg.Done()
}() }()
} }
...@@ -138,7 +169,9 @@ func TestTagConcurrentIncrements(t *testing.T) { ...@@ -138,7 +169,9 @@ func TestTagConcurrentIncrements(t *testing.T) {
// TestTagsMultipleConcurrentIncrements tests Inc calls concurrently // TestTagsMultipleConcurrentIncrements tests Inc calls concurrently
func TestTagsMultipleConcurrentIncrementsSyncMap(t *testing.T) { func TestTagsMultipleConcurrentIncrementsSyncMap(t *testing.T) {
ts := NewTags() mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
ts := NewTags(mockStatestore, logger)
n := 100 n := 100
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(10 * 5 * n) wg.Add(10 * 5 * n)
...@@ -152,7 +185,10 @@ func TestTagsMultipleConcurrentIncrementsSyncMap(t *testing.T) { ...@@ -152,7 +185,10 @@ func TestTagsMultipleConcurrentIncrementsSyncMap(t *testing.T) {
go func(tag *Tag, f State) { go func(tag *Tag, f State) {
for j := 0; j < n; j++ { for j := 0; j < n; j++ {
go func() { go func() {
tag.Inc(f) err := tag.Inc(f)
if err != nil {
t.Errorf("error incrementing tag counters: %v", err)
}
wg.Done() wg.Done()
}() }()
} }
...@@ -185,11 +221,16 @@ func TestTagsMultipleConcurrentIncrementsSyncMap(t *testing.T) { ...@@ -185,11 +221,16 @@ func TestTagsMultipleConcurrentIncrementsSyncMap(t *testing.T) {
// TestMarshallingWithAddr tests that marshalling and unmarshalling is done correctly when the // TestMarshallingWithAddr tests that marshalling and unmarshalling is done correctly when the
// tag Address (byte slice) contains some arbitrary value // tag Address (byte slice) contains some arbitrary value
func TestMarshallingWithAddr(t *testing.T) { func TestMarshallingWithAddr(t *testing.T) {
tg := NewTag(context.Background(), 111, "test/tag", 10, nil) mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
tg := NewTag(context.Background(), 111, "test/tag", 10, nil, mockStatestore, logger)
tg.Address = swarm.NewAddress([]byte{0, 1, 2, 3, 4, 5, 6}) tg.Address = swarm.NewAddress([]byte{0, 1, 2, 3, 4, 5, 6})
for _, f := range allStates { for _, f := range allStates {
tg.Inc(f) err := tg.Inc(f)
if err != nil {
t.Fatal(err)
}
} }
b, err := tg.MarshalBinary() b, err := tg.MarshalBinary()
...@@ -233,9 +274,14 @@ func TestMarshallingWithAddr(t *testing.T) { ...@@ -233,9 +274,14 @@ func TestMarshallingWithAddr(t *testing.T) {
// TestMarshallingNoAddress tests that marshalling and unmarshalling is done correctly // TestMarshallingNoAddress tests that marshalling and unmarshalling is done correctly
func TestMarshallingNoAddr(t *testing.T) { func TestMarshallingNoAddr(t *testing.T) {
tg := NewTag(context.Background(), 111, "test/tag", 10, nil) mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
tg := NewTag(context.Background(), 111, "test/tag", 10, nil, mockStatestore, logger)
for _, f := range allStates { for _, f := range allStates {
tg.Inc(f) err := tg.Inc(f)
if err != nil {
t.Fatal(err)
}
} }
b, err := tg.MarshalBinary() b, err := tg.MarshalBinary()
......
...@@ -26,6 +26,8 @@ import ( ...@@ -26,6 +26,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/ethersphere/bee/pkg/logging"
"github.com/ethersphere/bee/pkg/storage"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
) )
...@@ -37,19 +39,23 @@ var ( ...@@ -37,19 +39,23 @@ var (
// Tags hold tag information indexed by a unique random uint32 // Tags hold tag information indexed by a unique random uint32
type Tags struct { type Tags struct {
tags *sync.Map tags *sync.Map
stateStore storage.StateStorer
logger logging.Logger
} }
// NewTags creates a tags object // NewTags creates a tags object
func NewTags() *Tags { func NewTags(stateStore storage.StateStorer, logger logging.Logger) *Tags {
return &Tags{ return &Tags{
tags: &sync.Map{}, tags: &sync.Map{},
stateStore: stateStore,
logger: logger,
} }
} }
// Create creates a new tag, stores it by the name and returns it // Create creates a new tag, stores it by the name and returns it
// it returns an error if the tag with this name already exists // it returns an error if the tag with this name already exists
func (ts *Tags) Create(s string, total int64) (*Tag, error) { func (ts *Tags) Create(s string, total int64) (*Tag, error) {
t := NewTag(context.Background(), TagUidFunc(), s, total, nil) t := NewTag(context.Background(), TagUidFunc(), s, total, nil, ts.stateStore, ts.logger)
if _, loaded := ts.tags.LoadOrStore(t.Uid, t); loaded { if _, loaded := ts.tags.LoadOrStore(t.Uid, t); loaded {
return nil, errExists return nil, errExists
...@@ -74,8 +80,15 @@ func (ts *Tags) All() (t []*Tag) { ...@@ -74,8 +80,15 @@ func (ts *Tags) All() (t []*Tag) {
func (ts *Tags) Get(uid uint32) (*Tag, error) { func (ts *Tags) Get(uid uint32) (*Tag, error) {
t, ok := ts.tags.Load(uid) t, ok := ts.tags.Load(uid)
if !ok { if !ok {
// see if the tag is present in the store
// if yes, load it in to the memory
ta, err := ts.getTagFromStore(uid)
if err != nil {
return nil, ErrNotFound return nil, ErrNotFound
} }
ts.tags.LoadOrStore(ta.Uid, ta)
return ta, nil
}
return t.(*Tag), nil return t.(*Tag), nil
} }
...@@ -143,3 +156,33 @@ func (ts *Tags) UnmarshalJSON(value []byte) error { ...@@ -143,3 +156,33 @@ func (ts *Tags) UnmarshalJSON(value []byte) error {
return err return err
} }
// getTagFromStore get a given tag from the state store.
func (ts *Tags) getTagFromStore(uid uint32) (*Tag, error) {
key := "tags_" + strconv.Itoa(int(uid))
var data []byte
err := ts.stateStore.Get(key, &data)
if err != nil {
return nil, err
}
var ta Tag
err = ta.UnmarshalBinary(data)
if err != nil {
return nil, err
}
return &ta, nil
}
// Close is called when the node goes down. This is when all the tags in memory is persisted.
func (ts *Tags) Close() (err error) {
// store all the tags in memory
tags := ts.All()
for _, t := range tags {
ts.logger.Trace("updating tag: ", t.Uid)
err := t.saveTag()
if err != nil {
return err
}
}
return nil
}
...@@ -17,11 +17,18 @@ ...@@ -17,11 +17,18 @@
package tags package tags
import ( import (
"io/ioutil"
"testing" "testing"
"github.com/ethersphere/bee/pkg/logging"
statestore "github.com/ethersphere/bee/pkg/statestore/mock"
"github.com/ethersphere/bee/pkg/swarm"
) )
func TestAll(t *testing.T) { func TestAll(t *testing.T) {
ts := NewTags() mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
ts := NewTags(mockStatestore, logger)
if _, err := ts.Create("1", 1); err != nil { if _, err := ts.Create("1", 1); err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -52,3 +59,81 @@ func TestAll(t *testing.T) { ...@@ -52,3 +59,81 @@ func TestAll(t *testing.T) {
t.Fatalf("expected length to be 3 got %d", len(all)) t.Fatalf("expected length to be 3 got %d", len(all))
} }
} }
func TestPersistence(t *testing.T) {
mockStatestore := statestore.NewStateStore()
logger := logging.New(ioutil.Discard, 0)
ts := NewTags(mockStatestore, logger)
ta, err := ts.Create("one", 1)
if err != nil {
t.Fatal(err)
}
ta.Total = 10
ta.Seen = 2
ta.Split = 10
ta.Stored = 8
_, err = ta.DoneSplit(swarm.ZeroAddress)
if err != nil {
t.Fatal(err)
}
// simulate node closing down and booting up
err = ts.Close()
if err != nil {
t.Fatal(err)
}
ts = NewTags(mockStatestore, logger)
// Get the tag after the node bootup
rcvd1, err := ts.Get(ta.Uid)
if err != nil {
t.Fatal(err)
}
// check if the values ae intact after the bootup
if ta.Uid != rcvd1.Uid {
t.Fatalf("invalid uid: expected %d got %d", ta.Uid, rcvd1.Uid)
}
if ta.Total != rcvd1.Total {
t.Fatalf("invalid total: expected %d got %d", ta.Total, rcvd1.Total)
}
// See if tag is saved after syncing is over
for i := 0; i < 8; i++ {
err := ta.Inc(StateSent)
if err != nil {
t.Fatal(err)
}
err = ta.Inc(StateSynced)
if err != nil {
t.Fatal(err)
}
}
// simulate node closing down and booting up
err = ts.Close()
if err != nil {
t.Fatal(err)
}
ts = NewTags(mockStatestore, logger)
// get the tag after the node boot up
rcvd2, err := ts.Get(ta.Uid)
if err != nil {
t.Fatal(err)
}
// check if the values ae intact after the bootup
if ta.Uid != rcvd2.Uid {
t.Fatalf("invalid uid: expected %d got %d", ta.Uid, rcvd2.Uid)
}
if ta.Total != rcvd2.Total {
t.Fatalf("invalid total: expected %d got %d", ta.Total, rcvd2.Total)
}
if ta.Sent != rcvd2.Sent {
t.Fatalf("invalid sent: expected %d got %d", ta.Sent, rcvd2.Sent)
}
if ta.Synced != rcvd2.Synced {
t.Fatalf("invalid synced: expected %d got %d", ta.Synced, rcvd2.Synced)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment