
Six Months of Running LangGraph in Production: My Pitfall-Avoidance Guide
We started rebuilding our AI system on LangGraph in March, so it has been almost six months now. We hit a number of pitfalls along the way, some of which the official docs never mention, and I'm writing those lessons up here.
Conclusions first
If your system meets any of the following conditions, LangGraph is probably a good fit:
- You need complex multi-step decision flows
- You have explicit state-management requirements
- You need human review at key nodes
- You're building multi-agent collaboration
But if all you need is simple single-turn chat or plain RAG, LangChain is enough. Don't make life harder for yourself.
Pitfalls in State Management
1. Your checkpointer choice is life-or-death
Early on, a teammate used InMemorySaver for testing and everything looked fine. After launch, the first service restart wiped every conversation's history.
```python
# ❌ Never do this in production
from langgraph.checkpoint.memory import InMemorySaver

checkpointer = InMemorySaver()  # one service restart and everything is gone

# ✅ The right way in production
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool

# Use a connection pool; don't open a new connection per request
async def create_checkpointer():
    pool = AsyncConnectionPool(
        "postgresql://user:pass@localhost/db",
        min_size=10,
        max_size=100,
        max_idle=300.0,       # max idle time per connection
        max_lifetime=3600.0,  # max lifetime per connection
        open=False,
        kwargs={"autocommit": True, "row_factory": dict_row},
    )
    await pool.open()
    # Hand the saver the whole pool. The original version returned a saver
    # bound to one pooled connection from inside `async with`, which closes
    # that connection the moment the function returns.
    checkpointer = AsyncPostgresSaver(pool)
    await checkpointer.setup()  # creates the checkpoint tables on first run
    return checkpointer
```
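For context, here's a minimal sketch of how the checkpointer plugs into a compiled graph. ChatState and the trivial chat node are placeholders I've made up for illustration, not part of our real system:
```python
from typing import TypedDict
from langgraph.graph import StateGraph, START, END

class ChatState(TypedDict):
    messages: list

builder = StateGraph(ChatState)
builder.add_node("chat", lambda state: {"messages": state["messages"]})
builder.add_edge(START, "chat")
builder.add_edge("chat", END)

async def main():
    checkpointer = await create_checkpointer()
    graph = builder.compile(checkpointer=checkpointer)
    # Calls that share a thread_id resume from the persisted state
    config = {"configurable": {"thread_id": "user42_default"}}
    await graph.ainvoke({"messages": ["hello"]}, config)
```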
2. Thread ID management
At first we used the user ID as the thread_id; the moment one user opened multiple conversations at once, their states got tangled together. We switched to UUIDs and then couldn't trace a user's history at all.
Final approach: a composite ID scheme
```python
import hashlib
from datetime import datetime

class ThreadManager:
    @staticmethod
    def generate_thread_id(user_id: str, session_type: str = "default") -> str:
        """Generate a traceable thread_id."""
        # Format: userID_sessionType_timestamp_shortHash
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        unique_str = f"{user_id}_{session_type}_{timestamp}"
        short_hash = hashlib.md5(unique_str.encode()).hexdigest()[:8]
        return f"{user_id}_{session_type}_{timestamp}_{short_hash}"

    @staticmethod
    def parse_thread_id(thread_id: str) -> dict:
        """Parse a thread_id back into its metadata."""
        # Note: assumes user_id and session_type contain no underscores
        parts = thread_id.split("_")
        return {
            "user_id": parts[0],
            "session_type": parts[1],
            "timestamp": parts[2],
            "hash": parts[3],
        }
```
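A quick usage sketch showing the round trip, and where the ID ends up in the config LangGraph expects:
```python
thread_id = ThreadManager.generate_thread_id("user42", "support")
# e.g. "user42_support_20250301120000_3fa8c1d2"

config = {"configurable": {"thread_id": thread_id}}
# result = await graph.ainvoke(input_data, config)

meta = ThreadManager.parse_thread_id(thread_id)
assert meta["user_id"] == "user42"
```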
3. State version control
LangGraph versions every channel value it stores, so each new checkpoint only persists the values that actually changed. But if your state schema changes frequently, you'll run into compatibility problems.
```python
from typing import Optional, TypedDict

# Versioned state definitions
class StateV1(TypedDict):
    messages: list
    context: dict
    version: int  # always carry a version number

class StateV2(TypedDict):
    messages: list
    context: dict
    metadata: dict  # new in V2
    version: int

class StateMigrator:
    """Migrates old persisted state to the current schema."""

    @staticmethod
    def migrate(state: dict) -> dict:
        version = state.get("version", 1)
        if version == 1:
            # V1 -> V2 migration
            state["metadata"] = {}
            state["version"] = 2
        # future migrations chain on here
        return state

    @staticmethod
    def load_state(thread_id: str, checkpointer) -> Optional[dict]:
        """Load a checkpoint and migrate its state on the fly."""
        # get() expects a runnable config, not a bare {"thread_id": ...};
        # the state values live under the checkpoint's "channel_values"
        checkpoint = checkpointer.get({"configurable": {"thread_id": thread_id}})
        if checkpoint:
            return StateMigrator.migrate(checkpoint["channel_values"])
        return None
```
Error Handling
1. Node-level retries
LangGraph ships a retry_policy for re-running failed nodes: only the failing branch is retried, so you don't have to worry about redoing work that already succeeded. The default policy is too simplistic, though.
```python
import httpx
from langgraph.graph import StateGraph
from langgraph.types import RetryPolicy

class SmartRetryPolicy:
    """Per-node-type retry strategies."""

    @staticmethod
    def create_policy(node_name: str) -> RetryPolicy:
        # Different node types get different retry behavior
        if "llm" in node_name:
            return RetryPolicy(
                max_attempts=3,
                backoff_factor=2.0,
                max_interval=30.0,
                retry_on=SmartRetryPolicy.should_retry_llm,
            )
        elif "api" in node_name:
            return RetryPolicy(
                max_attempts=5,
                backoff_factor=1.5,
                max_interval=60.0,
                retry_on=SmartRetryPolicy.should_retry_api,
            )
        else:
            # default policy
            return RetryPolicy(max_attempts=2)

    @staticmethod
    def should_retry_llm(error: Exception) -> bool:
        """Should an LLM call be retried?"""
        # Rate limits and transient upstream failures: always retry
        if isinstance(error, httpx.HTTPStatusError):
            return error.response.status_code in [429, 502, 503, 504]
        # Network errors: retry
        if isinstance(error, (httpx.ConnectError, httpx.TimeoutException)):
            return True
        # Invalid parameters: don't retry
        if "invalid" in str(error).lower():
            return False
        return True

    @staticmethod
    def should_retry_api(error: Exception) -> bool:
        """Should a plain API call be retried?"""
        if isinstance(error, httpx.HTTPStatusError):
            # retry every 5xx, plus 429 rate limits
            return error.response.status_code >= 500 or error.response.status_code == 429
        return isinstance(error, (httpx.ConnectError, httpx.TimeoutException))

# Usage. State and process_llm are defined elsewhere; note the keyword
# is `retry` in older langgraph releases and `retry_policy` in newer ones.
builder = StateGraph(State)
builder.add_node(
    "llm_node",
    process_llm,
    retry_policy=SmartRetryPolicy.create_policy("llm_node"),
)
```
2. Global error recovery
Node retries don't solve everything; you also need a recovery mechanism around the whole graph:
```python
import asyncio
from langgraph.errors import GraphRecursionError

class ResilientGraphRunner:
    def __init__(self, graph, checkpointer):
        self.graph = graph
        self.checkpointer = checkpointer
        self.dead_letter_queue = []  # dead-letter queue

    async def run_with_recovery(
        self,
        input_data: dict,
        thread_id: str,
        max_recovery_attempts: int = 3,
    ):
        """Run the graph with a recovery loop around it."""
        attempt = 0
        last_error = None
        while attempt < max_recovery_attempts:
            try:
                # recursion_limit is a top-level config key,
                # not part of "configurable"
                config = {
                    "configurable": {"thread_id": thread_id},
                    "recursion_limit": 100,  # guard against infinite loops
                }
                return await self.graph.ainvoke(input_data, config)
            except GraphRecursionError as e:
                # Recursion depth exceeded, probably a cycle; don't retry
                last_error = e
                await self.handle_recursion_error(thread_id, e)
                break
            except Exception as e:
                last_error = e
                attempt += 1
                # Log the failure
                await self.log_error(thread_id, e, attempt)
                # Try resuming from the last successful checkpoint
                if attempt < max_recovery_attempts:
                    await self.recover_from_checkpoint(thread_id)
                    await asyncio.sleep(2 ** attempt)  # exponential backoff
        # All attempts failed: push to the dead-letter queue
        await self.send_to_dead_letter(thread_id, input_data, last_error)
        raise last_error

    async def recover_from_checkpoint(self, thread_id: str):
        """Roll back to the most recent successful checkpoint."""
        checkpoints = self.checkpointer.list(
            {"configurable": {"thread_id": thread_id}},
            limit=10,
        )
        for checkpoint in checkpoints:
            if checkpoint.metadata.get("status") == "success":
                # Re-write this checkpoint as the latest one. Newer
                # checkpointer APIs also require new_versions; the
                # checkpoint's own channel_versions work for a restore.
                self.checkpointer.put(
                    {"configurable": {"thread_id": thread_id}},
                    checkpoint.checkpoint,
                    checkpoint.metadata,
                    checkpoint.checkpoint["channel_versions"],
                )
                break

    # handle_recursion_error, log_error and send_to_dead_letter are
    # application-specific (alerting, logging, queueing) and omitted here
```
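Usage is then a thin wrapper around each request. The error payload shape here is just an example:
```python
runner = ResilientGraphRunner(graph, checkpointer)

async def handle_request(user_input: dict, thread_id: str) -> dict:
    try:
        return await runner.run_with_recovery(user_input, thread_id)
    except Exception:
        # The failing input is already in the dead-letter queue
        # for offline inspection; give the caller a clean error
        return {"error": "request failed, please retry later"}
```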
3. Tool-call error handling
Tool nodes now return ToolMessages carrying an error field when a tool call fails, but the default handling is crude:
```python
from datetime import datetime
from typing import Any, Dict, List

from langchain_core.messages import ToolMessage

class SafeToolExecutor:
    """Tool executor with fallbacks."""

    def __init__(self, tools: List, fallback_model=None):
        self.tools = {tool.name: tool for tool in tools}
        self.fallback_model = fallback_model
        self.execution_history = []  # execution log

    async def execute_with_fallback(
        self,
        tool_calls: List[Dict[str, Any]],
        state: Dict,
    ) -> List[ToolMessage]:
        """Run tool calls, degrading gracefully on failure."""
        results = []
        for tool_call in tool_calls:
            tool_name = tool_call.get("name")
            tool_args = tool_call.get("args", {})
            # Does the tool exist at all?
            if tool_name not in self.tools:
                results.append(ToolMessage(
                    content=f"Tool {tool_name} not found",
                    tool_call_id=tool_call.get("id"),
                    additional_kwargs={"error": "ToolNotFound"},
                ))
                continue
            # Run the tool
            try:
                result = await self.execute_single_tool(tool_name, tool_args, state)
                results.append(ToolMessage(
                    content=str(result),
                    tool_call_id=tool_call.get("id"),
                ))
            except Exception as e:
                # Log the failure
                self.execution_history.append({
                    "tool": tool_name,
                    "args": tool_args,
                    "error": str(e),
                    "timestamp": datetime.now(),
                })
                # Try the fallback chain
                fallback_result = await self.try_fallback(tool_name, tool_args, e, state)
                results.append(ToolMessage(
                    content=fallback_result,
                    tool_call_id=tool_call.get("id"),
                    additional_kwargs={"error": str(e), "fallback_used": True},
                ))
        return results

    async def execute_single_tool(self, tool_name: str, args: dict, state: dict):
        # state is available here for tools that need extra context
        return await self.tools[tool_name].ainvoke(args)

    async def try_fallback(
        self, tool_name: str, args: dict, error: Exception, state: dict
    ) -> str:
        """Fallback chain."""
        # Strategy 1: a designated backup tool
        backup_tool = self.get_backup_tool(tool_name)
        if backup_tool:
            try:
                return await backup_tool.ainvoke(args)
            except Exception:
                pass
        # Strategy 2: let the LLM improvise an answer
        if self.fallback_model:
            prompt = f"""
            Tool {tool_name} failed.
            Args: {args}
            Error: {error}
            Based on the current context, give a reasonable substitute answer.
            Context: {state.get('context', '')}
            """
            response = await self.fallback_model.ainvoke(prompt)
            return response.content  # ainvoke returns a message, not a string
        # Strategy 3: a meaningful error message
        return f"Tool execution failed, please try another approach: {error}"

    def get_backup_tool(self, tool_name: str):
        # App-specific: map primary tools to backups; none by default
        return None
```
Performance Optimization
1. Parallel execution, done right
A lot of people don't realize that LangGraph can parallelize automatically:
```python
from typing import List
from langgraph.graph import StateGraph, START, END

class OptimizedGraph:
    @staticmethod
    def build_parallel_graph():
        builder = StateGraph(State)

        def route_parallel(state) -> List[str]:
            """Return several node names; they all run in parallel."""
            tasks = []
            if state.get("need_search"):
                tasks.append("search_node")
            if state.get("need_calculation"):
                tasks.append("calc_node")
            if state.get("need_validation"):
                tasks.append("validate_node")
            return tasks if tasks else ["default_node"]

        # Register the worker nodes (implementations live elsewhere)
        for name, fn in [
            ("search_node", search_fn),
            ("calc_node", calc_fn),
            ("validate_node", validate_fn),
            ("default_node", default_fn),
            ("aggregate_node", aggregate_fn),
        ]:
            builder.add_node(name, fn)

        # A conditional edge that returns a list fans out in parallel
        builder.add_conditional_edges(
            START,
            route_parallel,
            ["search_node", "calc_node", "validate_node", "default_node"],
        )
        # Fan-in: aggregate once all parallel branches finish
        builder.add_edge(["search_node", "calc_node", "validate_node"], "aggregate_node")
        builder.add_edge("aggregate_node", END)
        return builder.compile()
```
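One thing the sketch above glosses over: if parallel branches write to the same state key, LangGraph raises InvalidUpdateError unless that key has a reducer. A minimal state schema that merges parallel writes (the field and node names below are illustrative):
```python
import operator
from typing import Annotated, TypedDict

class ParallelState(TypedDict):
    # Each parallel node returns {"results": [...]} and the
    # operator.add reducer concatenates them instead of clobbering
    results: Annotated[list, operator.add]
    need_search: bool
    need_calculation: bool
    need_validation: bool

def search_fn(state: ParallelState) -> dict:
    return {"results": [{"source": "search", "data": "..."}]}

def calc_fn(state: ParallelState) -> dict:
    return {"results": [{"source": "calc", "data": "..."}]}
```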
2. Node-level caching
LangGraph now supports node-level caching: individual node results can be cached to avoid recomputation and speed up execution:
```python
import hashlib
import pickle
from typing import List

class NodeCache:
    """Node-level cache."""

    def __init__(self, redis_client=None):
        self.redis = redis_client
        self.local_cache = {}  # in-process L1 cache

    def cache_key(self, node_name: str, state: dict) -> str:
        """Build a cache key."""
        # Only key on the fields a node actually reads; ignore the rest
        relevant_fields = self.get_relevant_fields(node_name)
        cache_data = {k: state.get(k) for k in relevant_fields}
        # Stable hash of the relevant slice of state
        data_str = pickle.dumps(cache_data, protocol=pickle.HIGHEST_PROTOCOL)
        return f"node:{node_name}:{hashlib.md5(data_str).hexdigest()}"

    def get_relevant_fields(self, node_name: str) -> List[str]:
        """Which state fields matter for each node."""
        field_map = {
            "search_node": ["query", "filters"],
            "llm_node": ["messages", "temperature"],
            "calc_node": ["formula", "variables"],
        }
        return field_map.get(node_name, ["messages"])

    async def get_or_compute(self, node_name: str, state: dict, compute_func, ttl: int = 3600):
        """Return the cached value, or compute and cache it."""
        cache_key = self.cache_key(node_name, state)
        # L1: memory
        if cache_key in self.local_cache:
            return self.local_cache[cache_key]
        # L2: Redis
        if self.redis:
            cached = await self.redis.get(cache_key)
            if cached:
                result = pickle.loads(cached)
                self.local_cache[cache_key] = result
                return result
        # Compute and cache
        result = await compute_func(state)
        self.local_cache[cache_key] = result
        if self.redis:
            # redis-py spells the TTL argument `ex`, not `expire`
            await self.redis.set(cache_key, pickle.dumps(result), ex=ttl)
        return result

# Cache decorator. It takes the cache at decoration time so the wrapped
# function keeps the normal single-argument node signature.
def cached_node(cache: NodeCache, ttl: int = 3600):
    def decorator(func):
        async def wrapper(state: dict):
            return await cache.get_or_compute(func.__name__, state, func, ttl)
        return wrapper
    return decorator

node_cache = NodeCache()

@cached_node(node_cache, ttl=7200)
async def expensive_search_node(state: dict):
    """Expensive search; results get cached."""
    results = await perform_search(state["query"])  # your actual search call
    return {"search_results": results}
```
3. Streaming output
Choosing the right stream_mode (mentioned earlier) matters a lot, but there are more knobs to turn:
```python
import base64
import gzip
import sys

class StreamOptimizer:
    """Streaming output optimizer."""

    @staticmethod
    async def optimized_stream(graph, input_data, config):
        """Optimized streaming loop."""
        # "updates" mode ships per-node deltas, not full state snapshots
        async for chunk in graph.astream(input_data, config, stream_mode="updates"):
            for node_name, updates in chunk.items():
                # Drop internal-only state updates
                filtered_updates = StreamOptimizer.filter_updates(updates)
                if filtered_updates:
                    # Compress big payloads
                    compressed = StreamOptimizer.compress_if_needed(filtered_updates)
                    yield node_name, compressed

    @staticmethod
    def filter_updates(updates: dict) -> dict:
        """Strip updates the client doesn't need."""
        internal_fields = [
            "_raw_response",
            "_checkpoint_data",
            "_debug_info",
            "tool_calls",  # clients rarely need raw tool calls
        ]
        return {
            k: v for k, v in updates.items()
            if k not in internal_fields and not k.startswith("_")
        }

    @staticmethod
    def compress_if_needed(data: dict) -> dict:
        """Compress oversized values."""
        for key, value in data.items():
            # gzip strings over ~10KB
            if isinstance(value, str) and sys.getsizeof(value) > 10240:
                compressed = gzip.compress(value.encode())
                data[key] = {
                    "compressed": True,
                    "data": base64.b64encode(compressed).decode(),
                }
            # big lists: send a preview plus a summary
            elif isinstance(value, list) and len(value) > 100:
                data[key] = {
                    "summary": f"List with {len(value)} items",
                    "preview": value[:10],  # first 10 only
                    "total": len(value),
                }
        return data
```
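For reference, the main stream_mode values behave quite differently. A sketch of the three most useful ones (inside an async function):
```python
# "values": the full state snapshot after every step (heaviest payload)
async for state in graph.astream(input_data, config, stream_mode="values"):
    ...

# "updates": per-node deltas (what StreamOptimizer above consumes)
async for update in graph.astream(input_data, config, stream_mode="updates"):
    ...

# "messages": LLM tokens plus metadata (best for chat UIs)
async for token, metadata in graph.astream(input_data, config, stream_mode="messages"):
    print(token.content, end="", flush=True)
```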
Production Deployment Details
1. Multi-environment configuration
```python
from enum import Enum
from typing import Optional

# Pydantic v2 moved BaseSettings into the pydantic-settings package
from pydantic_settings import BaseSettings, SettingsConfigDict

class Environment(Enum):
    DEV = "dev"
    STAGING = "staging"
    PROD = "prod"

class LangGraphConfig(BaseSettings):
    """Central configuration."""
    # Pick the env file per deployment (e.g. via an environment variable);
    # hardcoding one environment here defeats the purpose
    model_config = SettingsConfigDict(env_file=".env")

    environment: Environment

    # Database
    postgres_url: str
    postgres_pool_size: int = 20

    # Redis
    redis_url: str
    redis_pool_size: int = 50

    # LLM
    openai_api_key: str
    openai_timeout: int = 30
    openai_max_retries: int = 3

    # Graph
    max_recursion_depth: int = 100
    default_thread_ttl: int = 86400  # 24 hours

    # Observability
    enable_tracing: bool = True
    langsmith_api_key: Optional[str] = None

    def get_checkpointer_config(self) -> dict:
        """Checkpointer settings per environment."""
        if self.environment == Environment.DEV:
            # dev: in-memory
            return {"type": "memory"}
        elif self.environment == Environment.STAGING:
            # staging: SQLite
            return {"type": "sqlite", "path": "checkpoints.db"}
        else:
            # prod: PostgreSQL
            return {
                "type": "postgres",
                "url": self.postgres_url,
                "pool_size": self.postgres_pool_size,
            }
```
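A minimal factory that turns get_checkpointer_config() into an actual saver might look like this. The connection wiring is my assumption (and needs the langgraph-checkpoint-sqlite and langgraph-checkpoint-postgres extras installed); adapt it to however you manage connections:
```python
import sqlite3

from psycopg import Connection
from psycopg.rows import dict_row

from langgraph.checkpoint.memory import InMemorySaver
from langgraph.checkpoint.postgres import PostgresSaver
from langgraph.checkpoint.sqlite import SqliteSaver

def build_checkpointer(config: LangGraphConfig):
    cc = config.get_checkpointer_config()
    if cc["type"] == "memory":
        return InMemorySaver()
    if cc["type"] == "sqlite":
        conn = sqlite3.connect(cc["path"], check_same_thread=False)
        return SqliteSaver(conn)
    # postgres: a single autocommit connection; swap in a pool for real load
    conn = Connection.connect(cc["url"], autocommit=True, row_factory=dict_row)
    return PostgresSaver(conn)
```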
2. Monitoring and observability
```python
import json
from dataclasses import dataclass
from datetime import datetime
from typing import List

@dataclass
class GraphMetrics:
    """Metrics for one graph execution."""
    thread_id: str
    start_time: datetime
    end_time: datetime
    total_nodes_executed: int
    failed_nodes: List[str]
    retry_count: int
    total_tokens: int
    total_cost: float

class MetricsCollector:
    """Metrics collector. `prometheus_client` here is a thin wrapper
    around your metrics backend, not the raw prometheus-client package."""

    def __init__(self, prometheus_client=None):
        self.prometheus = prometheus_client
        self.metrics_buffer = []

    async def track_node_execution(self, node_name: str, duration: float, success: bool):
        """Track a single node execution."""
        if self.prometheus:
            self.prometheus.histogram(
                "langgraph_node_duration",
                duration,
                labels={"node": node_name, "success": str(success)},
            )
            self.prometheus.increment(
                "langgraph_node_executions",
                labels={"node": node_name, "status": "success" if success else "failure"},
            )

    async def track_graph_execution(self, metrics: GraphMetrics):
        """Track a whole graph run."""
        # Ship to the monitoring system
        if self.prometheus:
            duration = (metrics.end_time - metrics.start_time).total_seconds()
            self.prometheus.histogram("langgraph_graph_duration", duration)
            self.prometheus.gauge("langgraph_graph_cost", metrics.total_cost)
        # Buffer detailed records for offline analysis
        self.metrics_buffer.append(metrics)
        # Flush in batches
        if len(self.metrics_buffer) >= 100:
            await self.flush_metrics()

    async def flush_metrics(self):
        """Write the buffered metrics out in one batch."""
        if not self.metrics_buffer:
            return
        batch_data = [json.dumps(m.__dict__, default=str) for m in self.metrics_buffer]
        # write_to_datawarehouse is your ingestion hook (warehouse, log pipeline, ...)
        await write_to_datawarehouse(batch_data)
        self.metrics_buffer.clear()
```
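To actually feed track_node_execution, one option is to wrap each node function. A sketch, assuming the collector above:
```python
import time
from functools import wraps

collector = MetricsCollector()

def timed_node(func):
    """Wrap a node so its duration and outcome get reported."""
    @wraps(func)
    async def wrapper(state: dict):
        start = time.perf_counter()
        try:
            result = await func(state)
            await collector.track_node_execution(func.__name__, time.perf_counter() - start, True)
            return result
        except Exception:
            await collector.track_node_execution(func.__name__, time.perf_counter() - start, False)
            raise
    return wrapper

# builder.add_node("llm_node", timed_node(process_llm))
```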
3. Load balancing and scaling
```python
class GraphPoolManager:
    """Pool of compiled graph instances."""

    def __init__(self, min_instances=2, max_instances=10):
        self.min_instances = min_instances
        self.max_instances = max_instances
        self.instances = []
        self.current_index = 0

    async def get_instance(self):
        """Round-robin over the pool; returns a compiled graph."""
        if not self.instances:
            await self.initialize_pool()
        instance = self.instances[self.current_index]
        self.current_index = (self.current_index + 1) % len(self.instances)
        return instance

    async def scale_based_on_load(self, current_qps: float):
        """Scale the pool up or down based on load."""
        target_instances = self.calculate_target_instances(current_qps)
        current_count = len(self.instances)
        if target_instances > current_count:
            # scale up
            for _ in range(target_instances - current_count):
                self.instances.append(await self.create_instance())
        elif target_instances < current_count:
            # scale down
            excess = current_count - target_instances
            for _ in range(excess):
                instance = self.instances.pop()
                await self.destroy_instance(instance)

    def calculate_target_instances(self, qps: float) -> int:
        """How many instances do we need?"""
        # assume each instance handles ~100 QPS
        target = int(qps / 100) + 1
        return max(self.min_instances, min(target, self.max_instances))

    # initialize_pool / create_instance / destroy_instance depend on how
    # you build and tear down your graphs, so they're omitted here
```
Lessons from the Pitfalls
Things you must remember
- A checkpointer is required: if you build a custom checkpoint saver, implement the async methods too so you don't block the event loop
- Set the recursion limit deliberately: the default may be too low, but setting it sky-high makes runaway loops hard to spot (see the sketch after this list)
- Handle tool errors gracefully: tool-call failures are the norm, not the exception
- Make state updates atomic: parallel nodes updating the same field will race; use reducers, as in the ParallelState example in the parallel-execution section
- Monitor from day one: don't wait for an incident to add observability
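On the recursion-limit point, a minimal sketch of setting it per call and catching the overrun:
```python
from langgraph.errors import GraphRecursionError

config = {"configurable": {"thread_id": "user42_default"}, "recursion_limit": 50}
try:
    result = graph.invoke(input_data, config)  # default limit is 25
except GraphRecursionError:
    # Hitting the limit usually means a routing bug, not a too-small limit
    ...
```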
When you shouldn't use LangGraph
- Simple Q&A systems: use LangChain directly
- Pure streaming generation: the model's streaming API is enough
- Stateless API calls: FastAPI is a better fit
- Extremely latency-sensitive scenarios: graph traversal has overhead
Summary
LangGraph is genuinely powerful, but it's not a silver bullet. Used in the right place it can transform your system; used in the wrong place it's over-engineering.
The most important lesson: start simple and add complexity gradually. Don't open with a 20-node graph. Start with 3 to 5 nodes, get them stable, then add features.
One more thing: LangGraph's graph-based architecture really does give you a wide range of flexibility, from fully open-ended agents to fully deterministic workflows.
If you're also running LangGraph, I'd love to compare notes on the pitfalls. Only people who have actually run it in production know how deep they go.
The code in this post is simplified from our production code; copying it verbatim will probably need adjustments. The point is to understand the approach, not to copy the code.
Reposted from AI 博物院. Author: longyunfeigu
