Skip to content

Commit

Permalink
fix: shard queue panic
Browse files Browse the repository at this point in the history
  • Loading branch information
jayantxie committed Apr 25, 2023
1 parent 1d17d4d commit f98dffa
Showing 1 changed file with 73 additions and 47 deletions.
120 changes: 73 additions & 47 deletions mux/shard_queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package mux
import (
"fmt"
"runtime"
"sync"
"sync/atomic"

"github.com/bytedance/gopkg/util/gopool"
Expand All @@ -43,16 +42,18 @@ func init() {
// NewShardQueue .
func NewShardQueue(size int, conn netpoll.Connection) (queue *ShardQueue) {
queue = &ShardQueue{
conn: conn,
size: int32(size),
getters: make([][]WriterGetter, size),
swap: make([]WriterGetter, 0, 64),
locks: make([]int32, size),
conn: conn,
size: int32(size),
getters: make([][]WriterGetter, size),
swap: make([]WriterGetter, 0, 64),
locks: make([]int32, size),
closeNotif: make(chan struct{}),
}
for i := range queue.getters {
queue.getters[i] = make([]WriterGetter, 0, 64)
}
queue.list = make([]int32, size)
// To avoid w equals to r when loop writing, make list larger than size.
queue.list = make([]int32, size+1)
return queue
}

Expand All @@ -69,6 +70,8 @@ type ShardQueue struct {
getters [][]WriterGetter // len(getters) = size
swap []WriterGetter // use for swap
locks []int32 // len(locks) = size

closeNotif chan struct{}
queueTrigger
}

Expand All @@ -81,17 +84,29 @@ const (

// here for trigger
type queueTrigger struct {
trigger int32
state int32 // 0: active, 1: closing, 2: closed
runNum int32
w, r int32 // ptr of list
list []int32 // record the triggered shard
listLock sync.Mutex // list total lock
bufNum int32
state int32 // 0: active, 1: closing, 2: closed
runNum int32
w, r int32 // ptr of list
list []int32 // record the triggered shard
}

func (q *queueTrigger) length() int {
w := int(atomic.LoadInt32(&q.w))
r := int(atomic.LoadInt32(&q.r))
if w < r {
w += len(q.list)
}
return w - r
}

// Add adds to q.getters[shard]
func (q *ShardQueue) Add(gts ...WriterGetter) {
atomic.AddInt32(&q.bufNum, 1)
if atomic.LoadInt32(&q.state) != active {
if atomic.AddInt32(&q.bufNum, -1) <= 0 {
close(q.closeNotif)
}
return
}
shard := atomic.AddInt32(&q.idx, 1) % q.size
Expand All @@ -109,90 +124,101 @@ func (q *ShardQueue) Close() error {
return fmt.Errorf("shardQueue has been closed")
}
// wait for all tasks finished
for atomic.LoadInt32(&q.state) != closed {
if atomic.LoadInt32(&q.trigger) == 0 {
atomic.StoreInt32(&q.state, closed)
return nil
if atomic.LoadInt32(&q.bufNum) == 0 {
atomic.StoreInt32(&q.state, closed)
} else {
select {
case <-q.closeNotif:
}
runtime.Gosched()
atomic.StoreInt32(&q.state, closed)
}
return nil
}

// triggering shard.
func (q *ShardQueue) triggering(shard int32) {
q.listLock.Lock()
q.w = (q.w + 1) % q.size
q.list[q.w] = shard
q.listLock.Unlock()

if atomic.AddInt32(&q.trigger, 1) > 1 {
return
for {
ow := atomic.LoadInt32(&q.w)
nw := (ow + 1) % int32(len(q.list))
if atomic.CompareAndSwapInt32(&q.w, ow, nw) {
q.list[nw] = shard
break
}
}
q.foreach()
}

// foreach swap r & w. It's not concurrency safe.
// foreach swap r & w.
func (q *ShardQueue) foreach() {
if atomic.AddInt32(&q.runNum, 1) > 1 {
return
}
gopool.CtxGo(nil, func() {
var negNum int32 // is negative number of triggerNum
for triggerNum := atomic.LoadInt32(&q.trigger); triggerNum > 0; {
q.r = (q.r + 1) % q.size
shared := q.list[q.r]
var negBufNum int32 // is negative number of bufNum
for q.length() > 0 {
nr := (atomic.LoadInt32(&q.r) + 1) % int32(len(q.list))
atomic.StoreInt32(&q.r, nr)
shard := q.list[nr]

// lock & swap
q.lock(shared)
tmp := q.getters[shared]
q.getters[shared] = q.swap[:0]
q.lock(shard)
tmp := q.getters[shard]
q.getters[shard] = q.swap[:0]
q.swap = tmp
q.unlock(shared)
q.unlock(shard)

// deal
q.deal(q.swap)
negNum--
if triggerNum+negNum == 0 {
triggerNum = atomic.AddInt32(&q.trigger, negNum)
negNum = 0
if err := q.deal(q.swap); err != nil {
close(q.closeNotif)
return
}
negBufNum -= int32(len(q.swap))
}
if negBufNum < 0 {
if err := q.flush(); err != nil {
close(q.closeNotif)
return
}
}
q.flush()

// MUST decrease bufNum first.
if atomic.AddInt32(&q.bufNum, negBufNum) <= 0 && atomic.LoadInt32(&q.state) != active {
close(q.closeNotif)
return
}

// quit & check again
atomic.StoreInt32(&q.runNum, 0)
if atomic.LoadInt32(&q.trigger) > 0 {
if q.length() > 0 {
q.foreach()
return
}
// if state is closing, change it to closed
atomic.CompareAndSwapInt32(&q.state, closing, closed)
})
}

// deal is used to get deal of netpoll.Writer.
func (q *ShardQueue) deal(gts []WriterGetter) {
func (q *ShardQueue) deal(gts []WriterGetter) error {
writer := q.conn.Writer()
for _, gt := range gts {
buf, isNil := gt()
if !isNil {
err := writer.Append(buf)
if err != nil {
q.conn.Close()
return
return err
}
}
}
return nil
}

// flush is used to flush netpoll.Writer.
func (q *ShardQueue) flush() {
func (q *ShardQueue) flush() error {
err := q.conn.Writer().Flush()
if err != nil {
q.conn.Close()
return
}
return err
}

// lock shard.
Expand Down

0 comments on commit f98dffa

Please sign in to comment.