Commit 0722fbce authored by acud's avatar acud Committed by GitHub

api, netstore: dont set empty targets, return correct errors from netstore (#747)

* dont set empty targets, return correct errors from netstore
parent dfa735c5
...@@ -28,7 +28,9 @@ import ( ...@@ -28,7 +28,9 @@ import (
func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) { func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
logger := tracing.NewLoggerWithTraceID(r.Context(), s.Logger) logger := tracing.NewLoggerWithTraceID(r.Context(), s.Logger)
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(sctx.SetTargets(r.Context(), targets)) if targets != "" {
r = r.WithContext(sctx.SetTargets(r.Context(), targets))
}
ctx := r.Context() ctx := r.Context()
nameOrHex := mux.Vars(r)["address"] nameOrHex := mux.Vars(r)["address"]
......
...@@ -96,7 +96,9 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -96,7 +96,9 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) { func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) {
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(sctx.SetTargets(r.Context(), targets)) if targets != "" {
r = r.WithContext(sctx.SetTargets(r.Context(), targets))
}
nameOrHex := mux.Vars(r)["addr"] nameOrHex := mux.Vars(r)["addr"]
ctx := r.Context() ctx := r.Context()
......
...@@ -242,7 +242,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -242,7 +242,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
} }
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(sctx.SetTargets(r.Context(), targets)) if targets != "" {
r = r.WithContext(sctx.SetTargets(r.Context(), targets))
}
// read entry. // read entry.
j := seekjoiner.NewSimpleJoiner(s.Storer) j := seekjoiner.NewSimpleJoiner(s.Storer)
...@@ -303,7 +305,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) { ...@@ -303,7 +305,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
func (s *server) downloadHandler(w http.ResponseWriter, r *http.Request, reference swarm.Address, additionalHeaders http.Header) { func (s *server) downloadHandler(w http.ResponseWriter, r *http.Request, reference swarm.Address, additionalHeaders http.Header) {
logger := tracing.NewLoggerWithTraceID(r.Context(), s.Logger) logger := tracing.NewLoggerWithTraceID(r.Context(), s.Logger)
targets := r.URL.Query().Get("targets") targets := r.URL.Query().Get("targets")
r = r.WithContext(sctx.SetTargets(r.Context(), targets)) if targets != "" {
r = r.WithContext(sctx.SetTargets(r.Context(), targets))
}
rs := seekjoiner.NewSimpleJoiner(s.Storer) rs := seekjoiner.NewSimpleJoiner(s.Storer)
reader, l, err := rs.Join(r.Context(), reference) reader, l, err := rs.Join(r.Context(), reference)
......
...@@ -47,18 +47,19 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres ...@@ -47,18 +47,19 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
if s.recoveryCallback == nil { if s.recoveryCallback == nil {
return nil, err return nil, err
} }
targets, err := sctx.GetTargets(ctx) targets := sctx.GetTargets(ctx)
if err != nil { if targets != nil {
return nil, err go func() {
err := s.recoveryCallback(addr, targets)
if err != nil {
s.logger.Debugf("netstore: error while recovering chunk: %v", err)
}
}()
return nil, ErrRecoveryAttempt
} }
go func() { return nil, fmt.Errorf("netstore retrieve chunk: %w", err)
err := s.recoveryCallback(addr, targets)
if err != nil {
s.logger.Debugf("netstore: error while recovering chunk: %v", err)
}
}()
return nil, ErrRecoveryAttempt
} }
_, err = s.Storer.Put(ctx, storage.ModePutRequest, ch) _, err = s.Storer.Put(ctx, storage.ModePutRequest, ch)
if err != nil { if err != nil {
return nil, fmt.Errorf("netstore retrieve put: %w", err) return nil, fmt.Errorf("netstore retrieve put: %w", err)
......
...@@ -135,24 +135,6 @@ func TestInvalidRecoveryFunction(t *testing.T) { ...@@ -135,24 +135,6 @@ func TestInvalidRecoveryFunction(t *testing.T) {
} }
} }
func TestInvalidTargets(t *testing.T) {
hookWasCalled := make(chan bool, 1)
rec := &mockRecovery{
hookC: hookWasCalled,
}
retrieve, _, nstore := newRetrievingNetstore(rec)
addr := swarm.MustParseHexAddress("deadbeef")
retrieve.failure = true
ctx := context.Background()
ctx = sctx.SetTargets(ctx, "gh")
_, err := nstore.Get(ctx, storage.ModeGetRequest, addr)
if err != nil && !errors.Is(err, sctx.ErrTargetPrefix) {
t.Fatal(err)
}
}
// returns a mock retrieval protocol, a mock local storage and a netstore // returns a mock retrieval protocol, a mock local storage and a netstore
func newRetrievingNetstore(rec *mockRecovery) (ret *retrievalMock, mockStore, ns storage.Storer) { func newRetrievingNetstore(rec *mockRecovery) (ret *retrievalMock, mockStore, ns storage.Storer) {
retrieve := &retrievalMock{} retrieve := &retrievalMock{}
......
...@@ -69,15 +69,9 @@ func TestRecoveryHookCalls(t *testing.T) { ...@@ -69,15 +69,9 @@ func TestRecoveryHookCalls(t *testing.T) {
target := "BE" target := "BE"
// test cases variables // test cases variables
dummyContext := context.Background() // has no publisher
targetContext := sctx.SetTargets(context.Background(), target) targetContext := sctx.SetTargets(context.Background(), target)
for _, tc := range []recoveryHookTestCase{ for _, tc := range []recoveryHookTestCase{
{
name: "no targets in context",
ctx: dummyContext,
expectsFailure: true,
},
{ {
name: "targets set in context", name: "targets set in context",
ctx: targetContext, ctx: targetContext,
......
...@@ -63,10 +63,10 @@ func SetTargets(ctx context.Context, targets string) context.Context { ...@@ -63,10 +63,10 @@ func SetTargets(ctx context.Context, targets string) context.Context {
// GetTargets returns the specific target pinners for a corresponding chunk by // GetTargets returns the specific target pinners for a corresponding chunk by
// reading the prefix targets sent in the download API. // reading the prefix targets sent in the download API.
func GetTargets(ctx context.Context) (trojan.Targets, error) { func GetTargets(ctx context.Context) trojan.Targets {
targetString, ok := ctx.Value(targetsContextKey{}).(string) targetString, ok := ctx.Value(targetsContextKey{}).(string)
if !ok { if !ok {
return nil, ErrTargetPrefix return nil
} }
prefixes := strings.Split(targetString, ",") prefixes := strings.Split(targetString, ",")
...@@ -80,7 +80,7 @@ func GetTargets(ctx context.Context) (trojan.Targets, error) { ...@@ -80,7 +80,7 @@ func GetTargets(ctx context.Context) (trojan.Targets, error) {
targets = append(targets, target) targets = append(targets, target)
} }
if len(targets) <= 0 { if len(targets) <= 0 {
return nil, ErrTargetPrefix return nil
} }
return targets, nil return targets
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment