Skip to content

WaitGroup

1. 概述

WaitGroup 是 Go 语言中用于等待一组 Goroutine 完成的同步原语,它是 sync 包中的一个结构体。WaitGroup 提供了一种简单有效的方式来等待多个并发操作完成,而不需要手动管理通道或计数器。

在整个 Go 语言课程体系中,WaitGroup 是并发编程的基础组件之一,与 Goroutine、通道和上下文一起构成了 Go 语言并发模型的核心。掌握 WaitGroup 的使用和原理,对于构建可靠的并发系统至关重要。

2. 基本概念

2.1 语法

2.1.1 基本用法

go
import "sync"

// 创建 WaitGroup
var wg sync.WaitGroup

// 增加等待计数
wg.Add(n) // n 是要等待的 Goroutine 数量

// 减少等待计数(通常在 defer 中使用)
wg.Done()

// 等待所有 Goroutine 完成
wg.Wait()

2.1.2 示例代码

go
func main() {
    var wg sync.WaitGroup
    
    // 启动 3 个 Goroutine
    for i := 1; i <= 3; i++ {
        wg.Add(1) // 增加计数
        go func(id int) {
            defer wg.Done() // 减少计数
            fmt.Printf("Goroutine %d started\n", id)
            time.Sleep(time.Second)
            fmt.Printf("Goroutine %d completed\n", id)
        }(i)
    }
    
    fmt.Println("Waiting for all goroutines to complete...")
    wg.Wait() // 等待所有 Goroutine 完成
    fmt.Println("All goroutines completed")
}

2.2 语义

  • 计数器:WaitGroup 内部维护一个计数器,初始值为 0。
  • Add(n):将计数器增加 n,n 必须是正数。通常在启动 Goroutine 之前调用。
  • Done():将计数器减少 1,相当于 Add(-1)。通常在 Goroutine 结束时调用,最好使用 defer 确保即使发生错误也能正确减少计数。
  • Wait():阻塞当前 Goroutine,直到计数器变为 0。
  • 零值可用:WaitGroup 的零值是可用的,不需要初始化。

2.3 规范

  • 命名规范:WaitGroup 变量通常命名为 wg
  • 使用顺序
    1. 调用 Add() 设置需要等待的 Goroutine 数量。
    2. 启动 Goroutine。
    3. 在每个 Goroutine 中使用 defer wg.Done()
    4. 调用 Wait() 等待所有 Goroutine 完成。
  • 不要复制:WaitGroup 是一个结构体,不是引用类型,不要复制使用中的 WaitGroup。
  • 不要在 Wait() 之后调用 Add():在调用 Wait() 之后再调用 Add() 会导致不可预期的行为。
  • 合理设置计数:确保 Add() 的调用次数与 Done() 的调用次数匹配,避免计数器永远不为 0。

3. 原理深度解析

3.1 WaitGroup 结构体

WaitGroup 的底层实现是一个结构体,包含以下字段:

go
type WaitGroup struct {
    noCopy noCopy // 防止复制
    state1 [3]uint32 // 状态字段,包含计数器和等待者数量
}

其中 state1 字段是一个长度为 3 的 uint32 数组,用于存储:

  • 计数器值(counter):需要等待的 Goroutine 数量。
  • 等待者数量(waiter count):正在等待的 Goroutine 数量。
  • 信号量(semaphore):用于唤醒等待的 Goroutine。

3.2 核心方法实现

3.2.1 Add 方法

Add 方法的主要功能是增加计数器的值:

  1. 检查传入的 n 是否为负数,如果是则 panic。
  2. 原子地将 n 加到计数器上。
  3. 如果计数器变为负数,则 panic。
  4. 如果计数器大于 0 且之前计数器为 0,则重置等待者数量。

3.2.2 Done 方法

Done 方法的主要功能是减少计数器的值:

  1. 原子地将计数器减 1。
  2. 如果计数器变为负数,则 panic。
  3. 如果计数器变为 0 且有等待者,则唤醒所有等待的 Goroutine。

3.2.3 Wait 方法

Wait 方法的主要功能是等待计数器变为 0:

  1. 原子地检查计数器是否为 0,如果是则直接返回。
  2. 增加等待者数量。
  3. 阻塞当前 Goroutine,直到计数器变为 0。
  4. 减少等待者数量。

3.3 并发安全

WaitGroup 的所有方法都是并发安全的,使用原子操作来修改状态,确保在多 Goroutine 环境中安全使用。

3.4 内存模型

WaitGroup 遵循 Go 语言的内存模型,确保以下顺序:

  • 在调用 Done() 之前的所有操作,发生在 Wait() 返回之前。
  • 多个 Add() 调用的顺序不影响最终结果,因为它们是原子操作。
  • Wait() 调用会阻塞,直到所有 Add() 调用对应的 Done() 调用都完成。

4. 常见错误与踩坑点

4.1 忘记调用 Done()

错误表现Wait() 永远不会返回,导致程序永久阻塞。

产生原因:在 Goroutine 中忘记调用 wg.Done(),导致计数器永远不为 0。

解决方案:使用 defer wg.Done() 确保 Done() 被调用,即使发生错误。

go
// 错误示例
func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        fmt.Println("Goroutine started")
        // 忘记调用 wg.Done()
    }()
    wg.Wait() // 永远不会返回
    fmt.Println("All done")
}

// 正确示例
func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go func() {
        defer wg.Done() // 使用 defer 确保调用
        fmt.Println("Goroutine started")
    }()
    wg.Wait() // 会正常返回
    fmt.Println("All done")
}

4.2 Add() 调用次数与 Done() 不匹配

错误表现

  • 如果 Add() 调用次数多于 Done()Wait() 永远不会返回。
  • 如果 Done() 调用次数多于 Add(),会导致 panic。

产生原因:没有正确管理计数器,导致计数器要么永远不为 0,要么变为负数。

解决方案:确保 Add() 的调用次数与 Done() 的调用次数完全匹配。

go
// 错误示例:Add 次数多于 Done
func main() {
    var wg sync.WaitGroup
    wg.Add(2) // 增加 2
    go func() {
        defer wg.Done() // 减少 1
        fmt.Println("Goroutine 1")
    }()
    // 忘记启动第二个 Goroutine
    wg.Wait() // 永远不会返回
}

// 错误示例:Done 次数多于 Add
func main() {
    var wg sync.WaitGroup
    wg.Add(1) // 增加 1
    go func() {
        defer wg.Done() // 减少 1
        defer wg.Done() // 错误:多调用了一次
        fmt.Println("Goroutine")
    }()
    wg.Wait() // 会 panic
}

4.3 在 Wait() 之后调用 Add()

错误表现:可能导致不可预期的行为,如 Wait() 立即返回但新的 Goroutine 还未完成。

产生原因:在调用 Wait() 之后再调用 Add(),破坏了 WaitGroup 的正常使用流程。

解决方案:确保在调用 Wait() 之前完成所有 Add() 调用。

go
// 错误示例
func main() {
    var wg sync.WaitGroup
    
    wg.Add(1)
    go func() {
        defer wg.Done()
        fmt.Println("Goroutine 1")
    }()
    
    wg.Wait() // 等待第一个 Goroutine 完成
    
    // 错误:在 Wait() 之后调用 Add()
    wg.Add(1)
    go func() {
        defer wg.Done()
        fmt.Println("Goroutine 2")
    }()
    
    wg.Wait() // 会等待第二个 Goroutine 完成,但这种用法不推荐
}

// 正确示例
func main() {
    var wg sync.WaitGroup
    
    // 所有 Add() 调用都在 Wait() 之前
    wg.Add(2)
    
    go func() {
        defer wg.Done()
        fmt.Println("Goroutine 1")
    }()
    
    go func() {
        defer wg.Done()
        fmt.Println("Goroutine 2")
    }()
    
    wg.Wait() // 等待所有 Goroutine 完成
}

4.4 复制 WaitGroup

错误表现:复制的 WaitGroup 与原 WaitGroup 状态不同步,导致不可预期的行为。

产生原因:WaitGroup 是结构体,不是引用类型,复制后会创建一个新的实例,与原实例状态无关。

解决方案:通过指针传递 WaitGroup,而不是复制它。

go
// 错误示例
func worker(wg sync.WaitGroup) { // 复制 WaitGroup
    defer wg.Done() // 这会操作副本,不会影响原 WaitGroup
    fmt.Println("Worker")
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go worker(wg) // 传递副本
    wg.Wait() // 永远不会返回
}

// 正确示例
func worker(wg *sync.WaitGroup) { // 通过指针传递
    defer wg.Done() // 操作原 WaitGroup
    fmt.Println("Worker")
}

func main() {
    var wg sync.WaitGroup
    wg.Add(1)
    go worker(&wg) // 传递指针
    wg.Wait() // 会正常返回
}

4.5 错误处理不当

错误表现:Goroutine 中的错误没有被捕获和处理,导致程序行为不可预期。

产生原因:在使用 WaitGroup 时,只关注了 Goroutine 的完成状态,没有处理其中的错误。

解决方案:使用通道或其他方式收集和处理 Goroutine 中的错误。

go
// 错误示例
func main() {
    var wg sync.WaitGroup
    
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            if id == 1 {
                // 错误没有被处理
                panic("Error in goroutine")
            }
            fmt.Printf("Goroutine %d completed\n", id)
        }(i)
    }
    
    wg.Wait()
    fmt.Println("All goroutines completed")
}

// 正确示例
func main() {
    var wg sync.WaitGroup
    errCh := make(chan error, 3)
    
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            if id == 1 {
                errCh <- fmt.Errorf("Error in goroutine %d", id)
                return
            }
            fmt.Printf("Goroutine %d completed\n", id)
        }(i)
    }
    
    // 等待所有 Goroutine 完成
    wg.Wait()
    close(errCh)
    
    // 处理错误
    for err := range errCh {
        fmt.Printf("Error: %v\n", err)
    }
    
    fmt.Println("All goroutines completed")
}

5. 常见应用场景

5.1 等待多个 Goroutine 完成

场景描述:需要启动多个 Goroutine 执行任务,然后等待所有任务完成后再继续执行。

使用方法:使用 WaitGroup 跟踪所有 Goroutine 的完成状态。

示例代码

go
func main() {
    var wg sync.WaitGroup
    tasks := []string{"task1", "task2", "task3", "task4", "task5"}
    
    for _, task := range tasks {
        wg.Add(1)
        go func(t string) {
            defer wg.Done()
            fmt.Printf("Processing %s\n", t)
            time.Sleep(time.Second) // 模拟处理时间
            fmt.Printf("Completed %s\n", t)
        }(task)
    }
    
    fmt.Println("Waiting for all tasks to complete...")
    wg.Wait()
    fmt.Println("All tasks completed")
}

5.2 并发下载文件

场景描述:需要从多个 URL 并发下载文件,等待所有下载完成后进行后续处理。

使用方法:为每个下载任务启动一个 Goroutine,使用 WaitGroup 等待所有下载完成。

示例代码

go
func downloadFile(url string, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Downloading %s\n", url)
    // 模拟下载
    time.Sleep(time.Second)
    fmt.Printf("Downloaded %s\n", url)
}

func main() {
    var wg sync.WaitGroup
    urls := []string{
        "https://example.com/file1.txt",
        "https://example.com/file2.txt",
        "https://example.com/file3.txt",
    }
    
    for _, url := range urls {
        wg.Add(1)
        go downloadFile(url, &wg)
    }
    
    fmt.Println("Waiting for all downloads to complete...")
    wg.Wait()
    fmt.Println("All downloads completed")
}

5.3 并行处理数据

场景描述:需要对大量数据进行并行处理,提高处理效率。

使用方法:将数据分成多个部分,每个部分由一个 Goroutine 处理,使用 WaitGroup 等待所有处理完成。

示例代码

go
func processBatch(data []int, wg *sync.WaitGroup) {
    defer wg.Done()
    for _, item := range data {
        // 处理数据
        fmt.Printf("Processing %d\n", item)
        time.Sleep(time.Millisecond * 100)
    }
}

func main() {
    var wg sync.WaitGroup
    data := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
    batchSize := 3
    
    // 分批处理
    for i := 0; i < len(data); i += batchSize {
        end := i + batchSize
        if end > len(data) {
            end = len(data)
        }
        batch := data[i:end]
        
        wg.Add(1)
        go processBatch(batch, &wg)
    }
    
    fmt.Println("Waiting for all batches to complete...")
    wg.Wait()
    fmt.Println("All batches processed")
}

5.4 并发测试

场景描述:需要对某个功能进行并发测试,模拟多个客户端同时访问。

使用方法:启动多个 Goroutine 模拟并发请求,使用 WaitGroup 等待所有测试完成。

示例代码

go
func testRequest(id int, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Request %d started\n", id)
    // 模拟请求
    time.Sleep(time.Millisecond * 500)
    fmt.Printf("Request %d completed\n", id)
}

func main() {
    var wg sync.WaitGroup
    concurrency := 10
    
    for i := 1; i <= concurrency; i++ {
        wg.Add(1)
        go testRequest(i, &wg)
    }
    
    fmt.Printf("Waiting for %d concurrent requests to complete...\n", concurrency)
    wg.Wait()
    fmt.Println("All requests completed")
}

5.5 后台任务管理

场景描述:需要启动多个后台任务,等待所有任务完成后再退出程序。

使用方法:为每个后台任务启动一个 Goroutine,使用 WaitGroup 等待所有任务完成。

示例代码

go
func backgroundTask(name string, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Background task %s started\n", name)
    // 执行任务
    time.Sleep(2 * time.Second)
    fmt.Printf("Background task %s completed\n", name)
}

func main() {
    var wg sync.WaitGroup
    tasks := []string{"cleanup", "backup", "sync", "report"}
    
    for _, task := range tasks {
        wg.Add(1)
        go backgroundTask(task, &wg)
    }
    
    fmt.Println("Waiting for all background tasks to complete...")
    wg.Wait()
    fmt.Println("All background tasks completed, exiting")
}

6. 企业级进阶应用场景

6.1 工作池管理

场景描述:在高并发系统中,需要限制并发数量,避免系统资源耗尽。

使用方法:结合 WaitGroup 和通道实现工作池,控制并发数量。

示例代码

go
type WorkerPool struct {
    tasks chan Task
    wg    sync.WaitGroup
    size  int
}

type Task func()

func NewWorkerPool(size int) *WorkerPool {
    return &WorkerPool{
        tasks: make(chan Task),
        size:  size,
    }
}

func (p *WorkerPool) Start() {
    for i := 0; i < p.size; i++ {
        p.wg.Add(1)
        go func() {
            defer p.wg.Done()
            for task := range p.tasks {
                task()
            }
        }()
    }
}

func (p *WorkerPool) Submit(task Task) {
    p.tasks <- task
}

func (p *WorkerPool) Close() {
    close(p.tasks)
    p.wg.Wait()
}

func main() {
    pool := NewWorkerPool(5) // 5 个工作协程
    pool.Start()
    defer pool.Close()
    
    // 提交 20 个任务
    for i := 0; i < 20; i++ {
        taskID := i
        pool.Submit(func() {
            fmt.Printf("Processing task %d\n", taskID)
            time.Sleep(time.Millisecond * 100)
        })
    }
    
    fmt.Println("All tasks submitted, waiting for completion...")
}

6.2 分布式任务协调

场景描述:在分布式系统中,需要协调多个节点的任务执行,等待所有节点完成后进行汇总。

使用方法:在每个节点使用 WaitGroup 等待本地任务完成,然后通过网络通信协调所有节点的完成状态。

示例代码

go
// 节点本地任务处理
func processLocalTasks(nodeID string, tasks []Task) error {
    var wg sync.WaitGroup
    errCh := make(chan error, len(tasks))
    
    for _, task := range tasks {
        wg.Add(1)
        go func(t Task) {
            defer wg.Done()
            if err := t(); err != nil {
                errCh <- err
            }
        }(task)
    }
    
    wg.Wait()
    close(errCh)
    
    // 收集错误
    for err := range errCh {
        if err != nil {
            return err
        }
    }
    
    fmt.Printf("Node %s completed all tasks\n", nodeID)
    return nil
}

// 协调多个节点
func coordinateNodes(nodes []string, tasksPerNode int) error {
    var wg sync.WaitGroup
    errCh := make(chan error, len(nodes))
    
    for _, node := range nodes {
        wg.Add(1)
        go func(n string) {
            defer wg.Done()
            // 生成节点任务
            tasks := generateTasks(tasksPerNode)
            if err := processLocalTasks(n, tasks); err != nil {
                errCh <- err
            }
        }(node)
    }
    
    wg.Wait()
    close(errCh)
    
    // 收集错误
    for err := range errCh {
        if err != nil {
            return err
        }
    }
    
    fmt.Println("All nodes completed tasks")
    return nil
}

6.3 批量 API 调用

场景描述:需要批量调用外部 API,等待所有调用完成后处理结果。

使用方法:为每个 API 调用启动一个 Goroutine,使用 WaitGroup 等待所有调用完成,同时收集结果和错误。

示例代码

go
type APIResult struct {
    URL   string
    Data  string
    Error error
}

func callAPI(url string, results chan<- APIResult, wg *sync.WaitGroup) {
    defer wg.Done()
    fmt.Printf("Calling API: %s\n", url)
    
    // 模拟 API 调用
    time.Sleep(time.Millisecond * 500)
    
    // 模拟结果
    if url == "https://example.com/api/error" {
        results <- APIResult{URL: url, Error: fmt.Errorf("API error")}
        return
    }
    
    results <- APIResult{URL: url, Data: fmt.Sprintf("Response from %s", url)}
}

func main() {
    var wg sync.WaitGroup
    urls := []string{
        "https://example.com/api/1",
        "https://example.com/api/2",
        "https://example.com/api/error",
        "https://example.com/api/3",
    }
    
    results := make(chan APIResult, len(urls))
    
    for _, url := range urls {
        wg.Add(1)
        go callAPI(url, results, &wg)
    }
    
    // 等待所有 API 调用完成
    wg.Wait()
    close(results)
    
    // 处理结果
    for result := range results {
        if result.Error != nil {
            fmt.Printf("Error calling %s: %v\n", result.URL, result.Error)
        } else {
            fmt.Printf("Success calling %s: %s\n", result.URL, result.Data)
        }
    }
}

6.4 实时数据处理

场景描述:需要实时处理数据流,同时等待所有处理完成后进行汇总。

使用方法:使用 WaitGroup 等待所有数据处理完成,同时使用通道接收处理结果。

示例代码

go
func processData(data int, results chan<- int, wg *sync.WaitGroup) {
    defer wg.Done()
    // 处理数据
    result := data * 2
    time.Sleep(time.Millisecond * 50)
    results <- result
}

func main() {
    var wg sync.WaitGroup
    data := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
    results := make(chan int, len(data))
    
    for _, item := range data {
        wg.Add(1)
        go processData(item, results, &wg)
    }
    
    // 等待所有数据处理完成
    wg.Wait()
    close(results)
    
    // 汇总结果
    total := 0
    for result := range results {
        total += result
    }
    
    fmt.Printf("Total: %d\n", total)
}

7. 行业最佳实践

7.1 始终使用 defer wg.Done()

实践内容:在每个 Goroutine 中使用 defer wg.Done() 确保 Done() 被调用,即使发生错误。

推荐理由defer 语句可以确保即使 Goroutine 中发生 panic,Done() 也能被调用,避免 WaitGroup 永远等待。

7.2 集中管理 Add() 调用

实践内容:在启动 Goroutine 之前集中调用 Add(),避免在多个地方分散调用。

推荐理由:集中管理 Add() 调用可以使代码更加清晰,减少遗漏或重复调用的风险。

7.3 避免在 Wait() 之后调用 Add()

实践内容:确保所有 Add() 调用都在 Wait() 之前完成。

推荐理由:在 Wait() 之后调用 Add() 会破坏 WaitGroup 的正常使用流程,可能导致不可预期的行为。

7.4 通过指针传递 WaitGroup

实践内容:当需要在函数间传递 WaitGroup 时,使用指针传递,而不是复制。

推荐理由:WaitGroup 是结构体,复制后会创建一个新的实例,与原实例状态无关,通过指针传递可以确保操作的是同一个实例。

7.5 结合通道处理错误

实践内容:使用通道收集 Goroutine 中的错误,在 Wait() 之后处理这些错误。

推荐理由:WaitGroup 只关注 Goroutine 的完成状态,不处理错误,结合通道可以有效地收集和处理错误。

7.6 合理设置并发数量

实践内容:根据系统资源和任务特性,合理设置并发 Goroutine 的数量。

推荐理由:过多的并发会导致系统资源耗尽,过少的并发会影响处理效率,需要根据实际情况进行调整。

7.7 监控和调试

实践内容:在生产环境中监控 WaitGroup 的使用情况,如等待时间、并发数量等。

推荐理由:监控可以帮助发现潜在的问题,如 Goroutine 泄漏、死锁等。

7.8 与 Context 结合使用

实践内容:结合 context 包使用 WaitGroup,实现更复杂的并发控制,如超时控制、取消操作等。

推荐理由:Context 提供了取消信号和超时控制,与 WaitGroup 结合使用可以构建更加健壮的并发系统。

8. 常见问题答疑(FAQ)

8.1 WaitGroup 和通道有什么区别?

问题描述:WaitGroup 和通道都可以用于等待 Goroutine 完成,它们有什么区别?

回答内容

  • WaitGroup:专门用于等待一组 Goroutine 完成,不传递数据,只关注完成状态。
  • 通道:不仅可以用于等待 Goroutine 完成,还可以传递数据和错误。
  • 使用场景
    • 当只需要等待 Goroutine 完成,不需要返回数据时,使用 WaitGroup。
    • 当需要在 Goroutine 之间传递数据或错误时,使用通道。
    • 当需要更复杂的控制,如超时、取消等时,结合 Context 使用。

示例代码

go
// 使用 WaitGroup
func withWaitGroup() {
    var wg sync.WaitGroup
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            fmt.Printf("Goroutine %d\n", id)
        }(i)
    }
    wg.Wait()
}

// 使用通道
func withChannel() {
    ch := make(chan bool, 3)
    for i := 0; i < 3; i++ {
        go func(id int) {
            fmt.Printf("Goroutine %d\n", id)
            ch <- true
        }(i)
    }
    for i := 0; i < 3; i++ {
        <-ch
    }
}

8.2 如何处理 WaitGroup 中的错误?

问题描述:当使用 WaitGroup 时,如何处理 Goroutine 中发生的错误?

回答内容

  • 使用通道收集错误:在每个 Goroutine 中,将错误发送到一个通道。
  • Wait() 之后,从通道中读取并处理错误。
  • 可以使用 errgroup 包,它结合了 WaitGroup 和错误处理。

示例代码

go
// 使用通道收集错误
func withErrorChannel() {
    var wg sync.WaitGroup
    errCh := make(chan error, 3)
    
    for i := 0; i < 3; i++ {
        wg.Add(1)
        go func(id int) {
            defer wg.Done()
            if id == 1 {
                errCh <- fmt.Errorf("error in goroutine %d", id)
                return
            }
            fmt.Printf("Goroutine %d completed\n", id)
        }(i)
    }
    
    wg.Wait()
    close(errCh)
    
    for err := range errCh {
        if err != nil {
            fmt.Printf("Error: %v\n", err)
        }
    }
}

// 使用 errgroup
func withErrGroup() {
    g, ctx := errgroup.WithContext(context.Background())
    
    for i := 0; i < 3; i++ {
        id := i
        g.Go(func() error {
            if id == 1 {
                return fmt.Errorf("error in goroutine %d", id)
            }
            fmt.Printf("Goroutine %d completed\n", id)
            return nil
        })
    }
    
    if err := g.Wait(); err != nil {
        fmt.Printf("Error: %v\n", err)
    }
}

8.3 如何避免 WaitGroup 的计数器错误?

问题描述:如何确保 WaitGroup 的 Add()Done() 调用次数匹配?

回答内容

  • 始终使用 defer wg.Done() 确保 Done() 被调用。
  • 在启动 Goroutine 之前集中调用 Add(),避免在多个地方分散调用。
  • 对于复杂的场景,使用计数器变量来跟踪需要启动的 Goroutine 数量,然后一次性调用 Add()
  • 代码审查时特别关注 WaitGroup 的使用,确保 Add()Done() 调用次数匹配。

示例代码

go
// 集中管理 Add() 调用
func集中管理Add() {
    var wg sync.WaitGroup
    tasks := []Task{task1, task2, task3}
    
    // 集中调用 Add()
    wg.Add(len(tasks))
    
    for _, task := range tasks {
        go func(t Task) {
            defer wg.Done()
            t()
        }(task)
    }
    
    wg.Wait()
}

8.4 WaitGroup 可以重用吗?

问题描述:WaitGroup 在调用 Wait() 之后可以重用吗?

回答内容

  • 从技术上讲,WaitGroup 在调用 Wait() 之后可以重用,只要重新调用 Add() 设置新的计数器。
  • 但是,这种做法不推荐,因为容易导致逻辑混乱和错误。
  • 推荐的做法是为每个需要等待的 Goroutine 组创建一个新的 WaitGroup。

示例代码

go
// 不推荐的重用方式
func reuseWaitGroup() {
    var wg sync.WaitGroup
    
    // 第一组 Goroutine
    wg.Add(2)
    go func() { defer wg.Done(); fmt.Println("Goroutine 1") }()
    go func() { defer wg.Done(); fmt.Println("Goroutine 2") }()
    wg.Wait()
    
    // 重用 WaitGroup
    wg.Add(2)
    go func() { defer wg.Done(); fmt.Println("Goroutine 3") }()
    go func() { defer wg.Done(); fmt.Println("Goroutine 4") }()
    wg.Wait()
}

// 推荐的方式:为每组 Goroutine 创建新的 WaitGroup
func newWaitGroupPerGroup() {
    // 第一组 Goroutine
    var wg1 sync.WaitGroup
    wg1.Add(2)
    go func() { defer wg1.Done(); fmt.Println("Goroutine 1") }()
    go func() { defer wg1.Done(); fmt.Println("Goroutine 2") }()
    wg1.Wait()
    
    // 第二组 Goroutine
    var wg2 sync.WaitGroup
    wg2.Add(2)
    go func() { defer wg2.Done(); fmt.Println("Goroutine 3") }()
    go func() { defer wg2.Done(); fmt.Println("Goroutine 4") }()
    wg2.Wait()
}

8.5 WaitGroup 与 sync.Mutex 有什么关系?

问题描述:WaitGroup 和 sync.Mutex 都是同步原语,它们有什么关系和区别?

回答内容

  • WaitGroup:用于等待一组 Goroutine 完成,管理的是 Goroutine 的生命周期。
  • sync.Mutex:用于保护共享资源,防止多个 Goroutine 同时访问导致的竞态条件。
  • 关系:两者都是 sync 包中的同步原语,经常一起使用。例如,在多个 Goroutine 访问共享资源时,使用 Mutex 保护资源,使用 WaitGroup 等待所有 Goroutine 完成。

示例代码

go
func withMutexAndWaitGroup() {
    var wg sync.WaitGroup
    var mu sync.Mutex
    var counter int
    
    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            mu.Lock()
            counter++
            mu.Unlock()
        }()
    }
    
    wg.Wait()
    fmt.Printf("Counter: %d\n", counter)
}

8.6 如何在 WaitGroup 中处理超时?

问题描述:当使用 WaitGroup 时,如何设置超时,避免无限等待?

回答内容

  • 结合 context.WithTimeouttime.After 使用,在超时后取消等待。
  • 注意:这种方式只能取消等待,不能取消正在执行的 Goroutine,需要结合 Context 来取消 Goroutine。

示例代码

go
func withTimeout() {
    var wg sync.WaitGroup
    
    // 启动一个会阻塞的 Goroutine
    wg.Add(1)
    go func() {
        defer wg.Done()
        time.Sleep(5 * time.Second) // 模拟长时间运行
        fmt.Println("Goroutine completed")
    }()
    
    // 设置超时
    timeout := time.After(2 * time.Second)
    done := make(chan struct{})
    
    go func() {
        wg.Wait()
        close(done)
    }()
    
    select {
    case <-done:
        fmt.Println("All goroutines completed")
    case <-timeout:
        fmt.Println("Timeout waiting for goroutines")
    }
}

9. 实战练习

9.1 基础练习:并发计数器

题目:使用 WaitGroup 和 Mutex 实现一个并发安全的计数器,支持多个 Goroutine 同时递增。

解题思路

  • 使用 Mutex 保护计数器变量,避免竞态条件。
  • 使用 WaitGroup 等待所有 Goroutine 完成。
  • 启动多个 Goroutine 同时递增计数器。

常见误区

  • 忘记使用 Mutex 保护共享变量,导致竞态条件。
  • 忘记调用 wg.Done(),导致 wg.Wait() 永远不会返回。

分步提示

  1. 定义计数器变量和 Mutex。
  2. 创建 WaitGroup。
  3. 启动多个 Goroutine,每个 Goroutine 递增计数器。
  4. 在每个 Goroutine 中使用 defer wg.Done()
  5. 使用 Mutex 保护计数器的递增操作。
  6. 调用 wg.Wait() 等待所有 Goroutine 完成。
  7. 输出最终的计数器值。

参考代码

go
package main

import (
    "fmt"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    var mu sync.Mutex
    var counter int
    
    // 启动 1000 个 Goroutine 同时递增计数器
    for i := 0; i < 1000; i++ {
        wg.Add(1)
        go func() {
            defer wg.Done()
            mu.Lock()
            counter++
            mu.Unlock()
        }()
    }
    
    wg.Wait()
    fmt.Printf("Final counter value: %d\n", counter) // 应该输出 1000
}

9.2 进阶练习:并发文件处理

题目:使用 WaitGroup 实现并发文件处理,读取多个文件并统计总字数。

解题思路

  • 为每个文件启动一个 Goroutine 进行处理。
  • 使用 WaitGroup 等待所有文件处理完成。
  • 使用 Mutex 保护总字数变量,避免竞态条件。
  • 收集每个文件的处理结果和错误。

常见误区

  • 忘记处理文件读取错误。
  • 没有使用 Mutex 保护共享变量。
  • 文件路径处理错误。

分步提示

  1. 定义文件列表和总字数变量。
  2. 创建 WaitGroup 和 Mutex。
  3. 为每个文件启动一个 Goroutine。
  4. 在每个 Goroutine 中读取文件内容并计算字数。
  5. 使用 Mutex 保护总字数的更新。
  6. 处理文件读取错误。
  7. 调用 wg.Wait() 等待所有文件处理完成。
  8. 输出总字数。

参考代码

go
package main

import (
    "fmt"
    "io/ioutil"
    "sync"
)

func main() {
    var wg sync.WaitGroup
    var mu sync.Mutex
    var totalWords int
    
    files := []string{"file1.txt", "file2.txt", "file3.txt"}
    
    for _, file := range files {
        wg.Add(1)
        go func(f string) {
            defer wg.Done()
            
            // 读取文件内容
            content, err := ioutil.ReadFile(f)
            if err != nil {
                fmt.Printf("Error reading file %s: %v\n", f, err)
                return
            }
            
            // 简单计算字数(按空格分割)
            words := 0
            inWord := false
            for _, c := range content {
                if c == ' ' || c == '\n' || c == '\t' {
                    inWord = false
                } else if !inWord {
                    inWord = true
                    words++
                }
            }
            
            // 更新总字数
            mu.Lock()
            totalWords += words
            mu.Unlock()
            
            fmt.Printf("File %s: %d words\n", f, words)
        }(file)
    }
    
    wg.Wait()
    fmt.Printf("Total words: %d\n", totalWords)
}

9.3 挑战练习:并发 Web 爬虫

题目:使用 WaitGroup 实现一个并发 Web 爬虫,爬取多个网页并提取链接。

解题思路

  • 为每个 URL 启动一个 Goroutine 进行爬取。
  • 使用 WaitGroup 等待所有爬取任务完成。
  • 使用通道收集爬取结果和错误。
  • 使用互斥锁或 sync.Map 去重链接。

常见误区

  • 没有处理网络错误。
  • 没有去重链接,导致重复爬取。
  • 并发数量过高,导致被目标网站封禁。
  • 没有设置超时,导致某些请求永远阻塞。

分步提示

  1. 定义初始 URL 列表和已访问 URL 集合。
  2. 创建 WaitGroup 和通道用于收集结果。
  3. 为每个 URL 启动一个 Goroutine 进行爬取。
  4. 在每个 Goroutine 中发送 HTTP 请求,解析 HTML,提取链接。
  5. 去重链接,避免重复爬取。
  6. 处理网络错误和超时。
  7. 调用 wg.Wait() 等待所有爬取任务完成。
  8. 输出爬取结果。

参考代码

go
package main

import (
    "fmt"
    "net/http"
    "regexp"
    "sync"
    "time"
)

func main() {
    var wg sync.WaitGroup
    visited := sync.Map{}
    resultCh := make(chan string, 100)
    errorCh := make(chan error, 100)
    
    // 初始 URL 列表
    urls := []string{
        "https://example.com",
        "https://golang.org",
    }
    
    // 爬取函数
    var crawl func(url string)
    crawl = func(url string) {
        defer wg.Done()
        
        // 检查是否已访问
        if _, ok := visited.Load(url); ok {
            return
        }
        visited.Store(url, true)
        
        fmt.Printf("Crawling: %s\n", url)
        
        // 发送 HTTP 请求
        client := &http.Client{Timeout: 10 * time.Second}
        resp, err := client.Get(url)
        if err != nil {
            errorCh <- fmt.Errorf("Error crawling %s: %v", url, err)
            return
        }
        defer resp.Body.Close()
        
        // 读取响应体
        body, err := ioutil.ReadAll(resp.Body)
        if err != nil {
            errorCh <- fmt.Errorf("Error reading response: %v", err)
            return
        }
        
        // 提取链接
        var links []string
        re := regexp.MustCompile(`<a[^>]+href="([^"]+)"`)
        matches := re.FindAllStringSubmatch(string(body), -1)
        for _, match := range matches {
            if len(match) > 1 {
                link := match[1]
                // 处理相对链接
                if !strings.HasPrefix(link, "http") {
                    baseURL, err := url.Parse(url)
                    if err == nil {
                        relativeURL, err := baseURL.Parse(link)
                        if err == nil {
                            link = relativeURL.String()
                        }
                    }
                }
                links = append(links, link)
            }
        }
        
        // 发送结果
        for _, link := range links {
            resultCh <- link
        }
        
        // 递归爬取新链接
        for _, link := range links {
            if strings.HasPrefix(link, "http") {
                wg.Add(1)
                go crawl(link)
            }
        }
    }
    
    // 启动初始爬取
    for _, url := range urls {
        wg.Add(1)
        go crawl(url)
    }
    
    // 等待所有爬取任务完成
    go func() {
        wg.Wait()
        close(resultCh)
        close(errorCh)
    }()
    
    // 处理结果
    var collectedLinks []string
    for link := range resultCh {
        collectedLinks = append(collectedLinks, link)
    }
    
    // 处理错误
    for err := range errorCh {
        fmt.Printf("Error: %v\n", err)
    }
    
    fmt.Printf("Crawling completed. Collected %d links\n", len(collectedLinks))
    // 输出前 10 个链接
    for i, link := range collectedLinks[:min(10, len(collectedLinks))] {
        fmt.Printf("%d: %s\n", i+1, link)
    }
}

func min(a, b int) int {
    if a < b {
        return a
    }
    return b
}