告别偏科,能玩转多模态、多任务、多领域的强化智能体终于来了
随着 Llama 3 发布,未来大模型的参数量已飙升至惊人的 4000 亿。尽管每周几乎都有一个声称性能超强的大模型出来炸场,但 AI 应用还在等待属于它们的「ChatGPT 时刻」。其中,AI 智能体无疑是最被看好的赛道。
就连吴恩达都说,GPT-4 加上 AI 智能体,可能提前达到 GPT-5 的效果。
不过,我们熟知的智能体往往有点「偏科」。例如,第一个 AI 软件工程师 Devin,专精于代码。会打游戏的智能体往往也只能在某一个游戏里秀操作。寻找一个能够同时擅长多个领域,并能在其中无缝切换的通用模型仍是机器学习研究中的一个关键目标。
为了解决这个问题,研究者们对于智能体如何结合计算机视觉(CV)和自然语言处理(NLP)任务进行了广泛探索,但将强化学习(RL)任务整合进来的研究相对较少。这是由于 RL 任务本质上是异质的,这使得将 RL 任务与对话和图像识别等其他任务结合起来更加困难。这要求智能体能融会贯通不同领域任务中的不同模态、任务复杂性和数据类型。要达到全能型智能体,主要需要解决以下问题:(1)如何设计一个能够处理多种数据类型和模态的统一模型结构?(2)如何有效地平衡不同任务的学习进度和优先级?(3)如何确保智能体制定合适的学习目标,以避免不同任务之间的干扰和负向迁移?
来自 Hugging Face、法国国家信息与自动化研究所(INRIA)和波尔多大学的四位研究者提出了智能体中的「六边形战士」——Jack of All Trades (JAT)。JAT 是一个基于 Transformer 的多模态通用强化学习智能体框架。在此框架下,智能体能够通过同一套参数应对不同复杂度的多种任务,化身既会打游戏,又能控制机器人的全能高手。论文同时发布了大量 RL 智能体与 JAT 数据集。这是首个用于通用智能体训练的数据集 JAT 数据集,包含了由专家智能体收集的数十万条轨迹。
- 论文名称:《Jack of All Trades, Master of Some, a Multi-Purpose Transformer Agent》
- 论文链接:https://huggingface.co/papers/2402.09844
- 代码链接:https://github.com/huggingface/jat
- 项目链接:https://huggingface.co/jat-project/jat
- 数据集:https://huggingface.co/datasets/jat-project/jat-dataset
模型架构
JAT 的核心结构基于 Transformer,使用了 EleutherAI 的 GPT-Neo 实现。JAT 最大的创新点在于其嵌入机制,从本质上解决了数据类型不同的问题。JAT 模型将观察嵌入与其对应的奖励值和动作嵌入交错排列,形成一个序列。
图 1.JAT 网络架构。对于序列中的决策任务,一方面输入观察嵌入与奖励值,另一方面行动嵌入被编码并被交错放置。模型使用因果掩码自回归地生成下一个嵌入,并根据预期的模态进行解码。
因此,每个嵌入要么对应一个与奖励相关联的观察嵌入,要么对应一个动作嵌入。JAT 如何进一步对这些信息进行编码呢?这要取决于数据的类型。如果观察嵌入或动作嵌入的数据类型是图像,那么 JAT 将使用 CNN。如果是连续向量,则使用线性层。如果是离散值,则使用线性投影层。模型的输出也遵循相同的逻辑,具体取决于预测目标的数据类型。预测基于因果推理进行,将观察嵌入向后移动一个时间步,确保智能体可以根据所有先前的观察和动作嵌入来预测下一个动作嵌入。
这种嵌入设计让研究团队在训练智能体执行 NLP 和 CV 任务时兴致盎然。对于和文本相关的任务,作者让 JAT 模型采用 GPT-2 的分词策略,将文本转换为一个整数序列,然后通过一个查找表映射到一个嵌入向量序列。对于和图像有关的任务,JAT 模型将选择 ViT 方法,将图像切割成小块后,通过线性层转换为嵌入向量序列。JAT 模型再将图像和文本的向量序列拼接在一起,形成一个统一的序列,输入到 Transformer 中。
考虑到数据的模态变来变去,JAT 如何计算损失函数呢?它将针对每种模态分别计算 loss。对于图像和连续值,它使用均方误差(MSE)损失。对于离散值,它使用交叉熵损失。最终的损失是序列中每种元素损失的平均值。那么,这是否意味着 JAT 在预测动作嵌入和观察嵌入时的权重是相同的呢?实际上不是,在此后的章节中将一步探讨这个问题。
实验结果
研究团队共采用了 157 个训练任务来 JAT 评估。他们将这些任务分为 10 类,并记录了 JAT 的总奖励值。
JAT 模型在最终的检查点上达到了 65.8% 的专家得分,说明 JAT 能够在非常广泛的任务上达到专家水平。以下具体列出了 JAT 在四个常见的智能体训练环境中的得分:
- 对于 Atari 57,应用 JAT 模型的智能体实现了专家分数的 14.1%,这相当于人类表现的 37.6%。Atari 视频游戏广泛被用作评估和开发强化学习算法的基准环境,其中《吃豆人》是一款标志性游戏。在这一系列的 21 款游戏中,JAT 智能体的表现已经超越了人类玩家。值得注意的是, JAT 只用了单一网络就在所有 Atari 视频游戏中达到了这种水平;
- 对于 BabyAI,应用 JAT 模型的智能体达到了专家分数的 99.0%,只有一个任务的表现未能超过专家水平的 50%;
- 对于 Meta-World,应用 JAT 模型的智能体达到了专家分数的 65.5%;
- 对于 MuJoCo,应用 JAT 模型的智能体达到了专家分数的 84.8%。
JAT 智能体在 Atari 57 基线上和人类表现的对比
这些 JAT 智能体都可以通过项目主页下载,进一步测试和体验。更多细节请参阅论文原文。
专家智能体和 JAT 数据集
专家策略
传统的强化学习往往在单一环境中寻找专家策略,即在一个特定任务中寻找让模型表现最优的方法。构建跨领域的多功能智能体,也离不开这种方法。论文作者选择了 Atari、BabyAI、Meta-World 和 MuJoCo 一系列性质不同,难度各异的训练环境,直到训练出表现最好的智能体。这一系列采用 JAT 框架的专家智能体已经在项目主页上发布。
JAT 数据集
论文作者随论文同步发布了 JAT 数据集,这是首个针对通用智能体训练的专项数据集。其中包含了数十万条由上述专家智能体收集的轨迹数据。使用起来也很方便,可以像加载 Hugging Face 平台上的其他数据集一样简单。以下是调用代码示例:
JAT 数据集不仅包含强化学习的数据,还整合了来自维基百科等文本数据集,以及 Oscar、OK-VQA、Conceptual Captions 等针对视觉任务的数据集,提供了更丰富的数据类型选择。
增加模型预测观察嵌入的能力
智能体学得更好更快了
在训练强化学习智能体时,主要目标是使其在未曾遇到的任务中实现奖励最大化。然而,如果要求智能体预测未来可能遇到的情境,这一额外任务会促进还是阻碍其学习过程呢?
关于这个问题存在两种相反的观点。一方面,学会预判可能会让智能体对环境有更深入的理解,从而学得更好更快。另一方面,这可能会分散智能体对其主要目标的注意力,导致在预测观察嵌入和行动嵌入时都表现平庸。
为了得到问题的答案,论文作者进行了一个实验,使用了一个结合了观察损失和行动损失的损失函数,并通过权重参数 k 来平衡这两种损失。
研究团队在 95% 的置信区间内,针对选定任务,测量了预判将如何影响模型学习。每项任务进行了 100 次评估,基于这些评估得到了 k 值的范围。结果表明,适当选择 k 值可以显著提升智能体的表现。
当 k 值过高(高于 0.5)时,预测观察嵌入的额外任务阻碍了学习过程。但当 k 值较低时,对学习的影响可以忽略不计,且智能体的表现与没有额外预判任务时的表现相似。
研究团队发现,当 k=0.005 时,存在一个最佳临界点。这意味着,只要平衡得当,为智能体增加预测观察嵌入的任务,实际上可以提高智能体的学习效率。这一发现对于设计类似的智能体具有重要意义,突显了辅助目标在提升智能体学习效率方面的潜在价值。
未来展望
JAT 项目为通用智能体研究领域开辟了全新的方向。研究团队表示目前只是初步探索,以下几点思路可供未来研究者深入挖掘:
改进数据的质量:尽管填补了之前少有通用智能体训练数据集的空缺,JAT 数据集仍处于初级阶段。其中的专家轨迹仅来自每个环境中的一名专家智能体,这可能导致一些误差。虽然研究团队已尽力让智能体达到最优表现,但某些环境仍具挑战性。在这些环境中,智能体仍有很大进步空间。收集到更多数据,训练更多的专家智能体,将在很大程度上解决这些问题。
使用离线强化学习:JAT 智能体是仿照基线一比一地训练出来的。这意味着,其一,智能体无法利用次优的轨迹;其二,JAT 智能体无法超越专家。论文选择了这种方法是因为它比较简单,但研究团队相信,使用离线强化学习可以提高智能体的性能,同时,实现起来也不会过于复杂。
发挥更智能的多任务采样策略的全部潜力:目前,JAT 智能体均匀地从所有任务中采样数据,但这种方法可能限制了它的全部潜力。通过动态调整采样率,专注于最具挑战性的任务,或许也可以加速智能体的学习过程,并解锁显著的性能提升。
本文转自 机器之心 ,作者:机器之心