LangGraph的stream_mode到底怎么选?我调了一下午终于搞明白了 原创

发布于 2025-8-22 09:05
浏览
0收藏

最近在重构我们的AI对话系统,从简单的请求-响应模式升级到实时流式处理。过程中发现LangGraph的stream_mode远比文档上写的复杂,今天把对应的实践经验分享出来。

stream_mode到底是什么

简单说,stream_mode就是控制你在流式处理时能拿到什么数据。简单理解就是你的Graph在执行时,每完成一个节点都会产生输出。stream_mode决定你能看到什么:

  • 是看到完整的状态快照?
  • 还是只看变化的部分?
  • 或者只关心LLM的输出?

4种模式

values

这是默认的,每次返回完整的状态。说实话,大部分时候用这个就够了:

from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated
import operator
import time


# 定义状态类型
class GraphState(TypedDict):
    """
    Graph的状态定义

    - messages: 存储对话消息列表
    - step_count: 记录执行步骤数
    - result: 最终结果
    """
    messages: Annotated[list, operator.add]  # 使用operator.add来合并列表
    step_count: int
    result: str


def step_1(state: GraphState) -> GraphState:
    """
    第一个处理步骤:初始化和数据准备
    """
    print("🔄 执行步骤1: 数据准备阶段")
    time.sleep(1)  # 模拟处理时间

    return {
        "messages": ["步骤1: 开始数据准备"],
        "step_count": state.get("step_count", 0) + 1,
        "result": "数据准备完成"
    }


def step_2(state: GraphState) -> GraphState:
    """
    第二个处理步骤:数据处理
    """
    print("🔄 执行步骤2: 数据处理阶段")
    time.sleep(1.5)  # 模拟处理时间

    return {
        "messages": ["步骤2: 正在处理数据"],
        "step_count": state.get("step_count", 0) + 1,
        "result": "数据处理完成,准备分析"
    }


def step_3(state: GraphState) -> GraphState:
    """
    第三个处理步骤:数据分析和生成结果
    """
    print("🔄 执行步骤3: 数据分析阶段")
    time.sleep(2)  # 模拟处理时间

    total_messages = len(state.get("messages", []))

    return {
        "messages": ["步骤3: 分析完成,生成最终结果"],
        "step_count": state.get("step_count", 0) + 1,
        "result": f"分析完成!总共处理了 {total_messages + 1} 条消息,执行了 {state.get('step_count', 0) + 1} 个步骤"
    }


def create_workflow():
    """
    创建LangGraph工作流
    """
    # 创建状态图
    workflow = StateGraph(GraphState)

    # 添加节点
    workflow.add_node("step_1", step_1)
    workflow.add_node("step_2", step_2)
    workflow.add_node("step_3", step_3)

    # 定义边:step_1 -> step_2 -> step_3 -> END
    workflow.set_entry_point("step_1")
    workflow.add_edge("step_1", "step_2")
    workflow.add_edge("step_2", "step_3")
    workflow.add_edge("step_3", END)

    # 编译图
    app = workflow.compile()
    return app


def demo_stream_values():
    """
    演示 stream_mode="values" 的用法

    stream_mode="values" 会返回每个步骤完成后的完整状态值
    这让我们可以实时监控整个graph的状态变化
    """
    print("=" * 60)
    print("🚀 LangGraph Stream Mode Demo - stream_mode='values'")
    print("=" * 60)

    # 创建工作流
    app = create_workflow()

    # 初始状态
    initial_state = {
        "messages": [],
        "step_count": 0,
        "result": ""
    }

    print("\n📊 开始流式执行,实时显示每个步骤的状态变化:")
    print("-" * 60)

    # 使用stream方法,设置stream_mode="values"
    for i, output in enumerate(app.stream(initial_state, stream_mode="values")):
        print(f"\n🔍 步骤 {i} 完成后的状态:")
        print(f"   📝 消息列表: {output.get('messages', [])}")
        print(f"   🔢 步骤计数: {output.get('step_count', 0)}")
        print(f"   ✨ 当前结果: {output.get('result', '')}")
        print("-" * 40)

    print("\n✅ 工作流执行完成!")

我一开始就是用的这个,结果发现数据量特别大。比如我们有个处理报表的流程,状态里存了一个几千行的DataFrame,每个节点都要传输这么大的数据,难怪客户端卡。

updates

后来改成updates模式,立马快了很多:

def demo_update_values():
    """
    对比不同stream_mode的效果
    """
    print("\n" + "=" * 60)
    print("🔍 Stream Mode 对比演示")
    print("=" * 60)

    app = create_workflow()
    initial_state = {"messages": [], "step_count": 0, "result": ""}

    # 演示 stream_mode="updates"
    print("\n🔄 stream_mode='updates' - 只显示每步的更新内容:")
    for output in app.stream(initial_state, stream_mode="updates"):
        for node_name, updates in output.items():
            print(f"{node_name} 更新了: {updates}")

这个模式特别适合生产环境。比如你的状态里有个huge_data字段一直不变,用values模式每次都传,用updates就只传真正变化的部分。

不过要注意,你拿到的是增量更新,需要自己维护完整状态:

# 自己维护状态
current_state = {}
for chunk in app.stream(input_data, stream_mode="updates"):
    for node_name, updates in chunk.items():
        current_state.update(updates)
        # 现在current_state是最新的完整状态

debug

这个模式我只在开发时用,信息特别详细:

for chunk in app.stream(input_data, stream_mode="debug"):
    print(f"Debug info: {chunk}")

会输出类似这样的信息:

  • 节点开始执行
  • 节点执行结束
  • 状态变化
  • 错误信息
  • 执行时间

有一次一个节点莫名其妙执行了两次,就是用debug模式发现的,原来是我的条件边写错了。

messages

如果你在做聊天机器人,这个模式能省很多事.

from typing import TypedDict, List
from langgraph.graph import StateGraph, START
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage


class SimpleState(TypedDict):
    topic: str
    joke: str
    # 注意:这里没有 messages 字段!

model = ChatOpenAI(model="gpt-4o-mini")

def call_model(state: SimpleState):
    """调用 LLM 生成笑话"""
    # 这里调用了 LLM
    llm_response = model.invoke([
        {"role": "user", "content": f"Generate a joke about {state['topic']}"}
    ])
    # 返回的是 joke 字段,不是 messages
    return {"joke": llm_response.content}

graph1 = (
    StateGraph(SimpleState)
    .add_node("call_model", call_model)
    .add_edge(START, "call_model")
    .compile()
)

# stream_mode="messages" 仍然可以工作!
# 因为它拦截的是 model.invoke() 调用时产生的 tokens
for msg, metadata in graph1.stream({"topic": "cats"}, stream_mode="messages"):
    if msg.content:
        print(msg.content, end="|")
        
# 输出: Why| did| the| cat|...(流式输出)

你可能会觉得State里并没有messages字段,为什么stream_mode="messages" 仍旧能工作呢?这是因为:

当您使用 stream_mode="messages" 时,LangGraph 做了以下事情:

1. **Hook 机制**:
   - LangGraph 在底层使用回调(callbacks)系统
   - 当检测到 stream_mode="messages" 时,它会自动将 LLM 的 invoke 
     方法切换到 stream 模式

2. **事件监听**:
   - 监听所有 LangChain 模型的 on_llm_new_token 事件
   - 这些事件在 LLM 生成 tokens 时触发

3. **数据流**:

用户代码调用 model.invoke() ↓ LangGraph 检测到 stream_mode="messages" ↓ 自动将 invoke 转换为 stream 调用 ↓ 捕获 on_llm_new_token 事件 ↓ 将 tokens 作为 (message_chunk, metadata) 流式返回

4. **独立于 State**:
- stream_mode="messages" 工作在更底层
- 它不关心 State 的结构
- 只要有 LLM 调用,就能捕获 tokens
"""

# stream_mode="messages" 会捕获所有节点中的 LLM 调用
for msg, metadata in graph3.stream(
    {"input_text": "AI development"}, 
    stream_mode="messages"
):
    if msg.content:
        node = metadata.get("langgraph_node", "unknown")
        print(f"[{node}] {msg.content[:20]}...")

不同模式的区别如下:

print("\n不同 stream_mode 的区别:")

# 1. stream_mode="values" - 返回完整的 State
for chunk in graph1.stream({"topic": "cats"}, stream_mode="values"):
    print(f"Values mode - State: {chunk}")
    # 输出: {'topic': 'cats', 'joke': '完整的笑话内容'}

# 2. stream_mode="updates" - 返回 State 的更新
for chunk in graph1.stream({"topic": "dogs"}, stream_mode="updates"):
    print(f"Updates mode - Updates: {chunk}")
    # 输出: {'call_model': {'joke': '完整的笑话内容'}}

# 3. stream_mode="messages" - 返回 LLM tokens
for msg, metadata in graph1.stream({"topic": "birds"}, stream_mode="messages"):
    if msg.content:
        print(f"Messages mode - Token: {msg.content[:10]}...")
        # 输出: 流式的 tokens

消息增强的处理类

class EnhancedMessageProcessor:
    """增强的消息处理器"""
    
    def __init__(self, verbose: bool = True, show_tools: bool = True):
        self.verbose = verbose
        self.show_tools = show_tools
        self.message_buffer = []
        self.tool_calls_buffer = []
        self.current_node = None
        self.stats = {
            "total_messages": 0,
            "ai_messages": 0,
            "tool_messages": 0,
            "total_tokens": 0,
            "tool_calls": 0
        }
    
    def process(self, msg: BaseMessage, metadata: dict) -> None:
        """处理单个消息"""
        self.stats["total_messages"] += 1
        node = metadata.get("langgraph_node", "unknown")
        
        if node != self.current_node:
            if self.current_node:
                self._flush_buffer()
            self.current_node = node
            if self.verbose:
                print(f"\n📍 [{node}]", flush=True)
        
        # 处理不同类型的消息
        if isinstance(msg, AIMessageChunk):
            self._process_ai_chunk(msg, metadata)
        elif isinstance(msg, AIMessage):
            self._process_ai_message(msg, metadata)
        elif isinstance(msg, ToolMessage):
            self._process_tool_message(msg, metadata)
        elif isinstance(msg, HumanMessage):
            self._process_human_message(msg, metadata)
        else:
            self._process_other_message(msg, metadata)
    
    def _process_ai_chunk(self, msg: AIMessageChunk, metadata: dict):
        """处理 AI 消息块"""
        self.stats["ai_messages"] += 1
        
        # 处理文本内容
        if msg.content:
            self.message_buffer.append(msg.content)
            if self.verbose:
                print(msg.content, end="", flush=True)
            self.stats["total_tokens"] += len(msg.content.split())
        
        # 处理工具调用块
        if hasattr(msg, 'tool_call_chunks') and msg.tool_call_chunks:
            for chunk in msg.tool_call_chunks:
                self.tool_calls_buffer.append(chunk)
                if self.verbose and self.show_tools:
                    if chunk.get('name'):
                        print(f"\n🔧 准备调用: {chunk['name']}", end="")
                    if chunk.get('args'):
                        print(f" {chunk['args']}", end="")
        
        # 处理完整的工具调用
        if hasattr(msg, 'tool_calls') and msg.tool_calls:
            self.stats["tool_calls"] += len(msg.tool_calls)
            if self.verbose and self.show_tools:
                print(f"\n📞 工具调用检测到:")
                for tc in msg.tool_calls:
                    print(f"   • {tc['name']}: {tc.get('args', {})}")
    
    def _process_ai_message(self, msg: AIMessage, metadata: dict):
        """处理完整的 AI 消息"""
        if msg.content and self.verbose:
            print(f"\n✅ AI完整响应: {msg.content[:100]}...")
        
        if hasattr(msg, 'tool_calls') and msg.tool_calls and self.show_tools:
            print(f"\n🔨 即将执行工具:")
            for tc in msg.tool_calls:
                print(f"   • {tc['name']}({tc.get('args', {})})")
    
    def _process_tool_message(self, msg: ToolMessage, metadata: dict):
        """处理工具消息"""
        self.stats["tool_messages"] += 1
        if self.verbose and self.show_tools:
            try:
                # 尝试解析 JSON 结果
                result = json.loads(msg.content) if msg.content else {}
                print(f"\n📊 工具结果:")
                for key, value in result.items():
                    print(f"   • {key}: {value}")
            except:
                print(f"\n📊 工具结果: {msg.content}")
    
    def _process_human_message(self, msg: HumanMessage, metadata: dict):
        """处理人类消息"""
        if self.verbose:
            print(f"\n👤 用户: {msg.content}")
    
    def _process_other_message(self, msg: BaseMessage, metadata: dict):
        """处理其他类型消息"""
        if hasattr(msg, 'content') and msg.content and self.verbose:
            print(f"\n📝 {type(msg).__name__}: {msg.content}")
    
    def _flush_buffer(self):
        """清空缓冲区"""
        if self.message_buffer:
            full_message = "".join(self.message_buffer)
            self.message_buffer = []
        
        if self.tool_calls_buffer:
            self.tool_calls_buffer = []
    
    def get_stats(self) -> dict:
        """获取统计信息"""
        return self.stats

实际案例

分享一个真实的优化案例。我们有个数据分析的工作流:

class AnalysisState(TypedDict):
    raw_data: pd.DataFrame  # 原始数据,很大
    processed_data: dict    # 处理后的数据
    summary: str           # 分析总结
    step_info: str         # 当前步骤信息

# 之前的代码(慢)
asyncfor chunk in app.astream(initial_state):  # 默认values模式
    # 每次都传输完整的DataFrame
    print(f"当前步骤: {chunk.get('step_info')}")
    # 客户端:为啥这么卡?

# 优化后(快)
asyncfor chunk in app.astream(initial_state, stream_mode="updates"):
    for node_name, updates in chunk.items():
        # 只传输变化的部分
        if"step_info"in updates:
            print(f"当前步骤: {updates['step_info']}")
        if"summary"in updates:
            print(f"分析结果: {updates['summary']}")

效果立竿见影,传输的数据量少了90%。

选择建议

开发调试阶段:

  • 用debug模式,能看到所有细节
  • 出问题时方便定位

生产环境:

  • 优先用updates模式,性能最好
  • 只有真的需要完整状态时才用values

聊天应用:

  • 直接用messages模式,别自己解析了

性能敏感场景:

  • 一定要用updates
  • 我们测过,数据量大的时候updates比values快3-5倍

模式组合

最后贴个不同模式组合的例子:

for stream_mode, chunk in agent.stream(
            {"messages": [{"role": "user", "content": "book a hotel"}]},
            config,
            stream_mode=["messages", "updates"],
        ):

    if stream_mode == "messages":
        print(chunk)
        if isinstance(chunk, tuple) and len(chunk) == 2:
            message_chunk, metadata = chunk

            if hasattr(message_chunk, 'content') and message_chunk.content:
                print(message_chunk.content, end="", flush=True)
                # messages.append(message_chunk.content)

    elif stream_mode == "updates":
        # Check for interrupt signal in updates
        if isinstance(chunk, dict) and"__interrupt__"in chunk:
            is_interrupted = True
            interrupt_info = chunk["__interrupt__"]
            print(f"\n\n🛑 INTERRUPT DETECTED!")
            print(f"   Info: {interrupt_info}")
            # Don't break - let it finish streaming current content

        # Also check for tool calls that might trigger interrupts
        if isinstance(chunk, dict):
            for key, value in chunk.items():
                if isinstance(value, dict) and"messages"in value:
                    for msg in value.get("messages", []):
                        if hasattr(msg, "tool_calls") and msg.tool_calls:
                            print(f"\n🔧 Tool call detected: {msg.tool_calls[0].get('name', 'unknown')}")

总结

stream_mode这个参数看起来简单,但选对了能省很多事:

  • 别无脑用默认的values,根据场景选择
  • 生产环境首选updates,真的快很多
  • debug只在开发时用
  • messages是给聊天应用的特供

本文转载自​AI 博物院​​​ 作者:longyunfeigu

©著作权归作者所有,如需转载,请注明出处,否则将追究法律责任
收藏
回复
举报
回复
相关推荐