go的WaitGroup使用及源码分析

源码使用的是1.9版本;sync 包里的WaitGroup主要用于线程的同步;计数主线程创建的子线程(WaitGoup.Add(i)); 调用清除标记方法(WaitGroup.Done()); 使用WaitGroup.Wait()来阻塞,直到所有子线程(标记=0)执行完毕。
例子:

package mainimport ( "sync" "fmt" )func main(){ var swg sync.WaitGroup for i:=0; i<3; i++{ //增加一个计数器 swg.Add(1) go func(wg *sync.WaitGroup,mark int){ //减去计数器 defer wg.Done()//等价于 wg.Add(-1) fmt.Printf("%d goroutine finish \n",mark) }(&swg,i) } //等待所有go程结束 swg.Wait() }

结果:
2 goroutine finish 1 goroutine finish 0 goroutine finish

注意!如果将代码改成下面这样(子线程函数,传入的参数是waitgroup的值拷贝),会出现什么情况呢?
func main(){ var swg sync.WaitGroup for i:=0; i<3; i++{ swg.Add(1) go func(wg sync.WaitGroup,mark int){ defer wg.Done() fmt.Printf("%d goroutine finish \n",mark) }(swg,i) } swg.Wait() }

结果:
2 goroutine finish fatal error: all goroutines are asleep - deadlock! 1 goroutine finish 0 goroutine finish goroutine 1 [semacquire]: sync.runtime_Semacquire(0xc0420080dc) C:/Go/src/runtime/sema.go:56 +0x40 sync.(*WaitGroup).Wait(0xc0420080d0) C:/Go/src/sync/waitgroup.go:131 +0x79

【go的WaitGroup使用及源码分析】出现死锁,因为子协程传入的waitGroup对象是一份新值拷贝,主协程的waitGroup并没有调用Done()方法,导致标志位无法被释放;各位童鞋在使用的时候,记得传入waitGroup的引用拷贝。
WaitGroup源码分析(精简了无关主要逻辑的代码)
1、首先查看WaitGroup的数据结构:
type WaitGroup struct { noCopy noCopy //共12个字节,低4字节用于记录wait等待次数,高8字节是计数器(64位机器是高8字节,32机器是中间4个字节,因为64位机器的原子操作需要64位的对齐,但是32位的编译器不能确保。) state1 [12]byte //用于唤醒go程的信号量 semauint32 }

2、WaitGroup.Add()方法
func (wg *WaitGroup) Add(delta int) { statep := wg.state() //将标记为加delta state := atomic.AddUint64(statep, uint64(delta)<<32) //获得计数器数值 v := int32(state >> 32) //获得wait()等待次数 w := uint32(state) //标记位不能小于0(done过多或者Add()负值太多) if v < 0 { panic("sync: negative WaitGroup counter") } //不能并发的Add() 和Done() if w != 0 && delta > 0 && v == int32(delta) { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } //Add 完毕 if v > 0 || w == 0 { return } //执行到这,此时计数器V=0;那么等待计数器肯定和整个state的值相等,不然只有一个情况:有人调了Add(),并且是并发调用的。 if *statep != state { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } //所有状态位清零 *statep = 0 //唤醒等待的go程 for ; w != 0; w-- { runtime_Semrelease(&wg.sema, false) } } //根据编译器位数,获得标志位和等待次数的数据域 func (wg *WaitGroup) state() *uint64 { if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 { return (*uint64)(unsafe.Pointer(&wg.state1)) } else { return (*uint64)(unsafe.Pointer(&wg.state1[4])) } } // Done方法其实就是Add(-1) func (wg *WaitGroup) Done() { wg.Add(-1) }

3、Wait方法
func (wg *WaitGroup) Wait() { statep := wg.state() //循环检查计数器V啥时候等于0 for { state := atomic.LoadUint64(statep) v := int32(state >> 32) w := uint32(state) //v==0说明go程执行结束 if v == 0 { return } //尚有未执行完的go程,等待标志位+1(直接在低位处理,无需移位) if atomic.CompareAndSwapUint64(statep, state, state+1) { runtime_Semacquire(&wg.sema) if *statep != 0 { panic("sync: WaitGroup is reused before previous Wait has returned") } return } } }

    推荐阅读