Commit 6a2d3381 authored by Michael de Hoog's avatar Michael de Hoog

Use errgroup instead of semaphore

parent 57bef610
...@@ -7,7 +7,7 @@ import ( ...@@ -7,7 +7,7 @@ import (
"sync/atomic" "sync/atomic"
"github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/core/types"
"golang.org/x/sync/semaphore" "golang.org/x/sync/errgroup"
) )
type TxReceipt[T any] struct { type TxReceipt[T any] struct {
...@@ -27,8 +27,9 @@ type Queue[T any] struct { ...@@ -27,8 +27,9 @@ type Queue[T any] struct {
txMgr TxManager txMgr TxManager
pendingChanged func(uint64) pendingChanged func(uint64)
pending atomic.Uint64 pending atomic.Uint64
semaphore *semaphore.Weighted lock sync.Mutex
wg sync.WaitGroup ctx context.Context
group *errgroup.Group
} }
// NewQueue creates a new transaction sending Queue, with the following parameters: // NewQueue creates a new transaction sending Queue, with the following parameters:
...@@ -36,25 +37,27 @@ type Queue[T any] struct { ...@@ -36,25 +37,27 @@ type Queue[T any] struct {
// - pendingChanged: called whenever a tx send starts or finishes. The // - pendingChanged: called whenever a tx send starts or finishes. The
// number of currently pending txs is passed as a parameter. // number of currently pending txs is passed as a parameter.
func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uint64)) *Queue[T] { func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uint64)) *Queue[T] {
if maxPending > math.MaxInt64 { if maxPending > math.MaxInt {
// ensure we don't overflow as semaphore only accepts int64; in reality this will never be an issue // ensure we don't overflow as errgroup only accepts int; in reality this will never be an issue
maxPending = math.MaxInt64 maxPending = math.MaxInt
} }
var s *semaphore.Weighted group, cxt := errgroup.WithContext(context.Background())
if maxPending > 0 { if maxPending > 0 {
// only create a semaphore for limited-size queues group.SetLimit(int(maxPending))
s = semaphore.NewWeighted(int64(maxPending)) } else {
group.SetLimit(-1)
} }
return &Queue[T]{ return &Queue[T]{
txMgr: txMgr, txMgr: txMgr,
pendingChanged: pendingChanged, pendingChanged: pendingChanged,
semaphore: s, ctx: cxt,
group: group,
} }
} }
// Wait waits for all pending txs to complete (or fail). // Wait waits for all pending txs to complete (or fail).
func (q *Queue[T]) Wait() { func (q *Queue[T]) Wait() {
q.wg.Wait() _ = q.group.Wait()
} }
// Send will wait until the number of pending txs is below the max pending, // Send will wait until the number of pending txs is below the max pending,
...@@ -64,12 +67,15 @@ func (q *Queue[T]) Wait() { ...@@ -64,12 +67,15 @@ func (q *Queue[T]) Wait() {
// The actual tx sending is non-blocking, with the receipt returned on the // The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel. // provided receipt channel.
func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error { func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
if q.semaphore != nil { ctx, cancel := mergeContexts(ctx, q.ctx)
if err := q.semaphore.Acquire(ctx, 1); err != nil { defer cancel()
return err errChan := make(chan error)
} q.group.Go(func() error {
} sender, err := q.createTxSender(ctx, factory, receiptCh)
return q.trySend(ctx, factory, receiptCh) errChan <- err
return sender()
})
return <-errChan
} }
// TrySend sends the next tx, but only if the number of pending txs is below the // TrySend sends the next tx, but only if the number of pending txs is below the
...@@ -83,32 +89,35 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha ...@@ -83,32 +89,35 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha
// The actual tx sending is non-blocking, with the receipt returned on the // The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel. // provided receipt channel.
func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) { func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) {
if q.semaphore != nil && !q.semaphore.TryAcquire(1) { ctx, cancel := mergeContexts(ctx, q.ctx)
defer cancel()
errChan := make(chan error)
queued := q.group.TryGo(func() error {
sender, err := q.createTxSender(ctx, factory, receiptCh)
errChan <- err
return sender()
})
if !queued {
return false, nil return false, nil
} }
err := q.trySend(ctx, factory, receiptCh) err := <-errChan
return err == nil, err return err != nil, err
} }
func (q *Queue[T]) trySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error { func (q *Queue[T]) createTxSender(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (func() error, error) {
// lock to prevent concurrent access to the tx factory
q.lock.Lock()
defer q.lock.Unlock()
candidate, id, err := factory(ctx) candidate, id, err := factory(ctx)
release := func() {
if q.semaphore != nil {
q.semaphore.Release(1)
}
}
if err != nil { if err != nil {
release() return nil, err
return err
} }
q.pendingChanged(q.pending.Add(1)) q.pendingChanged(q.pending.Add(1))
q.wg.Add(1) return func() error {
go func() {
defer func() { defer func() {
release()
q.pendingChanged(q.pending.Add(^uint64(0))) // -1 q.pendingChanged(q.pending.Add(^uint64(0))) // -1
q.wg.Done()
}() }()
receipt, err := q.txMgr.Send(ctx, candidate) receipt, err := q.txMgr.Send(ctx, candidate)
receiptCh <- TxReceipt[T]{ receiptCh <- TxReceipt[T]{
...@@ -116,6 +125,23 @@ func (q *Queue[T]) trySend(ctx context.Context, factory TxFactory[T], receiptCh ...@@ -116,6 +125,23 @@ func (q *Queue[T]) trySend(ctx context.Context, factory TxFactory[T], receiptCh
Receipt: receipt, Receipt: receipt,
Err: err, Err: err,
} }
return err
}, nil
}
// mergeContexts creates a new Context that is canceled if either of the two
// contexts are closed. The CancelFunc should be called once finished.
func mergeContexts(ctx1 context.Context, ctx2 context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx1)
cl := make(chan struct{})
go func() {
defer cancel()
select {
case <-ctx2.Done():
case <-cl:
}
}() }()
return nil return ctx, func() {
close(cl)
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment