On this page

Python 多线程

Python 多线程编程全面指南

Python通过threading模块提供了多线程支持,允许程序并发执行多个任务(受GIL影响,标准CPython中多个线程不会真正并行执行Python字节码,详见第8节)。以下是Python多线程的详细说明。

1. 线程基础

创建线程

方法1:使用函数

import threading
import time

def task(name, delay):
    """Simulate one unit of work for the thread labelled *name*.

    Prints a start message, blocks for *delay* seconds, then prints an
    end message. Returns None.
    """
    print(f"线程 {name} 开始")
    time.sleep(delay)  # stands in for real work; other threads run meanwhile
    print(f"线程 {name} 结束")

# Build both workers up front: "A" sleeps 2 s, "B" sleeps 1 s.
t1 = threading.Thread(target=task, args=("A", 2))
t2 = threading.Thread(target=task, args=("B", 1))

# start() returns immediately; both tasks now run concurrently.
for _t in (t1, t2):
    _t.start()

# join() blocks the main thread until each worker has finished.
for _t in (t1, t2):
    _t.join()

print("所有线程执行完毕")

方法2:继承Thread类

class MyThread(threading.Thread):
    """Thread subclass whose work is defined by overriding ``run()``.

    Args:
        name: thread name, assigned through the ``Thread.name`` property.
        delay: seconds the thread sleeps between its two messages.
    """

    def __init__(self, name, delay):
        # Thread.__init__ must run before any Thread attribute is touched.
        super().__init__()
        self.name = name  # Thread.name is a property; this renames the thread
        self.delay = delay

    def run(self):
        """Announce start, sleep ``self.delay`` seconds, announce end."""
        print(f"线程 {self.name} 开始")
        time.sleep(self.delay)
        print(f"线程 {self.name} 结束")

# Instantiate the custom thread class, run both instances concurrently,
# then wait for both to complete.
t1 = MyThread("A", 2)
t2 = MyThread("B", 1)

for _t in (t1, t2):
    _t.start()

for _t in (t1, t2):
    _t.join()

2. 线程同步

锁(Lock)

shared_counter = 0
lock = threading.Lock()

def increment_counter():
    """Add 100000 to the module-level ``shared_counter``.

    Each increment is performed while holding ``lock``, so concurrent callers
    cannot interleave inside the read-modify-write of ``+= 1``.
    """
    global shared_counter
    remaining = 100000
    while remaining:
        lock.acquire()
        try:
            shared_counter += 1
        finally:
            lock.release()  # always release, even if the update raised
        remaining -= 1

# Five workers all hammer the same counter under the lock.
threads = [threading.Thread(target=increment_counter) for _ in range(5)]
for t in threads:
    t.start()

for t in threads:
    t.join()

print(f"最终计数器值: {shared_counter}")  # expected: 500000 (5 threads x 100000 locked increments)

可重入锁(RLock)

rlock = threading.RLock()

def recursive_function(count):
    """Recurse from *count* down to 1, holding ``rlock`` at every level.

    An RLock may be re-acquired by the thread that already owns it, so the
    nested ``with rlock`` blocks do not deadlock on themselves.
    """
    with rlock:
        if count <= 0:
            return
        print(f"进入层级 {count}")
        recursive_function(count - 1)
        print(f"退出层级 {count}")

# Run the recursion in a worker thread and wait for it to finish.
t = threading.Thread(target=recursive_function, kwargs={"count": 3})
t.start()
t.join()

条件变量(Condition)

from collections import deque

buffer = deque(maxlen=10)
condition = threading.Condition()

def producer(n_items=20, delay=0.1):
    """Append the integers ``0 .. n_items-1`` to the shared bounded ``buffer``.

    Waits on ``condition`` while the buffer is full and wakes every waiter
    after each append.

    Args:
        n_items: how many items to produce (default 20).
        delay: seconds to sleep between items, outside the lock (default 0.1).
    """
    for i in range(n_items):
        with condition:
            # Must be a loop, not an ``if``: wait() can return spuriously or
            # after a notify meant for someone else. Appending to a full
            # deque(maxlen=...) would silently discard the oldest item.
            while len(buffer) == buffer.maxlen:
                print("缓冲区满,生产者等待")
                condition.wait()
            buffer.append(i)
            print(f"生产: {i}")
            condition.notify_all()
        time.sleep(delay)

def consumer(n_items=20, delay=0.2):
    """Remove and print ``n_items`` items from the shared ``buffer``.

    Waits on ``condition`` while the buffer is empty and wakes every waiter
    after each removal.

    Args:
        n_items: how many items to consume (default 20).
        delay: seconds to sleep between items, outside the lock (default 0.2).
    """
    for _ in range(n_items):
        with condition:
            # Re-check after every wakeup; popleft() on an empty deque raises.
            while not buffer:
                print("缓冲区空,消费者等待")
                condition.wait()
            item = buffer.popleft()
            print(f"消费: {item}")
            condition.notify_all()
        time.sleep(delay)

# One producer and one consumer share the bounded buffer via the condition.
producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)

for _t in (producer_thread, consumer_thread):
    _t.start()

for _t in (producer_thread, consumer_thread):
    _t.join()

信号量(Semaphore)

semaphore = threading.Semaphore(3)  # at most 3 threads inside at once

def limited_resource(user, hold=1):
    """Occupy one of the three ``semaphore`` slots for *hold* seconds.

    A caller that finds all three slots taken blocks in ``with semaphore``
    until another holder leaves.

    Args:
        user: label printed in the progress messages.
        hold: seconds to keep the slot (default 1).
    """
    with semaphore:
        print(f"{user} 正在使用资源")
        time.sleep(hold)
        print(f"{user} 释放资源")

# Launch ten users at once; the semaphore admits only three at a time.
for i in range(10):
    threading.Thread(
        target=limited_resource,
        args=(f"User-{i}",),
    ).start()

事件(Event)

event = threading.Event()

def waiter():
    """Block until ``event`` is set, then report that execution continues."""
    print("等待事件触发")
    event.wait()  # returns immediately if the event is already set
    print("事件已触发,继续执行")

def setter():
    """After a 3-second pause, set ``event``, releasing every waiter."""
    time.sleep(3)
    print("设置事件")
    event.set()

# The waiter blocks on the event; the setter releases it about 3 s later.
t1 = threading.Thread(target=waiter)
t2 = threading.Thread(target=setter)

for _t in (t1, t2):
    _t.start()

for _t in (t1, t2):
    _t.join()

3. 线程间通信

队列(Queue)

from queue import Queue

q = Queue(maxsize=5)  # put() blocks once 5 items are waiting

def producer():
    """Put the integers 0..9 on ``q``, pausing 0.1 s after each one."""
    for value in range(10):
        q.put(value)
        print(f"生产: {value}")
        time.sleep(0.1)

def consumer():
    """Print items from ``q`` until the ``None`` sentinel arrives.

    ``task_done()`` is called for every real item; the sentinel itself is
    not acknowledged.
    """
    item = q.get()
    while item is not None:
        print(f"消费: {item}")
        q.task_done()
        item = q.get()

producer_thread = threading.Thread(target=producer)
consumer_thread = threading.Thread(target=consumer)

for _t in (producer_thread, consumer_thread):
    _t.start()

# Let production finish, then hand the consumer the None sentinel so its
# loop exits; only after that can the consumer be joined.
producer_thread.join()
q.put(None)
consumer_thread.join()

4. 线程局部数据

thread_local = threading.local()  # each thread sees its own attributes on this object

def show_thread_data():
    """Print the current thread's name and its private ``thread_local.data``."""
    current = threading.current_thread()
    print(f"{current.name}: {thread_local.data}")

def thread_function(value):
    """Store *value* in the calling thread's slot of ``thread_local`` and show it."""
    thread_local.data = value
    show_thread_data()

# Three named threads, each storing its own index in thread_local.
threads = [
    threading.Thread(target=thread_function, args=(n,), name=f"Thread-{n}")
    for n in range(3)
]
for t in threads:
    t.start()

for t in threads:
    t.join()

5. 线程池

使用concurrent.futures

from concurrent.futures import ThreadPoolExecutor
import urllib.request

urls = [
    'https://www.python.org',
    'https://www.google.com',
    'https://www.github.com'
]

def fetch_url(url, timeout=10):
    """Download *url* and return ``"<url>: <status>, <n> bytes"``.

    Args:
        url: address passed to ``urllib.request.urlopen``.
        timeout: seconds before a blocking connect/read gives up (default 10).
            Without a timeout, a stalled server would occupy a pool worker
            indefinitely.

    Raises:
        Whatever ``urlopen`` raises (e.g. ``urllib.error.URLError``); the
        caller is expected to handle it via ``future.result()``.
    """
    with urllib.request.urlopen(url, timeout=timeout) as response:
        return f"{url}: {response.getcode()}, {len(response.read())} bytes"

# Only ThreadPoolExecutor was imported above; as_completed must be imported
# too (the bare name `concurrent` is not bound by `from concurrent.futures import ...`).
from concurrent.futures import as_completed

with ThreadPoolExecutor(max_workers=3) as executor:
    # Submit one fetch per URL and remember which future belongs to which URL.
    future_to_url = {executor.submit(fetch_url, url): url for url in urls}

    # Handle results in completion order, not submission order.
    for future in as_completed(future_to_url):
        url = future_to_url[future]
        try:
            data = future.result()  # re-raises any exception from fetch_url
        except Exception as e:
            print(f"{url} 生成异常: {e}")
        else:
            print(data)

6. 定时器线程

def hello():
    """Print a greeting; intended as a ``threading.Timer`` callback."""
    greeting = "Hello, 定时器!"
    print(greeting)

# Fire hello() once, 5 seconds from now, on a background Timer thread.
timer = threading.Timer(interval=5.0, function=hello)
timer.start()

# A timer that has not fired yet can still be stopped:
# timer.cancel()

7. 守护线程

def daemon_task():
    """Print a heartbeat once per second forever; meant to run as a daemon thread."""
    heartbeat = "守护线程运行中..."
    while True:
        print(heartbeat)
        time.sleep(1)

# daemon=True: the interpreter does not wait for this thread at exit.
daemon = threading.Thread(target=daemon_task, daemon=True)
daemon.start()

# Main-thread work; the daemon keeps printing in the background meanwhile.
print("主线程开始工作")
time.sleep(3)
print("主线程结束")
# Once the main thread finishes, the program exits and the daemon thread is
# terminated without being joined.

8. GIL(全局解释器锁)的影响

Python的GIL限制:

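  • 在标准CPython中,同一时间只有一个线程执行Python字节码(Python 3.13起另提供可选的无GIL"自由线程"构建)
  • I/O密集型任务仍可从多线程受益,因为线程在阻塞I/O(以及time.sleep)期间会释放GIL
  • CPU密集型任务考虑多进程(multiprocessing模块或ProcessPoolExecutor),以真正利用多个CPU核心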

I/O密集型 vs CPU密集型

def io_bound_task(seconds=1):
    """Simulate an I/O wait by sleeping for *seconds* (default 1).

    ``time.sleep`` releases the GIL, so several of these overlap when run in
    separate threads.
    """
    time.sleep(seconds)

def cpu_bound_task(n=10**7):
    """Burn CPU by summing the squares of ``0 .. n-1`` and return that sum.

    Pure-Python arithmetic holds the GIL, so running this in several threads
    gives no speedup over running the calls one after another.

    Args:
        n: number of terms (default ``10**7``).
    """
    return sum(x * x for x in range(n))

# I/O-bound: sleeping threads release the GIL, so five waits overlap and the
# total is close to the length of a single task.
start = time.time()
threads = [threading.Thread(target=io_bound_task) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
elapsed = time.time() - start
print(f"I/O密集型任务耗时: {elapsed:.2f}秒")

# CPU-bound: the GIL lets only one thread run bytecode at a time, so five
# threads take about as long as five sequential calls.
start = time.time()
threads = [threading.Thread(target=cpu_bound_task) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
elapsed = time.time() - start
print(f"CPU密集型任务耗时: {elapsed:.2f}秒")

9. 线程安全数据结构

使用queue模块

from queue import Queue, LifoQueue, PriorityQueue

# Three queue flavours from the thread-safe `queue` module.
q = Queue()           # FIFO: first in, first out
lq = LifoQueue()      # LIFO: behaves like a stack
pq = PriorityQueue()  # smallest entry first; tuples compare by their first element first

q.put(1)
q.put(2)
lq.put(1)
lq.put(2)
pq.put((3, "低优先级"))
pq.put((1, "高优先级"))
pq.put((2, "中优先级"))

print(q.get())      # 1 - the oldest item
print(lq.get())     # 2 - the newest item
print(pq.get()[1])  # "高优先级" - the lowest priority number wins

10. 实际应用示例

多线程Web爬虫

import requests
from queue import Queue
from threading import Thread

NUM_WORKERS = 5
url_queue = Queue()
results = []  # (url, status_code, body_length) tuples appended by workers

def crawl():
    """Worker loop: fetch URLs from ``url_queue`` until a ``None`` sentinel.

    Successful fetches are appended to ``results``. Failures are reported and
    skipped (best-effort crawl). Every ``get()`` is paired with a
    ``task_done()`` in a ``finally`` block, so ``url_queue.join()`` cannot hang
    even if reporting a result raises.
    """
    while True:
        url = url_queue.get()
        try:
            if url is None:  # sentinel from the main thread: stop this worker
                return
            try:
                response = requests.get(url, timeout=3)
                results.append((url, response.status_code, len(response.text)))
                print(f"爬取 {url} 成功,状态码: {response.status_code}")
            except Exception as e:  # deliberate: one bad URL must not kill the worker
                print(f"爬取 {url} 失败: {e}")
        finally:
            url_queue.task_done()

# Seed the work queue with every URL before any worker starts.
urls = [
    'https://www.python.org',
    'https://www.google.com',
    'https://www.github.com',
    'https://www.example.com',
    'https://www.stackoverflow.com'
]
for url in urls:
    url_queue.put(url)

# Start the worker pool.
workers = [Thread(target=crawl) for _ in range(NUM_WORKERS)]
for t in workers:
    t.start()

# Block until every queued URL has been acknowledged with task_done().
url_queue.join()

# One None sentinel per worker makes each crawl() loop exit.
for _ in range(NUM_WORKERS):
    url_queue.put(None)
for t in workers:
    t.join()

# Report everything that was collected.
print("\n爬取结果:")
for result in results:
    print(result)

多线程端口扫描器

import socket
from queue import Queue
import threading

target = "localhost"
port_queue = Queue()
open_ports = []  # ports on `target` that accepted a TCP connection

def port_scan(port):
    """Record *port* in ``open_ports`` if a TCP connect to ``target`` succeeds.

    Best-effort: a socket error is treated as "not open" and ignored.
    """
    try:
        # The `with` block closes the socket even if settimeout/connect_ex
        # raises, so no file descriptor leaks on the error path.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            if sock.connect_ex((target, port)) == 0:
                open_ports.append(port)
    except socket.error:
        pass  # deliberate: unreachable/unresolvable counts as closed

def worker():
    """Scan ports from ``port_queue`` until a ``None`` sentinel arrives.

    ``task_done()`` runs in ``finally`` for every item, including the
    sentinel, so ``port_queue.join()`` cannot hang if a scan raises.
    """
    while True:
        port = port_queue.get()
        try:
            if port is None:  # sentinel: stop this worker
                return
            port_scan(port)
        finally:
            port_queue.task_done()

# Queue every well-known port (1-1024) for scanning.
for port in range(1, 1025):
    port_queue.put(port)

_NUM_SCANNERS = 100  # worker threads sharing the port queue

threads = [threading.Thread(target=worker) for _ in range(_NUM_SCANNERS)]
for t in threads:
    t.start()

# Block until every queued port has been processed.
port_queue.join()

# Shut the workers down with one None sentinel each.
for _ in range(_NUM_SCANNERS):
    port_queue.put(None)
for t in threads:
    t.join()

print(f"开放端口: {sorted(open_ports)}")

11. 常见问题与解决方案

1. 线程不安全的操作

# "Unsafe" example: several threads append to one shared list with no lock.
shared_list = []

def unsafe_append():
    """Append 0..999 to ``shared_list`` without any locking."""
    for i in range(1000):
        shared_list.append(i)

threads = [threading.Thread(target=unsafe_append) for _ in range(5)]
for t in threads: t.start()
for t in threads: t.join()

# In CPython a single list.append() call is atomic, so this prints 5000.
# The real hazard is *compound* operations on shared state: read-modify-write
# such as `counter += 1`, or check-then-act such as
# `if x not in lst: lst.append(x)`. Other threads can run between the steps.
print(len(shared_list))  # 5000 in CPython; don't extrapolate to compound ops

# Fix: guard shared state with a lock so multi-step updates become atomic.
lock = threading.Lock()
shared_list_safe = []

def safe_append():
    """Append 0..999 to ``shared_list_safe``, holding ``lock`` for each append."""
    for i in range(1000):
        with lock:
            shared_list_safe.append(i)

threads = [threading.Thread(target=safe_append) for _ in range(5)]
for t in threads: t.start()
for t in threads: t.join()

print(len(shared_list_safe))  # always 5000

2. 死锁

# Anti-example: two threads take the same two locks in opposite order.
lock1 = threading.Lock()
lock2 = threading.Lock()

def thread_a(delay=1):
    """Take lock1, pause *delay* seconds, then take lock2 (order: 1 -> 2)."""
    with lock1:
        time.sleep(delay)  # widens the window for thread_b to grab lock2
        with lock2:  # blocks forever if thread_b holds lock2 and waits for lock1
            print("线程A完成")

def thread_b(delay=1):
    """Take lock2, pause *delay* seconds, then take lock1 (order: 2 -> 1)."""
    with lock2:
        time.sleep(delay)
        with lock1:  # blocks forever if thread_a holds lock1 and waits for lock2
            print("线程B完成")

# Ways to avoid the deadlock:
# 1. Acquire locks in one fixed global order (e.g. always lock1 before lock2).
# 2. An RLock does NOT help here: it only lets the thread that already owns a
#    lock re-acquire it. Two threads each holding one lock still deadlock.
#    Protecting both resources with a single lock does remove the cycle.
# 3. Acquire with a timeout and back off on failure, as shown below.
def safe_thread_a(delay=1, timeout=2):
    """Like ``thread_a``, but gives up on lock2 after *timeout* seconds.

    Args:
        delay: seconds to hold lock1 before trying lock2 (default 1).
        timeout: seconds to wait for lock2 before reporting a timeout (default 2).
    """
    with lock1:
        time.sleep(delay)
        if lock2.acquire(timeout=timeout):
            try:
                print("线程A完成")
            finally:
                lock2.release()
        else:
            print("线程A获取锁2超时")

3. 线程泄漏

# Anti-example: threads are started and immediately forgotten.
def leak_threads(count=100, seconds=10):
    """Start *count* threads that each sleep *seconds*, keeping no handle to them.

    Nothing can join() or inspect these threads afterwards.
    """
    for _ in range(count):
        t = threading.Thread(target=time.sleep, args=(seconds,))
        t.start()  # started, then the reference is dropped

# Fix: keep a handle to every thread you start (or use a thread pool), and join them.
def safe_thread_management(count=100, seconds=10):
    """Start *count* threads that each sleep *seconds*, then join every one.

    Returns only after all started threads have finished.
    """
    threads = []
    for _ in range(count):
        t = threading.Thread(target=time.sleep, args=(seconds,))
        t.start()
        threads.append(t)

    for t in threads:
        t.join()  # wait for every thread to complete

12. 总结

Python多线程编程要点:

| 主题 | 关键点 |
|------|--------|
| 创建线程 | Thread类, target参数 |
| 线程同步 | Lock, RLock, Condition, Semaphore, Event |
| 线程通信 | Queue线程安全队列 |
| 线程池 | ThreadPoolExecutor高效管理 |
| GIL影响 | I/O密集型适用,CPU密集型考虑多进程 |
| 实际应用 | 爬虫、端口扫描等并发任务 |

使用多线程时需要注意:

  1. 共享资源的线程安全
  2. 避免死锁
  3. 合理控制线程数量
  4. 考虑GIL对性能的影响
  5. 使用高级工具如线程池简化管理

对于更复杂的并发需求,可以结合asyncio或multiprocessing模块使用。