Commit 8ce2ad39 authored by Michael de Hoog's avatar Michael de Hoog

Small refactor

parent dcdeeaf9
......@@ -49,23 +49,6 @@ func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uin
}
}
func (q *Queue[T]) groupContext() context.Context {
q.groupLock.Lock()
defer q.groupLock.Unlock()
if q.groupCtx != nil && q.groupCtx.Err() == nil {
// a group exists and has no errors, nothing to do
return q.groupCtx
}
q.Wait() // wait for existing processes in the group to stop (if any)
q.group, q.groupCtx = errgroup.WithContext(context.Background())
if q.maxPending > 0 {
q.group.SetLimit(int(q.maxPending))
} else {
q.group.SetLimit(-1)
}
return q.groupCtx
}
// Wait waits for all pending txs to complete (or fail).
func (q *Queue[T]) Wait() {
if q.group == nil {
......@@ -81,8 +64,7 @@ func (q *Queue[T]) Wait() {
// The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel.
func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
groupCtx := q.groupContext()
ctx, cancel := mergeContexts(ctx, groupCtx)
ctx, cancel := q.mergeWithGroupContext(ctx)
factoryErrCh := make(chan error)
q.group.Go(func() error {
defer cancel()
......@@ -102,8 +84,7 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha
// The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel.
func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) {
groupCtx := q.groupContext()
ctx, cancel := mergeContexts(ctx, groupCtx)
ctx, cancel := q.mergeWithGroupContext(ctx)
factoryErrCh := make(chan error)
started := q.group.TryGo(func() error {
defer cancel()
......@@ -122,7 +103,7 @@ func (q *Queue[T]) sendTx(ctx context.Context, factory TxFactory[T], factoryErro
factoryErrorCh <- err
if err != nil {
// Factory returned an error which was returned in the channel. This means
// there was no tx to send, so return nil.
// there is no tx to send, so return nil.
return nil
}
......@@ -139,15 +120,33 @@ func (q *Queue[T]) sendTx(ctx context.Context, factory TxFactory[T], factoryErro
return err
}
// mergeContexts creates a new Context that is canceled if either of the two
// contexts are closed. The CancelFunc should be called once finished.
func mergeContexts(ctx1 context.Context, ctx2 context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(ctx1)
// mergeWithGroupContext creates a new Context that is canceled if either the given context is
// Done, or the group context is canceled. The returned CancelFunc should be called once finished.
//
// If the group context doesn't exist or has already been canceled, a new one is created after
// waiting for existing group threads to complete.
func (q *Queue[T]) mergeWithGroupContext(ctx context.Context) (context.Context, context.CancelFunc) {
q.groupLock.Lock()
defer q.groupLock.Unlock()
if q.groupCtx == nil || q.groupCtx.Err() != nil {
// no group exists, or the existing context has an error, so we need to wait
// for existing group threads to complete (if any) and create a new group
q.Wait()
q.group, q.groupCtx = errgroup.WithContext(context.Background())
if q.maxPending > 0 {
q.group.SetLimit(int(q.maxPending))
} else {
q.group.SetLimit(-1)
}
}
ctx, cancel := context.WithCancel(ctx)
cl := make(chan struct{})
groupContext := q.groupCtx
go func() {
defer cancel()
select {
case <-ctx2.Done():
case <-groupContext.Done():
case <-cl:
}
}()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment