Commit 75838eaf authored by Viktor Trón's avatar Viktor Trón Committed by GitHub

trojan,pss: optimise trojan mining (#695)

* trojan,pss: optimise trojan mining
 - introduce context to Wrap and eliminate MinerTimeout
 - simplify Wrap
 - add mine function that finds a nonce that makes a function true
 - introduce mining benchmarks

* trojan: redo mine function to work with chunk not interface
parent cffa61f9
...@@ -67,7 +67,7 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top ...@@ -67,7 +67,7 @@ func (p *pss) Send(ctx context.Context, targets trojan.Targets, topic trojan.Top
return err return err
} }
var tc swarm.Chunk var tc swarm.Chunk
tc, err = m.Wrap(targets) tc, err = m.Wrap(ctx, targets)
if err != nil { if err != nil {
return err return err
......
...@@ -123,7 +123,7 @@ func TestDeliver(t *testing.T) { ...@@ -123,7 +123,7 @@ func TestDeliver(t *testing.T) {
// test chunk // test chunk
target := trojan.Target([]byte{1}) // arbitrary test target target := trojan.Target([]byte{1}) // arbitrary test target
targets := trojan.Targets([]trojan.Target{target}) targets := trojan.Targets([]trojan.Target{target})
c, err := msg.Wrap(targets) c, err := msg.Wrap(ctx, targets)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
...@@ -15,9 +15,6 @@ var ( ...@@ -15,9 +15,6 @@ var (
// ErrVarLenTargets is returned when the given target list for a trojan chunk has addresses of different lengths // ErrVarLenTargets is returned when the given target list for a trojan chunk has addresses of different lengths
ErrVarLenTargets = errors.New("target list cannot have targets of different length") ErrVarLenTargets = errors.New("target list cannot have targets of different length")
// ErrUnMarshallingTrojanMessage is returned when a trojan message could not be de-serialized // ErrUnmarshal is returned when a trojan message could not be de-serialized
ErrUnmarshal = errors.New("trojan message unmarshall error") ErrUnmarshal = errors.New("trojan message unmarshall error")
// ErrMinerTimeout is returned when mining a new nonce takes more time than swarm.TrojanMinerTimeout seconds
ErrMinerTimeout = errors.New("miner timeout error")
) )
...@@ -2,6 +2,4 @@ package trojan ...@@ -2,6 +2,4 @@ package trojan
var ( var (
Contains = contains Contains = contains
HashBytes = hashBytes
PadBytes = padBytesLeft
) )
...@@ -6,14 +6,13 @@ package trojan ...@@ -6,14 +6,13 @@ package trojan
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"math/big" random "math/rand"
"time"
bmtlegacy "github.com/ethersphere/bmt/legacy"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
bmtlegacy "github.com/ethersphere/bmt/legacy"
) )
// Topic is an alias for a 32 byte fixed-size array which contains an encoding of a message topic // Topic is an alias for a 32 byte fixed-size array which contains an encoding of a message topic
...@@ -39,13 +38,14 @@ const ( ...@@ -39,13 +38,14 @@ const (
// MaxPayloadSize + Topic + Length + Nonce = Default ChunkSize // MaxPayloadSize + Topic + Length + Nonce = Default ChunkSize
// (4030) + (32) + (2) + (32) = 4096 Bytes // (4030) + (32) + (2) + (32) = 4096 Bytes
MaxPayloadSize = swarm.ChunkSize - NonceSize - LengthSize - TopicSize MaxPayloadSize = swarm.ChunkSize - NonceSize - LengthSize - TopicSize
// NonceSize is a hash bit sequence
NonceSize = 32 NonceSize = 32
// LengthSize is the byte length to represent message
LengthSize = 2 LengthSize = 2
// TopicSize is a hash bit sequence
TopicSize = 32 TopicSize = 32
) )
var minerTimeout = 20 * time.Second
// NewTopic creates a new Topic variable with the given input string // NewTopic creates a new Topic variable with the given input string
// the input string is taken as a byte slice and hashed // the input string is taken as a byte slice and hashed
func NewTopic(topic string) Topic { func NewTopic(topic string) Topic {
...@@ -89,14 +89,33 @@ func NewMessage(topic Topic, payload []byte) (Message, error) { ...@@ -89,14 +89,33 @@ func NewMessage(topic Topic, payload []byte) (Message, error) {
// Wrap creates a new trojan chunk for the given targets and Message // Wrap creates a new trojan chunk for the given targets and Message
// a trojan chunk is a content-addressed chunk made up of span, a nonce, and a payload which contains the Message // a trojan chunk is a content-addressed chunk made up of span, a nonce, and a payload which contains the Message
// the chunk address will have one of the targets as its prefix and thus will be forwarded to the neighbourhood of the recipient overlay address the target is derived from // the chunk address will have one of the targets as its prefix and thus will be forwarded to the neighbourhood of the recipient overlay address the target is derived from
func (m *Message) Wrap(targets Targets) (swarm.Chunk, error) { // this is done by iteratively enumerating different nonces until the BMT hash of the serialization of the trojan chunk fields results in a chunk address that has one of the targets as its prefix
func (m *Message) Wrap(ctx context.Context, targets Targets) (swarm.Chunk, error) {
if err := checkTargets(targets); err != nil { if err := checkTargets(targets); err != nil {
return nil, err return nil, err
} }
targetsLen := len(targets[0])
// serialize message
b, err := m.MarshalBinary() // TODO: this should be encrypted
if err != nil {
return nil, err
}
span := make([]byte, 8) span := make([]byte, 8)
binary.LittleEndian.PutUint64(span, swarm.ChunkSize) binary.LittleEndian.PutUint64(span, uint64(len(b)+NonceSize))
return m.toChunk(targets, span) h := hasher(span, b)
f := func(nonce []byte) (swarm.Chunk, error) {
hash, err := h(nonce)
if err != nil {
return nil, err
}
if !contains(targets, hash[:targetsLen]) {
return nil, nil
}
chunk := swarm.NewChunk(swarm.NewAddress(hash), append(span, append(nonce, b...)...))
return chunk, nil
}
return mine(ctx, f)
} }
// Unwrap creates a new trojan message from the given chunk payload // Unwrap creates a new trojan message from the given chunk payload
...@@ -141,81 +160,19 @@ func checkTargets(targets Targets) error { ...@@ -141,81 +160,19 @@ func checkTargets(targets Targets) error {
return nil return nil
} }
// toChunk finds a nonce so that when the given trojan chunk fields are hashed, the result will fall in the neighbourhood of one of the given targets func hasher(span, b []byte) func([]byte) ([]byte, error) {
// this is done by iteratively enumerating different nonces until the BMT hash of the serialization of the trojan chunk fields results in a chunk address that has one of the targets as its prefix
// the function returns a new chunk, with the found matching hash to be used as its address,
// and its data set to the serialization of the trojan chunk fields which correctly hash into the matching address
func (m *Message) toChunk(targets Targets, span []byte) (swarm.Chunk, error) {
// start out with random nonce
nonce := make([]byte, NonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, err
}
nonceInt := new(big.Int).SetBytes(nonce)
targetsLen := len(targets[0])
// serialize message
b, err := m.MarshalBinary() // TODO: this should be encrypted
if err != nil {
return nil, err
}
errC := make(chan error)
var hash, s []byte
go func() {
defer close(errC)
// mining operation: hash chunk fields with different nonces until an acceptable one is found
for {
s = append(append(span, nonce...), b...) // serialize chunk fields
hash, err = hashBytes(s)
if err != nil {
errC <- err
return
}
// take as much of the hash as the targets are long
if contains(targets, hash[:targetsLen]) {
// if nonce found, stop loop and return chunk
errC <- nil
return
}
// else, add 1 to nonce and try again
nonceInt.Add(nonceInt, big.NewInt(1))
// loop around in case of overflow after 256 bits
if nonceInt.BitLen() > (NonceSize * swarm.SpanSize) {
nonceInt = big.NewInt(0)
}
nonce = padBytesLeft(nonceInt.Bytes()) // pad in case Bytes call is not 32 bytes long
}
}()
// checks whether the mining is completed or times out
select {
case err := <-errC:
if err == nil {
return swarm.NewChunk(swarm.NewAddress(hash), s), nil
}
return nil, err
case <-time.After(minerTimeout):
return nil, ErrMinerTimeout
}
}
// hashBytes hashes the given serialization of chunk fields with the hashing func
func hashBytes(s []byte) ([]byte, error) {
hashPool := bmtlegacy.NewTreePool(swarm.NewHasher, swarm.Branches, bmtlegacy.PoolSize) hashPool := bmtlegacy.NewTreePool(swarm.NewHasher, swarm.Branches, bmtlegacy.PoolSize)
return func(nonce []byte) ([]byte, error) {
s := append(nonce, b...) // serialize chunk fields
hasher := bmtlegacy.New(hashPool) hasher := bmtlegacy.New(hashPool)
hasher.Reset() if err := hasher.SetSpanBytes(span); err != nil {
span := binary.LittleEndian.Uint64(s[:8])
err := hasher.SetSpan(int64(span))
if err != nil {
return nil, err return nil, err
} }
if _, err := hasher.Write(s[8:]); err != nil { if _, err := hasher.Write(s); err != nil {
return nil, err return nil, err
} }
return hasher.Sum(nil), nil return hasher.Sum(nil), nil
}
} }
// contains returns whether the given collection contains the given element // contains returns whether the given collection contains the given element
...@@ -228,19 +185,6 @@ func contains(col Targets, elem []byte) bool { ...@@ -228,19 +185,6 @@ func contains(col Targets, elem []byte) bool {
return false return false
} }
// padBytesLeft adds 0s to the given byte slice as left padding,
// returning this as a new byte slice with a length of exactly 32
// given param is assumed to be at most 32 bytes long
func padBytesLeft(b []byte) []byte {
l := len(b)
if l == 32 {
return b
}
bb := make([]byte, 32)
copy(bb[32-l:], b)
return bb
}
// MarshalBinary serializes a message struct // MarshalBinary serializes a message struct
func (m *Message) MarshalBinary() (data []byte, err error) { func (m *Message) MarshalBinary() (data []byte, err error) {
data = append(m.length[:], m.Topic[:]...) data = append(m.length[:], m.Topic[:]...)
...@@ -269,3 +213,50 @@ func (m *Message) UnmarshalBinary(data []byte) (err error) { ...@@ -269,3 +213,50 @@ func (m *Message) UnmarshalBinary(data []byte) (err error) {
m.padding = data[payloadEnd:] m.padding = data[payloadEnd:]
return nil return nil
} }
func mine(ctx context.Context, f func(nonce []byte) (swarm.Chunk, error)) (swarm.Chunk, error) {
seeds := make([]uint32, 8)
for i := range seeds {
seeds[i] = random.Uint32()
}
initnonce := make([]byte, 32)
for i := 0; i < 8; i++ {
binary.LittleEndian.PutUint32(initnonce[i*4:i*4+4], seeds[i])
}
quit := make(chan struct{})
// make both errs and result channels buffered so they never block
result := make(chan swarm.Chunk, 8)
errs := make(chan error, 8)
for i := 0; i < 8; i++ {
go func(j int) {
nonce := make([]byte, 32)
copy(nonce, initnonce)
for seed := seeds[j]; ; seed++ {
binary.LittleEndian.PutUint32(nonce[j*4:j*4+4], seed)
res, err := f(nonce)
if err != nil {
errs <- err
return
}
if res != nil {
result <- res
return
}
select {
case <-quit:
return
default:
}
}
}(i)
}
defer close(quit)
select {
case <-ctx.Done():
return nil, ctx.Err()
case err := <-errs:
return nil, err
case res := <-result:
return res, nil
}
}
...@@ -5,11 +5,12 @@ ...@@ -5,11 +5,12 @@
package trojan_test package trojan_test
import ( import (
"bytes" "context"
"crypto/rand"
"encoding/binary" "encoding/binary"
"errors"
"reflect" "reflect"
"testing" "testing"
"time"
chunktesting "github.com/ethersphere/bee/pkg/storage/testing" chunktesting "github.com/ethersphere/bee/pkg/storage/testing"
"github.com/ethersphere/bee/pkg/swarm" "github.com/ethersphere/bee/pkg/swarm"
...@@ -69,7 +70,7 @@ func TestNewMessage(t *testing.T) { ...@@ -69,7 +70,7 @@ func TestNewMessage(t *testing.T) {
// its resulting data should have a hash that matches its address exactly // its resulting data should have a hash that matches its address exactly
func TestWrap(t *testing.T) { func TestWrap(t *testing.T) {
m := newTestMessage(t) m := newTestMessage(t)
c, err := m.Wrap(testTargets) c, err := m.Wrap(context.Background(), testTargets)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -98,21 +99,14 @@ func TestWrap(t *testing.T) { ...@@ -98,21 +99,14 @@ func TestWrap(t *testing.T) {
t.Fatalf("chunk span set to %d, but rest of chunk data is of size %d", span, remainingDataLen) t.Fatalf("chunk span set to %d, but rest of chunk data is of size %d", span, remainingDataLen)
} }
dataHash, err := trojan.HashBytes(data)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(addr.Bytes(), dataHash) {
t.Fatal("chunk address does not match its data hash")
}
} }
// TestWrapError tests that the creation of a chunk fails when given targets are invalid // TestWrapError tests that the creation of a chunk fails when given targets are invalid
func TestWrapError(t *testing.T) { func TestWrapError(t *testing.T) {
m := newTestMessage(t) m := newTestMessage(t)
ctx := context.Background()
emptyTargets := trojan.Targets([]trojan.Target{}) emptyTargets := trojan.Targets([]trojan.Target{})
if _, err := m.Wrap(emptyTargets); err != trojan.ErrEmptyTargets { if _, err := m.Wrap(ctx, emptyTargets); err != trojan.ErrEmptyTargets {
t.Fatalf("expected error when creating chunk for empty targets to be %q, but got %v", trojan.ErrEmptyTargets, err) t.Fatalf("expected error when creating chunk for empty targets to be %q, but got %v", trojan.ErrEmptyTargets, err)
} }
...@@ -120,62 +114,30 @@ func TestWrapError(t *testing.T) { ...@@ -120,62 +114,30 @@ func TestWrapError(t *testing.T) {
t2 := trojan.Target([]byte{25, 120}) t2 := trojan.Target([]byte{25, 120})
t3 := trojan.Target([]byte{180, 18, 255}) t3 := trojan.Target([]byte{180, 18, 255})
varLenTargets := trojan.Targets([]trojan.Target{t1, t2, t3}) varLenTargets := trojan.Targets([]trojan.Target{t1, t2, t3})
if _, err := m.Wrap(varLenTargets); err != trojan.ErrVarLenTargets { if _, err := m.Wrap(ctx, varLenTargets); err != trojan.ErrVarLenTargets {
t.Fatalf("expected error when creating chunk for variable-length targets to be %q, but got %v", trojan.ErrVarLenTargets, err) t.Fatalf("expected error when creating chunk for variable-length targets to be %q, but got %v", trojan.ErrVarLenTargets, err)
} }
} }
// TestWrapTimeout tests for mining timeout and avoid forever loop // TestWrapTimeout tests for mining timeout and avoid forever loop
func TestWrapTimeout(t *testing.T) { func TestWrapTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
m := newTestMessage(t) m := newTestMessage(t)
// a large target will take more than MinerTimeout seconds, so timeout error will be triggered // a large target will take more than MinerTimeout seconds, so timeout error will be triggered
buf := make([]byte, swarm.SectionSize) buf := make([]byte, 16)
_, err := rand.Read(buf)
if err != nil {
t.Fatal(err)
}
target := trojan.Target(buf) target := trojan.Target(buf)
targets := trojan.Targets([]trojan.Target{target}) targets := trojan.Targets([]trojan.Target{target})
if _, err := m.Wrap(targets); err != trojan.ErrMinerTimeout { if _, err := m.Wrap(ctx, targets); err == nil || !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("expected error when having lengthy target to be %q, but got %v", trojan.ErrMinerTimeout, err) t.Fatalf("expected context timeout, got %v", err)
}
}
// TestPadBytes tests that different types of byte slices are correctly padded with leading 0s
// all slices are interpreted as big-endian
func TestPadBytes(t *testing.T) {
s := make([]byte, 32)
// empty slice should be unchanged
p := trojan.PadBytes(s)
if !bytes.Equal(p, s) {
t.Fatalf("expected byte padding to result in %x, but is %x", s, p)
}
// slice of length 3
s = []byte{255, 128, 64}
p = trojan.PadBytes(s)
e := append(make([]byte, 29), s...) // 29 zeros plus the 3 original bytes
if !bytes.Equal(p, e) {
t.Fatalf("expected byte padding to result in %x, but is %x", e, p)
}
// simulate toChunk behavior
s = []byte{1, 0, 0, 0}
p = trojan.PadBytes(s)
e = append(make([]byte, 28), s...) // 28 zeros plus the 4 original bytes
if !bytes.Equal(p, e) {
t.Fatalf("expected byte padding to result in %x, but is %x", e, p)
} }
} }
// TestUnwrap tests the correct unwrapping of a trojan chunk to obtain a message // TestUnwrap tests the correct unwrapping of a trojan chunk to obtain a message
func TestUnwrap(t *testing.T) { func TestUnwrap(t *testing.T) {
m := newTestMessage(t) m := newTestMessage(t)
c, err := m.Wrap(testTargets) c, err := m.Wrap(context.Background(), testTargets)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -213,7 +175,7 @@ func TestIsPotential(t *testing.T) { ...@@ -213,7 +175,7 @@ func TestIsPotential(t *testing.T) {
// valid potential trojan // valid potential trojan
m := newTestMessage(t) m := newTestMessage(t)
c, err := m.Wrap(testTargets) c, err := m.Wrap(context.Background(), testTargets)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
......
// Copyright 2020 The Swarm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trojan_test
import (
"context"
"encoding/binary"
"fmt"
"testing"
"github.com/ethersphere/bee/pkg/trojan"
)
func newTargets(length, depth int) trojan.Targets {
targets := make([]trojan.Target, length)
for i := 0; i < length; i++ {
buf := make([]byte, 8)
binary.LittleEndian.PutUint64(buf, uint64(i))
targets[i] = trojan.Target(buf[:depth])
}
return trojan.Targets(targets)
}
func BenchmarkWrap(b *testing.B) {
payload := []byte("foopayload")
m, err := trojan.NewMessage(testTopic, payload)
if err != nil {
b.Fatal(err)
}
cases := []struct {
length int
depth int
}{
{1, 1},
{4, 1},
{16, 1},
{16, 2},
{64, 2},
{256, 2},
{256, 3},
{4096, 3},
{16384, 3},
}
for _, c := range cases {
name := fmt.Sprintf("length:%d,depth:%d", c.length, c.depth)
b.Run(name, func(b *testing.B) {
targets := newTargets(c.length, c.depth)
for i := 0; i < b.N; i++ {
if _, err := m.Wrap(context.Background(), targets); err != nil {
b.Fatal(err)
}
}
})
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment