Commit 1236d45d authored by acud's avatar acud Committed by GitHub

all: use bmt hasher pool (#823)

* use bmtpool
parent 61c70837
......@@ -8,7 +8,7 @@ require (
github.com/coreos/go-semver v0.3.0
github.com/davidlazar/go-crypto v0.0.0-20200604182044-b73af7476f6c // indirect
github.com/ethereum/go-ethereum v1.9.20
github.com/ethersphere/bmt v0.1.2
github.com/ethersphere/bmt v0.1.4
github.com/ethersphere/langos v1.0.0
github.com/ethersphere/manifest v0.3.2
github.com/ethersphere/sw3-bindings/v2 v2.1.0
......
......@@ -163,8 +163,8 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7
github.com/ethereum/go-ethereum v1.9.14/go.mod h1:oP8FC5+TbICUyftkTWs+8JryntjIJLJvWvApK3z2AYw=
github.com/ethereum/go-ethereum v1.9.20 h1:kk/J5OIoaoz3DRrCXznz3RGi212mHHXwzXlY/ZQxcj0=
github.com/ethereum/go-ethereum v1.9.20/go.mod h1:JSSTypSMTkGZtAdAChH2wP5dZEvPGh3nUTuDpH+hNrg=
github.com/ethersphere/bmt v0.1.2 h1:FEuvQY9xuK+rDp3VwDVyde8T396Matv/u9PdtKa2r9Q=
github.com/ethersphere/bmt v0.1.2/go.mod h1:fqRBDmYwn3lX2MH4lkImXQgFWeNP8ikLkS/hgi/HRws=
github.com/ethersphere/bmt v0.1.4 h1:+rkWYNtMgDx6bkNqGdWu+U9DgGI1rRZplpSW3YhBr1Q=
github.com/ethersphere/bmt v0.1.4/go.mod h1:Yd8ft1U69WDuHevZc/rwPxUv1rzPSMpMnS6xbU53aY8=
github.com/ethersphere/langos v1.0.0 h1:NBtNKzXTTRSue95uOlzPN4py7Aofs0xWPzyj4AI1Vcc=
github.com/ethersphere/langos v1.0.0/go.mod h1:dlcN2j4O8sQ+BlCaxeBu43bgr4RQ+inJ+pHwLeZg5Tw=
github.com/ethersphere/manifest v0.3.2 h1:IusNNfpqde2F7uWZ2DE9eyo9PMwUAMop3Ws1NBcdMyM=
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bmtpool
import (
"github.com/ethersphere/bee/pkg/swarm"
bmtlegacy "github.com/ethersphere/bmt/legacy"
"github.com/ethersphere/bmt/pool"
)
// capacity is the number of bmt Hashers kept in the shared pool.
const capacity = 8

// instance is the process-wide bmt hasher pool. Hashers are borrowed
// with Get and must be returned with Put. Initialized directly at the
// declaration rather than in init(), per Go convention for simple
// package-level setup.
var instance pool.Pooler = pool.New(capacity, swarm.BmtBranches)
// Get borrows a bmt Hasher from the shared pool.
// Instances are reset before being handed back to the caller, so the
// returned Hasher is ready for immediate use.
func Get() *bmtlegacy.Hasher {
	h := instance.Get()
	return h
}
// Put returns a bmt Hasher to the shared pool so it can be reused by
// a subsequent call to Get. The caller must not use h after Put.
func Put(h *bmtlegacy.Hasher) {
	instance.Put(h)
}
......@@ -10,8 +10,8 @@ import (
"errors"
"fmt"
"github.com/ethersphere/bee/pkg/bmtpool"
"github.com/ethersphere/bee/pkg/swarm"
bmtlegacy "github.com/ethersphere/bmt/legacy"
)
// NewChunk creates a new content-addressed single-span chunk.
......@@ -29,8 +29,8 @@ func NewChunkWithSpan(data []byte, span int64) (swarm.Chunk, error) {
return nil, fmt.Errorf("single-span chunk size mismatch; span is %d, chunk data length %d", span, len(data))
}
bmtPool := bmtlegacy.NewTreePool(swarm.NewHasher, swarm.Branches, bmtlegacy.PoolSize)
hasher := bmtlegacy.New(bmtPool)
hasher := bmtpool.Get()
defer bmtpool.Put(hasher)
// execute hash, compare and return result
spanBytes := make([]byte, 8)
......@@ -53,8 +53,8 @@ func NewChunkWithSpan(data []byte, span int64) (swarm.Chunk, error) {
// NewChunkWithSpanBytes deserializes a content-addressed chunk from separate
// data and span byte slices.
func NewChunkWithSpanBytes(data, spanBytes []byte) (swarm.Chunk, error) {
bmtPool := bmtlegacy.NewTreePool(swarm.NewHasher, swarm.Branches, bmtlegacy.PoolSize)
hasher := bmtlegacy.New(bmtPool)
hasher := bmtpool.Get()
defer bmtpool.Put(hasher)
// execute hash, compare and return result
err := hasher.SetSpanBytes(spanBytes)
......
......@@ -6,13 +6,10 @@ package bmt
import (
"errors"
"hash"
"github.com/ethersphere/bee/pkg/bmtpool"
"github.com/ethersphere/bee/pkg/file/pipeline"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bmt"
bmtlegacy "github.com/ethersphere/bmt/legacy"
"golang.org/x/crypto/sha3"
)
var (
......@@ -20,15 +17,13 @@ var (
)
type bmtWriter struct {
b bmt.Hash
next pipeline.ChainWriter
}
// NewBmtWriter returns a new bmtWriter. Partial writes are not supported.
// Note: branching factor is the BMT branching factor, not the merkle trie branching factor.
func NewBmtWriter(branches int, next pipeline.ChainWriter) pipeline.ChainWriter {
func NewBmtWriter(next pipeline.ChainWriter) pipeline.ChainWriter {
return &bmtWriter{
b: bmtlegacy.New(bmtlegacy.NewTreePool(hashFunc, branches, bmtlegacy.PoolSize)),
next: next,
}
}
......@@ -39,16 +34,20 @@ func (w *bmtWriter) ChainWrite(p *pipeline.PipeWriteArgs) error {
if len(p.Data) < swarm.SpanSize {
return errInvalidData
}
w.b.Reset()
err := w.b.SetSpanBytes(p.Data[:swarm.SpanSize])
hasher := bmtpool.Get()
err := hasher.SetSpanBytes(p.Data[:swarm.SpanSize])
if err != nil {
bmtpool.Put(hasher)
return err
}
_, err = w.b.Write(p.Data[swarm.SpanSize:])
_, err = hasher.Write(p.Data[swarm.SpanSize:])
if err != nil {
bmtpool.Put(hasher)
return err
}
p.Ref = w.b.Sum(nil)
p.Ref = hasher.Sum(nil)
bmtpool.Put(hasher)
return w.next.ChainWrite(p)
}
......@@ -56,7 +55,3 @@ func (w *bmtWriter) ChainWrite(p *pipeline.PipeWriteArgs) error {
// Sum delegates to the next writer in the chain; the bmt writer itself
// produces per-chunk references in ChainWrite and has no final sum of
// its own.
func (w *bmtWriter) Sum() ([]byte, error) {
	return w.next.Sum()
}
// hashFunc returns a new legacy Keccak-256 hasher, used as the base
// hash of the bmt tree.
func hashFunc() hash.Hash {
	return sha3.NewLegacyKeccak256()
}
......@@ -47,7 +47,7 @@ func TestBmtWriter(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
mockChainWriter := mock.NewChainWriter()
writer := bmt.NewBmtWriter(128, mockChainWriter)
writer := bmt.NewBmtWriter(mockChainWriter)
var data []byte
......@@ -81,7 +81,7 @@ func TestBmtWriter(t *testing.T) {
// TestSum tests that calling Sum on the writer calls the next writer's Sum.
func TestSum(t *testing.T) {
mockChainWriter := mock.NewChainWriter()
writer := bmt.NewBmtWriter(128, mockChainWriter)
writer := bmt.NewBmtWriter(mockChainWriter)
_, err := writer.Sum()
if err != nil {
t.Fatal(err)
......
......@@ -34,7 +34,7 @@ func NewPipelineBuilder(ctx context.Context, s storage.Storer, mode storage.Mode
func newPipeline(ctx context.Context, s storage.Storer, mode storage.ModePut) pipeline.Interface {
tw := hashtrie.NewHashTrieWriter(swarm.ChunkSize, swarm.Branches, swarm.HashSize, newShortPipelineFunc(ctx, s, mode))
lsw := store.NewStoreWriter(ctx, s, mode, tw)
b := bmt.NewBmtWriter(128, lsw)
b := bmt.NewBmtWriter(lsw)
return feeder.NewChunkFeederWriter(swarm.ChunkSize, b)
}
......@@ -43,7 +43,7 @@ func newPipeline(ctx context.Context, s storage.Storer, mode storage.ModePut) pi
func newShortPipelineFunc(ctx context.Context, s storage.Storer, mode storage.ModePut) func() pipeline.ChainWriter {
return func() pipeline.ChainWriter {
lsw := store.NewStoreWriter(ctx, s, mode, nil)
return bmt.NewBmtWriter(128, lsw)
return bmt.NewBmtWriter(lsw)
}
}
......@@ -55,7 +55,7 @@ func newShortPipelineFunc(ctx context.Context, s storage.Storer, mode storage.Mo
func newEncryptionPipeline(ctx context.Context, s storage.Storer, mode storage.ModePut) pipeline.Interface {
tw := hashtrie.NewHashTrieWriter(swarm.ChunkSize, 64, swarm.HashSize+encryption.KeyLength, newShortEncryptionPipelineFunc(ctx, s, mode))
lsw := store.NewStoreWriter(ctx, s, mode, tw)
b := bmt.NewBmtWriter(128, lsw)
b := bmt.NewBmtWriter(lsw)
enc := enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b)
return feeder.NewChunkFeederWriter(swarm.ChunkSize, enc)
}
......@@ -65,7 +65,7 @@ func newEncryptionPipeline(ctx context.Context, s storage.Storer, mode storage.M
func newShortEncryptionPipelineFunc(ctx context.Context, s storage.Storer, mode storage.ModePut) func() pipeline.ChainWriter {
return func() pipeline.ChainWriter {
lsw := store.NewStoreWriter(ctx, s, mode, nil)
b := bmt.NewBmtWriter(128, lsw)
b := bmt.NewBmtWriter(lsw)
return enc.NewEncryptionWriter(encryption.NewChunkEncrypter(), b)
}
}
......
......@@ -9,15 +9,13 @@ import (
"encoding/binary"
"errors"
"fmt"
"hash"
"github.com/ethersphere/bee/pkg/bmtpool"
"github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/file"
"github.com/ethersphere/bee/pkg/sctx"
"github.com/ethersphere/bee/pkg/swarm"
"github.com/ethersphere/bee/pkg/tags"
"github.com/ethersphere/bmt"
bmtlegacy "github.com/ethersphere/bmt/legacy"
"golang.org/x/crypto/sha3"
)
......@@ -29,11 +27,6 @@ type Putter interface {
// (128 ^ (9 - 1)) * 4096 = 295147905179352825856 bytes
const levelBufferLimit = 9
// hashFunc is a hasher factory used by the bmt hasher
func hashFunc() hash.Hash {
return sha3.NewLegacyKeccak256()
}
// SimpleSplitterJob encapsulated a single splitter operation, accepting blockwise
// writes of data whose length is defined in advance.
//
......@@ -50,7 +43,6 @@ type SimpleSplitterJob struct {
length int64 // number of bytes written to the data level of the hasher
sumCounts []int // number of sums performed, indexed per level
cursors []int // section write position, indexed per level
hasher bmt.Hash // underlying hasher used for hashing the tree
buffer []byte // keeps data and hashes, indexed by cursors
tag *tags.Tag
toEncrypt bool // to encryrpt the chunks or not
......@@ -66,7 +58,6 @@ func NewSimpleSplitterJob(ctx context.Context, putter Putter, spanLength int64,
if toEncrypt {
refSize += encryption.KeyLength
}
p := bmtlegacy.NewTreePool(hashFunc, swarm.Branches, bmtlegacy.PoolSize)
return &SimpleSplitterJob{
ctx: ctx,
......@@ -74,7 +65,6 @@ func NewSimpleSplitterJob(ctx context.Context, putter Putter, spanLength int64,
spanLength: spanLength,
sumCounts: make([]int, levelBufferLimit),
cursors: make([]int, levelBufferLimit),
hasher: bmtlegacy.New(p),
buffer: make([]byte, swarm.ChunkWithSpanSize*levelBufferLimit*2), // double size as temp workaround for weak calculation of needed buffer space
tag: sctx.GetTag(ctx),
toEncrypt: toEncrypt,
......@@ -167,16 +157,21 @@ func (s *SimpleSplitterJob) sumLevel(lvl int) ([]byte, error) {
}
}
s.hasher.Reset()
err = s.hasher.SetSpanBytes(c[:8])
hasher := bmtpool.Get()
err = hasher.SetSpanBytes(c[:8])
if err != nil {
bmtpool.Put(hasher)
return nil, err
}
_, err = s.hasher.Write(c[8:])
_, err = hasher.Write(c[8:])
if err != nil {
bmtpool.Put(hasher)
return nil, err
}
ref := s.hasher.Sum(nil)
ref := hasher.Sum(nil)
bmtpool.Put(hasher)
addr = swarm.NewAddress(ref)
// Add tag to the chunk if tag is valid
......
......@@ -345,8 +345,8 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
chunkvalidator := swarm.NewChunkValidator(content.NewValidator(), soc.NewValidator())
retrieve := retrieval.New(swarmAddress, storer, p2ps, kad, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), chunkvalidator, tracer)
tagg := tags.NewTags(stateStore, logger)
b.tagsCloser = tagg
tagService := tags.NewTags(stateStore, logger)
b.tagsCloser = tagService
if err = p2ps.AddProtocol(retrieve.Protocol()); err != nil {
return nil, fmt.Errorf("retrieval service: %w", err)
......@@ -358,22 +358,22 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
return nil, fmt.Errorf("swarm key: %w", err)
}
psss := pss.New(swarmPrivateKey, logger)
b.pssCloser = psss
pssService := pss.New(swarmPrivateKey, logger)
b.pssCloser = pssService
var ns storage.Storer
if o.GlobalPinningEnabled {
// create recovery callback for content repair
recoverFunc := recovery.NewRecoveryHook(psss)
recoverFunc := recovery.NewRecoveryHook(pssService)
ns = netstore.New(storer, recoverFunc, retrieve, logger, chunkvalidator)
} else {
ns = netstore.New(storer, nil, retrieve, logger, chunkvalidator)
}
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagg, psss.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
pushSyncProtocol := pushsync.New(p2ps, storer, kad, tagService, pssService.TryUnwrap, logger, acc, accounting.NewFixedPricer(swarmAddress, 10), tracer)
// set the pushSyncer in the PSS
psss.SetPushSyncer(pushSyncProtocol)
pssService.SetPushSyncer(pushSyncProtocol)
if err = p2ps.AddProtocol(pushSyncProtocol.Protocol()); err != nil {
return nil, fmt.Errorf("pushsync service: %w", err)
......@@ -382,10 +382,10 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
if o.GlobalPinningEnabled {
// register function for chunk repair upon receiving a trojan message
chunkRepairHandler := recovery.NewRepairHandler(ns, logger, pushSyncProtocol)
b.recoveryHandleCleanup = psss.Register(recovery.RecoveryTopic, chunkRepairHandler)
b.recoveryHandleCleanup = pssService.Register(recovery.RecoveryTopic, chunkRepairHandler)
}
pushSyncPusher := pusher.New(storer, kad, pushSyncProtocol, tagg, logger, tracer)
pushSyncPusher := pusher.New(storer, kad, pushSyncProtocol, tagService, logger, tracer)
b.pusherCloser = pushSyncPusher
pullStorage := pullstorage.New(storer)
......@@ -410,7 +410,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
var apiService api.Service
if o.APIAddr != "" {
// API server
apiService = api.New(tagg, ns, multiResolver, psss, logger, tracer, api.Options{
apiService = api.New(tagService, ns, multiResolver, pssService, logger, tracer, api.Options{
CORSAllowedOrigins: o.CORSAllowedOrigins,
GatewayMode: o.GatewayMode,
WsPingPeriod: 60 * time.Second,
......@@ -441,7 +441,7 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
if o.DebugAPIAddr != "" {
// Debug API server
debugAPIService := debugapi.New(swarmAddress, publicKey, overlayEthAddress, p2ps, pingPong, kad, storer, logger, tracer, tagg, acc, settlement, o.SwapEnable, swapService, chequebookService)
debugAPIService := debugapi.New(swarmAddress, publicKey, overlayEthAddress, p2ps, pingPong, kad, storer, logger, tracer, tagService, acc, settlement, o.SwapEnable, swapService, chequebookService)
// register metrics from components
debugAPIService.MustRegisterMetrics(p2ps.Metrics()...)
debugAPIService.MustRegisterMetrics(pingPong.Metrics()...)
......@@ -452,8 +452,8 @@ func NewBee(addr string, swarmAddress swarm.Address, publicKey ecdsa.PublicKey,
debugAPIService.MustRegisterMetrics(pushSyncPusher.Metrics()...)
debugAPIService.MustRegisterMetrics(pullSync.Metrics()...)
if pssService, ok := psss.(metrics.Collector); ok {
debugAPIService.MustRegisterMetrics(pssService.Metrics()...)
if pssServiceMetrics, ok := pssService.(metrics.Collector); ok {
debugAPIService.MustRegisterMetrics(pssServiceMetrics.Metrics()...)
}
if apiService != nil {
......
......@@ -15,11 +15,11 @@ import (
random "math/rand"
"github.com/btcsuite/btcd/btcec"
"github.com/ethersphere/bee/pkg/bmtpool"
"github.com/ethersphere/bee/pkg/crypto"
"github.com/ethersphere/bee/pkg/encryption"
"github.com/ethersphere/bee/pkg/encryption/elgamal"
"github.com/ethersphere/bee/pkg/swarm"
bmtlegacy "github.com/ethersphere/bmt/legacy"
)
var (
......@@ -176,10 +176,10 @@ func checkTargets(targets Targets) error {
}
func hasher(span, b []byte) func([]byte) ([]byte, error) {
hashPool := bmtlegacy.NewTreePool(swarm.NewHasher, swarm.Branches, bmtlegacy.PoolSize)
return func(nonce []byte) ([]byte, error) {
s := append(nonce, b...)
hasher := bmtlegacy.New(hashPool)
hasher := bmtpool.Get()
defer bmtpool.Put(hasher)
if err := hasher.SetSpanBytes(span); err != nil {
return nil, err
}
......
......@@ -18,6 +18,7 @@ const (
SpanSize = 8
SectionSize = 32
Branches = 128
BmtBranches = 128
ChunkSize = SectionSize * Branches
HashSize = 32
MaxPO uint8 = 15
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment