Commit 44534433 authored by Michael de Hoog's avatar Michael de Hoog

Simplify context management

parent 6704c867
......@@ -293,14 +293,14 @@ func (l *BatchSubmitter) loop() {
defer publishTicker.Stop()
receiptsCh := make(chan txmgr.TxReceipt[txData])
queue := txmgr.NewQueue[txData](l.txMgr, l.MaxPendingTransactions, l.metr.RecordPendingTx)
queue := txmgr.NewQueue[txData](l.killCtx, l.txMgr, l.MaxPendingTransactions, l.metr.RecordPendingTx)
for {
select {
case <-loadTicker.C:
l.loadBlocksIntoState(l.shutdownCtx)
case <-publishTicker.C:
_, _ = queue.TrySend(l.killCtx, l.publishStateToL1Factory(), receiptsCh)
_, _ = queue.TrySend(l.publishStateToL1Factory(), receiptsCh)
case r := <-receiptsCh:
l.handleReceipt(r)
case <-l.shutdownCtx.Done():
......@@ -333,7 +333,7 @@ func (l *BatchSubmitter) drainState(receiptsCh chan txmgr.TxReceipt[txData], que
case <-l.killCtx.Done():
return
default:
err := queue.Send(l.killCtx, l.publishStateToL1Factory(), receiptsCh)
err := queue.Send(l.publishStateToL1Factory(), receiptsCh)
if err != nil {
if err != io.EOF {
l.log.Error("error while publishing state on shutdown", "err", err)
......
......@@ -24,6 +24,7 @@ type TxReceipt[T any] struct {
type TxFactory[T any] func(ctx context.Context) (TxCandidate, T, error)
type Queue[T any] struct {
ctx context.Context
txMgr TxManager
maxPending uint64
pendingChanged func(uint64)
......@@ -37,12 +38,13 @@ type Queue[T any] struct {
// - maxPending: max number of pending txs at once (0 == no limit)
// - pendingChanged: called whenever a tx send starts or finishes. The
// number of currently pending txs is passed as a parameter.
func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uint64)) *Queue[T] {
func NewQueue[T any](ctx context.Context, txMgr TxManager, maxPending uint64, pendingChanged func(uint64)) *Queue[T] {
if maxPending > math.MaxInt {
// ensure we don't overflow as errgroup only accepts int; in reality this will never be an issue
maxPending = math.MaxInt
}
return &Queue[T]{
ctx: ctx,
txMgr: txMgr,
maxPending: maxPending,
pendingChanged: pendingChanged,
......@@ -63,11 +65,10 @@ func (q *Queue[T]) Wait() {
//
// The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel.
func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
ctx, cancel := q.mergeWithGroupContext(ctx)
func (q *Queue[T]) Send(factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
group, ctx := q.groupContext()
factoryErrCh := make(chan error)
q.group.Go(func() error {
defer cancel()
group.Go(func() error {
return q.sendTx(ctx, factory, factoryErrCh, receiptCh)
})
return <-factoryErrCh
......@@ -83,15 +84,13 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha
//
// The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel.
func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) {
ctx, cancel := q.mergeWithGroupContext(ctx)
func (q *Queue[T]) TrySend(factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) {
group, ctx := q.groupContext()
factoryErrCh := make(chan error)
started := q.group.TryGo(func() error {
defer cancel()
started := group.TryGo(func() error {
return q.sendTx(ctx, factory, factoryErrCh, receiptCh)
})
if !started {
cancel()
return false, nil
}
err := <-factoryErrCh
......@@ -125,30 +124,17 @@ func (q *Queue[T]) sendTx(ctx context.Context, factory TxFactory[T], factoryErro
//
// If the group context doesn't exist or has already been canceled, a new one is created after
// waiting for existing group threads to complete.
func (q *Queue[T]) mergeWithGroupContext(ctx context.Context) (context.Context, context.CancelFunc) {
func (q *Queue[T]) groupContext() (*errgroup.Group, context.Context) {
q.groupLock.Lock()
defer q.groupLock.Unlock()
if q.groupCtx == nil || q.groupCtx.Err() != nil {
// no group exists, or the existing context has an error, so we need to wait
// for existing group threads to complete (if any) and create a new group
q.Wait()
q.group, q.groupCtx = errgroup.WithContext(context.Background())
q.group, q.groupCtx = errgroup.WithContext(q.ctx)
if q.maxPending > 0 {
q.group.SetLimit(int(q.maxPending))
}
}
ctx, cancel := context.WithCancel(ctx)
cl := make(chan struct{})
groupContext := q.groupCtx
go func() {
defer cancel()
select {
case <-groupContext.Done():
case <-cl:
}
}()
return ctx, func() {
close(cl)
}
return q.group, q.groupCtx
}
......@@ -18,15 +18,15 @@ import (
"golang.org/x/exp/slices"
)
type queueFunc func(ctx context.Context, factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error)
type queueFunc func(factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error)
func sendQueueFunc(ctx context.Context, factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error) {
err := q.Send(ctx, factory, receiptCh)
func sendQueueFunc(factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error) {
err := q.Send(factory, receiptCh)
return err == nil, err
}
func trySendQueueFunc(ctx context.Context, factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error) {
return q.TrySend(ctx, factory, receiptCh)
func trySendQueueFunc(factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error) {
return q.TrySend(factory, receiptCh)
}
type queueCall struct {
......@@ -237,7 +237,7 @@ func TestSend(t *testing.T) {
})
ctx := context.Background()
queue := NewQueue[int](mgr, test.max, func(uint64) {})
queue := NewQueue[int](ctx, mgr, test.max, func(uint64) {})
// make all the queue calls given in the test case
start := time.Now()
......@@ -245,7 +245,7 @@ func TestSend(t *testing.T) {
msg := fmt.Sprintf("Call %d", i)
c := c
receiptCh := make(chan TxReceipt[int], 1)
queued, err := c.call(ctx, factory, receiptCh, queue)
queued, err := c.call(factory, receiptCh, queue)
require.Equal(t, c.queued, queued, msg)
if c.callErr {
require.Error(t, err, msg)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment