From f98dffad87c3902e76d001ec9c714fb0a692fb39 Mon Sep 17 00:00:00 2001 From: xiezhengyao Date: Mon, 24 Apr 2023 21:58:27 +0800 Subject: [PATCH] fix: shard queue panic --- mux/shard_queue.go | 120 +++++++++++++++++++++++++++------------------ 1 file changed, 73 insertions(+), 47 deletions(-) diff --git a/mux/shard_queue.go b/mux/shard_queue.go index 364fabae..ec078f6d 100644 --- a/mux/shard_queue.go +++ b/mux/shard_queue.go @@ -17,7 +17,6 @@ package mux import ( "fmt" "runtime" - "sync" "sync/atomic" "github.com/bytedance/gopkg/util/gopool" @@ -43,16 +42,18 @@ func init() { // NewShardQueue . func NewShardQueue(size int, conn netpoll.Connection) (queue *ShardQueue) { queue = &ShardQueue{ - conn: conn, - size: int32(size), - getters: make([][]WriterGetter, size), - swap: make([]WriterGetter, 0, 64), - locks: make([]int32, size), + conn: conn, + size: int32(size), + getters: make([][]WriterGetter, size), + swap: make([]WriterGetter, 0, 64), + locks: make([]int32, size), + closeNotif: make(chan struct{}), } for i := range queue.getters { queue.getters[i] = make([]WriterGetter, 0, 64) } - queue.list = make([]int32, size) + // To avoid w equals to r when loop writing, make list larger than size. + queue.list = make([]int32, size+1) return queue } @@ -69,6 +70,8 @@ type ShardQueue struct { getters [][]WriterGetter // len(getters) = size swap []WriterGetter // use for swap locks []int32 // len(locks) = size + + closeNotif chan struct{} queueTrigger } @@ -81,17 +84,29 @@ const ( // here for trigger type queueTrigger struct { - trigger int32 - state int32 // 0: active, 1: closing, 2: closed - runNum int32 - w, r int32 // ptr of list - list []int32 // record the triggered shard - listLock sync.Mutex // list total lock + bufNum int32 + state int32 // 0: active, 1: closing, 2: closed + runNum int32 + w, r int32 // ptr of list + list []int32 // record the triggered shard +} + +func (q *queueTrigger) length() int { + w := int(atomic.LoadInt32(&q.w)) + r := int(atomic.LoadInt32(&q.r)) + if w < r { + w += len(q.list) + } + return w - r } // Add adds to q.getters[shard] func (q *ShardQueue) Add(gts ...WriterGetter) { + atomic.AddInt32(&q.bufNum, 1) if atomic.LoadInt32(&q.state) != active { + if atomic.AddInt32(&q.bufNum, -1) <= 0 { + close(q.closeNotif) + } return } shard := atomic.AddInt32(&q.idx, 1) % q.size @@ -109,70 +124,80 @@ func (q *ShardQueue) Close() error { return fmt.Errorf("shardQueue has been closed") } // wait for all tasks finished - for atomic.LoadInt32(&q.state) != closed { - if atomic.LoadInt32(&q.trigger) == 0 { - atomic.StoreInt32(&q.state, closed) - return nil + if atomic.LoadInt32(&q.bufNum) == 0 { + atomic.StoreInt32(&q.state, closed) + } else { + select { + case <-q.closeNotif: } - runtime.Gosched() + atomic.StoreInt32(&q.state, closed) } return nil } // triggering shard. func (q *ShardQueue) triggering(shard int32) { - q.listLock.Lock() - q.w = (q.w + 1) % q.size - q.list[q.w] = shard - q.listLock.Unlock() - - if atomic.AddInt32(&q.trigger, 1) > 1 { - return + for { + ow := atomic.LoadInt32(&q.w) + nw := (ow + 1) % int32(len(q.list)) + if atomic.CompareAndSwapInt32(&q.w, ow, nw) { + q.list[nw] = shard + break + } } q.foreach() } -// foreach swap r & w. It's not concurrency safe. +// foreach swap r & w. func (q *ShardQueue) foreach() { if atomic.AddInt32(&q.runNum, 1) > 1 { return } gopool.CtxGo(nil, func() { - var negNum int32 // is negative number of triggerNum - for triggerNum := atomic.LoadInt32(&q.trigger); triggerNum > 0; { - q.r = (q.r + 1) % q.size - shared := q.list[q.r] + var negBufNum int32 // is negative number of bufNum + for q.length() > 0 { + nr := (atomic.LoadInt32(&q.r) + 1) % int32(len(q.list)) + atomic.StoreInt32(&q.r, nr) + shard := q.list[nr] // lock & swap - q.lock(shared) - tmp := q.getters[shared] - q.getters[shared] = q.swap[:0] + q.lock(shard) + tmp := q.getters[shard] + q.getters[shard] = q.swap[:0] q.swap = tmp - q.unlock(shared) + q.unlock(shard) // deal - q.deal(q.swap) - negNum-- - if triggerNum+negNum == 0 { - triggerNum = atomic.AddInt32(&q.trigger, negNum) - negNum = 0 + if err := q.deal(q.swap); err != nil { + close(q.closeNotif) + return + } + negBufNum -= int32(len(q.swap)) + } + if negBufNum < 0 { + if err := q.flush(); err != nil { + close(q.closeNotif) + return } } - q.flush() + + // MUST decrease bufNum first. + if atomic.AddInt32(&q.bufNum, negBufNum) <= 0 && atomic.LoadInt32(&q.state) != active { + close(q.closeNotif) + return + } // quit & check again atomic.StoreInt32(&q.runNum, 0) - if atomic.LoadInt32(&q.trigger) > 0 { + if q.length() > 0 { q.foreach() return } - // if state is closing, change it to closed - atomic.CompareAndSwapInt32(&q.state, closing, closed) }) } // deal is used to get deal of netpoll.Writer. -func (q *ShardQueue) deal(gts []WriterGetter) { +func (q *ShardQueue) deal(gts []WriterGetter) error { writer := q.conn.Writer() for _, gt := range gts { buf, isNil := gt() @@ -180,19 +205,20 @@ func (q *ShardQueue) deal(gts []WriterGetter) { err := writer.Append(buf) if err != nil { q.conn.Close() - return + return err } } } + return nil } // flush is used to flush netpoll.Writer. -func (q *ShardQueue) flush() { +func (q *ShardQueue) flush() error { err := q.conn.Writer().Flush() if err != nil { q.conn.Close() - return } + return err } // lock shard.