Layer-Wise Sub-Model Interpretability for Transformers
A Novel Approach to Improving AI Interpretability
1. Introduction
Transformers have achieved remarkable success in NLP and other domains, but they operate as complex black-box models, making their decisions hard to interpret (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). This lack of transparency has raised concerns about safety, trust, and accountability when deploying AI systems in real-world applications (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). Researchers and practitioners increasingly recognize that we need better interpretability techniques to understand why a transformer makes a given prediction, not just what it predicts. Yet current interpretability methods provide only partial insights into these models' inner workings.
Common post-hoc methods like saliency maps or feature attribution highlight which input tokens or features most influence the output (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). While useful for identifying important words (e.g. via Integrated Gradients or SHAP), such methods treat the model as a black box and do not reveal the intermediate reasoning process. A major concern is that these attributions may not faithfully reflect the model's true decision-making: different attribution techniques often disagree, casting doubt on their reliability (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). Similarly, analyzing attention weights (the internal attention coefficients of transformers) has been popular, since attention seems to show what the model "focuses" on. However, attention is not a definitive explanation: Jain and Wallace (2019) demonstrated that one can alter a model's attention distributions without changing its outputs, and Serrano and Smith (2019) found that even ablating or randomizing attention weights often leaves model behavior unchanged (Is Attention Interpretable in Transformer-Based Large Language Models? Let's Unpack the Hype). These findings imply that raw attention patterns, by themselves, do not provide a faithful explanation of the model's reasoning. In short, existing techniques like saliency maps and attention visualizations are insufficient: they offer surface-level insight but often miss the deeper computational patterns inside transformer networks.
The limitations of current interpretability methods motivate exploring new approaches. Recent progress in mechanistic interpretability attempts to open the black box by reverse-engineering model components into human-understandable algorithms (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). For example, researchers have identified individual neurons and attention heads that correspond to specific functions or concepts within language models (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). This work can yield mechanistic explanations for certain behaviors (e.g. a neuron that activates on negative words, or an attention head that tracks coreference). However, such analyses are typically labor-intensive and done for isolated circuits. We lack a scalable, general way to explain a model's entire reasoning process on a given input. There is a clear need for interpretability methods that can provide step-by-step explanations of how a transformer processes information through its layers, bridging the gap between low-level circuit findings and high-level feature attributions.
In this paper, we propose a novel approach to transformer interpretability: Layer-Wise Sub-Model Interpretable Transformers. The core idea is to leverage the transformer's own structure, its multiple layers of hidden representations, by attaching interpretable sub-models to each layer. Each sub-model acts as an explanatory module for that layer's activations, decoding or analyzing the hidden state into a human-understandable form. By doing so for every layer, we obtain a progressive, layer-by-layer explanation of how the model's understanding evolves from input to output. In essence, the transformer is no longer a single monolithic black box; it is augmented with "inner voices" or probes at each layer that narrate what that layer is doing or has learned. This approach promises fine-grained transparency: instead of only explaining the final prediction (as most methods do), we explain the intermediate thinking steps of the model. We hypothesize that layer-wise sub-models can address the shortcomings of saliency and attention-based methods by revealing the meaning of internal states, thus providing more faithful and detailed explanations of transformer decisions. The rest of this paper details this approach, including technical background, implementation examples, comparisons to existing methods, and implications for future research in AI interpretability.
2. Technical Background
Transformer Model Primer: Transformers are deep neural networks that process information in a sequence of layers, each producing a new representation of the input tokens. An input text is first tokenized into discrete tokens (subwords or words) and embedded into continuous vector representations. These token embeddings are then fed into the transformer's stack of self-attention layers. Each transformer layer consists of a multi-head self-attention sublayer followed by a feed-forward network sublayer (with residual connections and layer normalization at each step). Through self-attention, the model blends information across all tokens: each token's representation is updated by attending to other relevant tokens in the sequence. The feed-forward sublayer then further transforms these representations in a token-wise fashion. As the input passes through layer after layer, the token representations (often called hidden states) gradually encode more abstract and complex features. By the final layer, the transformer has distilled the input into a form from which it can easily produce an output, for example a probability distribution over next words in language modeling, or a classification label in a classifier. In summary, a transformer processes text through iterative hidden state transformations: at layer 1 it has a rudimentary understanding, and at layer $N$ (the last layer) it has a task-specific understanding ready to drive the output.
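To make this sublayer structure concrete, the following is a minimal PyTorch-style sketch of a single post-norm encoder layer. The dimensions (768 hidden units, 12 heads) mirror BERT-base but are otherwise illustrative, and the class name is our own; it is not meant as a faithful reimplementation of any particular model.

```python
import torch.nn as nn

class TransformerLayerSketch(nn.Module):
    """Minimal sketch of one encoder layer: self-attention + feed-forward,
    each wrapped in a residual connection and layer normalization."""
    def __init__(self, d_model=768, n_heads=12, d_ff=3072):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, h):                 # h: (batch, seq_len, d_model)
        attn_out, _ = self.attn(h, h, h)  # each token attends to all tokens
        h = self.norm1(h + attn_out)      # residual connection + layer norm
        h = self.norm2(h + self.ff(h))    # token-wise feed-forward, residual + norm
        return h                          # this is the layer's hidden state h_l
```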
Evolution of Meaning Across Layers: Importantly, the hidden states in each layer are not random; they contain structured information that evolves as depth increases. Empirical studies have shown that transformers tend to learn a hierarchy of linguistic and semantic features across their layers. For instance, in BERT, lower layers capture basic syntactic information (like part-of-speech tags), intermediate layers handle named entities and short-range dependencies, and higher layers encode high-level semantics and coreference links ([1905.05950] BERT Rediscovers the Classical NLP Pipeline). In other words, the model appears to rediscover the classic NLP pipeline internally: first morphology and syntax, then semantics and reasoning ([1905.05950] BERT Rediscovers the Classical NLP Pipeline). Similarly, analysis of feed-forward networks in transformers has found that early layers focus on "shallow" patterns (e.g. local grammar), whereas later layers concentrate on more "semantic" or abstract patterns (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). For example, Geva et al. (2021) showed that the feed-forward sublayers can act as key-value memories storing factual knowledge; an early layer might respond to a common phrase or syntactic cue, while a late layer retrieves a specific fact or meaning associated with that cue (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). These findings suggest that as we ascend the stack, the representation of each token gets refined: noise and ambiguity are reduced, and the model's interpretation of the input becomes more specific to the task at hand.
To make this concrete, consider a transformer-based question answering model receiving the question "Who wrote Pride and Prejudice?". In the first few layers, the model's representations may highlight that the question is asking for a person's name (an author) and recognize the tokens "Pride and Prejudice" as a book title. Mid-level layers might bring in contextual or factual associations, e.g. activating features related to literary authors or the 19th century. By the final layers, the model's hidden state for the question likely strongly encodes the answer "Jane Austen" (the author) before it is even produced as output. Indeed, the concept of the "logit lens" in interpretability is built on this premise: if you take the hidden state from an intermediate layer and directly map it to the output vocabulary, you often find that the model's latent prediction is increasingly close to the final answer as you go deeper (Eliciting Latent Predictions from Transformers with the Tuned Lens). Nostalgebraist (2020) first observed that decoding each layer of GPT-2 with the output embedding matrix yields distributions that converge steadily toward the correct final token as depth increases (Eliciting Latent Predictions from Transformers with the Tuned Lens). This implies the model is accumulating evidence or partial guesses layer by layer. In our example, a logit lens on an early layer might produce a broad distribution over many authors, but by layer 10 it might have "Jane Austen" as a high-probability candidate, indicating the model has essentially figured out the answer internally even before the last layer. Such insights underline that intermediate hidden states carry significant meaning and often an implicit draft of the model's eventual output.
Mechanistic Interpretability and Neuron Analysis: Beyond layer-wise aggregate behavior, researchers have also probed the internal mechanics at the level of neurons and attention heads. Each transformer layer contains numerous neurons (the dimensions of the hidden state) and multiple attention heads. Mechanistic interpretability research seeks to identify what these individual components represent and how they combine to perform computations (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). Notable findings include the discovery of specialized attention heads that perform distinct roles. For example, certain attention heads in GPT-2 were found to track long-range dependencies like quote pairings or to perform simple arithmetic in context. In BERT, some attention heads clearly specialize in coreference resolution (linking pronouns to the correct entities) (Is Attention Interpretable in Transformer-Based Large Language Models? Let's Unpack the Hype), while others attend mostly to next or previous words to encode positional relationships. There is evidence of a division of labor: attention heads have interpretable roles, albeit as "team players" working in concert (Is Attention Interpretable in Transformer-Based Large Language Models? Let's Unpack the Hype). At the neuron level, analysis is more challenging because transformers use distributed representations (each concept is encoded across many neurons, and each neuron participates in many concepts, a phenomenon known as superposition). Still, a few individual neurons have been identified with strong correlations to specific features. For instance, researchers found a neuron in GPT-2 that appears to track whether the text is inside a quotation (a quote-tracking neuron), and another that fires on negation words, indicating a negation concept. Such neuron-level discoveries often come from exhaustive probing or even automated techniques. Recent work by OpenAI used GPT-4 to generate natural language explanations for the behavior of every neuron in GPT-2, producing a dataset of (imperfect) descriptions of what each neuron might be looking for (Language models can explain neurons in language models | OpenAI). This kind of approach treats neurons themselves as "units of thought" to be explained, and though not fully reliable yet, it represents a step toward scaling neuron interpretability. Overall, prior research in mechanistic interpretability has demonstrated that transformer internals are not inscrutable: there is structure and algorithmic behavior we can uncover. However, most such work either focuses on local interpretability (e.g. one neuron or one head at a time) or requires significant manual effort and intuition to piece together into a global explanation.
Summary: A transformer processes inputs through successive transformations, with each layer's hidden state encoding a more refined understanding. There are rich internal dynamics: early layers handle low-level patterns, later layers handle high-level meaning ([1905.05950] BERT Rediscovers the Classical NLP Pipeline) (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). While tools like attention maps and saliency highlight some aspects of this process, they fall short of explaining how each layer contributes to the final decision. Mechanistic studies give insight into components but are hard to apply comprehensively to each instance of model usage. This background sets the stage for our proposed solution: using layer-wise sub-models to systematically interpret each layer, effectively peeling back the transformer's computations one layer at a time.
3. Proposed Approach: Layer-Wise Sub-Model Interpretable Transformers
We propose an approach that integrates interpretability into the very structure of transformer models by introducing layer-wise sub-models for interpretation. The key idea is to attach an auxiliary interpretable model to each transformer layer, which takes that layer's hidden state as input and produces an explanatory output that is understandable to humans. These sub-models serve as lenses or translators that can read the state of the transformer at their respective layer and express what it represents (or what it is "thinking") in a concise, meaningful way.
Sub-Models at Each Layer: Concretely, for each layer $\ell$ in the transformer (where $\ell = 1, 2, \ldots, N$), we introduce a sub-model $F_\ell$ that consumes the hidden activations $h_\ell$ (the set of token vectors after layer $\ell$) and generates an explanation or interpretation $E_\ell$. The form of $E_\ell$ can vary depending on the application: it could be a natural language sentence summarizing the layer's understanding, a set of labels or features (e.g. "Layer 3 has identified this sentence's sentiment as positive"), or even a distribution over the final output vocabulary (as in the logit lens). The sub-model $F_\ell$ itself could be as simple as a linear classifier or as complex as a small neural network; the only requirement is that $F_\ell$ is interpretable by design or yields an output in an interpretable format. For example, one might use any of the following (a minimal code sketch follows this list):
- A linear or decision-tree model that takes the hidden state and predicts a human-interpretable attribute (such as the part-of-speech tags of each token, the subject of the sentence, etc.).
- A small decoder network that cross-attends to $h_\ell$ and generates a natural-language description of what is encoded at layer $\ell$. This is akin to using a miniature language model to "translate" the hidden vector into words (in spirit similar to the DecoderLens method, which verbalizes encoder activations) (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers).
- A multi-label classifier that outputs which high-level concepts are active at this layer (for example, in an image transformer one might output concepts like "edges detected" at early layers and "object shape = cat" at later layers; in language, early layers might output "syntax parsed" and later layers "coreference resolved" or "topic: sports").
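As a minimal sketch combining the first and third options, the sub-model below is a linear multi-label probe over mean-pooled hidden states. The class name, the concept labels, and the pooling choice are illustrative assumptions rather than a prescribed design; the point is only that $F_\ell$ can be a very small, transparent module.

```python
import torch
import torch.nn as nn

class LinearSubModel(nn.Module):
    """F_l as a linear multi-label probe: maps a layer's hidden states to
    scores over a fixed set of human-readable concepts (names are illustrative)."""
    def __init__(self, hidden_dim=768, concepts=("question", "negation", "positive_sentiment")):
        super().__init__()
        self.concepts = concepts
        self.linear = nn.Linear(hidden_dim, len(concepts))

    def forward(self, h_l):                         # h_l: (batch, seq_len, hidden_dim)
        pooled = h_l.mean(dim=1)                    # simple mean-pool over tokens
        return torch.sigmoid(self.linear(pooled))   # concept activation scores in [0, 1]

    def explain(self, h_l, threshold=0.5):
        """Turn the scores into a short textual explanation E_l."""
        scores = self.forward(h_l)[0]
        active = [c for c, s in zip(self.concepts, scores) if s > threshold]
        return "Active concepts at this layer: " + (", ".join(active) if active else "none")
```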
Crucially, each sub-model focuses only on the state of one layer. This modularity means the interpretation of each layer is localized and specific, rather than entangled with the entire network at once. By designing $F_\ell$ to be much simpler than the full transformer, we aim for these sub-models to be transparent. In essence, $F_\ell$ acts as an interpreter for layer $\ell$, extracting the content of $h_\ell$ in a form we can reason about.
Extracting and Explaining Layer-Wise Meaning: With sub-models in place for all layers, we can obtain a sequence of explanations $(E_1, E_2, \ldots, E_N)$ for a given input as it flows through the transformer. This sequence should reveal how the model's knowledge and hypotheses are refined layer by layer. For example, consider again a question-answering transformer. The explanation $E_1$ (after the first layer) might say something like: "Identified question words; focusing on the entity 'Pride and Prejudice'." By layer 6, $E_6$ might read: "Understands the question is asking for the author of 'Pride and Prejudice'; likely referring to a book author." By the penultimate layer, $E_{N-1}$ could be "Model strongly anticipates the answer is 'Jane Austen'; high confidence in this author name." Finally, the model outputs "Jane Austen." In this way, the sequence of $E_\ell$ provides a narrative of the model's reasoning, from initial comprehension to final answer. Each $E_\ell$ explains the incremental progress made by layer $\ell$. Early layers handle more generic or surface-level information, while later layers provide more detailed and task-specific information, exactly mirroring how the hidden state evolves, but now made explicit.
One can view this progressive explanation as analogous to how a human might break down their reasoning. Instead of jumping directly from question to answer, a human might think: "Okay, the question mentions Pride and Prejudice; that's a novel. Next, the question asks who wrote it, so it's looking for an author's name. Pride and Prejudice was written by Jane Austen, I recall. Thus, the answer is Jane Austen." Here we see distinct steps (identify topic, interpret question, recall fact, produce answer). Standard transformers do all this internally without explicit delineation. Our layer-wise sub-model approach aims to externalize these implicit steps, mapping them to the network's layered processing.
Progressive Refinement and Consistency: An important aspect of our approach is that the explanations $E_\ell$ should show progressive refinement. That is, $E_{\ell+1}$ should generally build upon or update the information in $E_\ell$. If we observe drastic or contradictory changes in the explanations from one layer to the next, that might indicate an interesting phenomenon (e.g., the model revising an earlier hypothesis). In many cases, however, we expect a coherent story. This progressive property can be enforced or encouraged when designing the sub-model outputs. For instance, sub-models could be built to condition on previous explanations (though that introduces complexity and inter-dependence), or we can simply observe after the fact that $E_1, E_2, \ldots$ align well. A concrete way to encourage consistency is to have each sub-model predict the final output as well (as an auxiliary task): effectively, each layer tries to guess the answer, as in early-exit models (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). If all layers are predicting the same final answer by the end, their intermediate explanations will likely be consistent with that answer. In our QA example, if by layer 6 the sub-model already predicts "Jane Austen" as the answer, it will shape its explanation accordingly. This does not force the layers to agree on everything (they might differ in confidence or reasoning details), but it tethers them to a common end goal.
Design and Architecture Considerations: There are multiple ways to implement layer-wise interpretable sub-models, each with trade-offs:
- Probing Approach (Post-hoc): We can train each sub-model after the transformer is trained, in a post-hoc probing fashion. Here we treat the transformer as fixed and train $F_\ell$ to map $h_\ell$ to a desired explanation output. For example, $F_\ell$ could be a simple classifier trained on a labeled dataset to predict some known property of the input or output that we believe layer $\ell$ should capture. If no labeled data for intermediate concepts is available, we might train $F_\ell$ to mimic the transformer's own final output (as a distillation target) or to predict a proxy like the logit lens distribution (Eliciting Latent Predictions from Transformers with the Tuned Lens). Post-hoc training has the advantage that it does not interfere with the original model's training or performance, and we can probe for many different types of information from the same $h_\ell$. The downside is that $F_\ell$ might extract information that is present in $h_\ell$ but not necessarily used by the model, raising questions of faithfulness. However, techniques like the Tuned Lens (which learns a small transformation to better align intermediate hidden states with final logits) show that even a simple learned mapping can faithfully reveal the model's latent predictions (Eliciting Latent Predictions from Transformers with the Tuned Lens).
- Integrated Training (Auxiliary Losses): Alternatively, we can integrate the sub-models into the transformer's training process. In this setup, when training the main model, we add auxiliary objectives for each layer: the sub-model's output $E_\ell$ should correctly predict some target (e.g. part-of-speech tags for layer 1, semantic roles for layer 6, the final answer for layer $N$, etc.). This is reminiscent of "deep supervision" or early-exit networks, where intermediate layers are directly trained to produce outputs (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). By providing an explicit learning signal to each layer, we encourage the model to organize information in that layer in a human-interpretable way. For example, if layer 5 is trained to output a rough answer prediction, the model might ensure that by layer 5 it has already gathered most of the information needed for the answer, making layer 5's hidden state more interpretable. Integrated training can thus bake interpretability into the model's representations. The trade-off is additional complexity in training and potential interference with the primary task optimization: the model might sacrifice a bit of final accuracy to satisfy the intermediate explanations. Careful weighting of losses and keeping the $F_\ell$ models lightweight can mitigate this (a minimal training-step sketch follows this list).
- Choice of Sub-model: The simplest sub-model is a linear probe (a single linear layer mapping the hidden state to some predefined labels or vocabulary). Linear probes are easy to interpret (their weights highlight which dimensions correlate with which features) and cheap to train (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). They have been widely used to analyze transformers, e.g. testing which layers encode grammatical number or gender. However, linear probes might be too restrictive for rich explanations. More complex sub-models like a two-layer MLP or a mini-transformer decoder can capture nonlinear relations in $h_\ell$ and produce more flexible outputs (like sentences). We propose experimenting with a small transformer decoder as $F_\ell$ that generates a textual explanation from the hidden state. Such a decoder can be thought of as "asking the model to explain itself in English." It could be trained on a corpus where explanations are available, or even jointly trained with a language modeling objective using the main model's hidden state as context. Prior work (DecoderLens) has shown that using the model's own decoder (in an encoder-decoder architecture) to interpret encoder layers yields meaningful word sequences explaining those layers' content (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). In our approach, if using a decoder sub-model, we might feed $h_\ell$ through a cross-attention mechanism to an explanation decoder that has learned to output sentences like the ones we hypothesized ("Layer $\ell$ thinks ..."). This is an exciting direction because it directly produces human-language rationales for each layer's computations.
- Computational Trade-offs: Adding sub-models naturally increases the computational load. If we have $N$ layers and attach even a small network to each, inference time could grow (especially if we run all sub-models for every input to get a full explanation). There are ways to manage this: we could run sub-models only when needed (for research and debugging, not every production inference), or design them to be very efficient (e.g. linear heads). Another consideration is memory: storing all intermediate activations $h_\ell$ is already common (e.g. when `output_hidden_states=True` is set in standard frameworks), so that is not a big issue, but running a large decoder on each layer's output would be heavy. A compromise is to attach sub-models to a subset of layers (e.g. every 2nd or 3rd layer) if fine granularity is not needed. We will demonstrate with code that extracting hidden states is straightforward; training and using sub-models can be scaled according to the desired interpretability depth and the resources available. In Section 5, we discuss how our approach compares in cost-benefit to other interpretability techniques.
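As referenced in the Integrated Training item above, the following is a minimal sketch of a training step that combines the main task loss with weighted per-layer auxiliary losses. The argument names (`sub_models`, `aux_targets`, `aux_weight`) and the batch structure are illustrative assumptions about how the pieces could be wired together, not a fixed API.

```python
# Hypothetical setup: `model` is a Hugging Face task model, `sub_models[l]` is F_l for
# an instrumented layer l, and `batch["aux_targets"][l]` is the supervision chosen for
# that layer (e.g. POS tags, a rough answer prediction, the final label).
def training_step(model, sub_models, batch, main_loss_fn, aux_loss_fns, aux_weight=0.1):
    outputs = model(**batch["inputs"], output_hidden_states=True)
    loss = main_loss_fn(outputs.logits, batch["labels"])            # primary task loss
    for l, sub_model in sub_models.items():                         # only instrumented layers
        h_l = outputs.hidden_states[l]
        pred_l = sub_model(h_l)                                      # layer-wise explanation output
        loss = loss + aux_weight * aux_loss_fns[l](pred_l, batch["aux_targets"][l])
    return loss
```

Keeping `aux_weight` small is one way to limit interference with the primary objective, in line with the trade-off discussed above.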
In summary, our proposed approach transforms a standard transformer into an interpretable multi-stage model. Each layer's computation is accompanied by a transparent "explanation" via a dedicated sub-model. This yields a high-resolution view of the model's internal decision process, addressing the black-box criticism. We expect layer-wise explanations to help answer questions like: What does the model know after this layer? What intermediate decision did it make? By making the invisible visible at every layer, we aim to greatly enhance the interpretability of transformer-based AI systems.
4. Implementation and Example Code
Implementing layer-wise interpretability involves two main steps: (1) extracting the intermediate activations from a transformer, and (2) designing/training sub-models to interpret those activations. Modern deep learning frameworks make it easy to access hidden states, and we can leverage that to build our interpretability pipeline. In this section, we provide example code snippets and explanations to illustrate how one might realize the proposed approach in practice. The examples are in Python using the Hugging Face Transformers library and PyTorch, which are common tools for transformer models.
Extracting Hidden States: We first need to obtain the hidden state outputs for each layer of a transformer. In Hugging Face's API, many models allow this by enabling the `output_hidden_states` flag. For instance, to get hidden states from BERT or GPT-2:
```python
from transformers import AutoTokenizer, AutoModel

# Load a pretrained transformer (e.g., BERT base) with hidden state output enabled
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_hidden_states=True)

# Encode an example input
sentence = "The capital of France is [MASK]."
inputs = tokenizer(sentence, return_tensors='pt')

# Forward pass to get outputs (including hidden states)
outputs = model(**inputs)
hidden_states = outputs.hidden_states  # tuple of length (num_layers + 1)

print(f"Number of layers (including embedding output): {len(hidden_states)}")
print(f"Shape of hidden state tensor for layer 0 (embeddings): {hidden_states[0].shape}")
print(f"Shape for layer 1: {hidden_states[1].shape}")
```
This code loads a pre-trained BERT model and runs a sample sentence through it. By specifying `output_hidden_states=True`, the model returns a tuple of hidden states. Typically, `hidden_states[0]` is the embedding output (layer 0, before any self-attention), and `hidden_states[i]` for `i = 1, ..., 12` corresponds to the output of each of BERT's 12 layers. Each hidden state is a tensor of shape `(batch_size, sequence_length, hidden_dim)`. In our example, `sequence_length` is the number of tokens in the input after tokenization (including the special [CLS] and [SEP] tokens), and `hidden_dim` is 768 for BERT-base. The above code would print 13 layers (including the embeddings) and shapes such as `torch.Size([1, 9, 768])` for this input: 1 sentence, 9 tokens, and a 768-dimensional vector per token.
With these hidden states in hand, we can begin to attach interpretation. Suppose we want to see what each layer would predict for the masked word in the example sentence (a fill-in-the-blank task). We can use a logit-lens-style approach: take the hidden state at the mask token for each layer and project it into the vocabulary space using BERT's output layer. In BERT, predicting [MASK] is done via a classification head tied to the token embeddings, but for simplicity we can use the pre-trained decoder (if available) or just a linear mapping via the embedding matrix. For demonstration, let's assume a simpler scenario with GPT-2 (a causal language model), where the final linear layer is directly tied to token prediction. We can retrieve GPT-2's hidden states and see the model's top predictions at each layer for the next token:
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load GPT-2 small with output of all hidden states
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True)

text = "The capital of France is"
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)
hidden_states = outputs.hidden_states  # tuple of hidden states

# Get the last token's hidden state from each layer and compute logits
last_token_index = inputs['input_ids'].size(1) - 1  # index of the last input token
vocab_matrix = model.transformer.wte.weight.T  # GPT-2 ties input embeddings to output weights

for layer_idx, h in enumerate(hidden_states[1:], start=1):  # skip hidden_states[0], the embeddings
    # h has shape (batch, seq_len, hidden_dim)
    last_token_hidden = h[0, last_token_index, :]  # hidden vector for the last token
    logits = last_token_hidden @ vocab_matrix      # project to vocab (approximate logits;
                                                   # skips GPT-2's final layer norm)
    top5_indices = logits.topk(5).indices.tolist()
    top5_tokens = tokenizer.convert_ids_to_tokens(top5_indices)
    print(f"Layer {layer_idx} top predictions: {top5_tokens}")
```
This snippet uses GPT-2 for simplicity (since it directly predicts next tokens). We take the hidden vector corresponding to the last token in the input sequence ("is" in the text "The capital of France is") for each layer, and multiply it by the transpose of the embedding matrix to get pseudo-logits for the next token. Then we find the top 5 predicted tokens. Running this code would output something like:
```
Layer 1 top predictions: [' of', ' in', ' the', ',', ' and']
Layer 2 top predictions: [' the', ' of', ' in', ',', ' The']
Layer 3 top predictions: [' Paris', ' Lyon', 'The', ' Rome', ' London']
...
Layer 12 top predictions: [' Paris', ' Lyon', ' London', ' Rome', ' Berlin']
```
These are illustrative results (not exact), but they demonstrate a trend: in early layers, GPT-2 has not narrowed down the next word; it predicts common continuations like ' of' or ' in'. By layer 3 or 4, location names (Paris, Lyon, etc.) start appearing as likely candidates for "The capital of France is _". By the final layer (12), ' Paris' is the top prediction (the correct completion). This is exactly the latent knowledge refinement we expect, now made explicit per layer. Such intermediate predictions are a simple form of explanation: they tell us what the model is thinking at each layer about the next word. In our framework, one could consider this a very basic sub-model at each layer: a linear map to the vocabulary. Our approach generalizes this idea to more complex and descriptive interpretations, rather than just next-token guesses.
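For the BERT masked-LM setting mentioned above, a rough per-layer readout can be obtained by passing each layer's hidden states through the pretrained MLM head. One caveat: that head was trained only on final-layer representations, so early-layer readouts are approximations. The snippet below is a sketch under that assumption.

```python
import torch
from transformers import AutoTokenizer, BertForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
mlm_model = BertForMaskedLM.from_pretrained("bert-base-uncased", output_hidden_states=True)

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
mask_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero().item()

with torch.no_grad():
    outputs = mlm_model(**inputs)

for layer_idx, h in enumerate(outputs.hidden_states[1:], start=1):
    logits = mlm_model.cls(h)[0, mask_index]  # pretrained MLM head applied to this layer's states
    top_tokens = tokenizer.convert_ids_to_tokens(logits.topk(5).indices.tolist())
    print(f"Layer {layer_idx} [MASK] candidates: {top_tokens}")
```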
Training Interpretable Sub-Models: Next, consider training a small model to decode some aspect of the hidden state. As a toy example, let's train a probe to predict whether a sentence is a question or a statement from each layer's [CLS] token in BERT. We generate some labeled data (e.g., sentences ending with a question mark vs. a period), feed them through BERT to get hidden states, then train a logistic regression on the hidden features. Although we cannot fully execute training here, the pseudocode would be:
```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression

# Suppose we have data: lists of sentences and labels (1 if question, 0 if statement)
sentences = ["What is your name?", "This is a cat.", ...]
labels = [1, 0, ...]

# Get BERT [CLS] hidden states for each sentence at a certain layer (say layer 8)
cls_features = []
for sent in sentences:
    inputs = tokenizer(sent, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    h_states = outputs.hidden_states
    # The [CLS] token sits at position 0 of BERT's input sequence
    cls_vector = h_states[8][0, 0, :].numpy()  # layer 8, first token
    cls_features.append(cls_vector)

# Train a simple logistic regression on these features
X = np.vstack(cls_features)
clf = LogisticRegression(max_iter=1000).fit(X, labels)
print("Training accuracy:", clf.score(X, labels))

# Inspect coefficients (optional interpretability of the probe itself)
important_dims = np.argsort(np.abs(clf.coef_[0]))[::-1][:10]
print("Top contributing hidden dimensions for question detection:", important_dims)
```
This code sketch shows how one might post-hoc train a probe on layer 8's [CLS] token to classify sentence type. If BERT layer 8 encodes question-ness (which is plausible, as middle layers often capture sentence intent or syntax), the logistic regression should achieve good accuracy. The classifier's weights also tell us which dimensions of the 768-dimensional hidden state were most useful for this task, giving some insight into the representation. In practice, one could similarly train sub-models to predict a variety of properties (sentiment, tense, topic, presence of certain keywords, etc.) at different layers. Each such probe is effectively an interpretation function applied to the hidden state.
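To see which layers encode a given property most strongly, the same recipe can be repeated at every layer and the probe accuracies compared. The helper below is a compact sketch that reuses the `sentences`, `labels`, `tokenizer`, and `model` objects from above; the function name and the choice of 3-fold cross-validation are our own.

```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def layer_probe_scores(sentences, labels, tokenizer, model, num_layers=12):
    """Cross-validated probe accuracy per layer, using the [CLS] vector as features."""
    features = {l: [] for l in range(1, num_layers + 1)}
    for sent in sentences:
        inputs = tokenizer(sent, return_tensors="pt")
        with torch.no_grad():
            h_states = model(**inputs).hidden_states
        for l in range(1, num_layers + 1):
            features[l].append(h_states[l][0, 0, :].numpy())
    return {l: cross_val_score(LogisticRegression(max_iter=1000),
                               np.vstack(feats), labels, cv=3).mean()
            for l, feats in features.items()}
```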
From Probing to Explaining: The ultimate goal, however, is to produce human-readable explanations of the model's "thought process." Instead of just predicting properties, we might want each sub-model to output text explaining what the layer is doing. One way to achieve this is to use a language-modeling approach: train a small generative model that takes the hidden state as input and outputs an explanatory sentence. This could be done by creating a synthetic dataset: run a variety of inputs through the transformer, record the hidden states, and pair them with explanations. But who writes the explanations? If we have an existing interpretability tool or heuristic, we could generate them, or we could even use a large language model (like GPT-4) to label what each layer might be doing (as OpenAI did for neurons) (Language models can explain neurons in language models | OpenAI). This is an open-ended task, but as a proof of concept, one could manually label a few cases to train a prototype.
For example, we might take a set of sentences and manually note: at layer 2, the model likely identifies proper nouns; at layer 5, it resolves pronouns; at layer 10, it integrates context to decide sentiment. Training on such data (even if small) can allow a decoder to generate similar explanations for new inputs. A pseudo-training loop for a decoder sub-model might look like:
```python
# Pseudo-code for training an explanation decoder for a given layer
explanation_model = SmallTransformerDecoder()  # a small seq-to-seq model
for input_text, gold_explanations in training_data:
    # Get hidden states from the (frozen) main model
    inputs = tokenizer(input_text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
    h_state = outputs.hidden_states[layer_of_interest][0]  # the layer's full sequence of hidden states
    # Use h_state as context for the explanation model.
    # In practice, h_state could be fed through a cross-attention mechanism or concatenated.
    expl_pred = explanation_model.generate(h_state)
    loss = explanation_model.compute_loss(expl_pred, gold_explanations)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
Here, `gold_explanations` would be a target sentence we want the explanation model to output for that layer. Over time, the explanation decoder learns to associate patterns in the hidden state with certain explanation phrases. For instance, if many examples show that whenever the hidden state strongly encodes a negative sentiment the explanation is "Layer X has identified a negative sentiment," the decoder can learn to output that when it sees similar activation patterns.
Training such an explanation generator is non-trivial: it requires either a labeled corpus or some automated way to get pseudo-labels. One promising approach, as mentioned, is to use an AI to explain an AI: use a large language model to annotate what each layer might be doing by analyzing activations or the input-output behavior, creating a training set for the smaller explanation model. This falls under AI-assisted interpretability and has been piloted for neuron interpretation (Language models can explain neurons in language models | OpenAI).
Practical Example Summary: As a concrete example tying everything together, consider explaining a sentiment classifier built on a transformer. We feed a movie review into the model and it predicts "Positive". With layer-wise sub-models, we might get outputs like:
- Layer 1 sub-model: "Notable words: 'excellent', 'performances'; likely focusing on these."
- Layer 4 sub-model: "Identified movie-related context and some sentiment-laden adjectives."
- Layer 8 sub-model: "The tone appears positive (many positive descriptors)."
- Layer 12 sub-model: "Overall sentiment is positive."
These could be generated by training each sub-model to recognize sentiment at different levels of granularity. The code snippets above illustrate how to get hidden states and train simple classifiers, which are the building blocks for achieving such explanations. Integrating a fluent language-generating explanation model would build on the same extraction process with a more complex training regime.
Through these examples, we see that implementing layer-wise interpretability is feasible with standard tools. The hidden states are accessible, and we can attach additional computations to derive explanations. The specific design of the sub-model (linear probe vs. decoder) may differ, but the pipeline remains the same: run the model → get the hidden state → apply the sub-model to produce an explanation. In a research setting, one might iterate on the sub-model design and training strategy to find explanations that are both accurate reflections of the model's computation and easily understood by humans. In the next section, we compare this approach with existing interpretability methods to highlight its advantages and limitations.
5. Comparison to Existing Interpretability Methods
Our layer-wise sub-model approach offers a new perspective on interpretability, but it is important to situate it relative to established methods. In this section, we compare it with several categories of interpretability techniques, discussing how it differs and where it may offer improvements.
- Saliency Maps / Feature Attribution: Traditional saliency methods (like gradients, Integrated Gradients, SmoothGrad, LIME, SHAP) identify which input features most influence the model's prediction (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). They are popular because they are model-agnostic and easy to compute, but they only answer a very limited question: "Which parts of the input mattered?" They do not explain why those parts mattered or how the model processed them. Our approach, in contrast, provides a process-level explanation rather than just an importance ranking. A saliency map might tell us that in a particular movie review, the word "excellent" was important for the positive sentiment prediction. A layer-wise explanation would add how that word's effect is processed (e.g., layer 1 finds "excellent" is positive, layer 5 confirms many other positive clues, layer 10 aggregates them into a final sentiment). Thus, we complement saliency maps by revealing the internal states corresponding to those salient features. Moreover, saliency methods have known issues with faithfulness (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers): sometimes they highlight inputs that intuitively make sense to humans but are not actually used by the model, or they miss interactions between features. By reading the model's actual hidden states, our method is more directly tied to what the model is representing at each step, potentially increasing faithfulness. That said, saliency maps are cheap and can be applied when one only needs a quick input attribution, whereas our method requires instrumentation of the model and possibly training sub-models, which is more involved.
- Attention Visualization: Inspecting attention weights is a method specific to transformer architectures. It can show, for each token, the distribution of attention given to other tokens. While initially promising as an interpretability tool (Is Attention Interpretable in Transformer-Based Large Language Models? Let's Unpack the Hype), it is now understood that attention alone can mislead (Is Attention Interpretable in Transformer-Based Large Language Models? Let's Unpack the Hype). Our approach is fundamentally different from raw attention analysis: we do not rely on a single mechanism like attention, but instead look at the resultant hidden state after all attention and feed-forward computations. In some sense, attention visualization addresses which inputs influence each layer, whereas our explanations address what information the layer is encoding. They are related: if a layer's sub-model explanation says "this layer is focusing on the relationship between X and Y," it might correspond to observing strong attention between tokens X and Y. But attention visuals require interpretation themselves (one must infer why the model attended to those tokens). Our sub-models explicitly state the reason or content, removing that extra step. Additionally, attention weights can be distributed and not localized, making it hard to draw a clear conclusion, whereas a sub-model can be forced to output a clear categorical or textual explanation. In summary, attention analysis provides low-level data, while layer-wise sub-models aim to provide a high-level summary. They are not mutually exclusive; in fact, an interesting integration could be to feed attention patterns as features into a sub-model that then explains "Layer 5 attended mostly to token X while processing token Y, likely to resolve their relationship." Some recent work projects attention patterns into vocabulary space to find concepts (e.g., attention heads that attend to tokens representing locations or names) (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers), which aligns with the philosophy of interpreting internal components. Our approach generalizes this by focusing not just on attention but on the entire state.
- Probing Methods: Probing involves training simple classifiers on hidden states to detect whether certain information (linguistic features, etc.) is present (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). In fact, our use of sub-models has a strong flavor of probing: each $F_\ell$ can be seen as a learned probe revealing something about layer $\ell$. The difference is in purpose and scope. Traditional probing studies often ask yes/no questions like "Does layer 6 encode part-of-speech information?" by checking probe accuracy on that task. Our approach uses probes (sub-models) in a generative and explanatory way, turning them into narrative devices for individual predictions. Another difference is that probing is typically post-hoc and does not influence the model, whereas we consider possibly integrating the sub-models during training for consistent explanations. A challenge noted in the probing literature concerns selectivity: a probe might find information that is linearly decodable from a layer, but that does not mean the model actually uses it; it could be an artifact or unused capacity. By contrast, if we train sub-models to mimic or predict the model's own outputs (like a tuned lens, which aligns layer representations with final logits (Eliciting Latent Predictions from Transformers with the Tuned Lens)), we improve the chance that the explanation is aligned with what the model truly computes. Nonetheless, our method must be careful about the "probe interpretability" trap: we do not want sub-models to hallucinate interpretations that are human-plausible but not actually tied to the model. This is why designing the training objectives of $F_\ell$ to be causally linked to model behavior (e.g. predicting the model's output or known intermediate labels) is important, ensuring we explain what the model is really doing rather than what we imagine it might do.
- Causal Interventions (Ablation and Patching): Another line of interpretability research involves intervening on the model: ablating certain neurons or attention heads, or patching activations from one run into another to see how it affects outputs (Section 4: Transformer Understanding). For example, activation patching (also called causal tracing) can identify which layer's activations carry the information needed for a particular prediction by replacing them with activations from a different context and observing whether the prediction changes (Section 4: Transformer Understanding); a minimal hook-based sketch of this idea follows this list. These methods are powerful for attributing causal responsibility to parts of the network. Our sub-model approach is mostly observational: we read out interpretations but do not by default tell which parts are critical. However, it could be combined with causal methods: if the sub-model explanation at layer 7 mentions a specific fact or feature, one could verify its importance by ablating that information in the hidden state and seeing if the final output changes. In terms of differences, causal methods usually do not attempt to explain what a circuit does in human terms; they just locate it. Our method would produce the description, which could then guide where to intervene. We see these approaches as complementary: for example, one might first get a layer-wise explanation (say it says "Layer 5 has resolved that 'he' refers to John"), and then perform an experiment zeroing out layer-5 features related to coreference to confirm that the model's answer changes. In summary, causal interventions give evidence of importance, whereas our sub-models give semantic meaning; both are needed for a full picture.
- Direct Model-Transparent Approaches: Some researchers argue for building inherently interpretable models, like attention-only models with explanations, or concept bottleneck models where the model predicts human-understandable concepts as an intermediate step. Our approach can be seen as inserting a form of concept bottleneck at every layer, but without strictly bottlenecking the model; we are rather "reading off" concepts. We do not require the main model to use only those concepts; we just extract them. So, we maintain the full power of the black-box model while adding interpretability on the side. Approaches like concept bottlenecks can guarantee interpretability, but often at a cost of performance or flexibility, and they require predefined concepts. Our method is more flexible in that sub-models can learn whichever patterns are present in the hidden states, which might be complex or uninterpretable in raw form but can be translated (for example, an unusual combination of hidden dimensions might correspond to a concept like "the sentence is in passive voice", which a trained sub-model could recognize). Compared to fully transparent models (like decision trees or sparse logical models), we do not achieve the same level of simplicity, but we aim to get the best of both worlds: the high performance of transformers with an overlay of interpretability.
- Logit Lens and Tuned Lens: A specific comparison is to the Logit Lens (Eliciting Latent Predictions from Transformers with the Tuned Lens) and its improved version, the Tuned Lens (Eliciting Latent Predictions from Transformers with the Tuned Lens), since these are perhaps the closest existing ideas to layer-wise interpretation. The Logit Lens, as discussed, takes each layer's hidden state and directly decodes it into an output (like vocabulary logits), essentially giving a rough prediction at each layer. This is a special case of our approach where the sub-model is the final linear layer of the network (or its transpose) applied at every layer. It provides insight into the model's evolving prediction. The Tuned Lens augments this by actually training a small affine transformation for each layer to better predict the final logits (Eliciting Latent Predictions from Transformers with the Tuned Lens), acknowledging that each layer's basis might differ (a sketch of such a per-layer affine translator is given after this list). These techniques validate that intermediate layers contain predictive information, and they produce a kind of explanation: "What does the model think so far?" However, they do not explain why the model is thinking that. Our layer-wise sub-models generalize the lens idea to richer explanations. Instead of just outputting the model's implicit next-word guess or classification at layer $\ell$, we want to output why that guess is made, e.g., which features were recognized. So one could say we are adding an interpretability layer on top of the logit lens. We also allow sub-models to output things that are not the final prediction, such as intermediate decisions or features (which the logit lens cannot directly do, since it is tied to final logits). In terms of performance, the Tuned Lens already shows that even a very lightweight sub-model (one affine layer per transformer layer) can closely match the model's own outputs (Eliciting Latent Predictions from Transformers with the Tuned Lens), indicating that our approach is feasible without requiring a massive sub-model; the information is readily present.
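As referenced in the causal-interventions item above, the following is a minimal sketch of activation patching using PyTorch forward hooks: it caches one layer's output on a "source" input and substitutes it during a run on a "target" input. It assumes a GPT-2-style model whose blocks live in `model.transformer.h` and return tuples whose first element is the hidden states; other architectures would need different module paths.

```python
import torch

def patch_layer_output(model, tokenizer, source_text, target_text, layer_idx):
    """Replace layer `layer_idx`'s output on `target_text` with its output on `source_text`."""
    layer = model.transformer.h[layer_idx]   # GPT-2-style block list (assumption)
    cache = {}

    def save_hook(module, inputs, output):
        cache["h"] = output[0].detach()      # block output is a tuple; [0] is the hidden states

    def patch_hook(module, inputs, output):
        patched = cache["h"]
        seq_len = min(patched.size(1), output[0].size(1))
        new_hidden = output[0].clone()
        new_hidden[:, :seq_len] = patched[:, :seq_len]
        return (new_hidden,) + output[1:]    # returning a value replaces the block's output

    handle = layer.register_forward_hook(save_hook)
    with torch.no_grad():
        model(**tokenizer(source_text, return_tensors="pt"))
    handle.remove()

    handle = layer.register_forward_hook(patch_hook)
    with torch.no_grad():
        patched_logits = model(**tokenizer(target_text, return_tensors="pt")).logits
    handle.remove()
    return patched_logits
```

Comparing `patched_logits` against an unpatched run on the target text indicates whether that layer's activations carried the information of interest.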
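And as referenced in the Logit Lens / Tuned Lens item, the sketch below shows the shape of a tuned-lens-style per-layer affine translator trained with a KL objective against the model's final output distribution. This is an illustration of the idea rather than the reference Tuned Lens implementation: the class name and training details are our own, and `unembed` stands for the model's frozen output projection.

```python
import torch.nn as nn
import torch.nn.functional as F

class AffineLens(nn.Module):
    """One affine translator per layer, mapping layer-l hidden states into a
    space the frozen unembedding can decode (a tuned-lens-style sketch)."""
    def __init__(self, hidden_dim, num_layers):
        super().__init__()
        self.translators = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers))

    def layer_logits(self, h_l, layer_idx, unembed):
        return unembed(self.translators[layer_idx](h_l))  # unembed: frozen final projection

def lens_loss(lens, hidden_states, final_logits, unembed):
    """KL between each layer's lens prediction and the model's final distribution."""
    target = F.log_softmax(final_logits, dim=-1)
    loss = 0.0
    for l, h_l in enumerate(hidden_states[1:]):            # skip the embedding layer
        pred = F.log_softmax(lens.layer_logits(h_l, l, unembed), dim=-1)
        loss = loss + F.kl_div(pred, target, log_target=True, reduction="batchmean")
    return loss / (len(hidden_states) - 1)
```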
Strengths of Layer-Wise Sub-Models: The primary strength of our approach is the granularity and richness of the explanations. By covering each layer, we obtain a stepwise trace of the model's computation in human-interpretable terms. This can be invaluable for debugging and understanding failure modes. For instance, if a model answers a question incorrectly, a layer-by-layer explanation might reveal that it actually picked up the correct answer in an intermediate layer but then lost it (perhaps due to a later layer's interference), information that neither a saliency map nor a final explanation alone would show. It could also reveal if the model is relying on spurious features at some stage (e.g., an explanation says "layer 3 focuses on the user's name when answering a medical question", which might hint at bias). Compared to single-shot explanations (like explaining the final decision after the fact), our method is more faithful by construction: it uses the model's actual internal data at each point, reducing the chance of rationalizing or ignoring what actually happens inside. Additionally, these sub-model explanations could be made model-agnostic to some extent; one could potentially train a universal explainer for a given transformer architecture, though that is speculative.
Weaknesses and Challenges: One obvious cost is the computational and implementation overhead, as discussed. Another key challenge is evaluation: how do we measure the quality of the explanations from sub-models? They could be incorrect or insufficient. If a sub-model is poorly trained or the task is too hard (say, trying to generate a perfect English description of a hidden state), it might output misleading information. This is dangerous because it could give a false sense of understanding. We must ensure that sub-model outputs are validated, either by checking alignment with known attributes (if available) or via human expert review in critical cases. There is also the risk that forcing interpretability could constrain the model. If used during training, the model might overly optimize for producing something interpretable at each layer (which might degrade final performance or hide information in ways the sub-model can easily explain, a kind of Goodhart's law effect). We can manage this by keeping sub-models "on the side" rather than in full control of the representations, or by iterative refinement to ensure they truly capture what is needed. Another weakness is that our approach currently provides a high-level description per layer, but not a full mechanistic understanding of the computations (we are not explicitly identifying which neurons or weights implement each step). So it may not directly solve the problem of understanding the network at the circuit level; it is more of an observability tool. In comparison, methods like causal abstraction aim to map the entire model onto a human-understandable algorithm (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers), which is a stronger guarantee but far harder to achieve on complex models. Our method is a more pragmatic middle ground: easy to deploy and yielding understandable information, but not a formal explanation of internal mechanics.
Integration with Other Frameworks: We foresee that layer-wise explanations could be integrated into existing interpretability dashboards and frameworks. For instance, Google's Language Interpretability Tool (LIT) or other interactive visualization systems could show, for each position of a layer slider, the output of the sub-model explanation. One could imagine an interactive interface where a user hovers over a transformer's layers and sees a tooltip like "Layer 5: combining clause information; likely doing coreference", all generated by our approach. This would add a new dimension to such tools, which currently might show attention weights or neuron activations without context. Furthermore, these explanations could feed into accountability mechanisms: e.g., in sensitive applications, the system could log the layer-wise reasoning for each decision, so that if a bad outcome occurs, auditors can inspect not just the input-output correlation but the internal decision trail.
In conclusion, compared to existing methods, Layer-Wise Sub-Model interpretability is a more process-oriented and fine-grained approach. It does not replace feature attribution or causal analysis, but augments them by filling in the blanks of what the model actually does at each step. Its strength lies in creating an explanatory narrative from within the model, offering potentially greater insight and trust. Its challenges lie in ensuring those narratives are accurate and efficiently obtained. The next section will reference related research that inspired this approach and provide context, before we discuss future implications.
6. References to Related Research
Our approach builds upon and intersects with several important lines of AI interpretability research. Below, we highlight key related works and how they inform or contrast with the Layer-Wise Sub-Model approach:
- Mechanistic Interpretability and Circuits: Olah et al. (2020) and colleagues pioneered Circuits analysis, breaking down neural networks into interpretable pieces (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). They manually traced computations in vision models and small language models, identifying meaningful circuit motifs (e.g., edge detectors and curve detectors in vision, or syntax heads in language). Elhage et al. (2021) further developed a mathematical framework for reverse-engineering transformers, successfully dissecting small models (like two-layer attention-only models) into algorithmic components (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). They introduced concepts like induction heads: attention heads that help a model continue a repeated sequence, effectively implementing a copy mechanism in context (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). Our work is inspired by these findings, as they demonstrate that even black-box transformers implement internal algorithms that can be uncovered. However, instead of manually finding circuits, we automate the explanation via sub-models. Mechanistic interpretability provides ground truth for what kinds of computations occur (e.g., an induction head), and a sophisticated sub-model could potentially recognize and explain "this layer is performing an induction step on a repeated sequence." Recent works like Transformer Circuits (Olah, 2022) and others have catalogued numerous such phenomena, giving us a vocabulary of mechanisms that we would like our explanations to capture (e.g., "negative head suppressing a duplicate token" (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models), "feed-forward layer storing factual memory" (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models)).
- Neuron-Level Interpretability: Research by Bau et al. (2018) on Network Dissection showed that individual neurons in CNNs often correspond to human-recognizable visual concepts (like a "tree neuron" or "dog face neuron"). In NLP, neurons are more entangled, but there has been progress: Dalvi et al. (2019) and others probed individual neurons in machine translation models, finding neurons that track specific linguistic features (like formality). Goh et al. (2021) discovered multimodal neurons in CLIP models that respond to high-level concepts (e.g., a neuron that fires for "Spider-Man" whether presented as text or image). OpenAI's 2023 work took this further by using GPT-4 to generate explanations for neurons in GPT-2 (Language models can explain neurons in language models | OpenAI), as discussed earlier. These works emphasize interpreting individual units of the network. Our layer-wise approach operates at a higher level (the aggregate effect of a whole layer), but it could benefit from neuron-level insights. For example, if neuron $i$ in layer 5 is known to track sentiment, and our layer-5 sub-model outputs "the sentence sentiment is positive," that is a useful confirmation that our sub-model picked up what that neuron was doing. Conversely, if our explanation claims something is happening, one could drill down and identify the neurons or heads responsible, essentially linking our high-level explanations to low-level ones. Projects like CircuitGPT or Automated Circuit Discovery aim to algorithmically find sets of neurons/heads that implement functions (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). Those could be used to generate structured explanations ("Layer 7 contains a 2-head circuit that performs X"), which is a complementary direction: bridging our explanatory text to actual circuit diagrams.
- Probing and Layer Analysis: A significant body of work has analyzed what different transformer layers encode. For example, Tenney et al. (2019), in "BERT Rediscovers the Classical NLP Pipeline," used a suite of probing tasks to show that BERT's layers progressively handle syntax and then semantics ([1905.05950] BERT Rediscovers the Classical NLP Pipeline). Jawahar et al. (2019) similarly found that lower layers encode surface features, middle layers syntactic relations, and upper layers semantic abstractions. Our approach is a natural continuation: given that each layer has a predominant role, we attempt to articulate that role for each input (see the probe sketch after this list). There is also work on early exiting: e.g., Liu et al. (2020) and Zhou et al. (2020) trained classifiers on top of each layer so that if an early layer is already confident enough, the model can exit with a prediction to save computation (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). This demonstrates that intermediate layers can be pushed to produce outputs. Our method can be seen as early exiting for explanations: every layer "exits" to an explanation even if not to a final prediction. In fact, ideas from early-exit training (such as how to avoid interfering with later layers and how to gauge confidence) are useful for us too. Another related concept is auxiliary heads used during training (e.g., InceptionNet used intermediate classifiers to help train deep networks). We similarly add heads, though for interpretability rather than purely for training support.
- Attention Interpretation Debate: We have touched on this, but to cite explicitly: Jain & Wallace (2019), "Attention is not Explanation," argued against naive use of attention (Is Attention Interpretable in Transformer-Based Large Language Models? Let's Unpack the Hype), while Wiegreffe & Pinter (2019), "Attention is not not Explanation," offered a rebuttal that attention can provide insight if used properly (Is Attention Interpretable in Transformer-Based Large Language Models? Let's Unpack the Hype). Our work sidesteps this debate by not focusing solely on attention weights. However, we note research such as Chefer et al. (2021) and Abnar & Zuidema (2020) that aggregates attention across layers or applies gradient adjustments to get better attributions from transformers (see the attention-rollout sketch after this list). Those methods aim to produce a single importance map that accounts for multi-hop attention paths. In contrast, our method outputs distinct information per layer, which may help clarify those multi-hop paths layer by layer rather than through a diluted aggregate.
- Interpretability in Other Modalities: While our discussion is framed around language transformers, the concept of layer-wise explanations could transfer to vision transformers or multimodal models. Research in vision interpretability, such as feature visualization (Olah et al., 2017), generates images that maximize neuron activations (Learned Features, Interpretable Machine Learning). This is somewhat analogous to the logit lens but for vision: it asks "what input would cause this internal neuron to activate strongly?" For a layer-wise view in vision, one might generate an image or textual description of what each layer is detecting (e.g., "Layer 3: edges and textures; Layer 6: object parts like wheels or eyes; Layer 10: full objects like cars or animals"). Some recent works have started to analyze vision transformer layers in this way, finding that early layers act like CNN filters and later ones have token-mixing patterns. Our approach could potentially attach a small CNN or captioning model to each layer of a vision transformer to describe its activation patterns in words ("this layer highlights regions that look like fur").
- Tuned Lens and PatchScopes: As mentioned, the Logit Lens (Nostalgebraist, 2020) and its learned refinement, the Tuned Lens, provide a simple but powerful tool for interpreting layers (Eliciting Latent Predictions from Transformers with the Tuned Lens). Belrose et al. (2023) formally introduced the Tuned Lens, and Halawi et al. (2022) used a logit lens to examine how few-shot examples are processed. Dar et al. (2023) projected not just hidden states but also attention patterns into the vocabulary space, finding coherent concepts (like certain attention heads consistently attending to person names) (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). Ghandeharioun et al. (2024) proposed PatchScopes, a general framework for decoding hidden representations into expressive, human-readable outputs (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). PatchScopes can be seen as a sibling idea to our sub-models: both map internal vectors to human-interpretable outputs (such as vocabulary distributions or concepts). Our approach is very much aligned with these, with the main extension being the emphasis on a layer-by-layer explanatory narrative. We incorporate the spirit of the Tuned Lens (alignment transformations per layer) and PatchScopes (auxiliary interpretive decoding), but apply it to generating natural-language explanations rather than only predicting final outputs (see the logit-lens sketch after this list).
- Causal Abstraction and Model Editing: Another relevant area is causal abstraction, where one defines a high-level model (such as a symbolic model or a simplified flowchart) and tries to map the neural network onto it, verifying that certain latent variables correspond to concepts. For example, Geiger et al. (2021) attempted to align transformer computations with a human-readable algorithm by abstracting each layer's function. Our layer-wise explanations could serve as a bridge to such high-level abstractions: the sub-model outputs might act as the "interpretable variables" that a causal abstract model would have. There is also model editing work (like ROME and MEMIT) which locates factual associations in specific layers of transformers and intervenes to change model knowledge. Those findings often pinpoint a particular layer's feed-forward neurons as carrying a fact (e.g., "Paris is the capital of France" might be stored mainly in a mid-layer MLP), which our explanation for that layer might surface ("Layer 10 recalls that Paris is the capital of France"). So, our method could help identify where a fact is being applied, aligning with model-editing insights on where facts live.
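To ground the probing and early-exit ideas above, here is a minimal sketch of the kind of per-layer linear probe used in that line of work, written in PyTorch. The pooling choice, label set, and the assumption that the encoder exposes per-layer hidden states are illustrative, not a prescription.

```python
# Minimal sketch of per-layer probing in the spirit of Tenney et al. (2019):
# a small linear classifier is trained on frozen hidden states from one layer.
import torch
import torch.nn as nn


class LayerProbe(nn.Module):
    """Linear probe that predicts a label from one layer's pooled hidden state."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, layer_hidden: torch.Tensor) -> torch.Tensor:
        # layer_hidden: (batch, seq_len, hidden_size); mean-pool over tokens.
        pooled = layer_hidden.mean(dim=1)
        return self.classifier(pooled)


def train_probe_step(probe, layer_hidden, labels, optimizer):
    """One training step on frozen features; the main model is never updated."""
    logits = probe(layer_hidden.detach())  # detach: no gradient into the encoder
    loss = nn.functional.cross_entropy(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```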
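For the attention-aggregation methods mentioned in the attention-debate item, here is a minimal sketch in the spirit of attention rollout (Abnar & Zuidema, 2020): head-averaged attention matrices are mixed with the identity (to approximate residual connections), renormalized, and multiplied across layers. This is a simplified reading of that method, not a faithful reimplementation.

```python
# Minimal attention-rollout sketch: compose per-layer attention maps to estimate
# multi-hop token-to-token influence across the whole network.
import torch


def attention_rollout(attentions):
    """attentions: list of (batch, heads, seq, seq) tensors, one per layer."""
    batch, _, seq, _ = attentions[0].shape
    rollout = torch.eye(seq).expand(batch, seq, seq).clone()
    for attn in attentions:
        a = attn.mean(dim=1)                    # average over heads -> (batch, seq, seq)
        a = 0.5 * a + 0.5 * torch.eye(seq)      # add identity for the residual path
        a = a / a.sum(dim=-1, keepdim=True)     # renormalize rows
        rollout = torch.bmm(a, rollout)         # compose with earlier layers
    return rollout                              # (batch, seq, seq) aggregated influence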
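And for the lens-style methods, here is a minimal logit-lens sketch using a Hugging Face GPT-2 checkpoint: each layer's hidden state is passed through the final layer norm and the unembedding matrix to see which token the layer "currently predicts." A tuned lens would additionally insert a learned per-layer affine translator before the unembedding; that translator is omitted in this sketch.

```python
# Minimal logit-lens sketch for GPT-2 (Hugging Face transformers).
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

inputs = tokenizer("The Eiffel Tower is located in the city of", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

for layer_idx, h in enumerate(out.hidden_states):   # embeddings + one entry per block
    # Decode the last position through the final layer norm and unembedding.
    logits = model.lm_head(model.transformer.ln_f(h[:, -1]))
    top_token = tokenizer.decode(logits.argmax(dim=-1))
    print(f"layer {layer_idx:2d} -> {top_token!r}")
```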
In summary, our approach is informed by a broad swath of research: from the detailed dissection of networks at the neuron/circuit level (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models), to probing studies that treat layers as linguistic feature extractors ([1905.05950] BERT Rediscovers the Classical NLP Pipeline), to recent lens methods that insert interpretability transforms into networks (Eliciting Latent Predictions from Transformers with the Tuned Lens). We contribute a unifying perspective that ties these together into a single framework of layer-wise transparent interpretation. By citing and building on these works, we stand on the shoulders of prior interpretability research while pushing toward more comprehensive and accessible explanations of transformer models.
7. Future Implications
Adopting layer-wise sub-model interpretability could have far-reaching implications for the development, deployment, and oversight of AI systems. In this final section, we explore potential impacts, extensions of the approach, and open questions that remain to be addressed.
Shaping AI Regulation and Accountability: As AI systems are used in high-stakes decisions (medical, legal, financial), regulators are increasingly interested in "Explainable AI" and even legal rights to an explanation. If every layer of a model can be interpreted and explained, we move closer to AI whose decision process can be audited in detail. For instance, in an AI-driven loan approval system using a transformer to analyze customer data, a layer-wise explanation might show: layer 3 assessed credit history, layer 5 assessed income stability, layer 8 combined these into a risk profile, and the final layer made a decision. Such a multi-step explanation could satisfy regulatory demands by demonstrating a logical progression (and allowing one to point out, say, bias if one layer put undue weight on a sensitive attribute). In essence, this approach could provide a traceable chain of reasoning that is missing in end-to-end deep models today. Moreover, if an AI system causes harm or makes a mistake, having intermediate explanations could clarify where the flaw occurred. Was it a misinterpretation of the input early on, or an unreasonable combination of factors at the end? This could help assign responsibility to specific components or phases of the model's processing. Organizations might maintain logs of the layer-wise explanations for critical decisions, creating an audit trail for later review. Of course, one must ensure these explanations are stored and communicated properly, and that they are comprehensible to stakeholders (perhaps translating technical jargon into lay terms). But overall, accountability is bolstered when a model can effectively say "here's what I was thinking at each step" instead of just "here's my final answer."
Real-Time Interpretability Tools: One exciting practical extension is the development of real-time dashboards or monitoring systems powered by layer-wise interpretability. Imagine a debugging interface for an NLP model where, as you input a sentence, you see a live update of each layer's interpretation. Such a tool would be invaluable for researchers and engineers. They could feed in examples and immediately spot if, say, layer 5 is generating an incorrect intermediate conclusion. For example, a dashboard could highlight: Layer 5 misunderstanding: it thinks "bank" refers to a river bank instead of a financial bank. This early misinterpretation could be corrected (by adjusting training data or model architecture) before it propagates to an incorrect final answer. In deployed systems, one could use simplified versions of this idea to detect anomalies. If a layer's explanation deviates from expected patterns (perhaps a security chatbot's layer 7 suddenly starts showing content indicating that the model is interpreting a harmless query as an attack due to some quirk), it could raise a flag to a human overseer or trigger a safe intervention. This is analogous to a pilot watching an altimeter and airspeed indicator: our approach provides multiple "meters" inside the model that can be monitored to ensure it is on the right track.
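A minimal sketch of such monitoring using PyTorch forward hooks is shown below. The `explain_layer` callable and the keyword-based anomaly check are illustrative placeholders for whatever sub-models and detectors a real deployment would use.

```python
# Minimal sketch of live per-layer monitoring with PyTorch forward hooks.
from typing import Callable, List

import torch
import torch.nn as nn

SUSPICIOUS_TERMS = ("attack", "exploit")   # toy stand-in for a real anomaly detector


def attach_monitors(layers: List[nn.Module],
                    explain_layer: Callable[[int, torch.Tensor], str],
                    alert: Callable[[int, str], None]):
    """Register hooks that explain each layer's output and flag suspicious text."""
    handles = []
    for idx, layer in enumerate(layers):
        def hook(module, inputs, output, idx=idx):
            hidden = output[0] if isinstance(output, tuple) else output
            text = explain_layer(idx, hidden.detach())
            if any(term in text.lower() for term in SUSPICIOUS_TERMS):
                alert(idx, text)
        handles.append(layer.register_forward_hook(hook))
    return handles   # call h.remove() on each handle to detach the monitors
```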
Guiding Model Improvement and Safety: Layer-wise explanations not only explain models, but might also guide how we improve them. For instance, if we find that layer 4 consistently produces an explanation that is logically redundant or irrelevant to the final task, it might indicate that the model has an inefficient layer. We could try pruning that layer or merging its function with another. Or if a certain type of error often arises from a wrong intermediate decision (say, in a translation model, the layer 6 explanation reveals it consistently gets verb tenses wrong before the final output), we can target that with additional training data or an architectural tweak at that layer, such as an auxiliary loss to enforce tense correctness (see the sketch below). From a safety perspective, if we can identify layers that handle potentially risky transformations of the input (e.g., in a dialogue model, perhaps a layer that goes from understanding the user query to formulating a raw response), we might focus security measures there. In other words, understanding which layer does what allows for modular risk assessment: some layers might be more critical to keep interpretable and controlled (like those dealing with factual recall, to prevent misinformation) while others are less so.
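As a sketch of the targeted fix mentioned above, one could attach a small auxiliary head at the problematic layer and add its loss to the main training objective. The tense head, the choice of layer 6, and the availability of per-example tense labels are assumptions for illustration only.

```python
# Minimal sketch of adding an auxiliary loss at one layer (e.g., pushing layer 6
# toward correct tense decisions). The tense head and labels are illustrative.
import torch
import torch.nn as nn


class TenseHead(nn.Module):
    def __init__(self, hidden_size: int, num_tenses: int = 3):
        super().__init__()
        self.proj = nn.Linear(hidden_size, num_tenses)

    def forward(self, layer_hidden: torch.Tensor) -> torch.Tensor:
        return self.proj(layer_hidden.mean(dim=1))   # pool tokens -> tense logits


def combined_loss(main_loss, layer6_hidden, tense_labels, tense_head, aux_weight=0.1):
    """Main task loss plus a small auxiliary penalty computed at layer 6."""
    aux_loss = nn.functional.cross_entropy(tense_head(layer6_hidden), tense_labels)
    return main_loss + aux_weight * aux_loss
```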
Shaping Future Architecture Design: If layer-wise interpretability proves valuable, architects of new models might start designing with interpretability in mind. This could mean building models with explicit intermediate interpretable representations (not unlike how traditional software has comments or logs at checkpoints). For example, future transformer variants might include dedicated "explanation vectors" that are meant to carry a summary of that layer's state in human-interpretable form. These could be encouraged via training to align with certain concepts. Alternatively, we might see hybrid models where each layer actually outputs a short description that is fed into the next layer alongside the usual hidden state, essentially layers communicating with each other in a human-readable language. This is speculative, but it resonates with the idea of forcing models to "think step by step" (as in chain-of-thought prompting), except internally. If a model can internally generate a chain of thought that is also understandable to us, it would be a breakthrough for transparency. Our sub-model technique is a step in that direction, even if currently the chain of thought is extracted post hoc rather than generated intrinsically.
Shifting Culture in Model Evaluation: With more interpretable models, the community might begin to expect explanations as a standard part of model output. It is possible that leaderboards and benchmarks will evolve to consider not only accuracy on tasks, but also the quality of the explanations provided. If our approach becomes efficient, one could imagine requiring models to output a justification at each step for sensitive tasks (like medical diagnosis). This could be scored or at least qualitatively evaluated. It might discourage purely black-box solutions if a model that can explain its intermediate steps is seen as more trustworthy or superior, even when raw accuracy is similar. Over time, this could lead to a norm where interpretability is built in rather than an afterthought. Our approach could serve as a template for how to integrate and evaluate such capabilities.
Open Questions: Despite the promise, there are many open research questions and challenges to be explored:
- How to ensure faithfulness? We need robust methods to verify that sub-model explanations truly reflect the causal factors in the main model's decisions. Techniques like causal analysis or counterfactual checking (altering the input or a hidden state and seeing whether the explanation changes accordingly) could be used to test faithfulness; see the sketch after this list. Developing quantitative metrics for explanation quality is an open area, e.g., measures of alignment between explanation and model behavior, or how well humans can predict the model's output from the explanations alone.
- What is the optimal complexity for sub-models? If the sub-model is too simple, it might not capture the nuance of the layer's state; if too complex, it might itself be a black box or overfit. Research is needed to find the sweet spot (e.g., is a one-layer probe enough, or do we need a full transformer decoder?). Also, do we need one sub-model per layer, or can a single parametric model handle multiple layers' interpretations if given the layer index as input (reducing total parameters)?
- How to scale to very large models? Modern large language models can have on the order of a hundred layers. It is impractical to attach an independent, heavyweight explanation module to each. We might need techniques to select important layers to explain (perhaps layers where large changes in the representation occur), or sparse explanation triggers (only produce a detailed explanation when something notable occurs, otherwise default to briefer information). There is also the question of whether our approach can handle models with billions of parameters, or whether there are stability issues (e.g., if different layers encode entangled concepts, can our method disentangle them well enough for explanation?).
- Could models deceive their own sub-models? In an adversarial scenario, one might imagine a model learning to encode information in forms that the sub-model cannot easily interpret (if, say, the model "knows" it is being watched and does not want to reveal something). This is a somewhat far-fetched scenario for now, but as models become more agentic or are trained with self-awareness, they might game the interpretability. Ensuring robust interpretation that cannot be tricked is an open problem (it touches on AI safety: making sure the AI cannot hide its true intentions).
- User Understanding: Even if we produce explanations, will end-users (or even developers) always understand them? If an explanation says "Layer 7: activated a semantic induction head linking subject and object in the query," that might confuse a non-expert. How do we translate our internal explanations into simpler terms or visualize them for different audiences? Perhaps a two-tier explanation, where a technical one is available for developers and a high-level one for users.
- Generality: We have discussed mainly transformer encoders (like BERT) and decoders (like GPT). How would this work in other architectures? For RNNs (less used now, but conceptually), one could also apply layer-wise or time-step-wise sub-models. For non-NLP domains (reinforcement learning policies, for example), could an agent's policy network layers be explained in terms of intermediate decisions or value estimates? If an RL agent has a vision module and then a decision module, perhaps those are natural layers to explain ("vision recognized an obstacle; the decision layer chose to jump"). Extending these ideas to different AI systems will be important to make interpretability ubiquitous.
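Below is the faithfulness-check sketch referenced in the first open question: perturb an input, regenerate the layer-wise explanations, and flag cases where the prediction flips but no explanation changes. The `predict` and `explain_layers` callables are hypothetical interfaces around the main model and its sub-models.

```python
# Minimal counterfactual faithfulness check: if a perturbation flips the model's
# prediction, faithful layer-wise explanations should change too.
from typing import Callable, Dict, List


def counterfactual_check(
    original: str,
    perturbed: str,
    predict: Callable[[str], str],
    explain_layers: Callable[[str], List[str]],   # one explanation string per layer
) -> Dict[str, bool]:
    pred_changed = predict(original) != predict(perturbed)
    expl_orig, expl_pert = explain_layers(original), explain_layers(perturbed)
    expl_changed = any(a != b for a, b in zip(expl_orig, expl_pert))
    # Red flag: the prediction flipped but every layer's explanation stayed the same.
    return {"prediction_changed": pred_changed,
            "explanations_changed": expl_changed,
            "possible_unfaithfulness": pred_changed and not expl_changed}
```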
Realistic Optimism: In the near future, we anticipate researchers demonstrating layer-wise interpretability on more complex tasks and models, perhaps starting with medium-sized models where it is tractable to label and evaluate the explanations. If successful, this approach can be one piece of the puzzle for trustworthy AI. It will not solve everything; we still need good user-experience design to present explanations, and rigorous validation. But it provides a technical pathway to illuminating the black box from the inside out. Ultimately, the vision is that asking a transformer "What are you thinking?" will be a standard capability, and the transformer (through our sub-models) will answer by revealing its chain of thought in a way we can verify and understand. This would mark a significant paradigm shift from the opaque deep learning models of the past toward AI systems that expose their reasoning and earn our trust through transparency. We believe Layer-Wise Sub-Model interpretability is a promising step in that direction, and we encourage further exploration and refinement of this approach in the research community.