Commit 019ce89c authored by Michael de Hoog's avatar Michael de Hoog

Cond -> Semaphore

parent 74029794
......@@ -34,6 +34,7 @@ require (
github.com/urfave/cli/v2 v2.17.2-0.20221006022127-8f469abc00aa
golang.org/x/crypto v0.6.0
golang.org/x/exp v0.0.0-20230213192124-5e25df0256eb
golang.org/x/sync v0.1.0
golang.org/x/term v0.5.0
golang.org/x/time v0.0.0-20220922220347-f3bd1da661af
)
......@@ -176,7 +177,6 @@ require (
go.uber.org/zap v1.24.0 // indirect
golang.org/x/mod v0.8.0 // indirect
golang.org/x/net v0.7.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.5.0 // indirect
golang.org/x/text v0.7.0 // indirect
golang.org/x/tools v0.6.0 // indirect
......
......@@ -2,9 +2,12 @@ package txmgr
import (
"context"
"math"
"sync"
"sync/atomic"
"github.com/ethereum/go-ethereum/core/types"
"golang.org/x/sync/semaphore"
)
type TxReceipt[T any] struct {
......@@ -17,10 +20,9 @@ type TxFactory[T any] func(ctx context.Context) (*TxCandidate, T, error)
type Queue[T any] struct {
txMgr TxManager
maxPending uint64
pendingChanged func(uint64)
pending uint64
cond *sync.Cond
pending atomic.Uint64
semaphore *semaphore.Weighted
wg sync.WaitGroup
}
......@@ -29,11 +31,19 @@ type Queue[T any] struct {
// - pendingChanged: called whenever a job starts or finishes. The
// number of currently pending txs is passed as a parameter.
func NewQueue[T any](txMgr TxManager, maxPending uint64, pendingChanged func(uint64)) *Queue[T] {
if maxPending > math.MaxInt64 {
// ensure we don't overflow as semaphore only accepts int64; in reality this will never be an issue
maxPending = math.MaxInt64
}
var s *semaphore.Weighted
if maxPending > 0 {
// only create a semaphore for limited-size queues
s = semaphore.NewWeighted(int64(maxPending))
}
return &Queue[T]{
txMgr: txMgr,
maxPending: maxPending,
pendingChanged: pendingChanged,
cond: sync.NewCond(&sync.Mutex{}),
semaphore: s,
}
}
......@@ -46,10 +56,11 @@ func (q *Queue[T]) Wait() {
// and then send the next tx. The TxFactory should return `nil` if the next
// tx does not exist. Returns the error returned from the TxFactory (if any).
func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
q.cond.L.Lock()
defer q.cond.L.Unlock()
for q.full() {
q.cond.Wait()
if q.semaphore != nil {
err := q.semaphore.Acquire(ctx, 1)
if err != nil {
return err
}
}
return q.trySend(ctx, factory, receiptCh)
}
......@@ -60,34 +71,37 @@ func (q *Queue[T]) Send(ctx context.Context, factory TxFactory[T], receiptCh cha
// The TxFactory should return `nil` if the next tx does not exist. Returns
// the error returned from the TxFactory (if any).
func (q *Queue[T]) TrySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
q.cond.L.Lock()
defer q.cond.L.Unlock()
if q.semaphore != nil {
if !q.semaphore.TryAcquire(1) {
return nil
}
}
return q.trySend(ctx, factory, receiptCh)
}
func (q *Queue[T]) trySend(ctx context.Context, factory TxFactory[T], receiptCh chan TxReceipt[T]) error {
if q.full() {
return nil
}
candidate, data, err := factory(ctx)
release := func() {
if q.semaphore != nil {
q.semaphore.Release(1)
}
}
if err != nil {
release()
return err
}
if candidate == nil {
release()
return nil
}
q.pending++
q.pendingChanged(q.pending)
q.pendingChanged(q.pending.Add(1))
q.wg.Add(1)
go func() {
defer func() {
q.cond.L.Lock()
q.pending--
q.pendingChanged(q.pending)
release()
q.pendingChanged(q.pending.Add(^uint64(0))) // -1
q.wg.Done()
q.cond.L.Unlock()
q.cond.Broadcast()
}()
receipt, err := q.txMgr.Send(ctx, *candidate)
receiptCh <- TxReceipt[T]{
......@@ -98,7 +112,3 @@ func (q *Queue[T]) trySend(ctx context.Context, factory TxFactory[T], receiptCh
}()
return nil
}
func (q *Queue[T]) full() bool {
return q.maxPending > 0 && q.pending >= q.maxPending
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment