Created
May 31, 2024 22:16
-
-
Save pirogoeth/c91492edcbf18b26f349bbfe3458d27d to your computer and use it in GitHub Desktop.
Revisions
-
pirogoeth created this gist
May 31, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,61 @@ package goro import ( "context" "errors" "sync" ) type LimitGroup struct { sema *SlottedSemaphore fns []errWrapper lock sync.Mutex } func NewLimitGroup(limit int) *LimitGroup { return &LimitGroup{ sema: NewSlottedSemaphore(limit), fns: make([]errWrapper, 0), lock: sync.Mutex{}, } } func (lg *LimitGroup) Add(fn errWrapper) error { if ok := lg.lock.TryLock(); !ok { return errors.New("cannot add to a LimitGroup while running") } lg.fns = append(lg.fns, fn) lg.lock.Unlock() return nil } func (lg *LimitGroup) Run(parentCtx context.Context) error { lg.lock.Lock() ctx, cancel := context.WithCancel(parentCtx) errCh := make(chan error, len(lg.fns)) wg := sync.WaitGroup{} for _, fn := range lg.fns { wg.Add(1) go func(fn errWrapper) { slot := lg.sema.AcquireBlocking() defer slot.Release() errCh <- fn(ctx) wg.Done() }(fn) } wg.Wait() close(errCh) errs := make([]error, 0) for err := range errCh { errs = append(errs, err) } cancel() return errors.Join(errs...) } This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,65 @@ package goro import ( "context" "crypto/rand" "fmt" "math/big" "sync/atomic" "testing" "time" ) func TestLimit(t *testing.T) { ctx := context.Background() limit := NewLimitGroup(1) limit.Add(func(ctx context.Context) error { return nil }) err := limit.Run(ctx) if err != nil { t.Fail() } } func TestLimitSingleConcurrent(t *testing.T) { ctx := context.Background() executing := atomic.Bool{} nestedWork := func(_ context.Context) error { if executing.Load() { return fmt.Errorf("concurrent execution detected") } executing.Store(true) defer executing.Store(false) sleepTime, err := rand.Int(rand.Reader, big.NewInt(5)) if err != nil { return fmt.Errorf("could not get random int: %w", err) } time.Sleep(time.Duration(sleepTime.Int64()) * time.Millisecond) return nil } limit := NewLimitGroup(1) limit.Add(func(ctx context.Context) error { t.Log("run 1") return nestedWork(ctx) }) limit.Add(func(ctx context.Context) error { t.Log("run 2") return nestedWork(ctx) }) limit.Add(func(ctx context.Context) error { t.Log("run 3") return nestedWork(ctx) }) if err := limit.Run(ctx); err != nil { t.Error(err) } } This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,118 @@ package goro import ( "errors" "sync" "github.com/sirupsen/logrus" ) var ( ErrNoSlotsAvailable = errors.New("no slots available") ErrSlotAlreadyReleased = errors.New("slot already released") ) type slot struct { parentSema *SlottedSemaphore } func (s *slot) Release() error { if s.parentSema == nil { logrus.Errorf("slot already released") return ErrSlotAlreadyReleased } if err := s.parentSema.Release(s); err != nil { logrus.Fatalf("failed to release slot: %s", err.Error()) } logrus.Tracef("released slot %v", s) s.parentSema = nil return nil } func (s *slot) IsReleased() bool { return s.parentSema == nil } type SlottedSemaphore struct { sema []*slot lock *sync.Mutex acquisitionCond *sync.Cond } func NewSlottedSemaphore(limit int) *SlottedSemaphore { lock := &sync.Mutex{} return &SlottedSemaphore{ sema: make([]*slot, limit), lock: lock, acquisitionCond: sync.NewCond(lock), } } func (ss *SlottedSemaphore) findFreeSlot() int { for i, slot := range ss.sema { if slot == nil { logrus.Tracef("first free slot at %d", i) return i } } return -1 } func (ss *SlottedSemaphore) Acquire() (*slot, error) { ss.lock.Lock() defer ss.lock.Unlock() slotIdx := ss.findFreeSlot() if slotIdx == -1 { return nil, ErrNoSlotsAvailable } slot := &slot{ parentSema: ss, } ss.sema[slotIdx] = slot logrus.Tracef("slot %d acquired", slotIdx) return slot, nil } func (ss *SlottedSemaphore) AcquireBlocking() *slot { ss.lock.Lock() for { slotIdx := ss.findFreeSlot() if slotIdx != -1 { slot := &slot{ parentSema: ss, } ss.sema[slotIdx] = slot logrus.Tracef("slot %d acquired (blocking)", slotIdx) ss.lock.Unlock() return slot } logrus.Trace("no slots available, sleeping for release") ss.acquisitionCond.Wait() } } func (ss *SlottedSemaphore) Release(s *slot) error { ss.lock.Lock() defer ss.lock.Unlock() for i, slot := range ss.sema { if slot == s { logrus.Tracef("releasing slot at index %d", i) ss.sema[i] = nil ss.acquisitionCond.Signal() return nil } } return ErrSlotAlreadyReleased } This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,94 @@ package goro import ( "testing" "time" ) func TestSlottedSemaphore(t *testing.T) { ss := NewSlottedSemaphore(1) results := make([]int, 2) doneCh := make(chan bool) go func() { t.Log("Acquiring slot 1") slot := ss.AcquireBlocking() t.Logf("Acquired slot 1: %v", slot) defer slot.Release() time.Sleep(1 * time.Second) results[0] = 1 doneCh <- true }() go func() { t.Log("Acquiring slot 2") slot := ss.AcquireBlocking() t.Logf("Acquired slot 2: %v", slot) defer slot.Release() results[1] = 2 doneCh <- true }() <-doneCh <-doneCh if results[0] != 1 || results[1] != 2 { t.Fail() } } func TestSlottedSemaphoreDoubleRelease(t *testing.T) { ss := NewSlottedSemaphore(1) slot := ss.AcquireBlocking() slot.Release() err := slot.Release() if err != ErrSlotAlreadyReleased { t.Fail() } } func TestSlottedSemaphoreAsyncAcquire(t *testing.T) { ss := NewSlottedSemaphore(1) slot, err := ss.Acquire() if err != nil { t.Fail() } go func() { time.Sleep(1 * time.Second) slot.Release() }() for { if slot.IsReleased() { break } } } func TestSlottedSemaphoreGoroWait(t *testing.T) { ss := NewSlottedSemaphore(1) slot := ss.AcquireBlocking() done := make(chan bool) go func() { for i := 0; i < 10; i++ { _, err := ss.Acquire() if err == nil { t.Errorf("should not be able to acquire a slot") } t.Log("Waiting for slot") time.Sleep(500 * time.Millisecond) } done <- true }() <-done slot.Release() t.Log("Waiting goro closed") }