面向长代码序列的 Transformer 模型优化方法,提升长代码场景性能

人工智能 新闻
SASA方法将self-attention的计算稀疏化,同时结合了代码的结构特性,从而提升了长序列任务的性能,也降低了内存和计算复杂度。

阿里云机器学习平台PAI与华东师范大学高明教授团队合作在SIGIR2022上发表了结构感知的稀疏注意力Transformer模型SASA,这是面向长代码序列的Transformer模型优化方法,致力于提升长代码场景下的效果和性能。由于self-attention模块的复杂度随序列长度呈次方增长,多数编程预训练语言模型(Programming-based Pretrained Language Models, PPLM)采用序列截断的方式处理代码序列。SASA方法将self-attention的计算稀疏化,同时结合了代码的结构特性,从而提升了长序列任务的性能,也降低了内存和计算复杂度。

论文:Tingting Liu, Chengyu Wang, Cen Chen, Ming Gao, and Aoying Zhou. Understanding Long Programming Languages with Structure-Aware Sparse Attention. SIGIR 2022

模型框架

下图展示了SASA的整体框架:

其中,SASA主要包含两个阶段:预处理阶段和Sparse Transformer训练阶段。在预处理阶段得到两个token之间的交互矩阵,一个是top-k frequency矩阵,一个是AST pattern矩阵。Top-k frequency矩阵是利用代码预训练语言模型在CodeSearchNet语料上学习token之间的attention交互频率,AST pattern矩阵是解析代码的抽象语法树(Abstract Syntax Tree,AST ),根据语法树的连接关系得到token之间的交互信息。Sparse Transformer训练阶段以Transformer Encoder作为基础框架,将full self-attention替换为structure-aware sparse self-attention,在符合特定模式的token pair之间进行attention计算,从而降低计算复杂度。

SASA稀疏注意力一共包括如下四个模块:

  • Sliding window attention:仅在滑动窗口内的token之间计算self-attention,保留局部上下文的特征,计算复杂度为,为序列长度,是滑动窗口大小。
  • Global attention:设置一定的global token,这些token将与序列中所有token进行attention计算,从而获取序列的全局信息,计算复杂度为,为global token个数。
  • Top-k sparse attention:Transformer模型中的attention交互是稀疏且长尾的,对于每个token,仅与其attention交互最高的top-k个token计算attention,复杂度为。
  • AST-aware structure attention:代码不同于自然语言序列,有更强的结构特性,通过将代码解析成抽象语法树(AST),然后根据语法树中的连接关系确定attention计算的范围。

为了适应现代硬件的并行计算特性,我们将序列划分为若干block,而非以token为单位进行计算,每个query block与

个滑动窗口blocks和

个global blocks以及

个top-k和AST blocks计算attention,总体的计算复杂度为

b为block size。

每个sparse attention pattern 对应一个attention矩阵,以sliding window attention为例,其attention矩阵的计算为:

ASA伪代码:

实验结果

我们采用CodeXGLUE[1]提供的四个任务数据集进行评测,分别为code clone detection,defect detection,code search,code summarization。我们提取其中的序列长度大于512的数据组成长序列数据集,实验结果如下:

从实验结果可以看出,SASA在三个数据集上的性能明显超过所有Baseline。其中Roberta-base[2],CodeBERT[3],GraphCodeBERT[4]是采用截断的方式处理长序列,这将损失一部分的上下文信息。Longformer[5]和BigBird[6]是在自然语言处理中用于处理长序列的方法,但未考虑代码的结构特性,直接迁移到代码任务上效果不佳。

为了验证top-k sparse attention和AST-aware sparse attention模块的效果,我们在BigCloneBench和Defect Detection数据集上做了消融实验,结果如下:

sparse attention模块不仅对于长代码的任务性能有提升,还可以大幅减少显存使用,在同样的设备下,SASA可以设置更大的batch size,而full self-attention的模型则面临out of memory的问题,具体显存使用情况如下图:

SASA作为一个sparse attention的模块,可以迁移到基于Transformer的其他预训练模型上,用于处理长序列的自然语言处理任务,后续将集成到开源框架EasyNLP(https://github.com/alibaba/EasyNLP)中,贡献给开源社区。

论文链接:
https://arxiv.org/abs/2205.13730

责任编辑:张燕妮 来源: 阿里云云栖号
相关推荐

2024-01-30 01:12:37

自然语言时间序列预测Pytorch

2023-05-03 20:59:51

ChatGPTRMT模型

2023-09-19 10:31:09

算法数据

2023-08-24 16:42:29

Sample聊天实例应用

2021-06-16 07:05:02

gRPC 网关HTTP

2021-08-19 15:19:16

代码开发模型

2022-04-07 07:51:40

代码结构设计

2021-03-24 09:06:01

MySQL长连接短连接

2012-07-23 10:22:15

Python性能优化优化技巧

2018-06-11 16:20:22

数据不平衡数据集算法

2012-12-24 09:23:27

ASP.NETC#IIS

2009-11-27 13:24:20

PHP代码性能优化

2019-11-05 14:37:24

Java性能优化编程语言

2018-10-22 13:23:29

MySQL主从延时线程

2011-06-14 11:14:10

性能优化代码

2020-03-26 12:38:15

代码节点数据

2020-08-16 20:53:15

JavaScript代码开发

2015-09-16 14:47:14

Android性能优化代码

2015-11-05 09:02:05

Java代码性能优化

2016-10-13 10:57:55

phptcp专栏
点赞
收藏

51CTO技术栈公众号