Commit 29180941 authored by Michael de Hoog's avatar Michael de Hoog

Ensure new errgroup is created after cancelation

parent 3172ccfe
......@@ -25,10 +25,11 @@ type TxFactory[T any] func(ctx context.Context) (TxCandidate, T, error)
type Queue[T any] struct {
txMgr TxManager
maxPending uint64
pendingChanged func(uint64)
pending atomic.Uint64
lock sync.Mutex
ctx context.Context
groupLock sync.Mutex
groupCtx context.Context
group *errgroup.Group
}
......@@ -41,22 +42,35 @@ func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uin
// ensure we don't overflow as errgroup only accepts int; in reality this will never be an issue
maxPending = math.MaxInt
}
group, cxt := errgroup.WithContext(context.Background())
if maxPending > 0 {
group.SetLimit(int(maxPending))
} else {
group.SetLimit(-1)
}
return &Queue[T]{
txMgr: txMgr,
maxPending: maxPending,
pendingChanged: pendingChanged,
ctx: cxt,
group: group,
}
}
func (q *Queue[T]) groupContext() context.Context {
q.groupLock.Lock()
defer q.groupLock.Unlock()
if q.groupCtx != nil && q.groupCtx.Err() == nil {
// a group exists and has no errors, nothing to do
return q.groupCtx
}
q.Wait() // wait for existing processes in the group to stop (if any)
q.group, q.groupCtx = errgroup.WithContext(context.Background())
if q.maxPending > 0 {
q.group.SetLimit(int(q.maxPending))
} else {
q.group.SetLimit(-1)
}
return q.groupCtx
}
// Wait waits for all pending txs to complete (or fail).
func (q *Queue[T]) Wait() {
if q.group == nil {
return
}
_ = q.group.Wait()
}
......@@ -67,10 +81,11 @@ func (q *Queue[T]) Wait() {
// The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel.
func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
ctx, cancel := mergeContexts(ctx, q.ctx)
defer cancel()
groupCtx := q.groupContext()
ctx, cancel := mergeContexts(ctx, groupCtx)
factoryErrCh := make(chan error)
q.group.Go(func() error {
defer cancel()
return q.sendTx(ctx, factory, factoryErrCh, receiptCh)
})
return <-factoryErrCh
......@@ -87,13 +102,15 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha
// The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel.
func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) {
ctx, cancel := mergeContexts(ctx, q.ctx)
defer cancel()
groupCtx := q.groupContext()
ctx, cancel := mergeContexts(ctx, groupCtx)
factoryErrCh := make(chan error)
started := q.group.TryGo(func() error {
defer cancel()
return q.sendTx(ctx, factory, factoryErrCh, receiptCh)
})
if !started {
cancel()
return false, nil
}
err := <-factoryErrCh
......@@ -101,10 +118,6 @@ func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh
}
func (q *Queue[T]) sendTx(ctx context.Context, factory TxFactory[T], factoryErrorCh chan error, receiptCh chan TxReceipt[T]) error {
// lock to prevent concurrent access to the tx factory
q.lock.Lock()
defer q.lock.Unlock()
candidate, id, err := factory(ctx)
factoryErrorCh <- err
if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment