golang源代码分析之异步通信WaitGroup

Golang source code analysis asynchronous communication WaitGroup.

常用方法介绍

WaitGroup主要是阻塞主线程,等待一组goroutine执行完毕。通常我们会用到sync.WaitGroup中的三个函数

Add()       // 添加计数器
Done()      // 减掉计数器,等价于Add(-1)
Wait()      // 阻塞,知道计数器减为0

上面所述当计数器<=0之后,Add的参数不能为为负值,会出现如下错误panic("sync: negative WaitGroup counter")sync.WaitGroup的源码在 如下两个文件中,包含源码和单元测试的源码:

src/sync/waitgroup.go
src/sync/waitgroup_test.go

如何使用

sync.WaitGroup主要配置go func来使用,阻塞主线程,等到go func执行完毕。
如下代码其实是没有任何输出的,因为主线程执行完成先于go func。所以里面的代码是不会被执行的。

package main

func main() {

    go func() {
        for i := 0; i < 10; i++ {
            fmt.Println("i: ", i)
        }
    }()
}

那么,我们需要如何解决这个问题呢?这里我们就需要用到sync.WaitGroup这个解决方案了。

package main

func main() {
    var wg = sync.WaitGroup{}
    wg.Add(1)

    go func() {
        for i := 0; i < 10; i++ {
            fmt.Println("i: ", i)
        }
        wg.Done()
    }()

    wg.Wait()
}

源码分析

结构体WaitGroup有两个字段,noCopy和state1的一个数组

type WaitGroup struct {
    noCopy noCopy

    state1 [3]uint32
}

主要的方法是Add()Wait()。函数Done()的实现如下,其实就是wg.Add(-1)

func (wg *WaitGroup) Done() {
    wg.Add(-1)
}

state()

state()是一个私有方法,state()返回指向wg.state1中存储的state和sema字段的指针

// state returns pointers to the state and sema fields stored within wg.state1.
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    } else {
        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    }
}

Add()

Add adds delta, which may be negative, to the WaitGroup counter.
If the counter becomes zero, all goroutines blocked on Wait are released.
If the counter goes negative, Add panics.

Add的参数是一个整数,可能是负数,如果计数器变为0了,所有的goroutines会被阻塞等待释放,如果计数器变成负值,则会报错。 从方法statep, semap := wg.state()中拿到statep,它是一个原子计数器。

func (wg *WaitGroup) Add(delta int) {
    statep, semap := wg.state()
    if race.Enabled {
        _ = *statep // trigger nil deref early
        if delta < 0 {
            // Synchronize decrements with Wait.
            race.ReleaseMerge(unsafe.Pointer(wg))
        }
        race.Disable()
        defer race.Enable()
    }
    state := atomic.AddUint64(statep, uint64(delta)<<32)
    v := int32(state >> 32)
    w := uint32(state)
    if race.Enabled && delta > 0 && v == int32(delta) {
        // The first increment must be synchronized with Wait.
        // Need to model this as a read, because there can be
        // several concurrent wg.counter transitions from 0.
        race.Read(unsafe.Pointer(semap))
    }
    if v < 0 {
        panic("sync: negative WaitGroup counter")
    }
    if w != 0 && delta > 0 && v == int32(delta) {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    if v > 0 || w == 0 {
        return
    }
    // This goroutine has set counter to 0 when waiters > 0.
    // Now there can't be concurrent mutations of state:
    // - Adds must not happen concurrently with Wait,
    // - Wait does not increment waiters if it sees counter == 0.
    // Still do a cheap sanity check to detect WaitGroup misuse.
    if *statep != state {
        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    }
    // Reset waiters count to 0.
    *statep = 0
    for ; w != 0; w-- {
        runtime_Semrelease(semap, false)
    }
}

Wait()

// Wait blocks until the WaitGroup counter is zero.
func (wg *WaitGroup) Wait() {
    statep, semap := wg.state()
    if race.Enabled {
        _ = *statep // trigger nil deref early
        race.Disable()
    }
    for {
        state := atomic.LoadUint64(statep)
        v := int32(state >> 32)
        w := uint32(state)
        if v == 0 {
            // Counter is 0, no need to wait.
            if race.Enabled {
                race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
        // Increment waiters count.
        if atomic.CompareAndSwapUint64(statep, state, state+1) {
            if race.Enabled && w == 0 {
                // Wait must be synchronized with the first Add.
                // Need to model this is as a write to race with the read in Add.
                // As a consequence, can do the write only for the first waiter,
                // otherwise concurrent Waits will race with each other.
                race.Write(unsafe.Pointer(semap))
            }
            runtime_Semacquire(semap)
            if *statep != 0 {
                panic("sync: WaitGroup is reused before previous Wait has returned")
            }
            if race.Enabled {
                race.Enable()
                race.Acquire(unsafe.Pointer(wg))
            }
            return
        }
    }
}

0 comments

To reply to the article, please Login or registered