Skip to content

工作池模式

1. 概述

工作池模式是一种经典的并发设计模式,用于限制并发执行的任务数量,避免系统资源过载。在 Go 语言中,工作池模式通过固定数量的 goroutine(工作协程)来处理任务队列中的任务,实现了任务的高效分配和执行。

工作池模式在以下场景中尤为重要:

  • 处理大量并发请求,如 HTTP 服务器
  • 执行 CPU 密集型任务,如数据处理、图像处理
  • 执行 I/O 密集型任务,如文件操作、网络请求
  • 任何需要限制并发度的场景

2. 基本概念

2.1 语法

在 Go 语言中,实现工作池模式的核心语法元素包括:

go
// 创建任务通道和结果通道
taskChan := make(chan Task, taskBufferSize)
resultChan := make(chan Result, resultBufferSize)

// 启动工作协程
for i := 0; i < workerCount; i++ {
    go worker(i, taskChan, resultChan)
}

// 发送任务到通道
for _, task := range tasks {
    taskChan <- task
}
close(taskChan) // 任务发送完成后关闭通道

// 收集结果
for i := 0; i < len(tasks); i++ {
    result := <-resultChan
    // 处理结果
}
close(resultChan)

2.2 语义

  • 工作池:由固定数量的工作协程组成的集合
  • 任务:需要执行的工作单元
  • 任务通道:用于向工作池提交任务的通道
  • 结果通道:用于从工作池获取执行结果的通道
  • 工作协程:执行任务的 goroutine
  • 并发度:同时执行任务的工作协程数量

2.3 规范

  • 合理设置工作协程数量:根据系统资源和任务特性设置合适的并发度
  • 正确管理通道生命周期:任务发送完成后关闭任务通道
  • 处理所有任务和结果:确保所有任务都被执行,所有结果都被处理
  • 错误处理:在任务执行过程中妥善处理错误
  • 资源管理:确保所有 goroutine 能够正确退出,避免资源泄漏

3. 原理深度解析

3.1 工作原理

工作池模式的工作原理基于以下流程:

  1. 初始化:创建固定数量的工作协程
  2. 任务提交:将任务发送到任务通道
  3. 任务分配:工作协程从任务通道获取任务
  4. 任务执行:工作协程执行任务
  5. 结果收集:工作协程将执行结果发送到结果通道
  6. 结果处理:主协程从结果通道获取并处理结果
  7. 清理:所有任务执行完成后,关闭通道,工作协程退出

3.2 并发控制

工作池模式通过以下方式实现并发控制:

  • 固定工作协程数量:限制同时执行的任务数量
  • 通道阻塞:当任务通道为空时,工作协程会阻塞等待新任务
  • 缓冲区:使用带缓冲的通道可以减少阻塞,提高任务分配效率
  • 同步机制:使用 sync.WaitGroup 等待所有工作协程完成

3.3 任务调度

工作池模式中的任务调度主要通过通道的特性实现:

  • 公平分配:多个工作协程从同一个任务通道获取任务,Go 运行时会公平地分配任务
  • 负载均衡:任务会自动分配给空闲的工作协程
  • 顺序无关:任务的执行顺序可能与提交顺序不同,适合处理独立任务

4. 常见错误与踩坑点

4.1 错误表现

在使用工作池模式时,常见的错误包括:

  1. 死锁:工作协程和主协程之间相互等待,导致程序卡住
  2. 资源泄漏:工作协程没有正确退出,导致资源无法释放
  3. 任务丢失:任务通道关闭过早,导致部分任务未被处理
  4. 结果丢失:结果通道关闭过早,导致部分结果未被收集
  5. 并发度过高:工作协程数量设置过多,导致系统资源过载
  6. 错误处理不当:任务执行过程中的错误没有被妥善处理

4.2 产生原因

  • 通道操作不当:如关闭通道的时机不正确
  • goroutine 管理不当:如没有正确处理工作协程的退出条件
  • 并发度设置不合理:如工作协程数量过多或过少
  • 错误处理不完善:如忽略任务执行过程中的错误
  • 资源管理不当:如没有正确使用同步原语

4.3 解决方案

  1. 正确关闭通道:在所有任务发送完成后关闭任务通道
  2. 使用 for range 遍历通道:自动处理通道关闭的情况
  3. 合理设置并发度:根据系统资源和任务特性设置合适的工作协程数量
  4. 使用 sync.WaitGroup:等待所有工作协程完成
  5. 妥善处理错误:在任务执行过程中捕获和处理错误
  6. 监控系统资源:根据系统资源使用情况调整并发度

5. 常见应用场景

5.1 HTTP 服务器

场景描述:需要处理大量 HTTP 请求,限制并发连接数。

使用方法:创建固定数量的工作协程,每个协程处理一个 HTTP 请求。

示例代码

go
package main

import (
    "fmt"
    "net/http"
    "sync"
    "time"
)

type RequestTask struct {
    w    http.ResponseWriter
    r    *http.Request
    id   int
}

func worker(id int, tasks <-chan RequestTask, wg *sync.WaitGroup) {
    defer wg.Done()
    for task := range tasks {
        fmt.Printf("Worker %d handling request %d\n", id, task.id)
        // 模拟处理请求
        time.Sleep(time.Millisecond * 100)
        fmt.Fprintf(task.w, "Hello from worker %d!", id)
    }
}

func main() {
    const workerCount = 5
    const bufferSize = 100
    
    tasks := make(chan RequestTask, bufferSize)
    var wg sync.WaitGroup
    
    // 启动工作协程
    for i := 0; i < workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, &wg)
    }
    
    // HTTP 处理器
    http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
        taskID := time.Now().UnixNano()
        tasks <- RequestTask{w: w, r: r, id: int(taskID)}
    })
    
    fmt.Println("Server started on :8080")
    // 启动服务器(非阻塞)
    go func() {
        if err := http.ListenAndServe(":8080", nil); err != nil {
            fmt.Printf("Server error: %v\n", err)
        }
    }()
    
    // 运行一段时间
    time.Sleep(time.Minute)
    
    close(tasks)
    wg.Wait()
    fmt.Println("Server stopped")
}

5.2 数据处理

场景描述:需要处理大量数据,如文件处理、数据转换等。

使用方法:创建固定数量的工作协程,每个协程处理一部分数据。

示例代码

go
package main

import (
    "fmt"
    "sync"
    "time"
)

type DataTask struct {
    id   int
    data []int
}

type DataResult struct {
    taskID int
    sum    int
}

func worker(id int, tasks <-chan DataTask, results chan<- DataResult, wg *sync.WaitGroup) {
    defer wg.Done()
    for task := range tasks {
        fmt.Printf("Worker %d processing task %d\n", id, task.id)
        // 模拟数据处理
        sum := 0
        for _, num := range task.data {
            sum += num
            time.Sleep(time.Millisecond * 10)
        }
        results <- DataResult{taskID: task.id, sum: sum}
    }
}

func main() {
    const workerCount = 4
    const taskCount = 10
    
    tasks := make(chan DataTask, taskCount)
    results := make(chan DataResult, taskCount)
    var wg sync.WaitGroup
    
    // 启动工作协程
    for i := 0; i < workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, results, &wg)
    }
    
    // 生成任务
    for i := 0; i < taskCount; i++ {
        data := make([]int, 100)
        for j := range data {
            data[j] = j
        }
        tasks <- DataTask{id: i, data: data}
        fmt.Printf("Submitted task %d\n", i)
    }
    close(tasks)
    
    // 收集结果
    go func() {
        wg.Wait()
        close(results)
    }()
    
    var totalSum int
    for result := range results {
        fmt.Printf("Task %d sum: %d\n", result.taskID, result.sum)
        totalSum += result.sum
    }
    
    fmt.Printf("Total sum: %d\n", totalSum)
}

5.3 网络请求

场景描述:需要发送大量网络请求,限制并发请求数。

使用方法:创建固定数量的工作协程,每个协程发送一个网络请求。

示例代码

go
package main

import (
    "fmt"
    "net/http"
    "sync"
    "time"
)

type RequestTask struct {
    id  int
    url string
}

type RequestResult struct {
    taskID    int
    statusCode int
    err       error
}

func worker(id int, tasks <-chan RequestTask, results chan<- RequestResult, wg *sync.WaitGroup) {
    defer wg.Done()
    client := &http.Client{Timeout: 10 * time.Second}
    
    for task := range tasks {
        fmt.Printf("Worker %d requesting %s\n", id, task.url)
        resp, err := client.Get(task.url)
        var statusCode int
        if resp != nil {
            statusCode = resp.StatusCode
            resp.Body.Close()
        }
        results <- RequestResult{taskID: task.id, statusCode: statusCode, err: err}
        time.Sleep(time.Millisecond * 100) // 避免请求过快
    }
}

func main() {
    const workerCount = 3
    const taskCount = 10
    
    tasks := make(chan RequestTask, taskCount)
    results := make(chan RequestResult, taskCount)
    var wg sync.WaitGroup
    
    // 启动工作协程
    for i := 0; i < workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, results, &wg)
    }
    
    // 生成任务
    urls := []string{
        "https://www.google.com",
        "https://www.github.com",
        "https://www.go.dev",
        "https://www.amazon.com",
        "https://www.microsoft.com",
    }
    
    for i := 0; i < taskCount; i++ {
        url := urls[i%len(urls)]
        tasks <- RequestTask{id: i, url: url}
        fmt.Printf("Submitted task %d: %s\n", i, url)
    }
    close(tasks)
    
    // 收集结果
    go func() {
        wg.Wait()
        close(results)
    }()
    
    for result := range results {
        if result.err != nil {
            fmt.Printf("Task %d error: %v\n", result.taskID, result.err)
        } else {
            fmt.Printf("Task %d status code: %d\n", result.taskID, result.statusCode)
        }
    }
}

5.4 文件处理

场景描述:需要处理大量文件,如读取、写入、转换等。

使用方法:创建固定数量的工作协程,每个协程处理一个文件。

示例代码

go
package main

import (
    "fmt"
    "os"
    "sync"
    "time"
)

type FileTask struct {
    id   int
    path string
}

type FileResult struct {
    taskID int
    size   int64
    err    error
}

func worker(id int, tasks <-chan FileTask, results chan<- FileResult, wg *sync.WaitGroup) {
    defer wg.Done()
    for task := range tasks {
        fmt.Printf("Worker %d processing file %s\n", id, task.path)
        // 模拟文件处理
        time.Sleep(time.Millisecond * 100)
        
        // 获取文件大小
        info, err := os.Stat(task.path)
        var size int64
        if info != nil {
            size = info.Size()
        }
        results <- FileResult{taskID: task.id, size: size, err: err}
    }
}

func main() {
    const workerCount = 3
    
    tasks := make(chan FileTask, 10)
    results := make(chan FileResult, 10)
    var wg sync.WaitGroup
    
    // 启动工作协程
    for i := 0; i < workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, results, &wg)
    }
    
    // 生成任务
    files := []string{
        "./worker-pool.md",
        "./patterns.md",
        "./producer-consumer.md",
        "./fanout-fanin.md",
        "./pipeline.md",
    }
    
    for i, file := range files {
        tasks <- FileTask{id: i, path: file}
        fmt.Printf("Submitted task %d: %s\n", i, file)
    }
    close(tasks)
    
    // 收集结果
    go func() {
        wg.Wait()
        close(results)
    }()
    
    for result := range results {
        if result.err != nil {
            fmt.Printf("Task %d error: %v\n", result.taskID, result.err)
        } else {
            fmt.Printf("Task %d file size: %d bytes\n", result.taskID, result.size)
        }
    }
}

5.5 数据库操作

场景描述:需要执行大量数据库操作,限制并发连接数。

使用方法:创建固定数量的工作协程,每个协程执行一个数据库操作。

示例代码

go
package main

import (
    "database/sql"
    "fmt"
    "sync"
    "time"
    
    _ "github.com/mattn/go-sqlite3"
)

type DBTask struct {
    id   int
    name string
    age  int
}

type DBResult struct {
    taskID int
    id     int64
    err    error
}

func worker(id int, db *sql.DB, tasks <-chan DBTask, results chan<- DBResult, wg *sync.WaitGroup) {
    defer wg.Done()
    for task := range tasks {
        fmt.Printf("Worker %d inserting %s\n", id, task.name)
        // 执行数据库插入
        result, err := db.Exec("INSERT INTO users (name, age) VALUES (?, ?)", task.name, task.age)
        var lastID int64
        if result != nil {
            lastID, _ = result.LastInsertId()
        }
        results <- DBResult{taskID: task.id, id: lastID, err: err}
        time.Sleep(time.Millisecond * 50) // 模拟处理时间
    }
}

func main() {
    const workerCount = 2
    const taskCount = 10
    
    // 连接数据库
    db, err := sql.Open("sqlite3", ":memory:")
    if err != nil {
        fmt.Printf("Database error: %v\n", err)
        return
    }
    defer db.Close()
    
    // 创建表
    _, err = db.Exec("CREATE TABLE users (id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT, age INTEGER)")
    if err != nil {
        fmt.Printf("Create table error: %v\n", err)
        return
    }
    
    tasks := make(chan DBTask, taskCount)
    results := make(chan DBResult, taskCount)
    var wg sync.WaitGroup
    
    // 启动工作协程
    for i := 0; i < workerCount; i++ {
        wg.Add(1)
        go worker(i, db, tasks, results, &wg)
    }
    
    // 生成任务
    for i := 0; i < taskCount; i++ {
        task := DBTask{
            id:   i,
            name: fmt.Sprintf("User %d", i),
            age:  20 + i%30,
        }
        tasks <- task
        fmt.Printf("Submitted task %d: %s\n", i, task.name)
    }
    close(tasks)
    
    // 收集结果
    go func() {
        wg.Wait()
        close(results)
    }()
    
    for result := range results {
        if result.err != nil {
            fmt.Printf("Task %d error: %v\n", result.taskID, result.err)
        } else {
            fmt.Printf("Task %d inserted with ID: %d\n", result.taskID, result.id)
        }
    }
    
    // 验证数据
    rows, err := db.Query("SELECT id, name, age FROM users")
    if err != nil {
        fmt.Printf("Query error: %v\n", err)
        return
    }
    defer rows.Close()
    
    fmt.Println("\nAll users:")
    for rows.Next() {
        var id int64
        var name string
        var age int
        if err := rows.Scan(&id, &name, &age); err != nil {
            fmt.Printf("Scan error: %v\n", err)
            continue
        }
        fmt.Printf("ID: %d, Name: %s, Age: %d\n", id, name, age)
    }
}

6. 企业级进阶应用场景

6.1 微服务架构中的工作池

场景描述:在微服务架构中,需要处理大量的服务间调用,限制并发请求数。

使用方法:为每个服务调用类型创建工作池,根据服务的性能和容量设置合适的并发度。

示例代码

go
package main

import (
    "context"
    "fmt"
    "sync"
    "time"
)

type ServiceTask struct {
    id      int
    service string
    payload map[string]interface{}
}

type ServiceResult struct {
    taskID  int
    result  interface{}
    err     error
}

type WorkerPool struct {
    tasks     chan ServiceTask
    results   chan ServiceResult
    wg        sync.WaitGroup
    workerCount int
}

func NewWorkerPool(workerCount, bufferSize int) *WorkerPool {
    return &WorkerPool{
        tasks:     make(chan ServiceTask, bufferSize),
        results:   make(chan ServiceResult, bufferSize),
        workerCount: workerCount,
    }
}

func (p *WorkerPool) Start() {
    for i := 0; i < p.workerCount; i++ {
        p.wg.Add(1)
        go p.worker(i)
    }
}

func (p *WorkerPool) worker(id int) {
    defer p.wg.Done()
    for task := range p.tasks {
        fmt.Printf("Worker %d calling service %s\n", id, task.service)
        // 模拟服务调用
        time.Sleep(time.Millisecond * 100)
        
        // 模拟服务响应
        result := map[string]interface{}{
            "task_id": task.id,
            "service": task.service,
            "data":    "result data",
        }
        
        p.results <- ServiceResult{
            taskID: task.id,
            result: result,
            err:    nil,
        }
    }
}

func (p *WorkerPool) Submit(task ServiceTask) {
    p.tasks <- task
}

func (p *WorkerPool) Results() <-chan ServiceResult {
    return p.results
}

func (p *WorkerPool) Close() {
    close(p.tasks)
    go func() {
        p.wg.Wait()
        close(p.results)
    }()
}

func main() {
    // 创建工作池
    pool := NewWorkerPool(5, 100)
    pool.Start()
    
    // 提交任务
    services := []string{"user-service", "order-service", "payment-service", "inventory-service"}
    for i := 0; i < 20; i++ {
        task := ServiceTask{
            id:      i,
            service: services[i%len(services)],
            payload: map[string]interface{}{
                "id": i,
                "data": fmt.Sprintf("payload %d", i),
            },
        }
        pool.Submit(task)
        fmt.Printf("Submitted task %d to %s\n", i, task.service)
    }
    
    // 关闭任务通道
    pool.Close()
    
    // 收集结果
    for result := range pool.Results() {
        fmt.Printf("Task %d result: %v\n", result.taskID, result.result)
    }
    
    fmt.Println("All tasks completed")
}

6.2 实时数据处理系统

场景描述:需要处理实时数据流,如传感器数据、用户行为数据等,限制并发处理数。

使用方法:创建工作池处理数据,根据数据速率和处理能力设置合适的并发度。

示例代码

go
package main

import (
    "fmt"
    "sync"
    "time"
)

type SensorData struct {
    id        string
    value     float64
    timestamp time.Time
}

type ProcessedData struct {
    sensorID  string
    value     float64
    average   float64
    timestamp time.Time
}

type DataTask struct {
    id   int
    data SensorData
}

type DataResult struct {
    taskID int
    data   ProcessedData
    err    error
}

func worker(id int, tasks <-chan DataTask, results chan<- DataResult, wg *sync.WaitGroup) {
    defer wg.Done()
    // 模拟滑动窗口计算平均值
    window := make([]float64, 0, 10)
    
    for task := range tasks {
        fmt.Printf("Worker %d processing sensor %s\n", id, task.data.id)
        
        // 处理数据
        window = append(window, task.data.value)
        if len(window) > 10 {
            window = window[1:]
        }
        
        // 计算平均值
        var sum float64
        for _, v := range window {
            sum += v
        }
        avg := sum / float64(len(window))
        
        processed := ProcessedData{
            sensorID:  task.data.id,
            value:     task.data.value,
            average:   avg,
            timestamp: time.Now(),
        }
        
        results <- DataResult{taskID: task.id, data: processed, err: nil}
        time.Sleep(time.Millisecond * 50) // 模拟处理时间
    }
}

func main() {
    const workerCount = 4
    const bufferSize = 100
    
    tasks := make(chan DataTask, bufferSize)
    results := make(chan DataResult, bufferSize)
    var wg sync.WaitGroup
    
    // 启动工作协程
    for i := 0; i < workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, results, &wg)
    }
    
    // 模拟数据生成
    go func() {
        sensors := []string{"temp-1", "temp-2", "humid-1", "humid-2"}
        for i := 0; i < 50; i++ {
            sensor := sensors[i%len(sensors)]
            data := SensorData{
                id:        sensor,
                value:     float64(20 + i%20),
                timestamp: time.Now(),
            }
            tasks <- DataTask{id: i, data: data}
            fmt.Printf("Generated data from sensor %s: %f\n", sensor, data.value)
            time.Sleep(time.Millisecond * 20) // 模拟数据生成速率
        }
        close(tasks)
    }()
    
    // 收集结果
    go func() {
        wg.Wait()
        close(results)
    }()
    
    for result := range results {
        fmt.Printf("Task %d processed: sensor=%s, value=%.2f, average=%.2f\n", 
            result.taskID, result.data.sensorID, result.data.value, result.data.average)
    }
    
    fmt.Println("All data processed")
}

6.3 批量任务处理系统

场景描述:需要处理大量批量任务,如数据导入、报表生成等,限制并发执行数。

使用方法:创建工作池处理批量任务,根据任务大小和系统资源设置合适的并发度。

示例代码

go
package main

import (
    "fmt"
    "sync"
    "time"
)

type BatchTask struct {
    id       int
    name     string
    itemCount int
}

type BatchResult struct {
    taskID    int
    processed int
    err       error
}

func worker(id int, tasks <-chan BatchTask, results chan<- BatchResult, wg *sync.WaitGroup) {
    defer wg.Done()
    for task := range tasks {
        fmt.Printf("Worker %d processing batch %s (%d items)\n", id, task.name, task.itemCount)
        
        // 模拟批量处理
        processed := 0
        for i := 0; i < task.itemCount; i++ {
            // 模拟处理单个 item
            time.Sleep(time.Millisecond * 10)
            processed++
        }
        
        results <- BatchResult{taskID: task.id, processed: processed, err: nil}
        fmt.Printf("Worker %d completed batch %s\n", id, task.name)
    }
}

func main() {
    const workerCount = 3
    const bufferSize = 20
    
    tasks := make(chan BatchTask, bufferSize)
    results := make(chan BatchResult, bufferSize)
    var wg sync.WaitGroup
    
    // 启动工作协程
    for i := 0; i < workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, results, &wg)
    }
    
    // 生成批量任务
    batches := []BatchTask{
        {id: 1, name: "Import Users", itemCount: 1000},
        {id: 2, name: "Generate Reports", itemCount: 500},
        {id: 3, name: "Update Inventory", itemCount: 2000},
        {id: 4, name: "Process Orders", itemCount: 1500},
        {id: 5, name: "Cleanup Data", itemCount: 800},
    }
    
    for _, batch := range batches {
        tasks <- batch
        fmt.Printf("Submitted batch %s\n", batch.name)
    }
    close(tasks)
    
    // 收集结果
    go func() {
        wg.Wait()
        close(results)
    }()
    
    var totalProcessed int
    for result := range results {
        fmt.Printf("Batch %d processed %d items\n", result.taskID, result.processed)
        totalProcessed += result.processed
    }
    
    fmt.Printf("Total processed items: %d\n", totalProcessed)
}

7. 行业最佳实践

7.1 实践内容

  1. 根据任务类型设置并发度

    • CPU 密集型任务:并发度不宜过高,一般设置为 CPU 核心数
    • I/O 密集型任务:并发度可以设置得较高,如 CPU 核心数的 2-4 倍
  2. 使用带缓冲的通道

    • 任务通道:根据任务生成速率设置合适的缓冲区大小
    • 结果通道:根据结果处理速率设置合适的缓冲区大小
  3. 实现优雅关闭

    • 正确关闭任务通道,确保所有任务都被处理
    • 使用 sync.WaitGroup 等待所有工作协程完成
    • 处理所有结果,确保结果不丢失
  4. 错误处理

    • 在任务执行过程中捕获和处理错误
    • 将错误作为结果的一部分返回,而不是直接 panic
    • 实现错误重试机制,提高系统的可靠性
  5. 监控和度量

    • 监控工作池的状态,如活跃工作协程数、任务队列长度等
    • 度量任务执行时间、成功率等指标
    • 根据监控数据动态调整并发度
  6. 动态调整并发度

    • 根据系统负载和任务队列长度动态调整工作协程数量
    • 实现自动扩缩容机制,提高系统的弹性
  7. 使用上下文管理

    • 使用 context 包管理工作协程的生命周期
    • 支持任务取消和超时控制
  8. 任务优先级

    • 实现任务优先级队列,优先处理高优先级任务
    • 确保重要任务能够及时得到处理

7.2 推荐理由

  • 提高系统吞吐量:通过合理的并发度设置,充分利用系统资源,提高任务处理速度
  • 避免系统过载:限制并发度,防止系统资源被耗尽
  • 提高系统可靠性:通过错误处理和监控机制,提高系统的稳定性和可靠性
  • 改善用户体验:快速处理任务,减少用户等待时间
  • 便于维护和扩展:模块化的设计,便于系统的维护和扩展

8. 常见问题答疑(FAQ)

8.1 问题描述:如何确定工作池的大小?

回答内容:工作池的大小应根据以下因素确定:

  • 任务类型(CPU 密集型或 I/O 密集型)
  • 系统资源(CPU 核心数、内存大小)
  • 任务处理时间
  • 系统负载

一般来说,对于 CPU 密集型任务,工作池大小设置为 CPU 核心数;对于 I/O 密集型任务,工作池大小可以设置为 CPU 核心数的 2-4 倍。

示例代码

go
import "runtime"

// 获取 CPU 核心数
cpuCount := runtime.NumCPU()

// CPU 密集型任务
workerCount := cpuCount

// I/O 密集型任务
workerCount := cpuCount * 2

8.2 问题描述:如何处理任务执行过程中的错误?

回答内容:可以通过以下方式处理任务执行过程中的错误:

  • 将错误作为结果的一部分返回
  • 实现错误重试机制
  • 记录错误日志
  • 对严重错误进行告警

示例代码

go
type TaskResult struct {
    taskID int
    result interface{}
    err    error
}

func worker(tasks <-chan Task, results chan<- TaskResult) {
    for task := range tasks {
        result, err := processTask(task)
        results <- TaskResult{taskID: task.id, result: result, err: err}
    }
}

8.3 问题描述:如何实现工作池的动态扩缩容?

回答内容:可以通过以下方式实现工作池的动态扩缩容:

  • 监控任务队列长度和系统负载
  • 根据监控数据调整工作协程数量
  • 实现添加和移除工作协程的机制

示例代码

go
func (p *WorkerPool) adjustWorkerCount() {
    queueLength := len(p.tasks)
    currentWorkers := p.workerCount
    
    // 根据队列长度调整工作协程数量
    if queueLength > currentWorkers*2 && currentWorkers < p.maxWorkers {
        // 添加工作协程
        for i := 0; i < 2; i++ {
            p.workerCount++
            p.wg.Add(1)
            go p.worker(p.workerCount - 1)
        }
    } else if queueLength < currentWorkers/2 && currentWorkers > p.minWorkers {
        // 移除工作协程(通过发送终止信号)
        for i := 0; i < 2 && p.workerCount > p.minWorkers; i++ {
            p.terminateChan <- struct{}{}
            p.workerCount--
        }
    }
}

8.4 问题描述:如何处理长时间运行的任务?

回答内容:对于长时间运行的任务,可以采取以下措施:

  • 设置任务超时机制
  • 实现任务中断和恢复功能
  • 监控任务执行时间,对超时任务进行处理
  • 将长时间运行的任务拆分为多个小任务

示例代码

go
func worker(tasks <-chan Task, results chan<- TaskResult) {
    for task := range tasks {
        // 创建带超时的上下文
        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
        defer cancel()
        
        // 在 goroutine 中执行任务
        resultChan := make(chan TaskResult, 1)
        go func() {
            result, err := processTask(task)
            resultChan <- TaskResult{taskID: task.id, result: result, err: err}
        }()
        
        // 等待任务完成或超时
        select {
        case result := <-resultChan:
            results <- result
        case <-ctx.Done():
            results <- TaskResult{taskID: task.id, err: fmt.Errorf("task timeout")}
        }
    }
}

8.5 问题描述:如何实现任务优先级?

回答内容:可以通过以下方式实现任务优先级:

  • 使用优先级队列存储任务
  • 工作协程从优先级队列中获取任务
  • 确保高优先级任务先被处理

示例代码

go
type PriorityTask struct {
    task     Task
    priority int
}

type PriorityQueue []*PriorityTask

func (pq PriorityQueue) Len() int { return len(pq) }
func (pq PriorityQueue) Less(i, j int) bool { return pq[i].priority > pq[j].priority } // 高优先级优先
func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] }

func (pq *PriorityQueue) Push(x interface{}) {
    *pq = append(*pq, x.(*PriorityTask))
}

func (pq *PriorityQueue) Pop() interface{} {
    old := *pq
    n := len(old)
    item := old[n-1]
    *pq = old[0 : n-1]
    return item
}

// 工作协程从优先级队列获取任务
func worker(pq *PriorityQueue, mutex *sync.Mutex, cond *sync.Cond) {
    for {
        mutex.Lock()
        for pq.Len() == 0 {
            cond.Wait()
        }
        task := heap.Pop(pq).(*PriorityTask)
        mutex.Unlock()
        
        // 处理任务
        processTask(task.task)
    }
}

8.6 问题描述:如何测试工作池的性能?

回答内容:可以通过以下方式测试工作池的性能:

  • 测量任务处理时间
  • 测量系统资源使用情况
  • 测试不同并发度下的性能
  • 测试边界情况,如任务突增、系统负载高等

示例代码

go
func benchmarkWorkerPool(workerCount, taskCount int) time.Duration {
    start := time.Now()
    
    // 创建工作池
    pool := NewWorkerPool(workerCount, taskCount)
    pool.Start()
    
    // 提交任务
    for i := 0; i < taskCount; i++ {
        pool.Submit(Task{id: i})
    }
    
    // 关闭工作池
    pool.Close()
    
    // 收集结果
    for range pool.Results() {
        // 处理结果
    }
    
    return time.Since(start)
}

func main() {
    taskCount := 1000
    for workerCount := 1; workerCount <= 8; workerCount++ {
        duration := benchmarkWorkerPool(workerCount, taskCount)
        fmt.Printf("Worker count: %d, Duration: %v\n", workerCount, duration)
    }
}

9. 实战练习

9.1 基础练习:实现一个简单的工作池

解题思路:创建固定数量的工作协程,从任务通道获取任务并执行,将结果发送到结果通道。

常见误区:忘记关闭通道,导致工作协程无法退出。

分步提示

  1. 创建任务通道和结果通道
  2. 启动固定数量的工作协程
  3. 向任务通道发送任务
  4. 从结果通道接收结果
  5. 关闭通道,等待所有工作协程完成

参考代码

go
package main

import (
    "fmt"
    "sync"
    "time"
)

type Task struct {
    ID   int
    Data int
}

type Result struct {
    TaskID int
    Value  int
}

func worker(id int, tasks <-chan Task, results chan<- Result, wg *sync.WaitGroup) {
    defer wg.Done()
    for task := range tasks {
        fmt.Printf("Worker %d processing task %d\n", id, task.ID)
        // 模拟任务处理
        time.Sleep(time.Millisecond * 100)
        results <- Result{TaskID: task.ID, Value: task.Data * 2}
    }
}

func main() {
    const workerCount = 3
    const taskCount = 10
    
    tasks := make(chan Task, taskCount)
    results := make(chan Result, taskCount)
    var wg sync.WaitGroup
    
    // 启动工作协程
    for i := 0; i < workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, results, &wg)
    }
    
    // 发送任务
    for i := 0; i < taskCount; i++ {
        task := Task{ID: i, Data: i * 10}
        tasks <- task
        fmt.Printf("Submitted task %d: %d\n", i, task.Data)
    }
    close(tasks)
    
    // 收集结果
    go func() {
        wg.Wait()
        close(results)
    }()
    
    for result := range results {
        fmt.Printf("Result: Task %d = %d\n", result.TaskID, result.Value)
    }
    
    fmt.Println("All tasks completed")
}

9.2 进阶练习:实现带错误处理的工作池

解题思路:在工作池基础上添加错误处理功能,确保任务执行过程中的错误能够被妥善处理。

常见误区:错误处理不当,导致整个工作池崩溃。

分步提示

  1. 定义包含错误信息的结果结构
  2. 在任务执行过程中捕获错误
  3. 将错误作为结果的一部分返回
  4. 在主协程中处理错误

参考代码

go
package main

import (
    "fmt"
    "sync"
    "time"
)

type Task struct {
    ID   int
    Data int
}

type Result struct {
    TaskID int
    Value  int
    Err    error
}

func worker(id int, tasks <-chan Task, results chan<- Result, wg *sync.WaitGroup) {
    defer wg.Done()
    for task := range tasks {
        fmt.Printf("Worker %d processing task %d\n", id, task.ID)
        
        // 模拟任务处理,可能产生错误
        var err error
        var value int
        
        if task.Data%3 == 0 {
            err = fmt.Errorf("error processing task %d", task.ID)
        } else {
            // 模拟处理时间
            time.Sleep(time.Millisecond * 100)
            value = task.Data * 2
        }
        
        results <- Result{TaskID: task.ID, Value: value, Err: err}
    }
}

func main() {
    const workerCount = 3
    const taskCount = 10
    
    tasks := make(chan Task, taskCount)
    results := make(chan Result, taskCount)
    var wg sync.WaitGroup
    
    // 启动工作协程
    for i := 0; i < workerCount; i++ {
        wg.Add(1)
        go worker(i, tasks, results, &wg)
    }
    
    // 发送任务
    for i := 0; i < taskCount; i++ {
        task := Task{ID: i, Data: i * 10}
        tasks <- task
        fmt.Printf("Submitted task %d: %d\n", i, task.Data)
    }
    close(tasks)
    
    // 收集结果
    go func() {
        wg.Wait()
        close(results)
    }()
    
    var successCount, errorCount int
    for result := range results {
        if result.Err != nil {
            fmt.Printf("Task %d error: %v\n", result.TaskID, result.Err)
            errorCount++
        } else {
            fmt.Printf("Task %d result: %d\n", result.TaskID, result.Value)
            successCount++
        }
    }
    
    fmt.Printf("Processing completed: %d success, %d error\n", successCount, errorCount)
}

9.3 挑战练习:实现带动态扩缩容的工作池

解题思路:实现一个可以根据任务队列长度动态调整工作协程数量的工作池。

常见误区:动态扩缩容逻辑不正确,导致工作协程数量失控。

分步提示

  1. 实现工作池的基本功能
  2. 添加监控任务队列长度的机制
  3. 根据队列长度动态调整工作协程数量
  4. 实现工作协程的添加和移除

参考代码

go
package main

import (
    "fmt"
    "sync"
    "time"
)

type Task struct {
    ID   int
    Data int
}

type Result struct {
    TaskID int
    Value  int
}

type WorkerPool struct {
    tasks         chan Task
    results       chan Result
    wg            sync.WaitGroup
    workerCount   int
    minWorkers    int
    maxWorkers    int
    mu            sync.Mutex
    terminateChan chan struct{}
}

func NewWorkerPool(minWorkers, maxWorkers, bufferSize int) *WorkerPool {
    return &WorkerPool{
        tasks:         make(chan Task, bufferSize),
        results:       make(chan Result, bufferSize),
        minWorkers:    minWorkers,
        maxWorkers:    maxWorkers,
        workerCount:   0,
        terminateChan: make(chan struct{}),
    }
}

func (p *WorkerPool) Start() {
    // 启动最小数量的工作协程
    for i := 0; i < p.minWorkers; i++ {
        p.addWorker()
    }
    
    // 启动监控协程,动态调整工作协程数量
    go p.monitor()
}

func (p *WorkerPool) addWorker() {
    p.mu.Lock()
    id := p.workerCount
    p.workerCount++
    p.mu.Unlock()
    
    p.wg.Add(1)
    go p.worker(id)
    fmt.Printf("Added worker %d, total: %d\n", id, p.workerCount)
}

func (p *WorkerPool) removeWorker() {
    select {
    case p.terminateChan <- struct{}{}:
        p.mu.Lock()
        p.workerCount--
        fmt.Printf("Removed worker, total: %d\n", p.workerCount)
        p.mu.Unlock()
    default:
        // 没有可移除的工作协程
    }
}

func (p *WorkerPool) worker(id int) {
    defer p.wg.Done()
    
    for {
        select {
        case <-p.terminateChan:
            fmt.Printf("Worker %d exiting\n", id)
            return
        case task, ok := <-p.tasks:
            if !ok {
                fmt.Printf("Worker %d exiting (tasks channel closed)\n", id)
                return
            }
            
            // 处理任务
            fmt.Printf("Worker %d processing task %d\n", id, task.ID)
            time.Sleep(time.Millisecond * 100)
            p.results <- Result{TaskID: task.ID, Value: task.Data * 2}
        }
    }
}

func (p *WorkerPool) monitor() {
    ticker := time.NewTicker(time.Second)
    defer ticker.Stop()
    
    for {
        select {
        case <-ticker.C:
            p.mu.Lock()
            queueLength := len(p.tasks)
            currentWorkers := p.workerCount
            p.mu.Unlock()
            
            fmt.Printf("Monitoring: queue length=%d, workers=%d\n", queueLength, currentWorkers)
            
            // 动态调整工作协程数量
            if queueLength > currentWorkers*2 && currentWorkers < p.maxWorkers {
                // 添加工作协程
                p.addWorker()
            } else if queueLength < currentWorkers/2 && currentWorkers > p.minWorkers {
                // 移除工作协程
                p.removeWorker()
            }
        }
    }
}

func (p *WorkerPool) Submit(task Task) {
    p.tasks <- task
}

func (p *WorkerPool) Results() <-chan Result {
    return p.results
}

func (p *WorkerPool) Close() {
    close(p.tasks)
    go func() {
        p.wg.Wait()
        close(p.results)
        close(p.terminateChan)
    }()
}

func main() {
    // 创建工作池,最小 2 个工作协程,最大 5 个工作协程
    pool := NewWorkerPool(2, 5, 100)
    pool.Start()
    
    // 批量提交任务
    for i := 0; i < 30; i++ {
        task := Task{ID: i, Data: i * 10}
        pool.Submit(task)
        fmt.Printf("Submitted task %d\n", i)
        time.Sleep(time.Millisecond * 50) // 模拟任务生成速率
    }
    
    // 关闭工作池
    pool.Close()
    
    // 收集结果
    var results []Result
    for result := range pool.Results() {
        results = append(results, result)
    }
    
    fmt.Printf("Processed %d tasks\n", len(results))
}

10. 知识点总结

10.1 核心要点

  • 工作池模式:一种限制并发数量的设计模式,通过固定数量的工作协程处理任务队列
  • 并发控制:通过固定工作协程数量限制并发度,避免系统资源过载
  • 任务分配:通过通道实现任务的公平分配和负载均衡
  • 错误处理:在任务执行过程中妥善处理错误,确保系统的稳定性
  • 资源管理:确保所有工作协程能够正确退出,避免资源泄漏
  • 动态调整:根据系统负载和任务队列长度动态调整工作协程数量

10.2 易错点回顾

  • 死锁:工作协程和主协程之间相互等待,导致程序卡住
  • 资源泄漏:工作协程没有正确退出,导致资源无法释放
  • 任务丢失:任务通道关闭过早,导致部分任务未被处理
  • 结果丢失:结果通道关闭过早,导致部分结果未被收集
  • 并发度过高:工作协程数量设置过多,导致系统资源过载
  • 错误处理不当:任务执行过程中的错误没有被妥善处理
  • 动态扩缩容逻辑错误:导致工作协程数量失控

11. 拓展参考资料

11.1 官方文档链接

11.2 进阶学习路径建议

  • 并发编程进阶:深入学习 Go 语言的并发原语和调度器
  • 性能优化:学习如何优化工作池的性能
  • 分布式系统:学习如何在分布式环境中使用工作池模式
  • 容器化:学习如何在容器环境中部署和管理工作池
  • 监控和可观测性:学习如何监控工作池的状态和性能

11.3 相关学习资源