Commit 7df426d9 authored by 李伟@五瓣科技

BatchElem

parent 60913c29
package multisend
import (
"fmt"
"github.com/ethereum/go-ethereum/core/types"
)
// ClientFactory produces load testing clients.
type ClientFactory interface {
// ValidateConfig must check whether the given configuration is valid for
// our specific client factory.
ValidateConfig(cfg Config) error
// NewClient must instantiate a new load testing client, or produce an error
// if that process fails.
NewClient(cfg Config) (Client, error)
}
// Client generates transactions to be sent to a specific endpoint.
type Client interface {
// GenerateTx must generate a raw transaction to be sent to the relevant
// broadcast_tx method for a given endpoint.
GenerateTx() (*types.Transaction, error)
}
// Our global registry of client factories
var clientFactories = map[string]ClientFactory{}
// RegisterClientFactory allows us to programmatically register different client
// factories to easily switch between different ones at runtime.
func RegisterClientFactory(name string, factory ClientFactory) error {
if _, exists := clientFactories[name]; exists {
return fmt.Errorf("client factory with the specified name already exists: %s", name)
}
clientFactories[name] = factory
return nil
}
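As a point of reference, a factory implementation and its registration might look like the sketch below. The type names (simpleFactory, simpleClient), the size check, and the use of init() for registration are illustrative assumptions, not part of this commit.

// Sketch of a ClientFactory/Client pair (hypothetical names and logic).
type simpleFactory struct{}

type simpleClient struct {
	cfg Config
}

func (f *simpleFactory) ValidateConfig(cfg Config) error {
	// Reject configurations this client cannot satisfy (illustrative check).
	if cfg.Size < 32 {
		return fmt.Errorf("transaction size must be at least 32 bytes, got %d", cfg.Size)
	}
	return nil
}

func (f *simpleFactory) NewClient(cfg Config) (Client, error) {
	return &simpleClient{cfg: cfg}, nil
}

func (c *simpleClient) GenerateTx() (*types.Transaction, error) {
	// A real client would build and sign an Ethereum transaction here.
	return nil, fmt.Errorf("not implemented in this sketch")
}

func init() {
	// Registration fails if the name is already taken.
	if err := RegisterClientFactory("simple", &simpleFactory{}); err != nil {
		panic(err)
	}
}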
package multisend
import (
"encoding/json"
"fmt"
)
const (
SelectSuppliedEndpoints = "supplied" // Select only the supplied endpoint(s) for load testing (the default).
SelectDiscoveredEndpoints = "discovered" // Select newly discovered endpoints only (excluding supplied endpoints).
SelectAnyEndpoints = "any" // Select from any of supplied and/or discovered endpoints.
)
var validEndpointSelectMethods = map[string]interface{}{
SelectSuppliedEndpoints: nil,
SelectDiscoveredEndpoints: nil,
SelectAnyEndpoints: nil,
}
// Config represents the configuration for a single client (i.e. standalone or
// worker).
type Config struct {
ClientFactory string `json:"client_factory"` // Which client factory should we use for load testing?
Connections int `json:"connections"` // The number of WebSockets connections to make to each target endpoint.
Time int `json:"time"` // The total time, in seconds, for which to run the load test.
SendPeriod int `json:"send_period"` // The period (in seconds) at which to send batches of transactions.
Rate int `json:"rate"` // The number of transactions to generate, per send period.
Size int `json:"size"` // The desired size of each generated transaction, in bytes.
Count int `json:"count"` // The maximum number of transactions to send. Set to -1 for unlimited.
BroadcastTxMethod string `json:"broadcast_tx_method"` // The broadcast_tx method to use (can be "sync", "async" or "commit").
Endpoints []string `json:"endpoints"` // A list of the Ethereum node endpoints to which to connect for this load test.
EndpointSelectMethod string `json:"endpoint_select_method"` // The method by which to select endpoints for load testing.
ExpectPeers int `json:"expect_peers"` // The minimum number of peers to expect before starting a load test. Set to 0 by default (no minimum).
MaxEndpoints int `json:"max_endpoints"` // The maximum number of endpoints to use for load testing. Set to 0 by default (no maximum).
MinConnectivity int `json:"min_connectivity"` // The minimum number of peers to which each peer must be connected before starting the load test. Set to 0 by default (no minimum).
PeerConnectTimeout int `json:"peer_connect_timeout"` // The maximum time to wait (in seconds) for all peers to connect, if ExpectPeers > 0.
StatsOutputFile string `json:"stats_output_file"` // Where to store the final aggregate statistics file (in CSV format).
NoTrapInterrupts bool `json:"no_trap_interrupts"` // Should we avoid trapping Ctrl+Break? Only relevant for standalone execution mode.
}
var validBroadcastTxMethods = map[string]interface{}{
"async": nil,
"sync": nil,
"commit": nil,
}
func (c Config) Validate() error {
if len(c.ClientFactory) == 0 {
return fmt.Errorf("client factory name must be specified")
}
factory, factoryExists := clientFactories[c.ClientFactory]
if !factoryExists {
return fmt.Errorf("client factory \"%s\" does not exist", c.ClientFactory)
}
// client factory-specific configuration validation
if err := factory.ValidateConfig(c); err != nil {
return fmt.Errorf("invalid configuration for client factory \"%s\": %v", c.ClientFactory, err)
}
if c.Connections < 1 {
return fmt.Errorf("expected connections to be >= 1, but was %d", c.Connections)
}
if c.Time < 1 {
return fmt.Errorf("expected load test time to be >= 1 second, but was %d", c.Time)
}
if c.SendPeriod < 1 {
return fmt.Errorf("expected transaction send period to be >= 1 second, but was %d", c.SendPeriod)
}
if c.Rate < 1 {
return fmt.Errorf("expected transaction rate to be >= 1, but was %d", c.Rate)
}
if c.Count < 1 && c.Count != -1 {
return fmt.Errorf("expected max transaction count to either be -1 or >= 1, but was %d", c.Count)
}
if _, ok := validBroadcastTxMethods[c.BroadcastTxMethod]; !ok {
return fmt.Errorf("expected broadcast_tx method to be one of \"sync\", \"async\" or \"commit\", but was %s", c.BroadcastTxMethod)
}
if len(c.Endpoints) == 0 {
return fmt.Errorf("expected at least one endpoint to conduct load test against, but found none")
}
if _, ok := validEndpointSelectMethods[c.EndpointSelectMethod]; !ok {
return fmt.Errorf("invalid endpoint-select-method: %s", c.EndpointSelectMethod)
}
if c.ExpectPeers < 0 {
return fmt.Errorf("expect-peers must be at least 0, but got %d", c.ExpectPeers)
}
if c.ExpectPeers > 0 && c.PeerConnectTimeout < 1 {
return fmt.Errorf("peer-connect-timeout must be at least 1 if expect-peers is non-zero, but got %d", c.PeerConnectTimeout)
}
if c.MaxEndpoints < 0 {
return fmt.Errorf("invalid value for max-endpoints: %d", c.MaxEndpoints)
}
if c.MinConnectivity < 0 {
return fmt.Errorf("invalid value for min-peer-connectivity: %d", c.MinConnectivity)
}
return nil
}
// MaxTxsPerEndpoint estimates the maximum number of transactions that this
// configuration would generate for a single endpoint.
func (c Config) MaxTxsPerEndpoint() uint64 {
if c.Count > -1 {
return uint64(c.Count)
}
return uint64(c.Rate) * uint64(c.Time)
}
func (c Config) ToJSON() string {
b, err := json.Marshal(c)
if err != nil {
return fmt.Sprintf("%v", c)
}
return string(b)
}
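To make the field meanings concrete, the sketch below builds a small configuration by hand and runs it through Validate and ToJSON; the factory name "simple" and the endpoint URL are placeholder values.

// Sketch: constructing and validating a Config (placeholder values).
cfg := Config{
	ClientFactory:        "simple", // must match a registered client factory
	Connections:          1,
	Time:                 60,  // run for one minute
	SendPeriod:           1,   // send a batch every second
	Rate:                 100, // 100 txs per send period
	Size:                 250, // ~250-byte transactions
	Count:                -1,  // no overall cap
	BroadcastTxMethod:    "async",
	Endpoints:            []string{"ws://localhost:8546"},
	EndpointSelectMethod: SelectSuppliedEndpoints,
}
if err := cfg.Validate(); err != nil {
	// handle an invalid configuration
}
_ = cfg.ToJSON() // JSON form, useful for logging the run parameters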
// WorkerConfig holds the configuration options specific to a worker node.
// type WorkerConfig struct {
// ID string `json:"id"` // A unique ID for this worker instance. Will show up in the metrics reported by the coordinator for this worker.
// CoordAddr string `json:"coord_addr"` // The address at which to find the coordinator node.
// CoordConnectTimeout int `json:"connect_timeout"` // The maximum amount of time, in seconds, to allow for the coordinator to become available.
// }
// func (c WorkerConfig) Validate() error {
// if len(c.ID) > 0 && !isValidWorkerID(c.ID) {
// return fmt.Errorf("Invalid worker ID \"%s\": worker IDs can only be lowercase alphanumeric characters", c.ID)
// }
// if len(c.CoordAddr) == 0 {
// return fmt.Errorf("coordinator address must be specified")
// }
// if c.CoordConnectTimeout < 1 {
// return fmt.Errorf("expected connect-timeout to be >= 1, but was %d", c.CoordConnectTimeout)
// }
// return nil
// }
// func (c WorkerConfig) ToJSON() string {
// b, err := json.Marshal(c)
// if err != nil {
// return fmt.Sprintf("%v", c)
// }
// return string(b)
// }
package multisend
import (
"context"
"fmt"
"net"
"net/url"
"os"
"os/signal"
"syscall"
"time"
"code.wuban.net.cn/multisend/internal/logging"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/ethereum/go-ethereum/rpc"
)
type ethPeerInfo struct {
Addr string
Client *rpc.Client // rpc.DialWebsocket returns a *rpc.Client
PeerAddrs []string
SuccessfullyQueried bool
}
func waitForEthNetworkPeers(
endpoints []string,
selectionMethod string,
minDiscoveredPeers int,
minPeerConnectivity int,
maxReturnedPeers int,
timeout time.Duration,
logger logging.Logger,
) ([]string, error) {
logger.Info("waiting for eth public node to connect",
"endpoints", endpoints,
"selectionMethod", selectionMethod,
"minDiscoveredPeers", minDiscoveredPeers,
"minPeerConnectivity", minPeerConnectivity,
"maxReturnedPeers", maxReturnedPeers,
"timeout", fmt.Sprintf("%.2f seconds", timeout.Seconds()))
cancelc := make(chan struct{}, 1)
cancelTrap := trapInterrupts(func() { close(cancelc) }, logger)
defer close(cancelTrap)
startTime := time.Now()
suppliedPeers := make(map[string]*ethPeerInfo)
for _, peerURL := range endpoints {
u, err := url.Parse(peerURL)
if err != nil {
return nil, fmt.Errorf("failed to parse peer URL %s: %s", peerURL, err)
}
peerIP, err := lookupFirstIPv4Addr(u.Hostname())
if err != nil {
return nil, fmt.Errorf("failed to resolve IP address for endpoint %s: %s", peerURL, err)
}
// rpc.DialWebsocket requires a ws:// or wss:// URL; 8546 is geth's default WebSocket port
peerAddr := fmt.Sprintf("ws://%s:8546", peerIP)
client, err := rpc.DialWebsocket(context.Background(), peerAddr, "")
if err != nil {
return nil, err
}
suppliedPeers[peerAddr] = &ethPeerInfo{
Addr: peerAddr,
Client: client,
PeerAddrs: make([]string, 0),
}
}
peers := make(map[string]*ethPeerInfo)
for a, p := range suppliedPeers {
pc := *p
peers[a] = &pc
}
for {
remainingTimeout := timeout - time.Since(startTime)
if remainingTimeout < 0 {
return nil, fmt.Errorf("timed out waiting for Tendermint peer crawl to complete")
}
newPeers, err := getEthNetworkPeers(peers, remainingTimeout, cancelc, logger)
if err != nil {
return nil, err
}
// we only care if we've discovered more peers than in the previous attempt
if len(newPeers) > len(peers) {
peers = newPeers
}
peerCount := len(peers)
peerConnectivity := getMinPeerConnectivity(peers)
if peerCount >= minDiscoveredPeers && peerConnectivity >= minPeerConnectivity {
logger.Info("All required peers connected", "count", peerCount, "minConnectivity", minPeerConnectivity)
// we're done here
return filterEthPeerMap(suppliedPeers, peers, selectionMethod, maxReturnedPeers), nil
} else {
logger.Debug(
"Peers discovered so far",
"count", peerCount,
"minConnectivity", peerConnectivity,
"remainingTimeout", timeout-time.Since(startTime),
)
time.Sleep(1 * time.Second)
}
}
return nil, nil
}
func trapInterrupts(onKill func(), logger logging.Logger) chan struct{} {
sigc := make(chan os.Signal, 1)
cancelTrap := make(chan struct{})
signal.Notify(sigc, os.Interrupt, syscall.SIGTERM)
go func() {
select {
case <-sigc:
logger.Info("Caught kill signal")
onKill()
case <-cancelTrap:
logger.Debug("Interrupt trap cancelled")
}
}()
return cancelTrap
}
func lookupFirstIPv4Addr(hostname string) (string, error) {
ipRecords, err := net.LookupIP(hostname)
if err != nil {
return "", err
}
for _, ipRecord := range ipRecords {
ipv4 := ipRecord.To4()
if ipv4 != nil {
return ipv4.String(), nil
}
}
return "", fmt.Errorf("no IPv4 records for hostname: %s", hostname)
}
// Queries the given peers (in parallel) to construct a unique set of known
// peers across the entire network.
func getEthNetworkPeers(
peers map[string]*ethPeerInfo, // Any existing peers we know about already
timeout time.Duration, // Maximum timeout for the entire operation
cancelc chan struct{}, // Allows us to cancel the polling operations
logger logging.Logger,
) (map[string]*ethPeerInfo, error) {
startTime := time.Now()
peerInfoc := make(chan *ethPeerInfo, len(peers))
errc := make(chan error, len(peers))
logger.Debug("Querying peers for more peers", "count", len(peers), "peers", getPeerAddrs(peers))
// parallelize querying all the Ethereum nodes' peers
for _, peer := range peers {
// reset this every time
peer.SuccessfullyQueried = false
go func(peer_ *ethPeerInfo) {
// NOTE: assumes each endpoint exposes the "admin" RPC namespace;
// admin_peers is go-ethereum's equivalent of Tendermint's net_info.
var peersInfo []*p2p.PeerInfo
if err := peer_.Client.CallContext(context.Background(), &peersInfo, "admin_peers"); err != nil {
logger.Debug("Failed to query peer - skipping", "addr", peer_.Addr, "err", err)
errc <- err
return
}
peerAddrs := make([]string, 0)
for _, peerInfo := range peersInfo {
// RemoteAddress is "host:port"; keep just the host and assume the default WebSocket port
host, _, err := net.SplitHostPort(peerInfo.Network.RemoteAddress)
if err != nil {
continue
}
peerAddrs = append(peerAddrs, fmt.Sprintf("ws://%s:8546", host))
}
peerInfoc <- &ethPeerInfo{
Addr: peer_.Addr,
Client: peer_.Client,
PeerAddrs: peerAddrs,
SuccessfullyQueried: true,
}
}(peer)
}
result := make(map[string]*ethPeerInfo)
expectedNetInfoResults := len(peers)
receivedNetInfoResults := 0
for {
remainingTimeout := timeout - time.Since(startTime)
if remainingTimeout < 0 {
return nil, fmt.Errorf("timed out waiting for all peer network info to be returned")
}
select {
case <-cancelc:
return nil, fmt.Errorf("cancel signal received")
case peerInfo := <-peerInfoc:
result[peerInfo.Addr] = peerInfo
receivedNetInfoResults++
case <-errc:
receivedNetInfoResults++
case <-time.After(remainingTimeout):
return nil, fmt.Errorf("timed out while waiting for all peer network info to be returned")
}
if receivedNetInfoResults >= expectedNetInfoResults {
return resolveEthPeerMap(result), nil
} else {
// wait a little before polling again
time.Sleep(1 * time.Second)
}
}
}
func resolveEthPeerMap(peers map[string]*ethPeerInfo) map[string]*ethPeerInfo {
result := make(map[string]*ethPeerInfo)
for addr, peer := range peers {
result[addr] = peer
for _, peerAddr := range peer.PeerAddrs {
client, err := rpc.DialWebsocket(context.Background(), peerAddr, "")
if err != nil {
// skip unreachable peers rather than discarding the entire map
continue
}
if _, exists := result[peerAddr]; !exists {
result[peerAddr] = &ethPeerInfo{
Addr: peerAddr,
Client: client,
PeerAddrs: make([]string, 0),
}
}
}
}
return result
}
func getPeerAddrs(peers map[string]*ethPeerInfo) []string {
results := make([]string, 0)
for _, peer := range peers {
results = append(results, peer.Addr)
}
return results
}
func getMinPeerConnectivity(peers map[string]*ethPeerInfo) int {
minPeers := len(peers)
for _, peer := range peers {
// we only care about peers we've successfully queried so far
if !peer.SuccessfullyQueried {
continue
}
peerCount := len(peer.PeerAddrs)
if peerCount > 0 && peerCount < minPeers {
minPeers = peerCount
}
}
return minPeers
}
func filterEthPeerMap(suppliedPeers, newPeers map[string]*ethPeerInfo, selectionMethod string, maxCount int) []string {
result := make([]string, 0)
for peerAddr := range newPeers {
u, err := url.Parse(peerAddr)
if err != nil {
continue
}
addr := fmt.Sprintf("ws://%s:8546", u.Hostname())
switch selectionMethod {
case SelectSuppliedEndpoints:
// only add it to the result if it was in the original list
if _, ok := suppliedPeers[peerAddr]; ok {
result = append(result, addr)
}
case SelectDiscoveredEndpoints:
// only add it to the result if it wasn't in the original list
if _, ok := suppliedPeers[peerAddr]; !ok {
result = append(result, addr)
}
default:
// otherwise, always add it
result = append(result, addr)
}
// maxCount of 0 means no limit on the number of endpoints
if maxCount > 0 && len(result) >= maxCount {
break
}
}
return result
}
module code.wuban.net.cn/multisend
go 1.17
require (
github.com/StackExchange/wmi v0.0.0-20180116203802-5d049714c4a6 // indirect
github.com/btcsuite/btcd v0.20.1-beta // indirect
github.com/deckarep/golang-set v1.8.0 // indirect
github.com/ethereum/go-ethereum v1.10.16 // indirect
github.com/go-ole/go-ole v1.2.1 // indirect
github.com/go-stack/stack v1.8.0 // indirect
github.com/gorilla/websocket v1.4.2 // indirect
github.com/shirou/gopsutil v3.21.4-0.20210419000835-c7a38de76ee5+incompatible // indirect
github.com/sirupsen/logrus v1.8.1 // indirect
github.com/tklauser/go-sysconf v0.3.5 // indirect
github.com/tklauser/numcpus v0.2.2 // indirect
golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 // indirect
golang.org/x/sys v0.0.0-20210816183151-1e6c022a8912 // indirect
gopkg.in/natefinch/npipe.v2 v2.0.0-20160621034901-c1b8fa8bdcce // indirect
)
package logging
import (
"sync"
"github.com/sirupsen/logrus"
)
// Logger is the interface to our internal logger.
type Logger interface {
Debug(msg string, kvpairs ...interface{})
Info(msg string, kvpairs ...interface{})
Error(msg string, kvpairs ...interface{})
SetField(key string, val interface{})
PushFields()
PopFields()
}
// LogrusLogger is a thread-safe logger whose properties persist and can be modified.
type LogrusLogger struct {
mtx sync.Mutex
logger *logrus.Entry
ctx string
fields map[string]interface{}
pushedFieldSets []map[string]interface{}
}
// NoopLogger implements Logger, but does nothing.
type NoopLogger struct{}
// Compile-time checks that both loggers implement Logger
var _ Logger = (*LogrusLogger)(nil)
var _ Logger = (*NoopLogger)(nil)
//
// LogrusLogger
//
// NewLogrusLogger will instantiate a logger with the given context.
func NewLogrusLogger(ctx string, kvpairs ...interface{}) Logger {
var logger *logrus.Entry
if len(ctx) > 0 {
logger = logrus.WithField("ctx", ctx)
} else {
logger = logrus.NewEntry(logrus.New())
}
return &LogrusLogger{
logger: logger,
ctx: ctx,
fields: serializeKVPairs(kvpairs),
pushedFieldSets: []map[string]interface{}{},
}
}
func (l *LogrusLogger) withFields() *logrus.Entry {
if len(l.fields) > 0 {
return l.logger.WithFields(l.fields)
}
return l.logger
}
func serializeKVPairs(kvpairs ...interface{}) map[string]interface{} {
res := make(map[string]interface{})
if (len(kvpairs) % 2) == 0 {
for i := 0; i < len(kvpairs); i += 2 {
res[kvpairs[i].(string)] = kvpairs[i+1]
}
}
return res
}
func (l *LogrusLogger) withKVPairs(kvpairs ...interface{}) *logrus.Entry {
fields := serializeKVPairs(kvpairs...)
if len(fields) > 0 {
return l.withFields().WithFields(fields)
}
return l.withFields()
}
func (l *LogrusLogger) Debug(msg string, kvpairs ...interface{}) {
l.mtx.Lock()
defer l.mtx.Unlock()
l.withKVPairs(kvpairs...).Debugln(msg)
}
func (l *LogrusLogger) Info(msg string, kvpairs ...interface{}) {
l.mtx.Lock()
defer l.mtx.Unlock()
l.withKVPairs(kvpairs...).Infoln(msg)
}
func (l *LogrusLogger) Error(msg string, kvpairs ...interface{}) {
l.mtx.Lock()
defer l.mtx.Unlock()
l.withKVPairs(kvpairs...).Errorln(msg)
}
func (l *LogrusLogger) SetField(key string, val interface{}) {
l.mtx.Lock()
defer l.mtx.Unlock()
l.fields[key] = val
}
func (l *LogrusLogger) PushFields() {
l.mtx.Lock()
defer l.mtx.Unlock()
// snapshot a copy of the current fields so later SetField calls
// don't mutate the saved set (maps are reference types)
fieldsCopy := make(map[string]interface{}, len(l.fields))
for k, v := range l.fields {
fieldsCopy[k] = v
}
l.pushedFieldSets = append(l.pushedFieldSets, fieldsCopy)
}
func (l *LogrusLogger) PopFields() {
l.mtx.Lock()
defer l.mtx.Unlock()
pfsLen := len(l.pushedFieldSets)
if pfsLen > 0 {
l.fields = l.pushedFieldSets[pfsLen-1]
l.pushedFieldSets = l.pushedFieldSets[:pfsLen-1]
}
}
//
// NoopLogger
//
// NewNoopLogger will instantiate a logger that does nothing when called.
func NewNoopLogger() Logger {
return &NoopLogger{}
}
func (l *NoopLogger) Debug(msg string, kvpairs ...interface{}) {}
func (l *NoopLogger) Info(msg string, kvpairs ...interface{}) {}
func (l *NoopLogger) Error(msg string, kvpairs ...interface{}) {}
func (l *NoopLogger) SetField(key string, val interface{}) {}
func (l *NoopLogger) PushFields() {}
func (l *NoopLogger) PopFields() {}
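For orientation, typical usage of this Logger interface might look like the sketch below; the context string and field values are arbitrary examples.

// Sketch: using the Logger interface (illustrative values).
logger := NewLogrusLogger("example", "component", "peers")
logger.Info("Starting peer crawl", "endpoints", 3)

// Temporarily attach a field, then restore the previous field set.
logger.PushFields()
logger.SetField("endpoint", "ws://localhost:8546")
logger.Debug("Querying peer")
logger.PopFields()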
package logging
import (
"reflect"
"testing"
)
func TestKVPairSerialization(t *testing.T) {
testCases := []struct {
kvpairs []interface{}
expected map[string]interface{}
}{
{
[]interface{}{"a", 1, "b", "v"},
map[string]interface{}{
"a": 1,
"b": "v",
},
},
{
[]interface{}{"a"},
map[string]interface{}{},
},
{
[]interface{}{"a", 1, "b"},
map[string]interface{}{},
},
}
for i, tc := range testCases {
actual := serializeKVPairs(tc.kvpairs...)
if !reflect.DeepEqual(actual, tc.expected) {
t.Errorf("Test case %d: Expected result %v, but got %v", i, tc.expected, actual)
}
}
}
package multisend
import (
"time"
"code.wuban.net.cn/multisend/internal/logging"
)
func ExecuteStandalone(cfg Config) error {
logger := logging.NewLogrusLogger("loadtest")
if cfg.ExpectPeers > 0 {
peers, err := waitForEthNetworkPeers(
cfg.Endpoints,
cfg.EndpointSelectMethod,
cfg.ExpectPeers,
cfg.MinConnectivity,
cfg.MaxEndpoints,
time.Duration(cfg.PeerConnectTimeout)*time.Second,
logger,
)
if err != nil {
logger.Error("Failed while waiting for peers to connect", "err", err)
return err
}
cfg.Endpoints = peers
}
logger.Info("Connecting to remote endpoints")
tg := NewTransactorGroup()
if err := tg.AddAll(&cfg); err != nil {
return err
}
logger.Info("Initiating load test")
tg.Start()
var cancelTrap chan struct{}
if !cfg.NoTrapInterrupts {
// we want to know if the user hits Ctrl+Break
cancelTrap = trapInterrupts(func() { tg.Cancel() }, logger)
defer close(cancelTrap)
} else {
logger.Debug("Skipping trapping of interrupts (e.g. Ctrl+Break)")
}
if err := tg.Wait(); err != nil {
logger.Error("Failed to execute load test", "err", err)
return err
}
logger.Info("Load test complete!")
return nil
}
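A standalone run would typically be driven from a separate main package along the lines of the sketch below; the wiring, the factory name, and the endpoint are assumptions for illustration and are not part of this commit.

// Sketch: a minimal main package invoking a standalone load test.
package main

import (
	"log"

	multisend "code.wuban.net.cn/multisend"
)

func main() {
	cfg := multisend.Config{
		// ... populate as in the Config example earlier in this commit ...
		ClientFactory: "simple", // hypothetical registered factory name
		Endpoints:     []string{"ws://localhost:8546"},
	}
	if err := cfg.Validate(); err != nil {
		log.Fatalf("invalid config: %v", err)
	}
	if err := multisend.ExecuteStandalone(cfg); err != nil {
		log.Fatalf("load test failed: %v", err)
	}
}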
package multisend
// A generic message to/from a worker.
type workerMsg struct {
ID string `json:"id,omitempty"` // A UUID for this worker.
State workerState `json:"state,omitempty"` // The worker's desired or actual state.
TxCount int `json:"tx_count,omitempty"` // The total number of transactions sent thus far by this worker.
TotalTxBytes int64 `json:"total_tx_bytes,omitempty"` // The total number of transaction bytes sent thus far by this worker.
Error string `json:"error,omitempty"` // If the worker has failed somehow, a descriptive error message as to why.
Config *Config `json:"config,omitempty"` // The load testing configuration, if relevant.
}
type workerState string
// Remote worker possible states
const (
workerConnected workerState = "connected"
workerAccepted workerState = "accepted"
workerRejected workerState = "rejected"
workerTesting workerState = "testing"
workerFailed workerState = "failed"
workerCompleted workerState = "completed"
)
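As a concrete illustration of the wire format, a worker announcing itself might marshal the message below (assuming encoding/json; the ID is an arbitrary example, and zero-valued fields are dropped by the omitempty tags).

// Sketch: serializing a worker message (illustrative values).
msg := workerMsg{
	ID:    "worker-01",
	State: workerConnected,
}
b, _ := json.Marshal(msg)
// string(b) == `{"id":"worker-01","state":"connected"}`
_ = b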
package multisend
import (
"sync"
"time"
)
// TransactorGroup allows us to encapsulate the management of a group of
// transactors.
type TransactorGroup struct {
transactors []*Transactor
statsMtx sync.RWMutex
startTime time.Time
txCounts map[int]int // The counts of all of the total transactions per transactor.
txBytes map[int]int64 // The total number of transaction bytes sent per transactor.
progressCallbackMtx sync.RWMutex
progressCallbackInterval time.Duration
progressCallback func(g *TransactorGroup, txCount int, txBytes int64)
stopProgressReporter chan struct{} // Close this to stop the progress reporter.
progressReporterStopped chan struct{} // Closed when the progress reporter goroutine has completely stopped.
}
func NewTransactorGroup() *TransactorGroup {
return &TransactorGroup{
transactors: make([]*Transactor, 0),
txCounts: make(map[int]int),
txBytes: make(map[int]int64),
progressCallbackInterval: defaultProgressCallbackInterval,
stopProgressReporter: make(chan struct{}, 1),
progressReporterStopped: make(chan struct{}, 1),
}
}
// Add will instantiate a new Transactor with the given parameters. If
// instantiation fails it'll automatically shut down and close all other
// transactors, returning the error.
func (g *TransactorGroup) Add(remoteAddr string, config *Config) error {
t, err := NewTransactor(remoteAddr, config)
if err != nil {
g.close()
return err
}
id := len(g.transactors)
t.SetProgressCallback(id, g.getProgressCallbackInterval()/2, g.trackTransactorProgress)
g.transactors = append(g.transactors, t)
return nil
}
func (g *TransactorGroup) AddAll(cfg *Config) error {
for _, endpoint := range cfg.Endpoints {
for c := 0; c < cfg.Connections; c++ {
if err := g.Add(endpoint, cfg); err != nil {
return err
}
}
}
return nil
}
func (g *TransactorGroup) SetProgressCallback(interval time.Duration, callback func(*TransactorGroup, int, int64)) {
g.progressCallbackMtx.Lock()
g.progressCallbackInterval = interval
g.progressCallback = callback
g.progressCallbackMtx.Unlock()
}
// Start will iterate over all transactors and start them.
func (g *TransactorGroup) Start() {
go g.progressReporter()
for _, t := range g.transactors {
t.Start()
}
g.setStartTime(time.Now())
}
// Cancel signals to all transactors to stop their operations.
func (g *TransactorGroup) Cancel() {
for _, t := range g.transactors {
t.Cancel()
}
}
// Wait will wait for all transactors to complete, returning the first error
// we encounter.
func (g *TransactorGroup) Wait() error {
defer func() {
close(g.stopProgressReporter)
<-g.progressReporterStopped
}()
var wg sync.WaitGroup
var err error
errc := make(chan error, len(g.transactors))
for i, t := range g.transactors {
wg.Add(1)
go func(_i int, _t *Transactor) {
errc <- _t.Wait()
defer wg.Done()
// get the final tx count
g.trackTransactorProgress(_i, _t.GetTxCount(), _t.GetTxBytes())
}(i, t)
}
wg.Wait()
// collect the results
for i := 0; i < len(g.transactors); i++ {
if e := <-errc; e != nil {
err = e
break
}
}
return err
}
// func (g *TransactorGroup) WriteAggregateStats(filename string) error {
// stats := AggregateStats{
// TotalTxs: g.totalTxs(),
// TotalTimeSeconds: time.Since(g.getStartTime()).Seconds(),
// TotalBytes: g.totalBytes(),
// }
// return writeAggregateStats(filename, stats)
// }
func (g *TransactorGroup) progressReporter() {
defer close(g.progressReporterStopped)
ticker := time.NewTicker(g.getProgressCallbackInterval())
defer ticker.Stop()
for {
select {
case <-ticker.C:
g.reportProgress()
case <-g.stopProgressReporter:
return
}
}
}
func (g *TransactorGroup) setStartTime(startTime time.Time) {
g.statsMtx.Lock()
g.startTime = startTime
g.statsMtx.Unlock()
}
func (g *TransactorGroup) getStartTime() time.Time {
g.statsMtx.RLock()
defer g.statsMtx.RUnlock()
return g.startTime
}
func (g *TransactorGroup) trackTransactorProgress(id int, txCount int, txBytes int64) {
g.statsMtx.Lock()
g.txCounts[id] = txCount
g.txBytes[id] = txBytes
g.statsMtx.Unlock()
}
func (g *TransactorGroup) getProgressCallbackInterval() time.Duration {
g.progressCallbackMtx.RLock()
defer g.progressCallbackMtx.RUnlock()
return g.progressCallbackInterval
}
func (g *TransactorGroup) reportProgress() {
totalTxs := g.totalTxs()
totalBytes := g.totalBytes()
g.progressCallbackMtx.RLock()
if g.progressCallback != nil {
g.progressCallback(g, totalTxs, totalBytes)
}
g.progressCallbackMtx.RUnlock()
}
func (g *TransactorGroup) totalTxs() int {
g.statsMtx.RLock()
defer g.statsMtx.RUnlock()
total := 0
for _, txCount := range g.txCounts {
total += txCount
}
return total
}
func (g *TransactorGroup) totalBytes() int64 {
g.statsMtx.RLock()
defer g.statsMtx.RUnlock()
total := int64(0)
for _, txBytes := range g.txBytes {
total += txBytes
}
return total
}
func (g *TransactorGroup) close() {
for _, t := range g.transactors {
t.close()
}
}
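Outside of ExecuteStandalone, a caller could drive a TransactorGroup directly along the lines of the sketch below; note that the progress interval is read when transactors are added, so set the callback before AddAll. The callback body and the cfg value are illustrative.

// Sketch: driving a TransactorGroup by hand (cfg is an already-validated Config).
tg := NewTransactorGroup()
tg.SetProgressCallback(5*time.Second, func(g *TransactorGroup, txCount int, txBytes int64) {
	// e.g. log or export interim statistics
	fmt.Printf("progress: %d txs, %d bytes\n", txCount, txBytes)
})
if err := tg.AddAll(&cfg); err != nil {
	return err
}
tg.Start()
if err := tg.Wait(); err != nil {
	return err
}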