Commit 0722fbce authored by acud's avatar acud Committed by GitHub

api, netstore: dont set empty targets, return correct errors from netstore (#747)

* dont set empty targets, return correct errors from netstore
parent dfa735c5
......@@ -28,7 +28,9 @@ import (
func (s *server) bzzDownloadHandler(w http.ResponseWriter, r *http.Request) {
logger := tracing.NewLoggerWithTraceID(r.Context(), s.Logger)
targets := r.URL.Query().Get("targets")
if targets != "" {
r = r.WithContext(sctx.SetTargets(r.Context(), targets))
}
ctx := r.Context()
nameOrHex := mux.Vars(r)["address"]
......
......@@ -96,7 +96,9 @@ func (s *server) chunkUploadHandler(w http.ResponseWriter, r *http.Request) {
func (s *server) chunkGetHandler(w http.ResponseWriter, r *http.Request) {
targets := r.URL.Query().Get("targets")
if targets != "" {
r = r.WithContext(sctx.SetTargets(r.Context(), targets))
}
nameOrHex := mux.Vars(r)["addr"]
ctx := r.Context()
......
......@@ -242,7 +242,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
}
targets := r.URL.Query().Get("targets")
if targets != "" {
r = r.WithContext(sctx.SetTargets(r.Context(), targets))
}
// read entry.
j := seekjoiner.NewSimpleJoiner(s.Storer)
......@@ -303,7 +305,9 @@ func (s *server) fileDownloadHandler(w http.ResponseWriter, r *http.Request) {
func (s *server) downloadHandler(w http.ResponseWriter, r *http.Request, reference swarm.Address, additionalHeaders http.Header) {
logger := tracing.NewLoggerWithTraceID(r.Context(), s.Logger)
targets := r.URL.Query().Get("targets")
if targets != "" {
r = r.WithContext(sctx.SetTargets(r.Context(), targets))
}
rs := seekjoiner.NewSimpleJoiner(s.Storer)
reader, l, err := rs.Join(r.Context(), reference)
......
......@@ -47,10 +47,8 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
if s.recoveryCallback == nil {
return nil, err
}
targets, err := sctx.GetTargets(ctx)
if err != nil {
return nil, err
}
targets := sctx.GetTargets(ctx)
if targets != nil {
go func() {
err := s.recoveryCallback(addr, targets)
if err != nil {
......@@ -59,6 +57,9 @@ func (s *store) Get(ctx context.Context, mode storage.ModeGet, addr swarm.Addres
}()
return nil, ErrRecoveryAttempt
}
return nil, fmt.Errorf("netstore retrieve chunk: %w", err)
}
_, err = s.Storer.Put(ctx, storage.ModePutRequest, ch)
if err != nil {
return nil, fmt.Errorf("netstore retrieve put: %w", err)
......
......@@ -135,24 +135,6 @@ func TestInvalidRecoveryFunction(t *testing.T) {
}
}
func TestInvalidTargets(t *testing.T) {
hookWasCalled := make(chan bool, 1)
rec := &mockRecovery{
hookC: hookWasCalled,
}
retrieve, _, nstore := newRetrievingNetstore(rec)
addr := swarm.MustParseHexAddress("deadbeef")
retrieve.failure = true
ctx := context.Background()
ctx = sctx.SetTargets(ctx, "gh")
_, err := nstore.Get(ctx, storage.ModeGetRequest, addr)
if err != nil && !errors.Is(err, sctx.ErrTargetPrefix) {
t.Fatal(err)
}
}
// returns a mock retrieval protocol, a mock local storage and a netstore
func newRetrievingNetstore(rec *mockRecovery) (ret *retrievalMock, mockStore, ns storage.Storer) {
retrieve := &retrievalMock{}
......
......@@ -69,15 +69,9 @@ func TestRecoveryHookCalls(t *testing.T) {
target := "BE"
// test cases variables
dummyContext := context.Background() // has no publisher
targetContext := sctx.SetTargets(context.Background(), target)
for _, tc := range []recoveryHookTestCase{
{
name: "no targets in context",
ctx: dummyContext,
expectsFailure: true,
},
{
name: "targets set in context",
ctx: targetContext,
......
......@@ -63,10 +63,10 @@ func SetTargets(ctx context.Context, targets string) context.Context {
// GetTargets returns the specific target pinners for a corresponding chunk by
// reading the prefix targets sent in the download API.
func GetTargets(ctx context.Context) (trojan.Targets, error) {
func GetTargets(ctx context.Context) trojan.Targets {
targetString, ok := ctx.Value(targetsContextKey{}).(string)
if !ok {
return nil, ErrTargetPrefix
return nil
}
prefixes := strings.Split(targetString, ",")
......@@ -80,7 +80,7 @@ func GetTargets(ctx context.Context) (trojan.Targets, error) {
targets = append(targets, target)
}
if len(targets) <= 0 {
return nil, ErrTargetPrefix
return nil
}
return targets, nil
return targets
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment