寻找符号间的连接:基于稀疏字典学习的回路发现理论与实践

若稀疏字典学习可以提取Transformer中(几乎所有)有意义的特征,我们能否据此逆向出Transformer内部的(几乎所有)回路?

Authors

Affiliations

Zhengfu He

MOSS Team

Tianxiang Sun

MOSS Team

Qiong Tang

MOSS Team

Xipeng Qiu

MOSS Team

Published

Dec. 23, 2023

DOI

No DOI yet.

引言

从最广义的角度来看,本文及其后续工作属于机械论可解释性(Mechanistic Interpretability) 的范畴。这类工作的最终目标是能够完全逆向工程任意训练好的神经网络,从而理解其内部特征与回路即使其没有任何先验的可解释性设计。具体而言,我们希望能选择一个合适的可解释性粒度, 在这一粒度下将神经网络的内部状态分解为可解释的组分。若可以理解每个组分的功能,并解释组分间的交互,我们便有机会完全理解神经网络的内部世界。

神经元作为最基本的运算单位,却并不能作为人类解释神经网络的基本组分,这主要是因为单个神经元通常会表示多种语义,因此无法建立神经元激活和人类概念的双射。这种多语义性(Polysemanticity) 可能的重要来源是特征重叠假说(Superposition hypothesis)。 这种假说猜想模型尝试在一个高维空间表示比空间维度更多的独立特征,以增强其表达能力。这组特征在空间内构成了一组可解释的、过完备的基。

稀疏字典学习(Sparse Dictionary Learning)是一种基于稀疏自编码器(Sparse Autoencoder)的无监督特征分解方法,可以从神经网络表征空间中提取可解释、单语义的独立特征,并将任意给定的表征近似分解为若干个(更加单语义的)特征线性加权和。 这类方法为重叠问题带来了重要的进展。 对稀疏字典学习的可解释性研究已具备一定的基础,本文基于现有的研究对以下问题做了尝试: 若字典学习可以提取Transformer中有意义的特征,我们如何据此逆向出Transformer内部的回路?我们在一个黑白棋模拟任务训练的较小的Transformer上进行了稀疏字典学习,并提出一套理论解释:

本文分为以下5个主要部分:基本设定、字典实验、回路理论、回路实验与总分析。我们首先介绍选用的黑白棋模拟任务,Transformer结构与训练结果;字典学习实验部分包括字典学习的特征提取位置选择与分析结果;回路分析理论中,我们引入一套理论,利用已有的特征提取结果寻找重要的内部回路; 回路分析实验介绍了用这套理论对Transformer应用的结果。

本文的核心结论列举如下:

从黑白棋模型入手

作为后续的实验基座,我们用一个1.2M参数的decoder-only Transformer以自监督预训练的方式学习了一个可编程解决的黑白棋预测任务。本章主要介绍任务描述,模型架构与训练细节。

任务介绍

黑白棋(Othello)是一个棋牌游戏,两玩家轮流在8x8的棋盘上落子,落子规则如下: 以一个空格为中心,若在上下左右及对角线共8个方向上的任意一个方向上,存在一个同色的棋子,且空格与该棋子连线上全都是异色的棋子,则该空格可以落子,并把所有满足前述条件的对方棋子翻转为自己的颜色。 因此这个游戏也被称为翻转棋。

上图展示了一局游戏的进程这里有一个plotly的陈年bug,当动画回到过去的帧时,画布上的部分内容不会被清除。。棋盘的开局有4个棋子交错排列,双方不断落子直至填满整个棋盘游戏结束存在棋盘未被填满便无棋可下的情况,我们不考虑这样的少见情况,双方的目的是在终局时在棋盘上占据更多棋子。

正如图中所示,棋盘中有60个空位,因此游戏总共进行60步。将每一步的落子位置记录下来,我们便可以用一个长度为60的序列表示一盘游戏:

[37, 29, 18, 45, 22, 26, 19, 12, 54, 53, 25, 44, 38, 21, 62, 34, 46, 63, 13, 55, 42, 30, 23, 17, 5, 39, 11, 15, 16, 60, 7, 20, 43, 24, 31, 61, 32, 4, 47, 2, 10, 52, 51, 9, 0, 14, 33, 58, 59, 41, 49, 50, 3, 6, 57, 48, 56, 1, 8, 40]

通过python代码描述游戏规则,并在每一步从符合规则的下法中采样一个位置落子,我们可以生成数以百万计的数据。我们的设定即对这样的序列进行自回归建模,就如同语言模型建模下一个词的概率,这个任务建模下一步的合法概率这个任务只要求每一步均是合法的下法,并不在乎输赢。

为何选择这个训练任务?

这个任务最初是在这篇ICLR2023 spotlight文章中提出的。原作的主要思路是在这个任务上验证世界模型(World Model)的形成。 作者通过巧妙的probing和intervention为世界模型的形成提供了一个很强的证据:模型首先会首先计算出棋盘的状态并据此做出决策,而非单纯地记忆表面的规律。 这样的任务比先前的模加法、除法更复杂,而又比语言模型更简单。

这个任务的一个巧妙之处在于输入的序列本身提供的信息很少:只有落子顺序,并不包括当前棋盘的状态。 此外,模型不具备任何有关棋盘或规则的知识,其不知道输入的序列是按照两位玩家交替的规律展开的,也不预先知道输入序列与棋盘位置的映射关系,在这样艰难的条件下,其能够完成这一任务便很令人惊奇。即使人类先前知道序列的现实意义,想要得到下一步的合法落子位置也需要通过递归地模拟这一过程并经过细心的考察。而Transformer的计算资源是固定的,它不能显式地完成这种递归过程。 因此,即使不以这个任务作为理解大语言模型的出发点,完全理解其中的原理也能使研究者对Transformer的内部机制有很大启发。

对于本文接下来的字典学习和回路分析理论来说,这个任务是一个较好的入手点:它不需要10-100M级别的模型即可较好地完成任务,这种极小的模型使字典训练简单许多,调试与分析的过程也会更容易。 此外,这是一个更可解释的任务,它很少存在歧义,内在的逻辑较明确。这两个原因帮助我们排除了任务和模型的可能干扰,将研究重心放在可解释性理论本身。对这种选择的更多讨论在黑白棋和语言模型一章有更详细的展开。

模型架构

本文主要关注decoder-only Transformer。架构与配置如下所示:

{ 'act_fn': 'gelu', 'd_head': 16, 'd_mlp': 512, 'd_model': 128, 'd_vocab': 61, 'd_vocab_out': 61, 'n_ctx': 60, 'n_heads': 8, 'n_layers': 6, 'n_params': 1179648, 'normalization_type': 'PreLayerNorm', 'positional_embedding_type': 'learned', 'p_dropout': 0.0 }

尽管先前的黑白棋可解释性工作已经开源模型架构和参数,本文对其做出了一定修改,理由如下:

模型训练

训练这个合成任务上的Transformer模型训的设置与朴素的自回归文本建模相同。由于棋盘的开局中已有四个棋子被占据,我们按照阅读顺序将其余空位标号0-59。随机采样的一盘游戏被视作一条序列数据,用于模型自回归建模。 我们采样了1e9条数据并做了去重处理,使用8卡NVIDIA A100训练了1个epoch。模型loss收敛至~2.02,正确率99.5%。除错误率相比原文99.9%略微下降,其余设置几乎完全相同。

用字典学习提取模型中可解释的特征

本章我们主要介绍以下三个主题:为何用稀疏字典学习解释模型结合Transformer架构的字典训练思路实验表现

稀疏字典学习简介

若Transformer的内部结构对人类理解友好一些,线性的可解释的特征恰好就是神经元,而其激活强度就是这个特征的强度。这样我们只需要分析每个神经元具体在何种情况下被激活以确定其可解释意义,并对特定的输出观察每个神经元的激活值,理解其内部思路就不难了。 这个假设具有较大的局限性,但仍可以在许多语言模型、视觉模型或语言-视觉模型上发现可解释的神经元。 从最直观的角度来看,即使不考虑任何高级的特征,只要词表的大小大于模型的隐层维度,模型就不得不用一个有限维度的空间容纳多于其维数个数的特征。特征重叠假说认为特征会以相互干预的方式“塞入”这个拥挤的空间内。 在这种情况下,必然会存在至少一个特征,其在表征空间中的方向并不与任何一个神经元重合。此外,先前对Priviledged Bases的研究认为模型表示特征的方式是以下两种方式的混合策略:

由于以上的原因,从神经元理解模型内部特征存在不少问题,因此需要一种更一般的方式寻找这些可解释的方向构成的“可解释基”,它们大概率是过完备的。此外,特征重叠的一个重要动力是特征的稀疏性,这也是现实任务常具备的重要性质。 因此利用稀疏字典学习提取特征,相当符合过完备与稀疏这两个重要性质。总体而言,稀疏字典学习的目标是通过一个autoencoder寻找一组过完备的基\mathbf{d},使得对于模型给定位置的任意激活\mathbf{x},都可以分解为这组基上的稀疏加权和:

\mathbf{x} \approx \sum_{i}{c_i \mathbf{d}_i};\quad \text{s.t.} \quad \min{\sum_{i}{c_i}}

其中,c_i是第i个特征的激活强度。通过约束字典特征的稀疏激活,字典被迫寻找表征中隐含的基本特征并计算这组特征用于复原给定表征的最小权重。

先前的工作已经对字典训练的很多细节做了探索,我们认为训练字典和依据字典发现回路这两个问题在实践上具有极强的相关性,但在理论上是相当正交的。 因此我们直接采用现有的训练设定,对模型中给定位置的激活集合\mathbf{X},采用如下的稀疏字典学习流程:

\bar{\mathbf{x}} = \mathbf{x} - \mathbf{b}_d \mathbf{c} = \text{ReLU}(W_e \bar{\mathbf{x}} + \mathbf{b}_e) \hat{\mathbf{x}} = W_d \mathbf{c} + \mathbf{b}_d \mathcal{L} = \frac{1}{\mathbf{X}}\sum_{x\in \mathbf{X}}{\lVert \mathbf{x} - \hat{\mathbf{x}} \rVert}_2^2 + \lambda \lVert \mathbf{c} \rVert_1

Decoder bias \mathbf{b}_d的主要作用是对\mathbf{X}做平移,在调整后的集合上做重构;W_eW_d是encoder和decoder权重,W_e的主要作用是确定某个特征的激活强度,W_d的每一行都被限制为单位向量,主要作用是表示可解释的过完备基。 因此,\mathbf{c}是特征激活值,\hat{\mathbf{x}}是重构值,损失函数\mathcal{L}由重构误差和特征稀疏性混合构成。对这些设置的详细分析见原文

如下图所示,对于给定模块的某处表征,我们可以用字典学习将其(尽可能无损地)分解为若干可解释特征的带权和:

字典学习应从何处提取特征?

Transformer的可解释性架构已经在A Mathematical Framework of Transformer Circuits中介绍的比较详尽。将残差流视作Transformer的内存管理中心,并将每个Attn和MLP层视作从中读与写的模块是其中重要的核心思想。 先前基于字典学习的工作通常研究的对象包括:词表征、残差流和MLP隐层。 我们对这些工作做简要的评述:

基于以上分析,我们认为用字典学习分解以下三个部分可能是有益的:词表征,每个Attn层输出和每个MLP层输出。尽管相比于MLP,针对Attn head的分析已经有相当多的Mech Interp工作,我们认为将其纳入字典学习的统一理论是很有必要的,这种设定对系统和规模化地理解Transformer很有帮助。

用字典学习分解黑白棋模型

作为回顾,我们训练的黑白棋模型总共包含6层,采用Pre-norm,激活函数为gelu。模型隐层d_model=128,注意力头数n_heads=8,MLP隐层d_mlp=512,层数n_layers=6。

模型总共有12个Attn/MLP模块,我们针对每个模块的输出训练一个字典我们没有分解词表征,这是因为这个任务比较特殊:词表大小甚至小于模型隐层维度,这是很反常的;此外,由于任务的设定,每个token的定义是相当明确的。至于pos embedding的解释,本模型采用的learned pos emb不具备太多参考的意义。但是在未来应用于语言建模任务时,分解词表征可能十分重要。,字典的输入维度总是d_in=128,隐层n_components=1024。 我们采样了4e8条序列用于训练字典,即4e8 * 60 = 2.4e10个token,对于每个token,我们将其带上前文一起输入Transformer模型,获取每个模块的输出并存下来。尽管每个token都与前文有关,我们在训练时将这些token获得的表征视作完全独立的,打乱并采样用于重构训练每个模块的字典。

每层的平均输入表征的2-范数如图所示,图中几乎不可见的误差棒表示了平均的重构误差(由MSE计算的2-范数)。字典几乎可以无损地重构每个模块的输出

训练完成后,我们有12个字典,从底至上对应第0层的Attn至第五层的MLP,我们用L0A-L5M分别表示第0-5层的Attn层和MLP层。

对于每一个字典,我们统计其中特征的活跃程度我们采样5万条数据,统计每个token对应的激活在每个字典上的分解结果,若某特征的激活次数大于64,则我们认为是活跃的特征(64 / (60 * 50000) ~= 2e-5)。。在字典学习可以提取有意义的即使不是人类可理解的特征这一假设下,在这个模型中,所有层的输出都存在特征重叠

由于这个任务需要建模的世界相比于语言建模的世界简单许多,其隐层维度d_model却与语言模型仅差1-2个数量级,因此在真实的语言模型中重叠现象会更加严重。此外,通过重采样字典激活值等技巧可以让提取出的特征数量更多,可解释意义更详细

上图描述了该模型内部特征的过完备性,模型内部特征另一个的重要性质是稀疏性。下图展现了每个token在每个字典上的平均激活特征数目,在这个模型中,所有层的输出都可以用比隐层维度更少的特征复原

由于同样的原因,现实语言模型中的特征应当更稀疏

字典特征概览

这个部分主要介绍这个模型中现有工作已发现的特征,并由此引出我们如何通过考察字典特征殊途同归地得到相同的或扩展的结论。

在了解我们的字典特征前,首先简介已有的对这个模型的认识,这有助于读者理解我们解释特征的框架。 具体而言,对于一个给定的输入序列,可以唯一确定一个棋盘状态。例如:

[37, 29, 18, 45, 22, 26, 19, 12, 54, 53, 25, 44, 38, 21, 62, 34, 46, 63, 13, 55, 42, 30, 23, 17]

这段输入可以唯一确定以下的棋盘状态:

原文中作者们对于棋盘的每个位置与模型的每层残差流单独训练了一个probe分类器,通过分类棋盘目前的颜色(黑、白或空)以判断模型内部是否含有对于当前棋盘状态的知识。 作者们发现,当分类器是线性时,分类的错误率在20%左右,错误率较高。而采用非线性分类器可以将错误率降低至1.7%。因此原作中的结论是棋盘内部存在非线性表示的世界知识。 后续的一篇工作发现原文中的结论并不充分,若probe的目标由(黑/白/空)三种状态转化为(与当前棋子同色/不同色/空),则线性分类器可以获得很好的效果。 这个观察对特征的线性性这一假设有深刻的启发,我们在字典学习与Probing一章中详细展开。

就该问题而言,模型内部应当不存在有关黑色与白色的线性表示,而是将“自己的”和“对方的”两种特征线性表示在隐层空间内。在特征的线性性这一假设下,我们便可以认为后者这种表示方式是更加基本的,而黑白之分的表示在模型中是某种复合特征,对该问题的讨论见基本特征与复合特征一章。 尽管这两种表示方法的表示能力是等价的,以“自己和对方”的视角思考或许有助于模型在不断交错落子的序列中可以始终保持相同的视角:不论当前是奇数步或偶数步,模型仅需以一种方式表示棋盘状态,而非维护两个(可能较独立的)子空间分别用于黑色和白色。

在这一前置发现的基础上,我们的特征解释框架展开如下。

对于字典学习分解出来的特征,我们仅考察那些激活频率大于一定标准的“活跃特征”。 我们对这些活跃的特征做如下记号:对于第X层的Attn/MLP输出分解得到的第Y个特征,我们将其命名为LX{A/M}Y例如,第3层Attn层字典的第728个特征称为L3A728,第0层MLP的第17个特征称为L0M17。

针对每个特征,我们分别考察它们激活程度最高的样本。下图是一个例子,展示了L0A622被激活的程度最高的64个输入:

直接观察这样的图像比较难以寻找规律,极容易造成幻觉。因此,我们设计了以下的解释接口:

对于给定的字典特征,考察对于1.2M token中top-k使其激活的输入。对这k个输入序列/棋盘状态计算如下统计量:

以上的统计图中,k取2048。每张这样的统计图反映了一个特征的行为。上图中的第一行第二列说明在2048个L0A622最感兴趣的输入中,当前步全部落子于f-1位置,因此我们可以认为这是一个“当前步=f-1”的特征。我们在如何解释字典学习的特征?一章详细讨论了我们解释特征的方法

我们利用这种工具分析字典特征,主要发现了以下几类特征:

相比于之前的工作,我们产生了一些新的发现:

以上两处不同之处其实并不矛盾。我们认为探测出的结果是一种复合特征

总体而言,我们发现相当多的特征都可以解释,尽管存在无法直接理解的特征,这些特征的激活值往往较小。

从特征发现回路:我们的回路发现理论

在机械论可解释性的宏观视角来看,即使我们对字典学习的认识和工程经验不断精进,能够更好地理解特征,我们仍然需要一种可规模化的方法回答一个重要问题:不同的特征是如何被激活或抑制的,即模型的内部回路是如何由底层表征计算得到高层表征的?

我们将Transformer decoder分为三个主要部分:QK回路,OV回路和MLP。对这种数学模型的分析在A Mathematical Framework of Transformer Circuits已有较详细的介绍。从直观的角度理解,我们需要具体地回答如下三个问题:

为了方便理解,我们按照以下顺序介绍我们的理论框架:残差流的一个重要数学性质,OV回路的可解释性,QK回路的可解释性与MLP的可解释性。

残差流中的重要性质

如上图所示,在残差流中,任何模块的输入x可线性表示为其底部所有模块的输出之和,例如:

x_{\text{L1M}} = \text{LN}(\textbf{Embed} + \textbf{Out}_{\text{L0A}} + \textbf{Out}_{\text{L0M}} + \textbf{Out}_{\text{L1A}})

处理LayerNorm

在主流的Transformer架构中,LayerNorm通常在每个模块从残差流读取一份复制后,对其执行LayerNorm操作并用于后续计算,即Pre-norm。尽管每个模块的输入可以线性分解为底部所有模块的输出之和,LayerNorm本身并非一个线性操作,这使得我们无法将相同的操作应用于每个线性分量上,解决这个问题这对后文的分析有重要影响。

x = x - x.mean() # Linear Operation x = x * (1 / x.std()) # Non-Linear Operation !! x = ln.w @ x + ln.b # Linear Operation

上式用一个伪代码描述了LayerNorm的计算过程,其中计算标准差这一步对输入x不是线性的。对此,我们将x的标准差视作常数而非x的函数,于是LayerNorm在不影响计算结果的前提下转化为x的线性函数,我们就可以对x的任意线性分解分别进行这种新的LayerNorm操作以估计每个分量对结果的影响。

特别地,当我们关注的LayerNorm是模型在lm_head之前的FinalLayerNorm时,这样的技术就可以用于分析每个模块对输出logit的影响,这种技术被称为直接输出归因(Direct Logit Attribution),这种技术在相当多的Mech Interp文章中有应用

OV回路分析理论

Transformer内部的每个注意力头需要通过如下过程将token j的输入迁移至token i

\textbf{OV}^h_{i \gets j} = \textbf{AttnPattern}^h_{ij}W_O^hW^h_Vx_j

其中上标h表示对应的参数/表征是每个头独立的。\textbf{AttnPattern}^h_{ij}表示信息迁移的权重系数,W^h_OW^h_V是该注意力头的OV权重矩阵,可直观理解信息迁移时做的额外加工。 x_j是该注意力模块的输入,对每个头而言均相同。

由于多头注意力的独立可加性,注意力模块在token i处的输出结果可表示为:

\textbf{Out}_{\text{LXA},i} = \sum_h \sum_j \textbf{OV}^h_{i \gets j} = \sum_h \sum_j \textbf{AttnPattern}^h_{ij}W_O^hW^h_Vx_j

因此,某个注意力模块在token i处的输出是每个头输出的加和,而每个头的输出又是每个其余token由该头计算出的OV输出加和。进一步地,我们还可以将其余每个token在该模块处的输入分解为其底部所有特征的加和:

\textbf{Out}_{\text{LXA},i} = \sum_h \sum_j \textbf{AttnPattern}^h_{ij}W_O^hW^h_V \text{LN}(\sum_{\text{Lower Components}}\textbf{Out}_{\text{Lower Components}})

再进一步,我们又可以将其他模块的输出近似为字典学习的分解结果:

\textbf{Out}_{\text{LXA},i} \approx \sum_h \sum_j \textbf{AttnPattern}^h_{ij}W_O^hW^h_V \text{LN}(\sum_{\text{L.C.}}(\sum_k {c^{\text{L.C.}}_k \mathbf{d}^{\text{L.C.}}_k}))

其中c^{\text{L.C.}}_k \mathbf{d}^{\text{L.C.}}_k是给定token的给定底部模块的某个字典特征方向与其在当前输出下的激活值之乘积。对于任意的\textbf{Out}_{\text{LXA},i},稀疏字典通过其encoder计算每个特征的激活程度。例如对于L1A622,我们通过以下方式计算其对\textbf{Out}_{\text{L1A},i}的激活程度贡献:

\mathbf{c}_{\text{L1A622}, i} = \text{ReLU}(W^{\text{L1A}}_{e, 622} \textbf{Out}_{\text{L1A}, i} + \mathbf{b}^{\text{L1A}}_{e, 622})

若在token i处,L1A622被激活(\mathbf{c}_{\text{L1A622}} > 0),我们便有:

\mathbf{c}_{\text{L1A622}, i} \approx W^{\text{L1A}}_{e, 622} \sum_h \sum_j \textbf{AttnPattern}^h_{ij}W_O^hW^h_V \text{LN}(\sum_{\text{L.C.}}(\sum_k {c^{\text{L.C.}}_k \mathbf{d}^{\text{L.C.}}_k})) + \mathbf{b}^{\text{L1A}}_{e, 622}

采用上一章对LayerNorm的处理,整个式子对任意的底部特征激活c_k \mathbf{d}_k都是线性函数。这种分解可以将某个注意力模块输出的某个被激活字典特征归因至其左下方的所有特征,我们便可以回答一个重要问题:若某个注意力模块的输出特征被激活,是之前的哪些token中对应的残差流中的哪些特征导致的?

QK回路分析理论

对于QK回路的理解与记号可以很大程度借鉴上一部分。Transformer内部的每个注意力头通过如下过程决定token i需要对token j分配多少比例的注意力:

\textbf{AttnPattern}_{ij} = \text{Softmax}(x_iW_QW^\mathrm{T}_Kx^\mathrm{T})_j

其中x\in \mathrm{R}^{L \times D}表示整个序列在该注意力模块的输入集合,x_i\in \mathrm{R}^{1 \times D}是token i处的输入, W_Q, W_K \in \mathrm{R}^{D \times d}是QK对应的参数矩阵,共同组成了QK回路。D, d分别表示模型隐层维度和每个注意力头的隐层维度。此处省略注意力头上标是因为该操作对每个头是独立的。

在下文中,我们将计算Softmax前的变量x_iW_QW^\mathrm{T}_Kx^\mathrm{T}称为注意力分数Attn Score,归一化完成的结果称为注意力强度Attn Pattern。 Softmax作为重要的非线性模块,将注意力强度进行了归一化处理,这种处理使得注意力分数中的任何元素变化都将使整个序列的注意力强度变化,对解释造成了困扰。因此我们很难完全理解整个序列的注意力分数,只得退而求其次求解更加定性的问题: 对于一对产生较强注意力(例如\textbf{AttnPattern}_{ij} > 3/L)的token,是各自残差流中的哪对特征产生了较强的“共鸣”?

具体而言,由于每个token在给定注意力模块的输入可以近似地用其底部的所有特征表示,这个过程与上一部分讨论的相同,即:

x_{\text{LXA}} = \text{LN}(\sum\limits_{\text{Lower Components}}\textbf{Out}_{\text{Lower Components}}) \approx \text{LN}(\sum_{\text{L.C.}}(\sum_k {c^{\text{L.C.}}_k \mathbf{d}^{\text{L.C.}}_k}))

在可以用线性操作近似LayerNorm的前提下,我们发现\textbf{AttnScore}_{ij}可以用以下的形式近似:

\textbf{AttnScore}_{ij} \approx \underbrace{\text{LN}(\sum_m c_m \mathbf{d}_m)}_{\text{Residual Stream of the } i\text{-th token}}W_QW^\mathrm{T}_K\underbrace{\text{LN}(\sum_n c_n \mathbf{d}_n)^\mathrm{T}}_{\text{Residual Stream of the } j\text{-th token}}

为便于表示,在上式中我们将两维的线性展开\sum_{\text{L.C.}}(\sum_k {c^{\text{L.C.}}_k \mathbf{d}^{\text{L.C.}}_k})简化为单纯的线性组合\sum_m c_m \mathbf{d}_m。在前文所述的LayerNorm的线性近似下,我们更容易发现这种双线性结构有一种更直观的理解方式:

\textbf{AttnScore}_{ij} \approx \sum_m \sum_n ((c_m \mathbf{d}_m)W_QW^\mathrm{T}_K(c_n \mathbf{d}_n)^\mathrm{T})

于是对任意给定的\textbf{AttnScore}_{ij},可以近似分解为两token间所有特征对的贡献之和。于是我们就可以回答一个重要问题:若某一注意力头由一个token注意到另一个token,是这两者对应的残差流中哪些特征对导致的?

MLP分析理论

MLP模块占据了Transformer模型中的逾半数参数,然而领域内对MLP的可解释性研究相比于注意力模块显著较少。这可能是因为MLP本身形式较简单,而基于神经元的研究又存在一定的局限。在字典学习的框架下,我们似乎可以理解更多有关MLP的内部特征,于是我们提出一个问题:这些特征是如何通过MLP计算得到的?

这部分的记号相比于之前两个部分简单,这是因为MLP不具备多头注意力的可分解性,同时也不涉及token间交互。因此我们的问题可以这样描述:

若通过字典学习将LXM模块的输入分解为字典特征的加权和,则LXM的输出可表示如下:

\textbf{Out}_{\text{LXM}} = \text{MLP}_{\text{X}}(x_\text{LXM}) \approx \text{MLP}_{\text{X}}(\text{LN}(\sum_{\text{L.C.}}(\sum_k {c^{\text{L.C.}}_k \mathbf{d}^{\text{L.C.}}_k})))

若同样地用处理LayerNorm的策略将其近似为线性函数此处重载了c_m \mathbf{d}_m的记号。经过LayerNorm线性近似的每个特征方向和激活值都会发生变化,为便于理解我们不再引入新的记号。,并简化模块输入的表示,单个MLP的特征激活值可表述为:

\mathbf{c}_{\text{LXMY}} = \text{ReLU}(W^{\text{LXM}}_{e, Y} \text{MLP}_{\text{X}}(\sum_m c_m \mathbf{d}_m) + \mathbf{b}^{\text{LXM}}_{e, Y})

其中W^{\text{LXM}}_{e, Y}是LXM字典encoder的第Y行,\mathbf{b}^{\text{LXM}}_{e, Y}是LXM字典的encoder bias。我们重点关注W^{\text{LXM}}_{e, Y} \text{MLP}_{\text{X}}(\sum_m c_m \mathbf{d}_m)这一部分,于是这个问题可以更一般化为:若一非线性函数的输入是一个变量集合的线性组合,则其输出在某一方向上的投影主要由输入的哪一个子集决定?

若把MLP看做任意的非线性函数,回答这个问题将是非常困难的。好在MLP内部仍然具有一定的归纳偏置可供我们分析,在常规设置的单层MLP中,经过激活函数的神经元表示对最终结果的贡献是线性的,而每个输入对神经元的贡献是非线性的。具体而言,某个神经元N的激活值是MLP的输入x的函数:

N=\text{Act\_fn}(W_{in}x), \textbf{Out} = W_{out}N

MLP隐层的激活函数Act_fn带来的非线性性质是影响可解释性的关键。我们发现流行的激活函数设置总是满足某种自门控的形式x\cdot\sigma(x):对ReLU而言\sigma(x)是符号函数,SiLU是Sigmoid函数,GELU是Gaussian分布的累积分布函数,这些函数均取值0-1且具有单调不减性。

\text{Act\_fn}(W_{in}x) = (W_{in}x) \cdot \sigma(W_{in}x)

基于这些观察,我们定义组成x的每个特征c_m \mathbf{d}_m对输出\textbf{Out}近似直接贡献\textbf{ADC}

\textbf{ADC}(c_m \mathbf{d}_m) = W_{out}(\underbrace{(W_{in}c_m \mathbf{d}_m)}_{\text{Dictionary Feature}}\cdot\overbrace{\sigma(W_{in}x)}^{\text{Leave MLP input unchanged for }\sigma})

于是MLP的某一输出特征可表示为底部所有特征的近似直接贡献之和:

c_{\text{LXMY}} = \text{ReLU}(W^{\text{LXM}}_{e, Y} \sum_m \textbf{ADC}(c_m \mathbf{d}_m) + \mathbf{b}^{\text{LXM}}_{e, Y})

至此我们终于将MLP的输出特征与输入特征建立了联系,并且可以通过近似直接贡献考察每个输出特征的激活来源。

用近似直接贡献评估每个特征对输出的影响,具有以下性质:

由于MLP中的隐层神经元对MLP输出的贡献是线性的,因此这些神经元对任意MLP输出激活的贡献也是线性的,这是因为MLP的输出操作和字典encoder都是线性映射。 于是我们若考虑底部特征c_m \textbf{d}_m对输出特征c_{\text{LXMY}}的贡献,其有且仅有两个主要来源:激活对c_{\text{LXMY}}有正向影响的神经元抑制对c_{\text{LXMY}}有负向影响的神经元。 近似直接贡献能够很好地反应前者的贡献,因为对于单调不减的非负自门控函数族\sigma(x)而言,\frac{\partial}{\partial x} x\sigma(x)是恒正的。 因此尽管近似直接贡献将底部特征对高层MLP特征的贡献做了线性假设,其与实际的贡献结果是正相关的。

然而近似直接贡献不能捕捉c_m \textbf{d}_m对负向影响c_{\text{LXMY}}的神经元的抑制。这是由于其抑制主要使\sigma(x)\to 0,而这会使得在该神经元上的近似直接贡献趋于0。 但是对这种贡献的捕捉缺失存在一个上界,对任意特征而言,即使所有与之负相关的神经元均被抑制,也无法将该特征激活至一常数以上,这个常数由字典decoder bias和MLP层的LayerNorm bias决定。

回路分析理论:总结

我们在本章将Transformer分为QK,OV和MLP三大模块,针对各自的特性讨论了一个核心问题:若字典学习可以较好地分解表征,我们如何理解Transformer是如何根据底层特征计算得到高层特征的?我们认为对OV回路的理解是最完善的;对QK回路的分析有可能无法捕捉到有关整个序列的特征,但仍足以理解许多行为;对MLP的分析理论仍需要更多研究以确保其可以可信地泛化与规模化到所有情况。

理解黑白棋模型中的回路

至此,我们已经准备好了我们的理论框架,现用它来分析我们的黑白棋模型。本章的主要篇幅是对模型各种具体行为的样例分析,应用字典分解结果和回路分析理论,我们似乎能理解这个模型的大多数行为。最终我们给出一个较宏观的回路发现结果。

我们对本章的每个主题选择1-2个棋盘状态对模型的行为具体分析,展示的所有样例均是随机选择的。

用Direct Logit Attribution理解模型的输出与特征激活的关系

通过应用Direct Logit Attribution,很多的Mech Interp工作将特定的logit归因到某个MLP或注意力头上,以帮助分析每个头的作用。我们在这里提出一个更详细的问题:若字典学习可以将每个模块的输出分解为可解释特征之和,哪些特征对给定的logit有较大贡献?

我们随机采样一组数据,并从中随机采样一步,对应的棋盘状态如下:

模型给出的预测结果如下:

模型成功预测了合法的位置,我们可以分析任意一个logit,比如棋盘上的33号位置,其logit约8.29。我们对每个特征计算Direct Logit Attribution,结果如下:

在上图中,我们移除了贡献绝对值小于0.1的特征。我们发现L5M499贡献相当突出,这个特征的行为描述如下:

我们重点关注第一行第三列描述合法位置的统计量,这张热力图代表在L5M499被激活最强的2048个输入中,有大部分描述的是“d-1,e-1,f-1”是合法的,而这恰好符合上述的棋盘状态。再结合我们尝试的其他例子,我们得到一个结论:棋盘的多数logit是由少数几个L5M特征主要激活的,这种影响具有直接的因果性,且这些特征的响应较专一。

用近似直接贡献理解MLP特征的激活

上一部分我们已经将模型输出与最高层的MLP特征建立了联系,接下来一个较为自然的问题就是,每个MLP特征是如何计算出来的?

我们再次随机选取一个棋盘状态,如下所示:

在上图对应的棋盘状态中,我们选取L2M中激活最高的特征L2M845,其激活值是1.4135。这个特征的行为比较好理解,指的是模型在b-4或b-5落子并翻转了c-4处的棋子:

我们列出了近似直接贡献绝对值小于0.05的特征:

我们发现有四个重要的贡献者:L0A837, L1M280, embedding和L1M49。它们对应的特征简要描述如下:

这四个特征的共同之处在于:它们均描述了列4的翻转情况,从人类可理解的角度来看,这些特征均是{L2M845:c-4被翻转}的充分条件,我们有一定的信心相信这揭示了模型从底部特征得出高级特征的一种模式。

理解OV回路的信息迁移

我们仍然随机采样一个棋盘状态:

我们发现L2A474重点描述了以c-2为中心的某种特定棋盘状态,与当前的棋盘状态有较好的对应,在当前棋盘状态的激活值是0.70:

我们利用OV回路的分析理论,将该所有token位于L2A之下的所有特征对该特征的贡献陈列如下:

上图中省略了贡献绝对值小于0.03的特征,其中PX表示来自第X个token残差流中的信息(该例中X取0-7)。我们发现贡献最大的3个特征均与该棋盘状态相关:

除此之外,我们发现P6L0A629对P7L2A474的激活有很强的副作用,我们发现这也有很强的可解释意义:L0A629主要描述c-3是己方棋子,而由于P6L0A629位于P7L2A474的前一个token,这与P7L2A474描述的c-3是己方棋子是相悖的,这也是因为相隔奇数步的残差流中对己方和对方的认知是相反的。 这种矛盾主要是因为P7的落子恰好翻转了c-3位置的棋子,而上一步的残差流并不包含未来的信息。我们猜想这种模型中存在某种非常微妙的平衡,故由于棋子翻转从过去的token中带来的负面影响会被描述棋子翻转的特征带来的正面影响消除甚至覆盖,从而在棋局的不断进行中始终保持尽可能正确的棋盘状态信息。

对读者而言,这部分的理解应当是相当困难的,因为黑白棋模型的特殊理解方式“己方vs对方”与繁杂的棋盘记号,对表述和理解都带来很大的困难。简而言之,我们发现OV回路迁移的信息都具有很强的可解释意义,某个残差流中的Attn从其它残差流带来的可解释特征能够很大程度上转化为当前残差流的可解释特征,这使我们对Transformer内部神奇的信息流动机制一再感到惊奇。同时我们也更加相信理解模型的行为并非一个极其复杂的问题,经过细心的观察,我们有很大机会理解这些复杂的信息流动。

理解注意力强度的形成

注意力强度是可解释性研究中最容易入手的点之一,通过注意力分布的热力图,我们很容易认识到每个token对其它的token倾注了多少注意力。在本章节中,我们对该模型中最广泛的一类注意力模式做分析,它们形如以下的交错式注意结构:

我们猜测,这种注意力模式的形成是为了分别从属于己方/对方落子对应的残差流中迁移信息,若某token与当前token相隔偶数步,那么过去的残差流中描述“属于自己的棋子”的特征应当能够增强当前步的相应特征,我们发现这种机制多是靠位置编码实现的。这里我们举一个例子, 在上图的注意力模式图中,最后一个token对对方落子的token均赋予了很强的注意力。

应用我们的回路分析理论,我们调查上图中token 14对token 13的注意力是哪些特征对引起的。我们发现其Softmax前的注意力分数为5.80,线性分解后的贡献如图所示:

我们先前发现模型自发学习的位置编码显著包含了当前落子颜色的信息,且模型底层有若干特征表示了类似的概念。我们发现,图中贡献最大的特征对中许多都包含位置编码,我们还发现其余的若干重要特征(例如P13L0M1015与P14L0A195)均是与当前落子颜色强相关的特征。 这说明我们的回路发现理论可以识别与位置编码相关的特征,且发现这些特征直接贡献了前文所述的注意力模式。

总结

这部分内容是本文的核心,我们通过实验更加确信了我们的回路发现理论可以在前所未有的粒度(即字典学习特征)上发现模型的回路,我们能够理解相当一部分的MLP特征与Attn特征是如何计算得到的,且实验结果和人类理解是相符的。 在我们的理论框架下,再加之细心的实验与总结,我们可以比较深入地理解Othello模型大致的工作流程:

本章的发现本身便足以使我们对Transformer内部的结构产生相当程度的认识,例如OV回路分析中,我们发现了特征间的可解释关系和贡献/抑制关系具有很强的正相关性,以及MLP中高层抽象特征的形成通过近似直接贡献常能按人类可理解的方式分解为底层基本特征。 这些结果让我们对理解语言模型产生了很多信心,同时我们也从中获取了许多关于Transformer模型的知识。我们在这章中仅展示了一些有代表性的结果,这套理论可以发现的现象远不止于此。有兴趣的读者可以在github repo下载并运行交互式的开源实现,自行复现或探索这个模型内部的回路。

总讨论

黑白棋和语言模型

我们的最终目标是能够逆向的Transformer语言模型,在黑白棋模型中仅做概念验证。黑白棋任务是一个不错的语言模型简易替代品,因为它的复杂度很难仅凭记忆完成,这套“黑白棋语言”的内部世界较丰富,且存在不少抽象的概念。 这种具有一定复杂性的任务使其有作为初步尝试的价值,相较下算术任务等与语言模型相差太大,可能无法得出可迁移的结论。 同时其又不至于过分复杂以增加训练字典或解释的难度,仅60个token的词表大小,与无重复输入的性质均一定程度上使我们的可解释更加明确,排除了很多易混淆的概念。

我们总结了两点黑白棋与语言模型的区别,我们认为这些区别在字典学习的框架下是值得注意的。

重叠假说和我们的字典学习流程均假设表征空间可以被分解为若干个可解释特征的加和。这些特征在空间中以单位向量的形式出现,其激活强度代表了这种特征出现的强弱。 这些特征的激活强度均是非负值,实现上,我们用字典学习隐层的ReLU函数将所有小于0的特征激活重新规约为0即出于这个原因。

在这一假设下,方向相反的特征一般具有较强的负相关性,而非一个特征的正激活和负激活。这两种情况其实并非完全互斥的,例如在下图所示的特征重叠结构中,绿蓝两特征为两个无关概念,因稀疏性被表示于相反方向。而模型在左下-右上方向上安放了两个完全互斥的特征,我们可以也等价地理解为这是一个特征的正激活与负激活。不论按何种方式理解,这个空间一定发生了特征重叠。

总体而言,我们认为这个任务上的结果有一定借鉴意义,其与真实语言模型的差异不会影响我们的基本结论,但仍是重要的细节。

如何解释字典学习的特征?

假设我们提取出字典学习的特征,并能够对每个特征获取激活较大的若干输入,我们该如何解释这种输入?从长远来看,我们必然需要某种自动化的方法,例如用大语言模型解释语言特征。 即便如此,人类解释仍需要验证和补充这些结果,因为解释特征的最终目标是为人类所理解,整个解释流程中至少需要一个环节采用人力介入。我们认为如何可规模化的解释这些特征是一个重要问题, 我们猜想可以采用强大的多模态大模型解释某个特征最感兴趣的棋局,并采用人力调整解释结果。但鉴于这不是我们的研究重点,我们没有使用这些方法,而是通过人类直观解释。

为了评估解释结果,我们在少量的特征中进行了验证,流程与现有的自动化解释思路类似。我们以一个例子说明:

通过可视化特征激活可以获取对激活专一性与敏感性的直观理解,但定量化的评估指标仍是必要的,我们认为这也是未来重要的研究方向。

字典学习与Probing

字典学习与Probing都可以发现模型隐层中可解释的方向,两者间一个重要的区别在于字典学习额外引入了重构误差,清晰了我们距离完全理解隐层特征的距离。理想情况下,如果字典学习可以得到0重构误差,且每个特征都高度单语义可解释,我们就总能完全理解这个隐层包含的所有特征集合。

另一个重要的区别在于,相比于基于probing的方法需要首先提出特征的定义或分类标签,字典学习可以无监督地提取特征。但监督信号仍可能是需要的,这是因为解释这些特征时仍可能需要许多先验知识。 本文所示的任务便是个极好的例子。在原文中作者们采用黑色与白色做线性probing取得的效果并不理想,后续的工作中从“己方”与“对方”的角度极大地提升了probing的表现,这说明已有的对字典学习内部特征族的理解很可能帮我们解释一大批原本并不清晰的特征。

我们认为可能最重要的问题是字典学习特征的基本性,我们在下章中单独讨论。

基本特征与复合特征

先前对OthelloGPT的研究中,有研究者用probe的方法发现模型中存在Flipped features,即指示某一棋盘位置是否在当前步被翻转,而在我们的字典学习结果中,我们没有直接发现这样的特征。 但与之类似的,我们识别到一系列特征,指示当前位置落子于某处且翻转了某个固定方向的例子。例如,在L0M中存在一对特征,均对当前步落子于c-2处,然而仅当该步落子翻转了右侧的棋子时,L0M195被激活;该步落子翻转右上方棋子时,L0M205才被激活。 这两个特征在表征空间中具有-0.13的余弦相似度,说明两者之间具有相当的独立性。

因此我们猜测,字典学习得到的特征集合可能是更加“基本”的特征。在上例中,probing得到的Flipped features可能是该棋盘位置被来自其周围8个方向对应翻转特征的组合,因此这是一个由基本特征线性组合而成的复合特征。 这是一个比较模糊的话题。我们认为“基本”是一种相对的概念,Anthropic的字典学习研究中表明,当字典隐层不断被扩大,得到的特征的描述粒度不断细分,而大字典中的特征子集可以构成一个“特征簇”,在小字典中以单个或更少数量的子集展现。 我们在这一观察的基础上认为,字典越大,分解得到的特征越基本,这种基本性是由于字典训练过程中为了用有限的神经元数量形成稀疏分解为动力的,而probing得到的特征方向则不具备任何有关基本性的先验。这也是我们得出上述结论的依据。

类似地,我们猜想现实语言模型中指示情感或事实性的特征也可能是一系列基本特征的线性组合,在未来的字典学习研究计划中,我们认为这个问题可能十分重要。

基本特征与复合特征

回路与随机性

在QK,OV和MLP三种回路中,我们均可以刻画某种贡献,在QK回路中这种贡献是两条残差流中任意特征对的双线性结果,在OV回路中是来自其它token残差流的特征经过OV回路的结果,在MLP中是近似直接贡献。 然而,由于特征重叠的原因,独立的特征对无法被表示在正交的方向上,因此它们之间的贡献应当满足某种随机分布。我们需要明确任意两个特征间的贡献是来源于随机性,还是Transformer学习到的回路使然。 例如在前文所述的近似直接贡献图中,我们无法确定应该以何处作为分界线,以明确这之前的强贡献特征是模型实现有意为之,而这之后只是特征重叠带来的微量噪声:

现实的情况更可能是不存在这一条分界线。特征重叠的重要动力是减少训练损失,因此将两个较弱正相关的特征建立某种微弱的正向重叠以求期望上的最小信息干扰,故我们观测到的贡献某种程度上反映了相关性的强弱关系。

这模糊了我们对回路理解的定义。结合现有的回路研究,我们更加相信模型内部的多数行为由多个或正或负的回路组合而成,正如同模型预测的logit是一种混合策略,其内部信息的流动也是混合而成的。 在这一猜想下,我们或许只能理解最主要的那一部分,并随时准备迎接可能的可解释性幻觉。但仅是如此,我们认为已经足够我们非常深入地了解Transformer内部的工作原理。

此外,由于真实的语言模型中特征更稀疏,隐层却更宽,且所有特征张成的语义集合(即世界)更广阔,我们猜测在语言模型中来自OV回路的信息流动会更明确。

本文理论框架的可规模化

最终,我们希望用类似的分析,加以许多工程上的改进,将本文涉及的方法应用于较大的语言模型。尽管本文采用的模型较小,任务也并非语言建模,但由于字典学习部分和回路分析部分均与模型大小或任务无关,我们相信现有的结果已经足以支撑语言模型上的直接应用。

一个重要的待解决问题是语言模型上的字典训练。在最基础的稀疏自编码器结构上,已有很多工作提出了算法上的改进与实践经验,更优的算法设计可以帮助我们找到计算资源、重建误差和特征可解释性三者的帕累托最优边界。这一问题在所有字典学习工作都至关重要。 例如稀疏约束优化的最终目标应是字典隐层的L0-norm,广泛采用的L1-norm有凸性等很好的性质,但是否存在更好的稀疏约束损失是一个可能的重要问题。 字典优化上,规整Adam动量与重采样沉睡特征等训练技巧也仍需要验证与补充。

对字典特征的解释已在前文讨论过,在语言模型中,用强大的语言模型自动解释特征至少可以为每个特征赋予一个有相当可信度的初值。此外,搭建交互式的可解释性界面很可能是人类微调或分析的核心接口。 我们相信这些工程问题有可能在不断优化的过程中整合,发展成较成熟的可解释性范式。

我们的回路分析理论本身不受规模化的影响,但需要对transformer的架构发展做出一定调整。 对QK回路的分析与流行的位置编码方式是完全兼容的;流行的Llama架构中采用的GLU模块采用了一部分参数用于执行数据依赖的自门控,这会使近似直接贡献丢失更多信息,我们需要更好地推广近似直接贡献以更好地解决MLP回路的解释问题。

回路分析的实践中,交互式的接口会对回路发现带来极大便捷。这些工程问题都可能转化为可解释性可规模化过程中的副产物。

总结

我们提出了一种非常一般性的可解释性理论,在字典学习可以提取每个MLP与注意力模块中尽可能多可解释特征的假设下,我们进一步提出了一种回路发现理论框架,可以将计算图中的所有特征联系起来,形成一张极其密布的连接关系图。 尽管这张理想的图中涉及非常多的连接,我们认为特征的稀疏性与特征间相当程度上的独立性对理解这些连接很有帮助,使得我们可以在人类可接受的复杂度下理解每个特征与其底部所有特征的联系。其中稀疏性保证了每个输入只激活少数的特征,且我们猜想每个特征的激活来源只有其底部的少数特征而非所有特征共同作用。

由于我们理论的一般性,我们的视野涉及了非常多先前Mech Interp的研究方式,我们猜想通过训练线性probe可以发现的特征也可以通过字典学习发现,且Activation Patching等方法在这个框架中也可以自然地应用。如果这个方法能够被应用在至少GPT2-small级别的语言模型中,我们相信我们可以在这个框架下“重新发现”许多已有的结论,包括注意力头(组)的局部作用、MLP中的知识与发现种类繁多的全局回路,并帮助我们发现新的现象。 可以用两种“一般性”来总结这套理论的重要意义。 首先是特征一般性,我们不需要有关模型内部信息的任何先验知识,仅需无监督地训练字典便可分解出许多可解释的特征; 其次是回路一般性,我们的回路分析弱化了对回路结构的先验认识,我们只需定位到与之相关的特征,并从之出发(或者更简单地,从输出出发),将特征的激活分解为其(或其它token的)底部特征贡献。这个过程可以很轻松地发现局部的回路,递归地应用这个过程有望帮助我们发现许多组合回路。

但正因为其一般性,想要在一篇这样的散文中讨论所有的细节是很不现实的,仅是处在核心的字典学习也需要许多篇幅将其中许多重要的细节阐述清楚。 同时,正因为这种一般性,我们能在此展示的结论只是模型的众多行为中非常小的一部分。若将我们展示出的部分视作“特征簇”与“回路簇”中的代表,我们可以稍微自信一些地说我们可能理解了部分模型内部的工作原理,但仍没有解释Othello-GPT的所有现象,例如终局回路。 尽管我们相信这种分析已经将很多有关Transformer的奥秘分解出来,如何系统化地为人类理解仍是一个问题。理解特定的logit或特定的特征的激活来源具有相当细的粒度,也正因此,这种理解就需要重复天文数字般的次数才能完全理解模型的每一个特征与行为。

此外,我们认为我们在字典学习与回路发现理论上均处在初级阶段,我们无法声称字典学习完全提取了所有的特征,或声称每一个特征都是单语义的或可解释的,也无法声称发现的每一个特征激活都能充分或无误地被分解为底层的来源。字典学习与回路发现理论的任何优化都可能为我们展开新的可解释性可能。 但不论我们对其中的哪些细节存在担忧,我们认为这套理论为后续的,尤其是基于字典学习的,可解释性研究打下了不错的基础,让我们对完全理解Transformer内部的终极目标又多了一些希望。

这篇工作的另一个重要意义是作为Open-MOSS Interpretability组的第一篇post,明确了我们对可解释性研究的一些思考。我们希望重点研究Mechanistic Interpretability的系统化发展,找到一种人类理解和高维表征空间的连接,并将数以亿计的模型参数抽丝剥茧以理解其智能。字典学习是目前我们比较寄予厚望的基础,我们也随时准备拥抱新的可能性。

Contributions

Zhengfu He提出了本文的理论框架,完成实验与初稿的撰写;Tianxiang Sun为本文提供了大量意见,并提出了总讨论中字典特征解释方法一章涉及的观点;Qiong Tang负责本文的展示与绘图,为第四部分的实验设计提供了建设性意见;Xipeng Qiu是团队的导师,对总讨论中黑白棋与语言模型一章与可规模化一章提出了重要的建议。

Acknowledgements

Kenneth Li与Neel Nanda在OthelloGPT上的开源研究对本工作影响很大,我们在已有的大量有关Othello游戏和棋盘状态可视化的基础上做了改动,才使字典特征的解释成本大大降低。 Neel Nanda对Mech Interp社区的贡献与精妙的想法分享对本文的构思起到重要影响。其开源的Transformer_lens库解决了Mech Interp领域相当一部分的工程难题,也是实现本文所需的重要框架。

本文使用的计算资源由复旦大学智能计算平台(CFFF)支持,CFFF工作人员的热情与专业为这篇工作的顺利进行提供了长足的保障。

缺失以上提及的任何一部分,本文的完成都将困难重重,我们对各位的支持十分感激。

Footnotes

  1. 即使其没有任何先验的可解释性设计[↩]
  2. 这里有一个plotly的陈年bug,当动画回到过去的帧时,画布上的部分内容不会被清除。[↩]
  3. 存在棋盘未被填满便无棋可下的情况,我们不考虑这样的少见情况[↩]
  4. 这个任务只要求每一步均是合法的下法,并不在乎输赢。[↩]
  5. 99.9%[↩]
  6. 例如互联网 = 用户 + 网线 + 社交软件 + ...[↩]
  7. (d_mlp / d_model)的平方。[↩]
  8. 我们没有分解词表征,这是因为这个任务比较特殊:词表大小甚至小于模型隐层维度,这是很反常的;此外,由于任务的设定,每个token的定义是相当明确的。至于pos embedding的解释,本模型采用的learned pos emb不具备太多参考的意义。但是在未来应用于语言建模任务时,分解词表征可能十分重要。[↩]
  9. 我们采样5万条数据,统计每个token对应的激活在每个字典上的分解结果,若某特征的激活次数大于64,则我们认为是活跃的特征(64 / (60 * 50000) ~= 2e-5)。[↩]
  10. 即使不是人类可理解的[↩]
  11. 例如,第3层Attn层字典的第728个特征称为L3A728,第0层MLP的第17个特征称为L0M17。[↩]
  12. 此处重载了c_m \mathbf{d}_m的记号。经过LayerNorm线性近似的每个特征方向和激活值都会发生变化,为便于理解我们不再引入新的记号。[↩]
  13. 此处与当前步不符。[↩]

References

  1. Thread: Circuits
    Cammarata, N., Carter, S., Goh, G., Olah, C., Petrov, M., Schubert, L., Voss, C., Egan, B. and Lim, S.K., 2020. Distill. DOI: 10.23915/distill.00024
  2. A Mathematical Framework for Transformer Circuits
    Elhage, N., Nanda, N., Olsson, C., Henighan, T., Joseph, N., Mann, B., Askell, A., Bai, Y., Chen, A., Conerly, T., DasSarma, N., Drain, D., Ganguli, D., Hatfield-Dodds, Z., Hernandez, D., Jones, A., Kernion, J., Lovitt, L., Ndousse, K., Amodei, D., Brown, T., Clark, J., Kaplan, J., McCandlish, S. and Olah, C., 2021. Transformer Circuits Thread.
  3. Linear Algebraic Structure of Word Senses, with Applications to Polysemy[link]
    Arora, S., Li, Y., Liang, Y., Ma, T. and Risteski, A., 2018. Trans. Assoc. Comput. Linguistics, Vol 6, pp. 483--495. DOI: 10.1162/TACL\_A\_00034
  4. Zoom In: An Introduction to Circuits
    Olah, C., Cammarata, N., Schubert, L., Goh, G., Petrov, M. and Carter, S., 2020. Distill. DOI: 10.23915/distill.00024.001
  5. Toy Models of Superposition
    Elhage, N., Hume, T., Olsson, C., Schiefer, N., Henighan, T., Kravec, S., Hatfield-Dodds, Z., Lasenby, R., Drain, D., Chen, C., Grosse, R., McCandlish, S., Kaplan, J., Amodei, D., Wattenberg, M. and Olah, C., 2022. Transformer Circuits Thread.
  6. Learning Sparse Overcomplete Word Vectors Without Intermediate Dense Representations[link]
    Chen, Y., Li, G. and Jin, Z., 2017. Knowledge Science, Engineering and Management - 10th International Conference, KSEM 2017, Melbourne, VIC, Australia, August 19-20, 2017, Proceedings, Vol 10412, pp. 3--15. Springer. DOI: 10.1007/978-3-319-63558-3\_1
  7. SPINE: SParse Interpretable Neural Embeddings[link]
    Subramanian, A., Pruthi, D., Jhamtani, H., Berg-Kirkpatrick, T. and Hovy, E.H., 2018. Proceedings of the Thirty-Second AAAI Conference on Artificial Intelligence, (AAAI-18), the 30th innovative Applications of Artificial Intelligence (IAAI-18), and the 8th AAAI Symposium on Educational Advances in Artificial Intelligence (EAAI-18), New Orleans, Louisiana, USA, February 2-7, 2018, pp. 4921--4928. AAAI Press. DOI: 10.1609/AAAI.V32I1.11935
  8. Word Embedding Visualization Via Dictionary Learning[PDF]
    Zhang, J., Chen, Y., Cheung, B. and Olshausen, B.A., 2019. CoRR, Vol abs/1910.03833.
  9. Word2Sense: Sparse Interpretable Word Embeddings[link]
    Panigrahi, A., Simhadri, H.V. and Bhattacharyya, C., 2019. Proceedings of the 57th Conference of the Association for Computational Linguistics, ACL 2019, Florence, Italy, July 28- August 2, 2019, Volume 1: Long Papers, pp. 5692--5705. Association for Computational Linguistics. DOI: 10.18653/V1/P19-1570
  10. Transformer visualization via dictionary learning: contextualized embedding as a linear superposition of transformer factors[link]
    Yun, Z., Chen, Y., Olshausen, B.A. and LeCun, Y., 2021. Proceedings of Deep Learning Inside Out: The 2nd Workshop on Knowledge Extraction and Integration for Deep Learning Architectures, DeeLIO@NAACL-HLT 2021, Online, June 10 2021, pp. 1--10. Association for Computational Linguistics. DOI: 10.18653/V1/2021.DEELIO-1.1
  11. Sparse autoencoders find highly interpretable features in language models
    Cunningham, H., Ewart, A., Riggs, L., Huben, R. and Sharkey, L., 2023. arXiv preprint arXiv:2309.08600.
  12. Emergent World Representations: Exploring a Sequence Model Trained on a Synthetic Task[link]
    Li, K., Hopkins, A.K., Bau, D., Vi\'egas, F.B., Pfister, H. and Wattenberg, M., 2023. The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net.
  13. Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets[PDF]
    Power, A., Burda, Y., Edwards, H., Babuschkin, I. and Misra, V., 2022. CoRR, Vol abs/2201.02177.
  14. The Clock and the Pizza: Two Stories in Mechanistic Explanation of Neural Networks[link]
    Zhong, Z., Liu, Z., Tegmark, M. and Andreas, J., 2023. CoRR, Vol abs/2306.17844. DOI: 10.48550/ARXIV.2306.17844
  15. Progress measures for grokking via mechanistic interpretability[link]
    Nanda, N., Chan, L., Lieberum, T., Smith, J. and Steinhardt, J., 2023. The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net.
  16. On Interpretability and Feature Representations: An Analysis of the Sentiment Neuron[link]
    Donnelly, J. and Roegiest, A., 2019. Advances in Information Retrieval - 41st European Conference on IR Research, ECIR 2019, Cologne, Germany, April 14-18, 2019, Proceedings, Part I, Vol 11437, pp. 795--802. Springer. DOI: 10.1007/978-3-030-15712-8\_55
  17. On the importance of single directions for generalization[link]
    Morcos, A.S., Barrett, D.G.T., Rabinowitz, N.C. and Botvinick, M.M., 2018. 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings. OpenReview.net.
  18. Visualizing and Understanding Recurrent Networks[PDF]
    Karpathy, A., Johnson, J. and Fei-Fei, L., 2015. CoRR, Vol abs/1506.02078.
  19. Network Dissection: Quantifying Interpretability of Deep Visual Representations[link]
    Bau, D., Zhou, B., Khosla, A., Oliva, A. and Torralba, A., 2017. 2017 IEEE Conference on Computer Vision and Pattern Recognition, CVPR 2017, Honolulu, HI, USA, July 21-26, 2017, pp. 3319--3327. IEEE Computer Society. DOI: 10.1109/CVPR.2017.354
  20. Multimodal Neurons in Artificial Neural Networks
    Goh, G., †, N.C., †, C.V., Carter, S., Petrov, M., Schubert, L., Radford, A. and Olah, C., 2021. Distill. DOI: 10.23915/distill.00030
  21. Privileged Bases in the Transformer Residual Stream
    Elhage, N., Lasenby, R. and Olah, C., 2023. Transformer Circuits Thread.
  22. Towards Monosemanticity: Decomposing Language Models With Dictionary Learning
    Bricken, T., Templeton, A., Batson, J., Chen, B., Jermyn, A., Conerly, T., Turner, N., Anil, C., Denison, C., Askell, A., Lasenby, R., Wu, Y., Kravec, S., Schiefer, N., Maxwell, T., Joseph, N., Hatfield-Dodds, Z., Tamkin, A., Nguyen, K., McLean, B., Burke, J.E., Hume, T., Carter, S., Henighan, T. and Olah, C., 2023. Transformer Circuits Thread.
  23. Emergent Linear Representations in World Models of Self-Supervised Sequence Models[link]
    Nanda, N., Lee, A. and Wattenberg, M., 2023. CoRR, Vol abs/2309.00941. DOI: 10.48550/ARXIV.2309.00941
  24. Interpretability in the Wild: a Circuit for Indirect Object Identification in GPT-2 Small[link]
    Wang, K.R., Variengien, A., Conmy, A., Shlegeris, B. and Steinhardt, J., 2023. The Eleventh International Conference on Learning Representations, ICLR 2023, Kigali, Rwanda, May 1-5, 2023. OpenReview.net.
  25. Successor Heads: Recurring, Interpretable Attention Heads In The Wild
    Gould, R., Ong, E., Ogden, G. and Conmy, A., 2023. arXiv preprint arXiv:2312.09230.
  26. Deep Learning using Rectified Linear Units (ReLU)[PDF]
    Agarap, A.F., 2018. CoRR, Vol abs/1803.08375.
  27. Sigmoid-weighted linear units for neural network function approximation in reinforcement learning[link]
    Elfwing, S., Uchibe, E. and Doya, K., 2018. Neural Networks, Vol 107, pp. 3--11. DOI: 10.1016/J.NEUNET.2017.12.012
  28. Bridging Nonlinearities and Stochastic Regularizers with Gaussian Error Linear Units[PDF]
    Hendrycks, D. and Gimpel, K., 2016. CoRR, Vol abs/1606.08415.
  29. Softmax Linear Units
    Elhage, N., Hume, T., Olsson, C., Nanda, N., Henighan, T., Johnston, S., ElShowk, S., Joseph, N., DasSarma, N., Mann, B., Hernandez, D., Askell, A., Ndousse, K., Jones, A., Drain, D., Chen, A., Bai, Y., Ganguli, D., Lovitt, L., Hatfield-Dodds, Z., Kernion, J., Conerly, T., Kravec, S., Fort, S., Kadavath, S., Jacobson, J., Tran-Johnson, E., Kaplan, J., Clark, J., Brown, T., McCandlish, S., Amodei, D. and Olah, C., 2022. Transformer Circuits Thread.
  30. Language models can explain neurons in language models
    Bills, S., Cammarata, N., Mossing, D., Tillman, H., Gao, L., Goh, G., Sutskever, I., Leike, J., Wu, J. and Saunders, W., 2023.
  31. Understanding intermediate layers using linear classifier probes[link]
    Alain, G. and Bengio, Y., 2017. 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Workshop Track Proceedings. OpenReview.net.
  32. Linear Representations of Sentiment in Large Language Models[link]
    Tigges, C., Hollinsworth, O.J., Geiger, A. and Nanda, N., 2023. CoRR, Vol abs/2310.15154. DOI: 10.48550/ARXIV.2310.15154
  33. Inference-Time Intervention: Eliciting Truthful Answers from a Language Model[link]
    Li, K., Patel, O., Vi\'egas, F.B., Pfister, H. and Wattenberg, M., 2023. CoRR, Vol abs/2306.03341. DOI: 10.48550/ARXIV.2306.03341
  34. Adam: A Method for Stochastic Optimization[PDF]
    Kingma, D.P. and Ba, J., 2015. 3rd International Conference on Learning Representations, ICLR 2015, San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings.
  35. LLaMA: Open and Efficient Foundation Language Models[link]
    Touvron, H., Lavril, T., Izacard, G., Martinet, X., Lachaux, M., Lacroix, T., Rozi\`ere, B., Goyal, N., Hambro, E., Azhar, F., Rodriguez, A., Joulin, A., Grave, E. and Lample, G., 2023. CoRR, Vol abs/2302.13971. DOI: 10.48550/ARXIV.2302.13971
  36. GLU Variants Improve Transformer[PDF]
    Shazeer, N., 2020. CoRR, Vol abs/2002.05202.