Commit 44534433 authored by Michael de Hoog's avatar Michael de Hoog

Simplify context management

parent 6704c867
...@@ -293,14 +293,14 @@ func (l *BatchSubmitter) loop() { ...@@ -293,14 +293,14 @@ func (l *BatchSubmitter) loop() {
defer publishTicker.Stop() defer publishTicker.Stop()
receiptsCh := make(chan txmgr.TxReceipt[txData]) receiptsCh := make(chan txmgr.TxReceipt[txData])
queue := txmgr.NewQueue[txData](l.txMgr, l.MaxPendingTransactions, l.metr.RecordPendingTx) queue := txmgr.NewQueue[txData](l.killCtx, l.txMgr, l.MaxPendingTransactions, l.metr.RecordPendingTx)
for { for {
select { select {
case <-loadTicker.C: case <-loadTicker.C:
l.loadBlocksIntoState(l.shutdownCtx) l.loadBlocksIntoState(l.shutdownCtx)
case <-publishTicker.C: case <-publishTicker.C:
_, _ = queue.TrySend(l.killCtx, l.publishStateToL1Factory(), receiptsCh) _, _ = queue.TrySend(l.publishStateToL1Factory(), receiptsCh)
case r := <-receiptsCh: case r := <-receiptsCh:
l.handleReceipt(r) l.handleReceipt(r)
case <-l.shutdownCtx.Done(): case <-l.shutdownCtx.Done():
...@@ -333,7 +333,7 @@ func (l *BatchSubmitter) drainState(receiptsCh chan txmgr.TxReceipt[txData], que ...@@ -333,7 +333,7 @@ func (l *BatchSubmitter) drainState(receiptsCh chan txmgr.TxReceipt[txData], que
case <-l.killCtx.Done(): case <-l.killCtx.Done():
return return
default: default:
err := queue.Send(l.killCtx, l.publishStateToL1Factory(), receiptsCh) err := queue.Send(l.publishStateToL1Factory(), receiptsCh)
if err != nil { if err != nil {
if err != io.EOF { if err != io.EOF {
l.log.Error("error while publishing state on shutdown", "err", err) l.log.Error("error while publishing state on shutdown", "err", err)
......
...@@ -24,6 +24,7 @@ type TxReceipt[T any] struct { ...@@ -24,6 +24,7 @@ type TxReceipt[T any] struct {
type TxFactory[T any] func(ctx context.Context) (TxCandidate, T, error) type TxFactory[T any] func(ctx context.Context) (TxCandidate, T, error)
type Queue[T any] struct { type Queue[T any] struct {
ctx context.Context
txMgr TxManager txMgr TxManager
maxPending uint64 maxPending uint64
pendingChanged func(uint64) pendingChanged func(uint64)
...@@ -37,12 +38,13 @@ type Queue[T any] struct { ...@@ -37,12 +38,13 @@ type Queue[T any] struct {
// - maxPending: max number of pending txs at once (0 == no limit) // - maxPending: max number of pending txs at once (0 == no limit)
// - pendingChanged: called whenever a tx send starts or finishes. The // - pendingChanged: called whenever a tx send starts or finishes. The
// number of currently pending txs is passed as a parameter. // number of currently pending txs is passed as a parameter.
func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uint64)) *Queue[T] { func NewQueue[T any](ctx context.Context, txMgr TxManager, maxPending uint64, pendingChanged func(uint64)) *Queue[T] {
if maxPending > math.MaxInt { if maxPending > math.MaxInt {
// ensure we don't overflow as errgroup only accepts int; in reality this will never be an issue // ensure we don't overflow as errgroup only accepts int; in reality this will never be an issue
maxPending = math.MaxInt maxPending = math.MaxInt
} }
return &Queue[T]{ return &Queue[T]{
ctx: ctx,
txMgr: txMgr, txMgr: txMgr,
maxPending: maxPending, maxPending: maxPending,
pendingChanged: pendingChanged, pendingChanged: pendingChanged,
...@@ -63,11 +65,10 @@ func (q *Queue[T]) Wait() { ...@@ -63,11 +65,10 @@ func (q *Queue[T]) Wait() {
// //
// The actual tx sending is non-blocking, with the receipt returned on the // The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel. // provided receipt channel.
func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error { func (q *Queue[T]) Send(factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
ctx, cancel := q.mergeWithGroupContext(ctx) group, ctx := q.groupContext()
factoryErrCh := make(chan error) factoryErrCh := make(chan error)
q.group.Go(func() error { group.Go(func() error {
defer cancel()
return q.sendTx(ctx, factory, factoryErrCh, receiptCh) return q.sendTx(ctx, factory, factoryErrCh, receiptCh)
}) })
return <-factoryErrCh return <-factoryErrCh
...@@ -83,15 +84,13 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha ...@@ -83,15 +84,13 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha
// //
// The actual tx sending is non-blocking, with the receipt returned on the // The actual tx sending is non-blocking, with the receipt returned on the
// provided receipt channel. // provided receipt channel.
func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) { func (q *Queue[T]) TrySend(factory TxFactory[T], receiptCh chan TxReceipt[T]) (bool, error) {
ctx, cancel := q.mergeWithGroupContext(ctx) group, ctx := q.groupContext()
factoryErrCh := make(chan error) factoryErrCh := make(chan error)
started := q.group.TryGo(func() error { started := group.TryGo(func() error {
defer cancel()
return q.sendTx(ctx, factory, factoryErrCh, receiptCh) return q.sendTx(ctx, factory, factoryErrCh, receiptCh)
}) })
if !started { if !started {
cancel()
return false, nil return false, nil
} }
err := <-factoryErrCh err := <-factoryErrCh
...@@ -125,30 +124,17 @@ func (q *Queue[T]) sendTx(ctx context.Context, factory TxFactory[T], factoryErro ...@@ -125,30 +124,17 @@ func (q *Queue[T]) sendTx(ctx context.Context, factory TxFactory[T], factoryErro
// //
// If the group context doesn't exist or has already been canceled, a new one is created after // If the group context doesn't exist or has already been canceled, a new one is created after
// waiting for existing group threads to complete. // waiting for existing group threads to complete.
func (q *Queue[T]) mergeWithGroupContext(ctx context.Context) (context.Context, context.CancelFunc) { func (q *Queue[T]) groupContext() (*errgroup.Group, context.Context) {
q.groupLock.Lock() q.groupLock.Lock()
defer q.groupLock.Unlock() defer q.groupLock.Unlock()
if q.groupCtx == nil || q.groupCtx.Err() != nil { if q.groupCtx == nil || q.groupCtx.Err() != nil {
// no group exists, or the existing context has an error, so we need to wait // no group exists, or the existing context has an error, so we need to wait
// for existing group threads to complete (if any) and create a new group // for existing group threads to complete (if any) and create a new group
q.Wait() q.Wait()
q.group, q.groupCtx = errgroup.WithContext(context.Background()) q.group, q.groupCtx = errgroup.WithContext(q.ctx)
if q.maxPending > 0 { if q.maxPending > 0 {
q.group.SetLimit(int(q.maxPending)) q.group.SetLimit(int(q.maxPending))
} }
} }
return q.group, q.groupCtx
ctx, cancel := context.WithCancel(ctx)
cl := make(chan struct{})
groupContext := q.groupCtx
go func() {
defer cancel()
select {
case <-groupContext.Done():
case <-cl:
}
}()
return ctx, func() {
close(cl)
}
} }
...@@ -18,15 +18,15 @@ import ( ...@@ -18,15 +18,15 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
type queueFunc func(ctx context.Context, factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error) type queueFunc func(factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error)
func sendQueueFunc(ctx context.Context, factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error) { func sendQueueFunc(factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error) {
err := q.Send(ctx, factory, receiptCh) err := q.Send(factory, receiptCh)
return err == nil, err return err == nil, err
} }
func trySendQueueFunc(ctx context.Context, factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error) { func trySendQueueFunc(factory TxFactory[int], receiptCh chan TxReceipt[int], q *Queue[int]) (bool, error) {
return q.TrySend(ctx, factory, receiptCh) return q.TrySend(factory, receiptCh)
} }
type queueCall struct { type queueCall struct {
...@@ -237,7 +237,7 @@ func TestSend(t *testing.T) { ...@@ -237,7 +237,7 @@ func TestSend(t *testing.T) {
}) })
ctx := context.Background() ctx := context.Background()
queue := NewQueue[int](mgr, test.max, func(uint64) {}) queue := NewQueue[int](ctx, mgr, test.max, func(uint64) {})
// make all the queue calls given in the test case // make all the queue calls given in the test case
start := time.Now() start := time.Now()
...@@ -245,7 +245,7 @@ func TestSend(t *testing.T) { ...@@ -245,7 +245,7 @@ func TestSend(t *testing.T) {
msg := fmt.Sprintf("Call %d", i) msg := fmt.Sprintf("Call %d", i)
c := c c := c
receiptCh := make(chan TxReceipt[int], 1) receiptCh := make(chan TxReceipt[int], 1)
queued, err := c.call(ctx, factory, receiptCh, queue) queued, err := c.call(factory, receiptCh, queue)
require.Equal(t, c.queued, queued, msg) require.Equal(t, c.queued, queued, msg)
if c.callErr { if c.callErr {
require.Error(t, err, msg) require.Error(t, err, msg)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment