使用的go版本为 go1.21.2
首先我们写一个简单的WaitGroup的使用代码
package main
import (
"fmt"
"sync"
)
func main() {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
fmt.Println("xiaochuan")
}()
wg.Wait()
}
WaitGroup的基本使用场景就是等待子协程完毕后,执行主协程,比如我的api需要多个下游api支持开多个协程进行访问,等待耗时最高的api返回过来后执行,这种场景是比较适合WaitGroup的。
我们来看一下WaitGroup构造体相关的底层源码
WaitGroup结构体
//代码位于 GOROOT/src/sync/waitgroup.go L:23
type WaitGroup struct {
//防止WaitGroup被复制, 君子协议,编译可以通过,某些编辑器会报waring
//有兴趣可以看一下这里 https://github.com/golang/go/issues/8005#issuecomment-190753527
noCopy noCopy
// 高32位表示计数器,低32位表示等待的waiter数量。
// 低版本go的state字段类型是[3]uint32,需要进行位数对齐
state atomic.Uint64
// 信号量
sema uint32
}
编辑器的warning
Add函数
//代码位于 GOROOT/src/sync/waitgroup.go L:43
func (wg *WaitGroup) Add(delta int) {
if race.Enabled { //使用竞态检查
if delta < 0 { //如果传递的数值是负数,递减等待同步
// Synchronize decrements with Wait.
race.ReleaseMerge(unsafe.Pointer(wg))
}
race.Disable() //竞态检查 禁用
defer race.Enable() //竞态检查 启用
}
//计算我们要进行add的值,将其加入到比特位上
//<< 32 为二进制左位移 32位
state := wg.state.Add(uint64(delta) << 32)
v := int32(state >> 32) // state变量的高位是计数
w := uint32(state) // state变量的低位是waiter计数
//使用竞态检查,当前传入的值与v相同,说明当前是第一次调度add
if race.Enabled && delta > 0 && v == int32(delta) {
// The first increment must be synchronized with Wait.
// Need to model this as a read, because there can be
// several concurrent wg.counter transitions from 0.
race.Read(unsafe.Pointer(&wg.sema))
}
//如果 计数器小于0 说明了多进行了done操作或者add传递负数,业务代码的出现逻辑错误了
if v < 0 {
panic("sync: negative WaitGroup counter")
}
// 如果当前存在等待,而且计数器不为0
// 说明当前有地方调度了Wait后,又进行add操作了, 违反了官方的使用设计
if w != 0 && delta > 0 && v == int32(delta) {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 计数大于0,没有等待,就是单纯的add直接返回
if v > 0 || w == 0 {
return
}
// 再做一次检测,防止有并发调度
// 比如我有两个goroutine A goroutine 在add, B goroutine 在调度 wait
// 刚刚好A加完了计数,B突然wait导致state更变就会触发这个panic
if wg.state.Load() != state {
panic("sync: WaitGroup misuse: Add called concurrently with Wait")
}
// 重置waiter为0
wg.state.Store(0)
for ; w != 0; w-- { // 逐步释放信号量
runtime_Semrelease(&wg.sema, false, 0)
}
}
Done函数
//代码位于 GOROOT/src/sync/waitgroup.go L:86
//这个很简单 调用了一下add函数传了一个-1
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
Wait函数
//代码位于 GOROOT/src/sync/waitgroup.go L:91
func (wg *WaitGroup) Wait() {
if race.Enabled { //使用竞态检查
race.Disable() //竞态检查 禁用
}
for {
state := wg.state.Load() // 原子操作读取state字段
v := int32(state >> 32) // state变量的高位是计数
w := uint32(state) // state变量的低位是waiter计数
if v == 0 { // 如果当前计数器为0 就没必要等待直接返回了
if race.Enabled {
race.Enable() //竞态检查 启用
race.Acquire(unsafe.Pointer(wg))
}
return
}
// 将waiter计数+1 因为waiter处于低32位所以不需要位移直接加就行了
if wg.state.CompareAndSwap(state, state+1) {
if race.Enabled && w == 0 { // 使用竞态检查,第一次进行wait操作
// Wait must be synchronized with the first Add.
// Need to model this is as a write to race with the read in Add.
// As a consequence, can do the write only for the first waiter,
// otherwise concurrent Waits will race with each other.
race.Write(unsafe.Pointer(&wg.sema))
}
// 获取信号量,这行代码会进行G的阻塞
runtime_Semacquire(&wg.sema)
//重新获取一下state,正常来讲计数为0, waiter为0
//执行判断之前,又有一个协程进行了add操作,会触发panic
if wg.state.Load() != 0 {
panic("sync: WaitGroup is reused before previous Wait has returned")
}
if race.Enabled { //使用竞态检查
race.Enable() //竞态检查 启用
race.Acquire(unsafe.Pointer(wg))
}
return
}
}
}
总结
我们从上面的源码分析了解WaitGroup的数据结构、Add、Done和Wait这些基本操作原理,在项目中我们可以使用比特位来减少内存的占用,从源码分析我们得知Go官方设计不允许进行WaitGroup复制(君子协议)与并发调度同一个WaitGroup操作。