Commit 6a2d3381 authored by Michael de Hoog's avatar Michael de Hoog

Use errgroup instead of semaphore

parent 57bef610
......@@ -7,7 +7,7 @@ import (
"sync/atomic"
"github.com/ethereum/go-ethereum/core/types"
"golang.org/x/sync/semaphore"
"golang.org/x/sync/errgroup"
)
type TxReceipt[T any] struct {
......@@ -27,8 +27,9 @@ type Queue[T any] struct {
txMgr TxManager
pendingChanged func(uint64)
pending atomic.Uint64
semaphore *semaphore.Weighted
wg sync.WaitGroup
lock sync.Mutex
ctx context.Context
group *errgroup.Group
}
// NewQueue creates a new transaction sending Queue, with the following parameters:
......@@ -36,25 +37,27 @@ type Queue[T any] struct {
// - pendingChanged: called whenever a tx send starts or finishes. The
// number of currently pending txs is passed as a parameter.
func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uint64)) *Queue[T] {
if maxPending > math.MaxInt64 {
// ensure we don't overflow as semaphore only accepts int64; in reality this will never be an issue
maxPending = math.MaxInt64
if maxPending > math.MaxInt {
// ensure we don't overflow as errgroup only accepts int; in reality this will never be an issue
maxPending = math.MaxInt
}
var s *semaphore.Weighted
group, cxt := errgroup.WithContext(context.Background())
if maxPending > 0 {
// only create a semaphore for limited-size queues
s = semaphore.NewWeighted(int64(maxPending))
group.SetLimit(int(maxPending))
} else {
group.SetLimit(-1)
}
return &Queue[T]{
txMgr: txMgr,
pendingChanged: pendingChanged,
semaphore: s,
ctx: cxt,
group: group,
}
}
// Wait waits for all pending txs to complete (or fail).
func (q *Queue[T]) Wait() {
q.wg.Wait()
_ = q.group.Wait()
}
// Send will wait until the number of pending txs is below the max pending,
......@@ -64,12 +67,15 @@ func (q *Queue[T]) Wait() {
// The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel.
func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
if q.semaphore != nil {
if err := q.semaphore.Acquire(ctx, 1); err != nil {
return err
}
}
return q.trySend(ctx, factory, receiptCh)
ctx, cancel := mergeContexts(ctx, q.ctx)
defer cancel()
errChan := make(chan error)
q.group.Go(func() error {
sender, err := q.createTxSender(ctx, factory, receiptCh)
errChan <- err
return sender()
})
return <-errChan
}
// TrySend sends the next tx, but only if the number of pending txs is below the
......@@ -83,32 +89,35 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha
// The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel.
func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) {
if q.semaphore != nil && !q.semaphore.TryAcquire(1) {
ctx, cancel := mergeContexts(ctx, q.ctx)
defer cancel()
errChan := make(chan error)
queued := q.group.TryGo(func() error {
sender, err := q.createTxSender(ctx, factory, receiptCh)
errChan <- err
return sender()
})
if !queued {
return false, nil
}
err := q.trySend(ctx, factory, receiptCh)
return err == nil, err
err := <-errChan
return err != nil, err
}
func (q *Queue[T]) trySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
func (q *Queue[T]) createTxSender(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (func() error, error) {
// lock to prevent concurrent access to the tx factory
q.lock.Lock()
defer q.lock.Unlock()
candidate, id, err := factory(ctx)
release := func() {
if q.semaphore != nil {
q.semaphore.Release(1)
}
}
if err != nil {
release()
return err
return nil, err
}
q.pendingChanged(q.pending.Add(1))
q.wg.Add(1)
go func() {
return func() error {
defer func() {
release()
q.pendingChanged(q.pending.Add(^uint64(0))) // -1
q.wg.Done()
}()
receipt, err := q.txMgr.Send(ctx, candidate)
receiptCh <- TxReceipt[T]{
......@@ -116,6 +125,23 @@ func (q *Queue[T]) trySend(ctx context.Context, factory TxFactory[T], receiptCh
Receipt: receipt,
Err: err,
}
return err
}, nil
}
// mergeContexts creates a new Context that is canceled if either of the two
// contexts are closed. The CancelFunc should be called once finished.
func mergeContexts(ctx1 context.Context, ctx2 context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx1)
cl := make(chan struct{})
go func() {
defer cancel()
select {
case <-ctx2.Done():
case <-cl:
}
}()
return nil
return ctx, func() {
close(cl)
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment