Commit 29180941 authored by Michael de Hoog's avatar Michael de Hoog

Ensure new errgroup is created after cancelation

parent 3172ccfe
...@@ -25,10 +25,11 @@ type TxFactory[T any] func(ctx context.Context) (TxCandidate, T, error) ...@@ -25,10 +25,11 @@ type TxFactory[T any] func(ctx context.Context) (TxCandidate, T, error)
type Queue[T any] struct { type Queue[T any] struct {
txMgr TxManager txMgr TxManager
maxPending uint64
pendingChanged func(uint64) pendingChanged func(uint64)
pending atomic.Uint64 pending atomic.Uint64
lock sync.Mutex groupLock sync.Mutex
ctx context.Context groupCtx context.Context
group *errgroup.Group group *errgroup.Group
} }
...@@ -41,22 +42,35 @@ func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uin ...@@ -41,22 +42,35 @@ func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uin
// ensure we don't overflow as errgroup only accepts int; in reality this will never be an issue // ensure we don't overflow as errgroup only accepts int; in reality this will never be an issue
maxPending = math.MaxInt maxPending = math.MaxInt
} }
group, cxt := errgroup.WithContext(context.Background())
if maxPending > 0 {
group.SetLimit(int(maxPending))
} else {
group.SetLimit(-1)
}
return &Queue[T]{ return &Queue[T]{
txMgr: txMgr, txMgr: txMgr,
maxPending: maxPending,
pendingChanged: pendingChanged, pendingChanged: pendingChanged,
ctx: cxt,
group: group,
} }
} }
func (q *Queue[T]) groupContext() context.Context {
q.groupLock.Lock()
defer q.groupLock.Unlock()
if q.groupCtx != nil && q.groupCtx.Err() == nil {
// a group exists and has no errors, nothing to do
return q.groupCtx
}
q.Wait() // wait for existing processes in the group to stop (if any)
q.group, q.groupCtx = errgroup.WithContext(context.Background())
if q.maxPending > 0 {
q.group.SetLimit(int(q.maxPending))
} else {
q.group.SetLimit(-1)
}
return q.groupCtx
}
// Wait waits for all pending txs to complete (or fail). // Wait waits for all pending txs to complete (or fail).
func (q *Queue[T]) Wait() { func (q *Queue[T]) Wait() {
if q.group == nil {
return
}
_ = q.group.Wait() _ = q.group.Wait()
} }
...@@ -67,10 +81,11 @@ func (q *Queue[T]) Wait() { ...@@ -67,10 +81,11 @@ func (q *Queue[T]) Wait() {
// The actual tx sending is non-blocking, with the receipt returned on the // The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel. // provided receipt channel.
func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error { func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
ctx, cancel := mergeContexts(ctx, q.ctx) groupCtx := q.groupContext()
defer cancel() ctx, cancel := mergeContexts(ctx, groupCtx)
factoryErrCh := make(chan error) factoryErrCh := make(chan error)
q.group.Go(func() error { q.group.Go(func() error {
defer cancel()
return q.sendTx(ctx, factory, factoryErrCh, receiptCh) return q.sendTx(ctx, factory, factoryErrCh, receiptCh)
}) })
return <-factoryErrCh return <-factoryErrCh
...@@ -87,13 +102,15 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha ...@@ -87,13 +102,15 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha
// The actual tx sending is non-blocking, with the receipt returned on the // The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel. // provided receipt channel.
func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) { func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) {
ctx, cancel := mergeContexts(ctx, q.ctx) groupCtx := q.groupContext()
defer cancel() ctx, cancel := mergeContexts(ctx, groupCtx)
factoryErrCh := make(chan error) factoryErrCh := make(chan error)
started := q.group.TryGo(func() error { started := q.group.TryGo(func() error {
defer cancel()
return q.sendTx(ctx, factory, factoryErrCh, receiptCh) return q.sendTx(ctx, factory, factoryErrCh, receiptCh)
}) })
if !started { if !started {
cancel()
return false, nil return false, nil
} }
err := <-factoryErrCh err := <-factoryErrCh
...@@ -101,10 +118,6 @@ func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh ...@@ -101,10 +118,6 @@ func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh
} }
func (q *Queue[T]) sendTx(ctx context.Context, factory TxFactory[T], factoryErrorCh chan error, receiptCh chan TxReceipt[T]) error { func (q *Queue[T]) sendTx(ctx context.Context, factory TxFactory[T], factoryErrorCh chan error, receiptCh chan TxReceipt[T]) error {
// lock to prevent concurrent access to the tx factory
q.lock.Lock()
defer q.lock.Unlock()
candidate, id, err := factory(ctx) candidate, id, err := factory(ctx)
factoryErrorCh <- err factoryErrorCh <- err
if err != nil { if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment