当前位置:首页 > 编程笔记 > 正文
已解决

flash attention的CUDA编程和二维线程块实现softmax

来自网友在路上 174874提问 提问时间:2023-09-24 02:51:59阅读次数: 74

最佳答案 问答题库748位专家为你答疑解惑

本文参考了链接添加链接描述

flash attention介绍

flash attention的介绍可以参考论文:FlashAttention: Fast and Memory-Efficient Exact Attention
with IO-Awareness,具体的数学公式参考下面这个图片:其中注意关于矩阵S有两个维度,softmax的操作维度是dim=1,用pytorch表示就是torch.softmax(S, dim=1)
在这里插入图片描述
对于flash attention来说,里面有两次矩阵乘法,对于这样的二维数组矩阵乘法,一般来说都会考虑使用二维线程块,但是我们之前实现的softmax都是以一维线程块来处理,其中专门用到了一个cub库的函数BlockReduce,经过本人测试,发现这个函数只能针对一维线程块做线程块内部的规约,不能用于二维线程块内部针对某个维度规约,因此在实现flash attention之前,我们需要编写一个二维线程块实现softmax的算法,其中注意BLOCK_DIM_x和BLOCK_DIM_y都必须要选取2的幂次方。

二维线程块实现softmax

之前我们实现一维线程块处理softmax的时候,参考链接添加链接描述

查看全文

99%的人还看了

猜你感兴趣

版权申明

本文"flash attention的CUDA编程和二维线程块实现softmax":http://eshow365.cn/6-12472-0.html 内容来自互联网,请自行判断内容的正确性。如有侵权请联系我们,立即删除!