Skip to content

Instantly share code, notes, and snippets.

@pirogoeth
Created May 31, 2024 22:16
Show Gist options
  • Save pirogoeth/c91492edcbf18b26f349bbfe3458d27d to your computer and use it in GitHub Desktop.
Save pirogoeth/c91492edcbf18b26f349bbfe3458d27d to your computer and use it in GitHub Desktop.

Revisions

  1. pirogoeth created this gist May 31, 2024.
    61 changes: 61 additions & 0 deletions limit.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,61 @@
    package goro

    import (
    "context"
    "errors"
    "sync"
    )

    type LimitGroup struct {
    sema *SlottedSemaphore
    fns []errWrapper
    lock sync.Mutex
    }

    func NewLimitGroup(limit int) *LimitGroup {
    return &LimitGroup{
    sema: NewSlottedSemaphore(limit),
    fns: make([]errWrapper, 0),
    lock: sync.Mutex{},
    }
    }

    func (lg *LimitGroup) Add(fn errWrapper) error {
    if ok := lg.lock.TryLock(); !ok {
    return errors.New("cannot add to a LimitGroup while running")
    }

    lg.fns = append(lg.fns, fn)
    lg.lock.Unlock()

    return nil
    }

    func (lg *LimitGroup) Run(parentCtx context.Context) error {
    lg.lock.Lock()

    ctx, cancel := context.WithCancel(parentCtx)
    errCh := make(chan error, len(lg.fns))
    wg := sync.WaitGroup{}

    for _, fn := range lg.fns {
    wg.Add(1)
    go func(fn errWrapper) {
    slot := lg.sema.AcquireBlocking()
    defer slot.Release()

    errCh <- fn(ctx)
    wg.Done()
    }(fn)
    }
    wg.Wait()
    close(errCh)

    errs := make([]error, 0)
    for err := range errCh {
    errs = append(errs, err)
    }

    cancel()
    return errors.Join(errs...)
    }
    65 changes: 65 additions & 0 deletions limit_test.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,65 @@
    package goro

    import (
    "context"
    "crypto/rand"
    "fmt"
    "math/big"
    "sync/atomic"
    "testing"
    "time"
    )

    func TestLimit(t *testing.T) {
    ctx := context.Background()

    limit := NewLimitGroup(1)
    limit.Add(func(ctx context.Context) error {
    return nil
    })
    err := limit.Run(ctx)
    if err != nil {
    t.Fail()
    }
    }

    func TestLimitSingleConcurrent(t *testing.T) {
    ctx := context.Background()

    executing := atomic.Bool{}

    nestedWork := func(_ context.Context) error {
    if executing.Load() {
    return fmt.Errorf("concurrent execution detected")
    }

    executing.Store(true)
    defer executing.Store(false)

    sleepTime, err := rand.Int(rand.Reader, big.NewInt(5))
    if err != nil {
    return fmt.Errorf("could not get random int: %w", err)
    }

    time.Sleep(time.Duration(sleepTime.Int64()) * time.Millisecond)
    return nil
    }

    limit := NewLimitGroup(1)
    limit.Add(func(ctx context.Context) error {
    t.Log("run 1")
    return nestedWork(ctx)
    })
    limit.Add(func(ctx context.Context) error {
    t.Log("run 2")
    return nestedWork(ctx)
    })
    limit.Add(func(ctx context.Context) error {
    t.Log("run 3")
    return nestedWork(ctx)
    })

    if err := limit.Run(ctx); err != nil {
    t.Error(err)
    }
    }
    118 changes: 118 additions & 0 deletions slotsema.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,118 @@
    package goro

    import (
    "errors"
    "sync"

    "github.com/sirupsen/logrus"
    )

    var (
    ErrNoSlotsAvailable = errors.New("no slots available")
    ErrSlotAlreadyReleased = errors.New("slot already released")
    )

    type slot struct {
    parentSema *SlottedSemaphore
    }

    func (s *slot) Release() error {
    if s.parentSema == nil {
    logrus.Errorf("slot already released")
    return ErrSlotAlreadyReleased
    }

    if err := s.parentSema.Release(s); err != nil {
    logrus.Fatalf("failed to release slot: %s", err.Error())
    }

    logrus.Tracef("released slot %v", s)
    s.parentSema = nil

    return nil
    }

    func (s *slot) IsReleased() bool {
    return s.parentSema == nil
    }

    type SlottedSemaphore struct {
    sema []*slot
    lock *sync.Mutex
    acquisitionCond *sync.Cond
    }

    func NewSlottedSemaphore(limit int) *SlottedSemaphore {
    lock := &sync.Mutex{}
    return &SlottedSemaphore{
    sema: make([]*slot, limit),
    lock: lock,
    acquisitionCond: sync.NewCond(lock),
    }
    }

    func (ss *SlottedSemaphore) findFreeSlot() int {
    for i, slot := range ss.sema {
    if slot == nil {
    logrus.Tracef("first free slot at %d", i)
    return i
    }
    }

    return -1
    }

    func (ss *SlottedSemaphore) Acquire() (*slot, error) {
    ss.lock.Lock()
    defer ss.lock.Unlock()

    slotIdx := ss.findFreeSlot()
    if slotIdx == -1 {
    return nil, ErrNoSlotsAvailable
    }

    slot := &slot{
    parentSema: ss,
    }
    ss.sema[slotIdx] = slot
    logrus.Tracef("slot %d acquired", slotIdx)

    return slot, nil
    }

    func (ss *SlottedSemaphore) AcquireBlocking() *slot {
    ss.lock.Lock()

    for {
    slotIdx := ss.findFreeSlot()
    if slotIdx != -1 {
    slot := &slot{
    parentSema: ss,
    }
    ss.sema[slotIdx] = slot
    logrus.Tracef("slot %d acquired (blocking)", slotIdx)
    ss.lock.Unlock()

    return slot
    }

    logrus.Trace("no slots available, sleeping for release")
    ss.acquisitionCond.Wait()
    }
    }

    func (ss *SlottedSemaphore) Release(s *slot) error {
    ss.lock.Lock()
    defer ss.lock.Unlock()

    for i, slot := range ss.sema {
    if slot == s {
    logrus.Tracef("releasing slot at index %d", i)
    ss.sema[i] = nil
    ss.acquisitionCond.Signal()
    return nil
    }
    }

    return ErrSlotAlreadyReleased
    }
    94 changes: 94 additions & 0 deletions slotsema_test.go
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,94 @@
    package goro

    import (
    "testing"
    "time"
    )

    func TestSlottedSemaphore(t *testing.T) {
    ss := NewSlottedSemaphore(1)
    results := make([]int, 2)
    doneCh := make(chan bool)

    go func() {
    t.Log("Acquiring slot 1")
    slot := ss.AcquireBlocking()
    t.Logf("Acquired slot 1: %v", slot)
    defer slot.Release()

    time.Sleep(1 * time.Second)
    results[0] = 1
    doneCh <- true
    }()

    go func() {
    t.Log("Acquiring slot 2")
    slot := ss.AcquireBlocking()
    t.Logf("Acquired slot 2: %v", slot)
    defer slot.Release()

    results[1] = 2
    doneCh <- true
    }()

    <-doneCh
    <-doneCh

    if results[0] != 1 || results[1] != 2 {
    t.Fail()
    }
    }

    func TestSlottedSemaphoreDoubleRelease(t *testing.T) {
    ss := NewSlottedSemaphore(1)
    slot := ss.AcquireBlocking()
    slot.Release()
    err := slot.Release()
    if err != ErrSlotAlreadyReleased {
    t.Fail()
    }
    }

    func TestSlottedSemaphoreAsyncAcquire(t *testing.T) {
    ss := NewSlottedSemaphore(1)
    slot, err := ss.Acquire()
    if err != nil {
    t.Fail()
    }

    go func() {
    time.Sleep(1 * time.Second)
    slot.Release()
    }()

    for {
    if slot.IsReleased() {
    break
    }
    }
    }

    func TestSlottedSemaphoreGoroWait(t *testing.T) {
    ss := NewSlottedSemaphore(1)
    slot := ss.AcquireBlocking()
    done := make(chan bool)

    go func() {
    for i := 0; i < 10; i++ {
    _, err := ss.Acquire()
    if err == nil {
    t.Errorf("should not be able to acquire a slot")
    }

    t.Log("Waiting for slot")
    time.Sleep(500 * time.Millisecond)
    }

    done <- true
    }()

    <-done
    slot.Release()

    t.Log("Waiting goro closed")
    }