Commit 9ee31073 authored by aloknerurkar's avatar aloknerurkar Committed by GitHub

feat(debugapi, postage): dilute batch handling (#2410)

parent da0cb074
...@@ -908,3 +908,38 @@ paths: ...@@ -908,3 +908,38 @@ paths:
$ref: "SwarmCommon.yaml#/components/responses/500" $ref: "SwarmCommon.yaml#/components/responses/500"
default: default:
description: Default response description: Default response
"/stamps/dilute/{id}/{depth}":
patch:
summary: Dilute an existing postage batch.
description: Be aware, this endpoint creates on-chain transactions and transfers BZZ from the node's Ethereum account and hence directly manipulates the wallet balance!
tags:
- Postage Stamps
parameters:
- in: path
name: id
schema:
$ref: "SwarmCommon.yaml#/components/schemas/BatchID"
required: true
description: Batch ID to dilute
- in: path
name: depth
schema:
type: integer
required: true
description: New batch depth. Must be higher than the previous depth.
responses:
"202":
description: Returns the postage batch ID that was diluted.
content:
application/json:
schema:
$ref: "SwarmCommon.yaml#/components/schemas/BatchIDResponse"
"400":
$ref: "SwarmCommon.yaml#/components/responses/400"
"429":
$ref: "SwarmCommon.yaml#/components/responses/429"
"500":
$ref: "SwarmCommon.yaml#/components/responses/500"
default:
description: Default response
...@@ -69,7 +69,7 @@ type Service struct { ...@@ -69,7 +69,7 @@ type Service struct {
// The following are semaphores which exists to limit concurrent access // The following are semaphores which exists to limit concurrent access
// to some parts of the resources in order to avoid undefined behaviour. // to some parts of the resources in order to avoid undefined behaviour.
postageCreateSem *semaphore.Weighted postageSem *semaphore.Weighted
cashOutChequeSem *semaphore.Weighted cashOutChequeSem *semaphore.Weighted
} }
...@@ -88,7 +88,7 @@ func New(publicKey, pssPublicKey ecdsa.PublicKey, ethereumAddress common.Address ...@@ -88,7 +88,7 @@ func New(publicKey, pssPublicKey ecdsa.PublicKey, ethereumAddress common.Address
s.blockTime = blockTime s.blockTime = blockTime
s.metricsRegistry = newMetricsRegistry() s.metricsRegistry = newMetricsRegistry()
s.transaction = transaction s.transaction = transaction
s.postageCreateSem = semaphore.NewWeighted(1) s.postageSem = semaphore.NewWeighted(1)
s.cashOutChequeSem = semaphore.NewWeighted(1) s.cashOutChequeSem = semaphore.NewWeighted(1)
s.setRouter(s.newBasicRouter()) s.setRouter(s.newBasicRouter())
......
...@@ -21,6 +21,20 @@ import ( ...@@ -21,6 +21,20 @@ import (
"github.com/gorilla/mux" "github.com/gorilla/mux"
) )
func (s *Service) postageAccessHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !s.postageSem.TryAcquire(1) {
s.logger.Debug("postage access: simultaneous on-chain operations not supported")
s.logger.Error("postage access: simultaneous on-chain operations not supported")
jsonhttp.TooManyRequests(w, "simultaneous on-chain operations not supported")
return
}
defer s.postageSem.Release(1)
h.ServeHTTP(w, r)
})
}
type batchID []byte type batchID []byte
func (b batchID) MarshalJSON() ([]byte, error) { func (b batchID) MarshalJSON() ([]byte, error) {
...@@ -68,14 +82,6 @@ func (s *Service) postageCreateHandler(w http.ResponseWriter, r *http.Request) { ...@@ -68,14 +82,6 @@ func (s *Service) postageCreateHandler(w http.ResponseWriter, r *http.Request) {
immutable, _ = strconv.ParseBool(val[0]) immutable, _ = strconv.ParseBool(val[0])
} }
if !s.postageCreateSem.TryAcquire(1) {
s.logger.Debug("create batch: simultaneous on-chain operations not supported")
s.logger.Error("create batch: simultaneous on-chain operations not supported")
jsonhttp.TooManyRequests(w, "simultaneous on-chain operations not supported")
return
}
defer s.postageCreateSem.Release(1)
batchID, err := s.postageContract.CreateBatch(ctx, amount, uint8(depth), immutable, label) batchID, err := s.postageContract.CreateBatch(ctx, amount, uint8(depth), immutable, label)
if err != nil { if err != nil {
if errors.Is(err, postagecontract.ErrInsufficientFunds) { if errors.Is(err, postagecontract.ErrInsufficientFunds) {
...@@ -324,7 +330,7 @@ func (s *Service) estimateBatchTTL(id []byte) (int64, error) { ...@@ -324,7 +330,7 @@ func (s *Service) estimateBatchTTL(id []byte) (int64, error) {
func (s *Service) postageTopUpHandler(w http.ResponseWriter, r *http.Request) { func (s *Service) postageTopUpHandler(w http.ResponseWriter, r *http.Request) {
idStr := mux.Vars(r)["id"] idStr := mux.Vars(r)["id"]
if idStr == "" || len(idStr) != 64 { if len(idStr) != 64 {
s.logger.Error("topup batch: invalid batchID") s.logger.Error("topup batch: invalid batchID")
jsonhttp.BadRequest(w, "invalid batchID") jsonhttp.BadRequest(w, "invalid batchID")
return return
...@@ -355,14 +361,6 @@ func (s *Service) postageTopUpHandler(w http.ResponseWriter, r *http.Request) { ...@@ -355,14 +361,6 @@ func (s *Service) postageTopUpHandler(w http.ResponseWriter, r *http.Request) {
ctx = sctx.SetGasPrice(ctx, p) ctx = sctx.SetGasPrice(ctx, p)
} }
if !s.postageCreateSem.TryAcquire(1) {
s.logger.Debug("topup batch: simultaneous on-chain operations not supported")
s.logger.Error("topup batch: simultaneous on-chain operations not supported")
jsonhttp.TooManyRequests(w, "simultaneous on-chain operations not supported")
return
}
defer s.postageCreateSem.Release(1)
err = s.postageContract.TopUpBatch(ctx, id, amount) err = s.postageContract.TopUpBatch(ctx, id, amount)
if err != nil { if err != nil {
if errors.Is(err, postagecontract.ErrInsufficientFunds) { if errors.Is(err, postagecontract.ErrInsufficientFunds) {
...@@ -381,3 +379,57 @@ func (s *Service) postageTopUpHandler(w http.ResponseWriter, r *http.Request) { ...@@ -381,3 +379,57 @@ func (s *Service) postageTopUpHandler(w http.ResponseWriter, r *http.Request) {
BatchID: id, BatchID: id,
}) })
} }
func (s *Service) postageDiluteHandler(w http.ResponseWriter, r *http.Request) {
idStr := mux.Vars(r)["id"]
if len(idStr) != 64 {
s.logger.Error("dilute batch: invalid batchID")
jsonhttp.BadRequest(w, "invalid batchID")
return
}
id, err := hex.DecodeString(idStr)
if err != nil {
s.logger.Debugf("dilute batch: invalid batchID: %v", err)
s.logger.Error("dilute batch: invalid batchID")
jsonhttp.BadRequest(w, "invalid batchID")
return
}
depthStr := mux.Vars(r)["depth"]
depth, err := strconv.ParseUint(depthStr, 10, 8)
if err != nil {
s.logger.Debugf("dilute batch: invalid depth: %v", err)
s.logger.Error("dilute batch: invalid depth")
jsonhttp.BadRequest(w, "invalid depth")
return
}
ctx := r.Context()
if price, ok := r.Header[gasPriceHeader]; ok {
p, ok := big.NewInt(0).SetString(price[0], 10)
if !ok {
s.logger.Error("dilute batch: bad gas price")
jsonhttp.BadRequest(w, errBadGasPrice)
return
}
ctx = sctx.SetGasPrice(ctx, p)
}
err = s.postageContract.DiluteBatch(ctx, id, uint8(depth))
if err != nil {
if errors.Is(err, postagecontract.ErrInvalidDepth) {
s.logger.Debugf("dilute batch: invalid depth: %v", err)
s.logger.Error("dilte batch: invalid depth")
jsonhttp.BadRequest(w, "invalid depth")
return
}
s.logger.Debugf("dilute batch: failed to dilute: %v", err)
s.logger.Error("dilute batch: failed to dilute")
jsonhttp.InternalServerError(w, "cannot dilute batch")
return
}
jsonhttp.Accepted(w, &postageCreateResponse{
BatchID: id,
})
}
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"math/big" "math/big"
"net/http" "net/http"
"testing" "testing"
"time"
"github.com/ethersphere/bee/pkg/bigint" "github.com/ethersphere/bee/pkg/bigint"
"github.com/ethersphere/bee/pkg/debugapi" "github.com/ethersphere/bee/pkg/debugapi"
...@@ -486,3 +487,237 @@ func TestPostageTopUpStamp(t *testing.T) { ...@@ -486,3 +487,237 @@ func TestPostageTopUpStamp(t *testing.T) {
) )
}) })
} }
func TestPostageDiluteStamp(t *testing.T) {
newBatchDepth := uint8(17)
diluteBatch := func(id string, depth uint8) string {
return fmt.Sprintf("/stamps/dilute/%s/%d", id, depth)
}
t.Run("ok", func(t *testing.T) {
contract := contractMock.New(
contractMock.WithDiluteBatchFunc(func(ctx context.Context, id []byte, newDepth uint8) error {
if !bytes.Equal(id, batchOk) {
return errors.New("incorrect batch ID in call")
}
if newDepth != newBatchDepth {
return fmt.Errorf("called with wrong depth. wanted %d, got %d", newBatchDepth, newDepth)
}
return nil
}),
)
ts := newTestServer(t, testServerOptions{
PostageContract: contract,
})
jsonhttptest.Request(t, ts.Client, http.MethodPatch, diluteBatch(batchOkStr, newBatchDepth), http.StatusAccepted,
jsonhttptest.WithExpectedJSONResponse(&debugapi.PostageCreateResponse{
BatchID: batchOk,
}),
)
})
t.Run("with-custom-gas", func(t *testing.T) {
contract := contractMock.New(
contractMock.WithDiluteBatchFunc(func(ctx context.Context, id []byte, newDepth uint8) error {
if !bytes.Equal(id, batchOk) {
return errors.New("incorrect batch ID in call")
}
if newDepth != newBatchDepth {
return fmt.Errorf("called with wrong depth. wanted %d, got %d", newBatchDepth, newDepth)
}
if sctx.GetGasPrice(ctx).Cmp(big.NewInt(10000)) != 0 {
return fmt.Errorf("called with wrong gas price. wanted %d, got %d", 10000, sctx.GetGasPrice(ctx))
}
return nil
}),
)
ts := newTestServer(t, testServerOptions{
PostageContract: contract,
})
jsonhttptest.Request(t, ts.Client, http.MethodPatch, diluteBatch(batchOkStr, newBatchDepth), http.StatusAccepted,
jsonhttptest.WithRequestHeader("Gas-Price", "10000"),
jsonhttptest.WithExpectedJSONResponse(&debugapi.PostageCreateResponse{
BatchID: batchOk,
}),
)
})
t.Run("with-error", func(t *testing.T) {
contract := contractMock.New(
contractMock.WithDiluteBatchFunc(func(ctx context.Context, id []byte, newDepth uint8) error {
return errors.New("err")
}),
)
ts := newTestServer(t, testServerOptions{
PostageContract: contract,
})
jsonhttptest.Request(t, ts.Client, http.MethodPatch, diluteBatch(batchOkStr, newBatchDepth), http.StatusInternalServerError,
jsonhttptest.WithExpectedJSONResponse(&jsonhttp.StatusResponse{
Code: http.StatusInternalServerError,
Message: "cannot dilute batch",
}),
)
})
t.Run("with depth error", func(t *testing.T) {
contract := contractMock.New(
contractMock.WithDiluteBatchFunc(func(ctx context.Context, id []byte, newDepth uint8) error {
return postagecontract.ErrInvalidDepth
}),
)
ts := newTestServer(t, testServerOptions{
PostageContract: contract,
})
jsonhttptest.Request(t, ts.Client, http.MethodPatch, diluteBatch(batchOkStr, newBatchDepth), http.StatusBadRequest,
jsonhttptest.WithExpectedJSONResponse(&jsonhttp.StatusResponse{
Code: http.StatusBadRequest,
Message: "invalid depth",
}),
)
})
t.Run("invalid batch id", func(t *testing.T) {
ts := newTestServer(t, testServerOptions{})
jsonhttptest.Request(t, ts.Client, http.MethodPatch, "/stamps/dilute/abcd/2", http.StatusBadRequest,
jsonhttptest.WithExpectedJSONResponse(&jsonhttp.StatusResponse{
Code: http.StatusBadRequest,
Message: "invalid batchID",
}),
)
})
t.Run("invalid depth", func(t *testing.T) {
ts := newTestServer(t, testServerOptions{})
wrongURL := fmt.Sprintf("/stamps/dilute/%s/depth", batchOkStr)
jsonhttptest.Request(t, ts.Client, http.MethodPatch, wrongURL, http.StatusBadRequest,
jsonhttptest.WithExpectedJSONResponse(&jsonhttp.StatusResponse{
Code: http.StatusBadRequest,
Message: "invalid depth",
}),
)
})
}
// Tests the postageAccessHandler middleware for any set of operations that are guarded
// by the postage semaphore
func TestPostageAccessHandler(t *testing.T) {
type operation struct {
name string
method string
url string
respCode int
resp interface{}
}
success := []operation{
{
name: "create batch ok",
method: http.MethodPost,
url: "/stamps/1000/17?label=test",
respCode: http.StatusCreated,
resp: &debugapi.PostageCreateResponse{
BatchID: batchOk,
},
},
{
name: "topup batch ok",
method: http.MethodPatch,
url: fmt.Sprintf("/stamps/topup/%s/10", batchOkStr),
respCode: http.StatusAccepted,
resp: &debugapi.PostageCreateResponse{
BatchID: batchOk,
},
},
{
name: "dilute batch ok",
method: http.MethodPatch,
url: fmt.Sprintf("/stamps/dilute/%s/18", batchOkStr),
respCode: http.StatusAccepted,
resp: &debugapi.PostageCreateResponse{
BatchID: batchOk,
},
},
}
failure := []operation{
{
name: "create batch not ok",
method: http.MethodPost,
url: "/stamps/1000/17?label=test",
respCode: http.StatusTooManyRequests,
resp: &jsonhttp.StatusResponse{
Code: http.StatusTooManyRequests,
Message: "simultaneous on-chain operations not supported",
},
},
{
name: "topup batch not ok",
method: http.MethodPatch,
url: fmt.Sprintf("/stamps/topup/%s/10", batchOkStr),
respCode: http.StatusTooManyRequests,
resp: &jsonhttp.StatusResponse{
Code: http.StatusTooManyRequests,
Message: "simultaneous on-chain operations not supported",
},
},
{
name: "dilute batch not ok",
method: http.MethodPatch,
url: fmt.Sprintf("/stamps/dilute/%s/18", batchOkStr),
respCode: http.StatusTooManyRequests,
resp: &jsonhttp.StatusResponse{
Code: http.StatusTooManyRequests,
Message: "simultaneous on-chain operations not supported",
},
},
}
for _, op1 := range success {
for _, op2 := range failure {
t.Run(op1.name+"-"+op2.name, func(t *testing.T) {
wait, done := make(chan struct{}), make(chan struct{})
contract := contractMock.New(
contractMock.WithCreateBatchFunc(func(ctx context.Context, ib *big.Int, d uint8, i bool, l string) ([]byte, error) {
<-wait
return batchOk, nil
}),
contractMock.WithTopUpBatchFunc(func(ctx context.Context, id []byte, ib *big.Int) error {
<-wait
return nil
}),
contractMock.WithDiluteBatchFunc(func(ctx context.Context, id []byte, newDepth uint8) error {
<-wait
return nil
}),
)
ts := newTestServer(t, testServerOptions{
PostageContract: contract,
})
go func() {
defer close(done)
jsonhttptest.Request(t, ts.Client, op1.method, op1.url, op1.respCode, jsonhttptest.WithExpectedJSONResponse(op1.resp))
}()
time.Sleep(time.Millisecond * 100)
jsonhttptest.Request(t, ts.Client, op2.method, op2.url, op2.respCode, jsonhttptest.WithExpectedJSONResponse(op2.resp))
close(wait)
<-done
})
}
}
}
...@@ -205,17 +205,26 @@ func (s *Service) newRouter() *mux.Router { ...@@ -205,17 +205,26 @@ func (s *Service) newRouter() *mux.Router {
) )
router.Handle("/stamps/{amount}/{depth}", web.ChainHandlers( router.Handle("/stamps/{amount}/{depth}", web.ChainHandlers(
s.postageAccessHandler,
web.FinalHandler(jsonhttp.MethodHandler{ web.FinalHandler(jsonhttp.MethodHandler{
"POST": http.HandlerFunc(s.postageCreateHandler), "POST": http.HandlerFunc(s.postageCreateHandler),
})), })),
) )
router.Handle("/stamps/topup/{id}/{amount}", web.ChainHandlers( router.Handle("/stamps/topup/{id}/{amount}", web.ChainHandlers(
s.postageAccessHandler,
web.FinalHandler(jsonhttp.MethodHandler{ web.FinalHandler(jsonhttp.MethodHandler{
"PATCH": http.HandlerFunc(s.postageTopUpHandler), "PATCH": http.HandlerFunc(s.postageTopUpHandler),
})), })),
) )
router.Handle("/stamps/dilute/{id}/{depth}", web.ChainHandlers(
s.postageAccessHandler,
web.FinalHandler(jsonhttp.MethodHandler{
"PATCH": http.HandlerFunc(s.postageDiluteHandler),
})),
)
return router return router
} }
......
...@@ -29,6 +29,7 @@ import ( ...@@ -29,6 +29,7 @@ import (
"github.com/ethersphere/bee/pkg/postage" "github.com/ethersphere/bee/pkg/postage"
"github.com/ethersphere/bee/pkg/postage/batchstore" "github.com/ethersphere/bee/pkg/postage/batchstore"
mockPost "github.com/ethersphere/bee/pkg/postage/mock" mockPost "github.com/ethersphere/bee/pkg/postage/mock"
"github.com/ethersphere/bee/pkg/postage/postagecontract"
mockPostContract "github.com/ethersphere/bee/pkg/postage/postagecontract/mock" mockPostContract "github.com/ethersphere/bee/pkg/postage/postagecontract/mock"
postagetesting "github.com/ethersphere/bee/pkg/postage/testing" postagetesting "github.com/ethersphere/bee/pkg/postage/testing"
"github.com/ethersphere/bee/pkg/pss" "github.com/ethersphere/bee/pkg/pss"
...@@ -245,6 +246,28 @@ func NewDevBee(logger logging.Logger, o *DevOptions) (b *DevBee, err error) { ...@@ -245,6 +246,28 @@ func NewDevBee(logger logging.Logger, o *DevOptions) (b *DevBee, err error) {
return nil return nil
}, },
), ),
mockPostContract.WithDiluteBatchFunc(
func(ctx context.Context, batchID []byte, newDepth uint8) error {
batch, err := batchStore.Get(batchID)
if err != nil {
return err
}
if newDepth < batch.Depth {
return postagecontract.ErrInvalidDepth
}
newBalance := big.NewInt(0).Div(batch.Value, big.NewInt(int64(1<<(newDepth-batch.Depth))))
err = batchStore.Put(batch, newBalance, newDepth)
if err != nil {
return err
}
post.HandleDepthIncrease(batch.ID, newDepth, newBalance)
return nil
},
),
) )
feedFactory := factory.New(storer) feedFactory := factory.New(storer)
......
...@@ -159,6 +159,11 @@ func (svc *batchService) UpdateDepth(id []byte, depth uint8, normalisedBalance * ...@@ -159,6 +159,11 @@ func (svc *batchService) UpdateDepth(id []byte, depth uint8, normalisedBalance *
if err != nil { if err != nil {
return fmt.Errorf("put: %w", err) return fmt.Errorf("put: %w", err)
} }
if bytes.Equal(svc.owner, b.Owner) && svc.batchListener != nil {
svc.batchListener.HandleDepthIncrease(id, depth, normalisedBalance)
}
cs, err := svc.updateChecksum(txHash) cs, err := svc.updateChecksum(txHash)
if err != nil { if err != nil {
return fmt.Errorf("update checksum: %w", err) return fmt.Errorf("update checksum: %w", err)
......
...@@ -41,6 +41,7 @@ func newMockListener() *mockListener { ...@@ -41,6 +41,7 @@ func newMockListener() *mockListener {
type mockBatchListener struct { type mockBatchListener struct {
createCount int createCount int
topupCount int topupCount int
diluteCount int
} }
func (m *mockBatchListener) HandleCreate(b *postage.Batch) { func (m *mockBatchListener) HandleCreate(b *postage.Batch) {
...@@ -51,6 +52,10 @@ func (m *mockBatchListener) HandleTopUp(_ []byte, _ *big.Int) { ...@@ -51,6 +52,10 @@ func (m *mockBatchListener) HandleTopUp(_ []byte, _ *big.Int) {
m.topupCount++ m.topupCount++
} }
func (m *mockBatchListener) HandleDepthIncrease(_ []byte, _ uint8, _ *big.Int) {
m.diluteCount++
}
func TestBatchServiceCreate(t *testing.T) { func TestBatchServiceCreate(t *testing.T) {
testChainState := postagetesting.NewChainState() testChainState := postagetesting.NewChainState()
...@@ -279,19 +284,29 @@ func TestBatchServiceUpdateDepth(t *testing.T) { ...@@ -279,19 +284,29 @@ func TestBatchServiceUpdateDepth(t *testing.T) {
testBatch := postagetesting.MustNewBatch() testBatch := postagetesting.MustNewBatch()
t.Run("expect get error", func(t *testing.T) { t.Run("expect get error", func(t *testing.T) {
svc, _, _ := newTestStoreAndService( testBatchListener := &mockBatchListener{}
svc, _, _ := newTestStoreAndServiceWithListener(
t, t,
testBatch.Owner,
testBatchListener,
mock.WithGetErr(errTest, 0), mock.WithGetErr(errTest, 0),
) )
if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance, testTxHash); err == nil { if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance, testTxHash); err == nil {
t.Fatal("expected get error") t.Fatal("expected get error")
} }
if testBatchListener.diluteCount != 0 {
t.Fatalf("unexpected batch listener count, exp %d found %d", 0, testBatchListener.diluteCount)
}
}) })
t.Run("expect put error", func(t *testing.T) { t.Run("expect put error", func(t *testing.T) {
svc, batchStore, _ := newTestStoreAndService( testBatchListener := &mockBatchListener{}
svc, batchStore, _ := newTestStoreAndServiceWithListener(
t, t,
testBatch.Owner,
testBatchListener,
mock.WithPutErr(errTest, 1), mock.WithPutErr(errTest, 1),
) )
putBatch(t, batchStore, testBatch) putBatch(t, batchStore, testBatch)
...@@ -299,10 +314,19 @@ func TestBatchServiceUpdateDepth(t *testing.T) { ...@@ -299,10 +314,19 @@ func TestBatchServiceUpdateDepth(t *testing.T) {
if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance, testTxHash); err == nil { if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance, testTxHash); err == nil {
t.Fatal("expected put error") t.Fatal("expected put error")
} }
if testBatchListener.diluteCount != 0 {
t.Fatalf("unexpected batch listener count, exp %d found %d", 0, testBatchListener.diluteCount)
}
}) })
t.Run("passes", func(t *testing.T) { t.Run("passes", func(t *testing.T) {
svc, batchStore, _ := newTestStoreAndService(t) testBatchListener := &mockBatchListener{}
svc, batchStore, _ := newTestStoreAndServiceWithListener(
t,
testBatch.Owner,
testBatchListener,
)
putBatch(t, batchStore, testBatch) putBatch(t, batchStore, testBatch)
if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance, testTxHash); err != nil { if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance, testTxHash); err != nil {
...@@ -317,6 +341,43 @@ func TestBatchServiceUpdateDepth(t *testing.T) { ...@@ -317,6 +341,43 @@ func TestBatchServiceUpdateDepth(t *testing.T) {
if val.Depth != testNewDepth { if val.Depth != testNewDepth {
t.Fatalf("wrong batch depth set: want %v, got %v", testNewDepth, val.Depth) t.Fatalf("wrong batch depth set: want %v, got %v", testNewDepth, val.Depth)
} }
if testBatchListener.diluteCount != 1 {
t.Fatalf("unexpected batch listener count, exp %d found %d", 1, testBatchListener.diluteCount)
}
})
// if a batch with a different owner is diluted we should not see any event fired in the
// batch service
t.Run("passes without BatchEventListener update", func(t *testing.T) {
testBatchListener := &mockBatchListener{}
// create a owner different from the batch owner
owner := make([]byte, 32)
rand.Read(owner)
svc, batchStore, _ := newTestStoreAndServiceWithListener(
t,
owner,
testBatchListener,
)
putBatch(t, batchStore, testBatch)
if err := svc.UpdateDepth(testBatch.ID, testNewDepth, testNormalisedBalance, testTxHash); err != nil {
t.Fatalf("update depth: %v", err)
}
val, err := batchStore.Get(testBatch.ID)
if err != nil {
t.Fatalf("batch store get: %v", err)
}
if val.Depth != testNewDepth {
t.Fatalf("wrong batch depth set: want %v, got %v", testNewDepth, val.Depth)
}
if testBatchListener.diluteCount != 0 {
t.Fatalf("unexpected batch listener count, exp %d found %d", 0, testBatchListener.diluteCount)
}
}) })
} }
......
...@@ -53,4 +53,5 @@ type Listener interface { ...@@ -53,4 +53,5 @@ type Listener interface {
type BatchEventListener interface { type BatchEventListener interface {
HandleCreate(*Batch) HandleCreate(*Batch)
HandleTopUp(id []byte, newBalance *big.Int) HandleTopUp(id []byte, newBalance *big.Int)
HandleDepthIncrease(id []byte, newDepth uint8, normalisedBalance *big.Int)
} }
...@@ -92,6 +92,8 @@ func (m *mockPostage) HandleCreate(_ *postage.Batch) {} ...@@ -92,6 +92,8 @@ func (m *mockPostage) HandleCreate(_ *postage.Batch) {}
func (m *mockPostage) HandleTopUp(_ []byte, _ *big.Int) {} func (m *mockPostage) HandleTopUp(_ []byte, _ *big.Int) {}
func (m *mockPostage) HandleDepthIncrease(_ []byte, _ uint8, _ *big.Int) {}
func (m *mockPostage) Close() error { func (m *mockPostage) Close() error {
return nil return nil
} }
...@@ -29,16 +29,24 @@ var ( ...@@ -29,16 +29,24 @@ var (
erc20ABI = parseABI(sw3abi.ERC20ABIv0_3_1) erc20ABI = parseABI(sw3abi.ERC20ABIv0_3_1)
batchCreatedTopic = postageStampABI.Events["BatchCreated"].ID batchCreatedTopic = postageStampABI.Events["BatchCreated"].ID
batchTopUpTopic = postageStampABI.Events["BatchTopUp"].ID batchTopUpTopic = postageStampABI.Events["BatchTopUp"].ID
batchDiluteTopic = postageStampABI.Events["BatchDepthIncrease"].ID
ErrBatchCreate = errors.New("batch creation failed") ErrBatchCreate = errors.New("batch creation failed")
ErrInsufficientFunds = errors.New("insufficient token balance") ErrInsufficientFunds = errors.New("insufficient token balance")
ErrInvalidDepth = errors.New("invalid depth") ErrInvalidDepth = errors.New("invalid depth")
ErrBatchTopUp = errors.New("batch topUp failed") ErrBatchTopUp = errors.New("batch topUp failed")
ErrBatchDilute = errors.New("batch dilute failed")
approveDescription = "Approve tokens for postage operations"
createBatchDescription = "Postage batch creation"
topUpBatchDescription = "Postage batch top up"
diluteBatchDescription = "Postage batch dilute"
) )
type Interface interface { type Interface interface {
CreateBatch(ctx context.Context, initialBalance *big.Int, depth uint8, immutable bool, label string) ([]byte, error) CreateBatch(ctx context.Context, initialBalance *big.Int, depth uint8, immutable bool, label string) ([]byte, error)
TopUpBatch(ctx context.Context, batchID []byte, topupBalance *big.Int) error TopUpBatch(ctx context.Context, batchID []byte, topupBalance *big.Int) error
DiluteBatch(ctx context.Context, batchID []byte, newDepth uint8) error
} }
type postageContract struct { type postageContract struct {
...@@ -80,7 +88,7 @@ func (c *postageContract) sendApproveTransaction(ctx context.Context, amount *bi ...@@ -80,7 +88,7 @@ func (c *postageContract) sendApproveTransaction(ctx context.Context, amount *bi
GasPrice: sctx.GetGasPrice(ctx), GasPrice: sctx.GetGasPrice(ctx),
GasLimit: 65000, GasLimit: 65000,
Value: big.NewInt(0), Value: big.NewInt(0),
Description: "Approve tokens for postage operations", Description: approveDescription,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -98,25 +106,19 @@ func (c *postageContract) sendApproveTransaction(ctx context.Context, amount *bi ...@@ -98,25 +106,19 @@ func (c *postageContract) sendApproveTransaction(ctx context.Context, amount *bi
return receipt, nil return receipt, nil
} }
func (c *postageContract) sendCreateBatchTransaction(ctx context.Context, owner common.Address, initialBalance *big.Int, depth uint8, nonce common.Hash, immutable bool) (*types.Receipt, error) { func (c *postageContract) sendTransaction(ctx context.Context, callData []byte, desc string) (*types.Receipt, error) {
callData, err := postageStampABI.Pack("createBatch", owner, initialBalance, depth, BucketDepth, nonce, immutable)
if err != nil {
return nil, err
}
request := &transaction.TxRequest{ request := &transaction.TxRequest{
To: &c.postageContractAddress, To: &c.postageContractAddress,
Data: callData, Data: callData,
GasPrice: sctx.GetGasPrice(ctx), GasPrice: sctx.GetGasPrice(ctx),
GasLimit: 160000, GasLimit: 160000,
Value: big.NewInt(0), Value: big.NewInt(0),
Description: "Postage batch creation", Description: desc,
} }
txHash, err := c.transactionService.Send(ctx, request) txHash, err := c.transactionService.Send(ctx, request)
if err != nil { if err != nil {
return nil, fmt.Errorf("send: depth %d bucketDepth %d immutable %t: %w", depth, BucketDepth, immutable, err) return nil, err
} }
receipt, err := c.transactionService.WaitForReceipt(ctx, txHash) receipt, err := c.transactionService.WaitForReceipt(ctx, txHash)
...@@ -131,34 +133,46 @@ func (c *postageContract) sendCreateBatchTransaction(ctx context.Context, owner ...@@ -131,34 +133,46 @@ func (c *postageContract) sendCreateBatchTransaction(ctx context.Context, owner
return receipt, nil return receipt, nil
} }
func (c *postageContract) sendTopUpBatchTransaction(ctx context.Context, batchID []byte, topUpAmount *big.Int) (*types.Receipt, error) { func (c *postageContract) sendCreateBatchTransaction(ctx context.Context, owner common.Address, initialBalance *big.Int, depth uint8, nonce common.Hash, immutable bool) (*types.Receipt, error) {
callData, err := postageStampABI.Pack("topUp", common.BytesToHash(batchID), topUpAmount) callData, err := postageStampABI.Pack("createBatch", owner, initialBalance, depth, BucketDepth, nonce, immutable)
if err != nil { if err != nil {
return nil, err return nil, err
} }
request := &transaction.TxRequest{ receipt, err := c.sendTransaction(ctx, callData, createBatchDescription)
To: &c.postageContractAddress, if err != nil {
Data: callData, return nil, fmt.Errorf("create batch: depth %d bucketDepth %d immutable %t: %w", depth, BucketDepth, immutable, err)
GasPrice: sctx.GetGasPrice(ctx),
GasLimit: 160000,
Value: big.NewInt(0),
Description: "Postage batch top up",
} }
txHash, err := c.transactionService.Send(ctx, request) return receipt, nil
}
func (c *postageContract) sendTopUpBatchTransaction(ctx context.Context, batchID []byte, topUpAmount *big.Int) (*types.Receipt, error) {
callData, err := postageStampABI.Pack("topUp", common.BytesToHash(batchID), topUpAmount)
if err != nil { if err != nil {
return nil, err return nil, err
} }
receipt, err := c.transactionService.WaitForReceipt(ctx, txHash) receipt, err := c.sendTransaction(ctx, callData, topUpBatchDescription)
if err != nil {
return nil, fmt.Errorf("topup batch: amount %d: %w", topUpAmount.Int64(), err)
}
return receipt, nil
}
func (c *postageContract) sendDiluteTransaction(ctx context.Context, batchID []byte, newDepth uint8) (*types.Receipt, error) {
callData, err := postageStampABI.Pack("increaseDepth", common.BytesToHash(batchID), newDepth)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if receipt.Status == 0 { receipt, err := c.sendTransaction(ctx, callData, diluteBatchDescription)
return nil, transaction.ErrTransactionReverted if err != nil {
return nil, fmt.Errorf("dilute batch: new depth %d: %w", newDepth, err)
} }
return receipt, nil return receipt, nil
...@@ -218,7 +232,7 @@ func (c *postageContract) CreateBatch(ctx context.Context, initialBalance *big.I ...@@ -218,7 +232,7 @@ func (c *postageContract) CreateBatch(ctx context.Context, initialBalance *big.I
} }
for _, ev := range receipt.Logs { for _, ev := range receipt.Logs {
if ev.Address == c.postageContractAddress && ev.Topics[0] == batchCreatedTopic { if ev.Address == c.postageContractAddress && len(ev.Topics) > 0 && ev.Topics[0] == batchCreatedTopic {
var createdEvent batchCreatedEvent var createdEvent batchCreatedEvent
err = transaction.ParseEvent(&postageStampABI, "BatchCreated", &createdEvent, *ev) err = transaction.ParseEvent(&postageStampABI, "BatchCreated", &createdEvent, *ev)
if err != nil { if err != nil {
...@@ -281,6 +295,31 @@ func (c *postageContract) TopUpBatch(ctx context.Context, batchID []byte, topUpA ...@@ -281,6 +295,31 @@ func (c *postageContract) TopUpBatch(ctx context.Context, batchID []byte, topUpA
return ErrBatchTopUp return ErrBatchTopUp
} }
func (c *postageContract) DiluteBatch(ctx context.Context, batchID []byte, newDepth uint8) error {
batch, err := c.postageStorer.Get(batchID)
if err != nil {
return err
}
if batch.Depth > newDepth {
return fmt.Errorf("new depth should be greater: %w", ErrInvalidDepth)
}
receipt, err := c.sendDiluteTransaction(ctx, batch.ID, newDepth)
if err != nil {
return err
}
for _, ev := range receipt.Logs {
if ev.Address == c.postageContractAddress && len(ev.Topics) > 0 && ev.Topics[0] == batchDiluteTopic {
return nil
}
}
return ErrBatchDilute
}
type batchCreatedEvent struct { type batchCreatedEvent struct {
BatchId [32]byte BatchId [32]byte
TotalAmount *big.Int TotalAmount *big.Int
......
...@@ -352,3 +352,137 @@ func newTopUpEvent(postageContractAddress common.Address, batch *postage.Batch) ...@@ -352,3 +352,137 @@ func newTopUpEvent(postageContractAddress common.Address, batch *postage.Batch)
BlockNumber: batch.Start + 1, BlockNumber: batch.Start + 1,
} }
} }
func TestDiluteBatch(t *testing.T) {
defer func(b uint8) {
postagecontract.BucketDepth = b
}(postagecontract.BucketDepth)
postagecontract.BucketDepth = 9
owner := common.HexToAddress("abcd")
postageStampAddress := common.HexToAddress("ffff")
bzzTokenAddress := common.HexToAddress("eeee")
ctx := context.Background()
t.Run("ok", func(t *testing.T) {
txHashDilute := common.HexToHash("c3a7")
batch := postagetesting.MustNewBatch(postagetesting.WithOwner(owner.Bytes()))
batch.Depth = uint8(10)
batch.BucketDepth = uint8(9)
batch.Value = big.NewInt(100)
newDepth := batch.Depth + 1
postageMock := postageMock.New(postageMock.WithIssuer(postage.NewStampIssuer(
"label",
"keyID",
batch.ID,
batch.Value,
batch.Depth,
batch.BucketDepth,
batch.Start,
batch.Immutable,
)))
batchStoreMock := postagestoreMock.New(postagestoreMock.WithBatch(batch))
expectedCallData, err := postagecontract.PostageStampABI.Pack("increaseDepth", common.BytesToHash(batch.ID), newDepth)
if err != nil {
t.Fatal(err)
}
contract := postagecontract.New(
owner,
postageStampAddress,
bzzTokenAddress,
transactionMock.New(
transactionMock.WithSendFunc(func(ctx context.Context, request *transaction.TxRequest) (txHash common.Hash, err error) {
if *request.To == postageStampAddress {
if !bytes.Equal(expectedCallData[:64], request.Data[:64]) {
return common.Hash{}, fmt.Errorf("got wrong call data. wanted %x, got %x", expectedCallData, request.Data)
}
return txHashDilute, nil
}
return common.Hash{}, errors.New("sent to wrong contract")
}),
transactionMock.WithWaitForReceiptFunc(func(ctx context.Context, txHash common.Hash) (receipt *types.Receipt, err error) {
if txHash == txHashDilute {
return &types.Receipt{
Logs: []*types.Log{
newDiluteEvent(postageStampAddress, batch),
},
Status: 1,
}, nil
}
return nil, errors.New("unknown tx hash")
}),
),
postageMock,
batchStoreMock,
)
err = contract.DiluteBatch(ctx, batch.ID, newDepth)
if err != nil {
t.Fatal(err)
}
si, err := postageMock.GetStampIssuer(batch.ID)
if err != nil {
t.Fatal(err)
}
if si == nil {
t.Fatal("stamp issuer not set")
}
})
t.Run("batch doesnt exist", func(t *testing.T) {
errNotFound := errors.New("not found")
contract := postagecontract.New(
owner,
postageStampAddress,
bzzTokenAddress,
transactionMock.New(),
postageMock.New(),
postagestoreMock.New(postagestoreMock.WithGetErr(errNotFound, 0)),
)
err := contract.DiluteBatch(ctx, postagetesting.MustNewID(), uint8(17))
if !errors.Is(err, errNotFound) {
t.Fatal("expected error on topup of non existent batch")
}
})
t.Run("invalid depth", func(t *testing.T) {
batch := postagetesting.MustNewBatch(postagetesting.WithOwner(owner.Bytes()))
batch.Depth = uint8(16)
batchStoreMock := postagestoreMock.New(postagestoreMock.WithBatch(batch))
contract := postagecontract.New(
owner,
postageStampAddress,
bzzTokenAddress,
transactionMock.New(),
postageMock.New(),
batchStoreMock,
)
err := contract.DiluteBatch(ctx, batch.ID, batch.Depth-1)
if !errors.Is(err, postagecontract.ErrInvalidDepth) {
t.Fatalf("expected error %v. got %v", postagecontract.ErrInvalidDepth, err)
}
})
}
func newDiluteEvent(postageContractAddress common.Address, batch *postage.Batch) *types.Log {
b, err := postagecontract.PostageStampABI.Events["BatchDepthIncrease"].Inputs.NonIndexed().Pack(
uint8(0),
big.NewInt(0),
)
if err != nil {
panic(err)
}
return &types.Log{
Address: postageContractAddress,
Data: b,
Topics: []common.Hash{postagecontract.BatchDiluteTopic, common.BytesToHash(batch.ID)},
BlockNumber: batch.Start + 1,
}
}
...@@ -8,4 +8,5 @@ var ( ...@@ -8,4 +8,5 @@ var (
PostageStampABI = postageStampABI PostageStampABI = postageStampABI
BatchCreatedTopic = batchCreatedTopic BatchCreatedTopic = batchCreatedTopic
BatchTopUpTopic = batchTopUpTopic BatchTopUpTopic = batchTopUpTopic
BatchDiluteTopic = batchDiluteTopic
) )
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
type contractMock struct { type contractMock struct {
createBatch func(ctx context.Context, initialBalance *big.Int, depth uint8, immutable bool, label string) ([]byte, error) createBatch func(ctx context.Context, initialBalance *big.Int, depth uint8, immutable bool, label string) ([]byte, error)
topupBatch func(ctx context.Context, id []byte, amount *big.Int) error topupBatch func(ctx context.Context, id []byte, amount *big.Int) error
diluteBatch func(ctx context.Context, id []byte, newDepth uint8) error
} }
func (c *contractMock) CreateBatch(ctx context.Context, initialBalance *big.Int, depth uint8, immutable bool, label string) ([]byte, error) { func (c *contractMock) CreateBatch(ctx context.Context, initialBalance *big.Int, depth uint8, immutable bool, label string) ([]byte, error) {
...@@ -24,6 +25,10 @@ func (c *contractMock) TopUpBatch(ctx context.Context, batchID []byte, amount *b ...@@ -24,6 +25,10 @@ func (c *contractMock) TopUpBatch(ctx context.Context, batchID []byte, amount *b
return c.topupBatch(ctx, batchID, amount) return c.topupBatch(ctx, batchID, amount)
} }
func (c *contractMock) DiluteBatch(ctx context.Context, batchID []byte, newDepth uint8) error {
return c.diluteBatch(ctx, batchID, newDepth)
}
// Option is a an option passed to New // Option is a an option passed to New
type Option func(*contractMock) type Option func(*contractMock)
...@@ -49,3 +54,9 @@ func WithTopUpBatchFunc(f func(ctx context.Context, batchID []byte, amount *big. ...@@ -49,3 +54,9 @@ func WithTopUpBatchFunc(f func(ctx context.Context, batchID []byte, amount *big.
m.topupBatch = f m.topupBatch = f
} }
} }
func WithDiluteBatchFunc(f func(ctx context.Context, batchID []byte, newDepth uint8) error) Option {
return func(m *contractMock) {
m.diluteBatch = f
}
}
...@@ -120,6 +120,21 @@ func (ps *service) HandleTopUp(batchID []byte, newValue *big.Int) { ...@@ -120,6 +120,21 @@ func (ps *service) HandleTopUp(batchID []byte, newValue *big.Int) {
} }
} }
func (ps *service) HandleDepthIncrease(batchID []byte, newDepth uint8, normalisedBalance *big.Int) {
ps.lock.Lock()
defer ps.lock.Unlock()
for _, v := range ps.issuers {
if bytes.Equal(batchID, v.data.BatchID) {
if newDepth > v.data.BatchDepth {
v.data.BatchDepth = newDepth
v.data.BatchAmount = normalisedBalance
}
return
}
}
}
// StampIssuers returns the currently active stamp issuers. // StampIssuers returns the currently active stamp issuers.
func (ps *service) StampIssuers() []*StampIssuer { func (ps *service) StampIssuers() []*StampIssuer {
ps.lock.Lock() ps.lock.Lock()
......
...@@ -130,4 +130,17 @@ func TestGetStampIssuer(t *testing.T) { ...@@ -130,4 +130,17 @@ func TestGetStampIssuer(t *testing.T) {
t.Fatalf("expected amount %d got %d", 10, ps.StampIssuers()[0].Amount().Int64()) t.Fatalf("expected amount %d got %d", 10, ps.StampIssuers()[0].Amount().Int64())
} }
}) })
t.Run("dilute", func(t *testing.T) {
ps.HandleDepthIncrease(ids[2], 17, big.NewInt(1))
_, err := ps.GetStampIssuer(ids[2])
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if ps.StampIssuers()[1].Amount().Cmp(big.NewInt(1)) != 0 {
t.Fatalf("expected amount %d got %d", 1, ps.StampIssuers()[1].Amount().Int64())
}
if ps.StampIssuers()[1].Depth() != 17 {
t.Fatalf("expected depth %d got %d", 17, ps.StampIssuers()[1].Depth())
}
})
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment