Layer-Wise Sub-Model Interpretability for Transformers

John Crafts | Informational Blog

A Novel Approach to Improving AI Interpretability

1. Introduction

Transformers have achieved remarkable success in NLP and other domains, but they operate as complex black-box models, making their decisions hard to interpret (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). This lack of transparency has raised concerns about safety, trust, and accountability when deploying AI systems in real-world applications (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). Researchers and practitioners increasingly recognize that we need better interpretability techniques to understand why a transformer makes a given prediction, not just what it predicts. Yet current interpretability methods provide only partial insights into these models’ inner workings.

Common post-hoc methods like saliency maps or feature attribution highlight which input tokens or features most influence the output (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). While useful for identifying important words (e.g. via Integrated Gradients or SHAP), such methods treat the model largely as a black box and do not reveal the intermediate reasoning process. A major concern is that these attributions may not faithfully reflect the model’s true decision-making: different attribution techniques often disagree, casting doubt on their reliability (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). Similarly, analyzing attention weights (the internal attention coefficients of transformers) has been popular, since attention seems to show what the model “focuses” on. However, attention is not a definitive explanation: Jain and Wallace (2019) showed that very different attention distributions can yield the same predictions, and Serrano and Smith (2019) found that zeroing out even the most highly weighted attention components often fails to change a model’s decision (Is Attention Interpretable in Transformer-Based Large Language Models? Let’s Unpack the Hype). These findings imply that raw attention patterns, by themselves, do not provide a faithful explanation of the model’s reasoning. In short, saliency maps and attention visualizations are insufficient on their own – they offer surface-level insight but often miss the deeper computational patterns inside transformer networks.

The limitations of current interpretability methods motivate exploring new approaches. Recent progress in mechanistic interpretability attempts to open the black box by reverse-engineering model components into human-understandable algorithms (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). For example, researchers have identified individual neurons and attention heads that correspond to specific functions or concepts within language models (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). This work can yield mechanistic explanations for certain behaviors (e.g. a neuron that activates on negative words, an attention head that tracks coreference). However, such analyses are typically labor-intensive and limited to isolated circuits. We lack a scalable, general way to explain a model’s entire reasoning process on a given input. There is a clear need for interpretability methods that can provide step-by-step explanations of how a transformer processes information through its layers, bridging the gap between low-level circuit findings and high-level feature attributions.

In this paper, we propose a novel approach to transformer interpretability: Layer-Wise Sub-Model Interpretable Transformers. The core idea is to leverage the transformer’s own structure – its multiple layers of hidden representations – by attaching interpretable sub-models to each layer. Each sub-model acts as an explanatory module for that layer’s activations, decoding or analyzing the hidden state into a human-understandable form. By doing so for every layer, we obtain a progressive, layer-by-layer explanation of how the model’s understanding evolves from input to output. In essence, the transformer is no longer a single monolithic black box; it is augmented with “inner voices” or probes at each layer that narrate what that layer is doing or has learned. This approach promises fine-grained transparency: instead of only explaining the final prediction (as most methods do), we explain the intermediate steps of the model’s computation. We hypothesize that layer-wise sub-models can address the shortcomings of saliency and attention-based methods by revealing the meaning of internal states, thus providing more faithful and detailed explanations of transformer decisions. The rest of this paper details this approach, including technical background, implementation examples, comparisons to existing methods, and implications for future research in AI interpretability.

2. Technical Background

Transformer Model Primer: Transformers are deep neural networks that process information in a sequence of layers, each producing a new representation of the input tokens. An input text is first tokenized into discrete tokens (subwords or words) and embedded into continuous vector representations. These token embeddings are then fed into the transformer’s stack of self-attention layers. Each transformer layer consists of a multi-head self-attention sublayer followed by a feed-forward network sublayer (with residual connections and layer normalization at each step). Through self-attention, the model blends information across all tokens: each token’s representation is updated by attending to other relevant tokens in the sequence. The feed-forward sublayer then further transforms these representations in a token-wise fashion. As the input passes through layer after layer, the token representations (often called hidden states) gradually encode more abstract and complex features. By the final layer, the transformer has distilled the input into a form from which it can easily produce an output – for example, a probability distribution over next words in language modeling, or a classification label in a classifier. In summary, a transformer processes text through iterative hidden state transformations: at layer 1 it has a rudimentary understanding, and at layer $N$ (the last layer) it has a task-specific understanding ready to drive the output.
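For reference, one layer of the post-layer-norm arrangement described above (the one used in the original Transformer and in BERT) can be written as follows, with $h_{\ell-1}$ the matrix of token vectors entering layer $\ell$ (one row per token) and $h_0$ the token embeddings plus position embeddings. $\mathrm{MHA}$ runs several attention heads in parallel, each with its own projections $W^Q, W^K, W^V$, and combines their outputs through an output projection $W^O$; $f$ is a nonlinearity such as GELU, and the softmax is taken row-wise. GPT-2 instead applies layer normalization before each sublayer, so LN sits in a different place there.

$$
\begin{aligned}
\mathrm{Attn}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V, \qquad Q = h_{\ell-1}W^Q,\; K = h_{\ell-1}W^K,\; V = h_{\ell-1}W^V \\
\tilde{h}_\ell &= \mathrm{LN}\big(h_{\ell-1} + \mathrm{MHA}(h_{\ell-1})\big) \\
h_\ell &= \mathrm{LN}\big(\tilde{h}_\ell + \mathrm{FFN}(\tilde{h}_\ell)\big), \qquad \mathrm{FFN}(x) = f(xW_1 + b_1)W_2 + b_2
\end{aligned}
$$

The attention line is the only place where tokens exchange information; the FFN line acts on each token's row independently, which is what "token-wise" means above.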

Evolution of Meaning Across Layers: Importantly, the hidden states in each layer are not random – they contain structured information that evolves as depth increases. Empirical studies have shown that transformers tend to learn a hierarchy of linguistic and semantic features across their layers. For instance, in BERT, lower layers capture basic syntactic information (like part-of-speech tags), intermediate layers handle named entities and short-range dependencies, and higher layers encode high-level semantics and coreference links ([1905.05950] BERT Rediscovers the Classical NLP Pipeline). In other words, the model appears to rediscover the classic NLP pipeline internally: first morphology and syntax, then semantics and reasoning ([1905.05950] BERT Rediscovers the Classical NLP Pipeline). Similarly, analysis of feed-forward networks in transformers has found that early layers focus on “shallow” patterns (e.g. local grammar), whereas later layers concentrate on more “semantic” or abstract patterns (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). For example, Geva et al. (2021) showed that the feed-forward sublayers can act as key-value memory storing factual knowledge; an early layer might respond to a common phrase or syntactic cue, while a late layer retrieves a specific fact or meaning associated with that cue (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). These findings suggest that as we ascend the stack, the representation of each token gets refined: noise and ambiguity are reduced, and the model’s interpretation of the input becomes more specific to the task at hand.
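The key-value reading of Geva et al. (2021) can be made concrete. Ignoring biases, a feed-forward sublayer computes $\mathrm{FFN}(x) = f(xK^\top)V$: each row $k_i$ of the first weight matrix acts as a key matched against the input, and the resulting coefficient scales the corresponding value vector $v_i$. The toy sketch below assumes ReLU for $f$ and tiny hand-picked vectors purely for illustration; BERT and GPT-2 use GELU and far larger dimensions.

```python
def dot(a, b):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def ffn_as_memory(x, keys, values):
    """Compute FFN(x) = sum_i relu(x . k_i) * v_i (biases omitted).

    Returns the per-key memory coefficients and the output vector.
    """
    coeffs = [max(0.0, dot(x, k)) for k in keys]  # how strongly each key "fires"
    out = [0.0] * len(values[0])
    for c, v in zip(coeffs, values):
        for j, v_j in enumerate(v):
            out[j] += c * v_j  # add each value vector weighted by its key's coefficient
    return coeffs, out


# Two hypothetical keys (input patterns) and their associated value vectors.
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

coeffs, out = ffn_as_memory([2.0, -1.0], keys, values)
print(coeffs)  # [2.0, 0.0]: only the first key matches the input
print(out)     # [2.0, 2.0, 0.0]: the output is the first value vector, scaled
```

An "early layer responding to a syntactic cue" and a "late layer retrieving a fact" differ, in this view, only in what their keys match and what their values write back.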

To concretize, consider a transformer-based question answering model receiving the question “Who wrote Pride and Prejudice?”. In the first few layers, the model’s representations may highlight that the question is asking for a person’s name (an author) and recognize the span “Pride and Prejudice” (several tokens) as a book title. Mid-level layers might bring in contextual or factual associations, e.g. activating features related to literary authors or the 19th century. By the final layers, the model’s hidden state for the question likely strongly encodes the answer “Jane Austen” before it is even produced as output. Indeed, the “logit lens” in interpretability is built on this premise: if you take the hidden state from an intermediate layer and map it directly to the output vocabulary, you often find that the model’s latent prediction gets closer to the final answer as you go deeper (Eliciting Latent Predictions from Transformers with the Tuned Lens). Nostalgebraist (2020) first observed that decoding each layer of GPT-2 with the output embedding matrix yields distributions that converge roughly monotonically toward the model’s final output distribution (Eliciting Latent Predictions from Transformers with the Tuned Lens). This implies the model is accumulating evidence or partial guesses layer by layer. In our example, a logit lens on an early layer might produce a broad distribution over many authors, but by layer 10 it might rank “Jane Austen” as a high-probability candidate, indicating the model has essentially worked out the answer internally before the last layer. Such insights underline that intermediate hidden states carry significant meaning and often an implicit draft of the model’s eventual output.
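In symbols, the logit lens reads out layer $\ell$ by applying the model's final layer normalization $\mathrm{LN}_f$ and unembedding matrix $W_U$ directly to the intermediate hidden state $h_\ell$ at the position being predicted, skipping the remaining layers. The Tuned Lens cited above refines this by first passing $h_\ell$ through a learned affine translator $(A_\ell, b_\ell)$, trained with the transformer frozen so that the readout matches the model's final output distribution $p_N$ in KL divergence:

$$
\begin{aligned}
\text{logit lens:}\quad & p_\ell = \mathrm{softmax}\big(W_U\,\mathrm{LN}_f(h_\ell)\big) \\
\text{tuned lens:}\quad & p_\ell = \mathrm{softmax}\big(W_U\,\mathrm{LN}_f(A_\ell h_\ell + b_\ell)\big), \qquad
(A_\ell, b_\ell) = \arg\min_{A,\,b}\ \mathbb{E}\Big[D_{\mathrm{KL}}\big(p_N \,\big\|\, \mathrm{softmax}(W_U\,\mathrm{LN}_f(A h_\ell + b))\big)\Big]
\end{aligned}
$$

Here $h_\ell$ is treated as a column vector. For a model with tied embeddings such as GPT-2, $W_U$ is the token embedding matrix itself (one row per vocabulary item), and at the last layer the logit lens reproduces the distribution the model actually emits.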

Mechanistic Interpretability and Neuron Analysis: Beyond layer-wise aggregate behavior, researchers have also probed the internal mechanics at the level of neurons and attention heads. Each transformer layer contains multiple attention heads and numerous neurons (in most mechanistic work, the hidden units of the feed-forward sublayer). Mechanistic interpretability research seeks to identify what these individual components represent and how they combine to perform computations (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). Notable findings include the discovery of specialized attention heads that perform distinct roles. For example, certain attention heads in GPT-2 have been reported to track long-range dependencies like quote pairings or to support simple arithmetic in context. In BERT, some attention heads perform coreference resolution (linking pronouns to the correct entities) reasonably well (Is Attention Interpretable in Transformer-Based Large Language Models? Let’s Unpack the Hype), while others attend mostly to the next or previous word to encode positional relationships. There is evidence of a division of labor: attention heads have interpretable roles, albeit as “team players” working in concert (Is Attention Interpretable in Transformer-Based Large Language Models? Let’s Unpack the Hype). At the neuron level, analysis is more challenging because transformers use distributed representations: each concept is encoded across many neurons, and each neuron participates in many concepts (polysemanticity, often attributed to superposition, in which a model represents more features than it has dimensions). Still, a few individual neurons have been identified with strong correlations to specific features. For instance, researchers have reported a neuron in GPT-2 that appears to track whether the text is inside a quotation (a quote-tracking neuron), and another that fires on negation words. Such neuron-level discoveries often come from exhaustive probing or from automated techniques.
Recent work by OpenAI used GPT-4 to generate natural language explanations for the behavior of every neuron in GPT-2 XL, producing a dataset of (imperfect) descriptions of what each neuron might be looking for (Language models can explain neurons in language models | OpenAI). This kind of approach treats neurons themselves as “units of thought” to be explained, and though not yet fully reliable, it represents a step toward scaling neuron interpretability. Overall, prior research in mechanistic interpretability has demonstrated that transformer internals are not inscrutable: there is structure and algorithmic behavior we can uncover. However, most such work either focuses on local interpretability (e.g. one neuron or one head at a time) or requires significant manual effort and intuition to piece together into a global explanation.
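How "imperfect" an explanation is can be quantified. The OpenAI study scores an explanation by having a model simulate the neuron's activations from the explanation alone and comparing them with the real activations, using correlation-based scoring. The sketch below is a simplified stand-in for that pipeline: it computes only the Pearson correlation, and both activation lists are hypothetical values for six tokens, the middle two of which sit inside quotation marks.

```python
import math


def correlation_score(real, simulated):
    """Pearson correlation between real and simulated neuron activations.

    1.0 means the explanation predicts the neuron perfectly (up to scale and shift);
    values near 0 mean it says little about when the neuron fires.
    """
    n = len(real)
    mean_r = sum(real) / n
    mean_s = sum(simulated) / n
    cov = sum((r - mean_r) * (s - mean_s) for r, s in zip(real, simulated))
    var_r = sum((r - mean_r) ** 2 for r in real)
    var_s = sum((s - mean_s) ** 2 for s in simulated)
    if var_r == 0 or var_s == 0:
        return 0.0  # a constant sequence carries no signal to correlate with
    return cov / math.sqrt(var_r * var_s)


# Hypothetical real activations of one neuron over six tokens.
real = [0.0, 0.1, 0.9, 0.8, 0.1, 0.0]
# Hypothetical activations simulated from the explanation "fires inside quotation marks".
simulated = [0.0, 0.0, 1.0, 1.0, 0.0, 0.0]
print(round(correlation_score(real, simulated), 3))
```

Because correlation ignores scale, an explanation only needs to predict when the neuron is relatively active, not its exact magnitude.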

Summary: A transformer processes inputs through successive transformations, with each layer’s hidden state encoding a more refined understanding. There is a rich internal dynamics: early layers handle low-level patterns, later layers handle high-level meaning ([1905.05950] BERT Rediscovers the Classical NLP Pipeline) (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). While tools like attention maps and saliency highlight some aspects of this process, they fall short of explaining how each layer contributes to the final decision. Mechanistic studies give insight into components but are hard to apply comprehensively to each instance of model usage. This background sets the stage for our proposed solution: using layer-wise sub-models to systematically interpret each layer, effectively peeling back the transformer’s computations one layer at a time.

3. Proposed Approach: Layer-Wise Sub-Model Interpretable Transformers

We propose an approach that integrates interpretability into the very structure of transformer models by introducing layer-wise sub-models for interpretation. The key idea is to attach an auxiliary interpretable model to each transformer layer, which takes that layer’s hidden state as input and produces an explanatory output that is understandable to humans. These sub-models serve as lenses or translators that can read the state of the transformer at their respective layer and express what it represents (or what it is “thinking”) in a concise, meaningful way.

Sub-Models at Each Layer: Concretely, for each layer $\ell$ in the transformer (where $\ell = 1, 2, …, N$), we introduce a sub-model $F_\ell$ that consumes the hidden activations $h_\ell$ (the set of token vectors after layer $\ell$) and generates an explanation or interpretation $E_\ell$. The form of $E_\ell$ can vary depending on the application – it could be a natural language sentence summarizing the layer’s understanding, a set of labels or features (e.g. “Layer 3 has identified this sentence’s sentiment as positive”), or even a distribution over the final output vocabulary (as in the logit lens). The sub-model $F_\ell$ itself could be as simple as a linear classifier or as complex as a small neural network; the only requirement is that $F_\ell$ is interpretable by design or yields an output in an interpretable format. For example, one might use:

  • A linear or decision-tree model that takes the hidden state and predicts a human-interpretable attribute (such as the part-of-speech tags of each token, the subject of the sentence, etc.).
  • A small decoder network that cross-attends to $h_\ell$ and generates a natural-language description of what is encoded at layer $\ell$. This is akin to using a miniature language model to “translate” the hidden vector into words (in spirit similar to the DecoderLens method which verbalizes encoder activations) (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers).
  • A multi-label classifier that outputs which high-level concepts are active at this layer (for example, in an image transformer one might output concepts like “edges detected” at early layers and “object shape = cat” at later layers; in language, early layers might output “syntax parsed” and later “coreference resolved” or “topic: sports”).

Crucially, each sub-model focuses only on the state of one layer. This modularity means the interpretation of each layer is localized and specific, rather than entangled with the entire network at once. By designing $F_\ell$ to be much simpler than the full transformer, we aim for these sub-models to be transparent. In essence, $F_\ell$ acts as an interpreter for layer $\ell$, extracting the content of $h_\ell$ in a form we can reason about.
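One way to organize the sub-models in code is a mapping from layer index to a callable $F_\ell$, applied to the matching entry of the hidden-state sequence to produce the explanations $E_\ell$. The framework-free sketch below is illustrative only: the class name `LayerwiseInterpreter`, the toy threshold probes, and the use of plain lists for $h_\ell$ are hypothetical choices, not part of any library.

```python
from typing import Callable, Dict, List, Sequence

Vector = List[float]
SubModel = Callable[[Sequence[Vector]], str]  # F_l: token vectors of one layer -> explanation E_l


class LayerwiseInterpreter:
    """Holds one interpretable sub-model per selected layer."""

    def __init__(self, sub_models: Dict[int, SubModel]):
        self.sub_models = sub_models

    def explain(self, hidden_states: Sequence[Sequence[Vector]]) -> Dict[int, str]:
        """Return {layer: E_layer}; hidden_states[0] is the embedding output."""
        return {layer: f(hidden_states[layer]) for layer, f in sorted(self.sub_models.items())}


def threshold_probe(dim: int, threshold: float, label: str) -> SubModel:
    """Hypothetical probe: reports `label` if the mean of one feature exceeds a threshold."""
    def probe(tokens: Sequence[Vector]) -> str:
        mean = sum(t[dim] for t in tokens) / len(tokens)
        return label if mean > threshold else f"not {label}"
    return probe


# Toy hidden states: 2 tokens with 2 features each, for embeddings + 2 layers.
hidden_states = [
    [[0.0, 0.0], [0.1, 0.0]],  # layer 0 (embeddings)
    [[0.9, 0.1], [0.7, 0.2]],  # layer 1
    [[0.8, 0.9], [0.6, 0.7]],  # layer 2
]
interpreter = LayerwiseInterpreter({
    1: threshold_probe(0, 0.5, "question detected"),
    2: threshold_probe(1, 0.5, "author entity expected"),
})
print(interpreter.explain(hidden_states))
# {1: 'question detected', 2: 'author entity expected'}
```

Keeping each $F_\ell$ behind the same narrow interface makes it easy to swap a linear probe for a decoder at one layer without touching the others, which is the modularity argued for above.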

Extracting and Explaining Layer-Wise Meaning: With sub-models in place for all layers, we can obtain a sequence of explanations $(E_1, E_2, …, E_N)$ for a given input as it flows through the transformer. This sequence should reveal how the model’s knowledge and hypotheses are refined layer by layer. For example, consider again a question-answering transformer. The explanation $E_1$ (after the first layer) might say something like: “Identified question words; focusing on entity ‘Pride and Prejudice’.” By layer 6, $E_6$ might read: “Understands the question is asking for an author of ‘Pride and Prejudice’; likely referring to a book author.” By the penultimate layer, $E_{N-1}$ could be “Model strongly anticipates the answer is ‘Jane Austen’; high confidence in this author name.” Finally, the model outputs “Jane Austen.” In this way, the sequence of $E_\ell$ provides a narrative of the model’s reasoning, from initial comprehension to final answer. Each $E_\ell$ explains the incremental progress made by layer $\ell$. Early layers handle more generic or surface-level information, while later layers provide more detailed and task-specific information – exactly mirroring how the hidden state evolves, but now made explicit.
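If each explanation $E_\ell$ carries the layer's current best guess at the answer, as in the $E_{N-1}$ example above, the point at which the model settles on its answer can be read off mechanically. The sketch below compares each layer's guess with the final layer's and reports the earliest layer from which the guess no longer changes. The function name and the example guesses are hypothetical.

```python
from typing import List, Tuple


def agreement_report(layer_guesses: List[str]) -> Tuple[List[bool], int]:
    """Compare each layer's guessed answer with the last layer's.

    Returns a per-layer list of booleans (True = agrees with the final answer) and
    the 1-based index of the earliest layer from which every later layer agrees.
    """
    final = layer_guesses[-1]
    agrees = [g == final for g in layer_guesses]
    settle_layer = len(layer_guesses)
    # Walk backwards from the last layer while the guesses keep matching the final one.
    for i in range(len(layer_guesses) - 1, -1, -1):
        if not agrees[i]:
            break
        settle_layer = i + 1
    return agrees, settle_layer


# Hypothetical top-1 answer guesses from layers 1..8 of a QA model.
guesses = ["the", "author", "Austen", "Dickens", "Austen", "Austen", "Austen", "Austen"]
agrees, settle = agreement_report(guesses)
print(agrees)  # [False, False, True, False, True, True, True, True]
print(settle)  # 5: layer 3 agreed briefly, but layer 4 switched to "Dickens"
```

A guess that appears, disappears, and returns, like layer 3 here, is exactly the kind of revised hypothesis that a layer-by-layer narrative should surface rather than hide.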

One can view this progressive explanation as analogous to how a human might break down their reasoning. Instead of jumping directly from question to answer, a human might think: “Okay, the question mentions Pride and Prejudice, that’s a novel. Next, the question asks who wrote it – so it’s looking for an author’s name. Pride and Prejudice was written by Jane Austen, I recall. Thus, the answer is Jane Austen.” Here we see distinct steps (identify topic, interpret question, recall fact, produce answer). Standard transformers do all this internally without explicit delineation. Our layer-wise sub-model approach aims to externalize these implicit steps, mapping them to the network’s layered processing.

Progressive Refinement and Consistency: An important aspect of our approach is that the explanations $E_\ell$ should show progressive refinement. That is, $E_{\ell+1}$ should generally build upon or update the information in $E_\ell$. If we observe drastic or contradictory changes in the explanations from one layer to the next, that might indicate an interesting phenomenon (e.g., the model revising an earlier hypothesis). In many cases, however, we expect a coherent story. This progressive property can be enforced or encouraged when designing the sub-model outputs. For instance, sub-models could be built to condition on previous explanations (though that introduces complexity and inter-dependence), or we can simply check after the fact that $E_1, E_2, …$ align well. A concrete way to encourage consistency is to have each sub-model predict the final output as well (as an auxiliary task), so that each layer tries to guess the answer, as in early-exit models (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). If all layers are predicting the same final answer by the end, their intermediate explanations will likely be consistent with that answer. In our QA example, if by layer 6 the sub-model already predicts “Jane Austen” as the answer, it will shape its explanation accordingly. This doesn’t force the layers to agree on everything (they might differ in confidence or reasoning details), but it tethers them to a common end goal.
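One way to write this as a training objective, assuming a task loss $\mathcal{L}_{\text{task}}$ on the model's final output $\hat{y}$, a set $S$ of layers that carry sub-models, per-layer targets $y_\ell$ (the final answer $y$ itself in the consistency scheme above, or other labels such as part-of-speech tags), and non-negative weights $\lambda_\ell$:

$$
\mathcal{L} = \mathcal{L}_{\text{task}}(\hat{y}, y) + \sum_{\ell \in S} \lambda_\ell\, \mathcal{L}_\ell\big(F_\ell(h_\ell),\, y_\ell\big)
$$

If the transformer is frozen, only the $F_\ell$ receive gradients and this reduces to post-hoc probing. If it is trained jointly, each auxiliary term also sends gradients into layers $1, \dots, \ell$, which is what lets the objective reshape the representations; small $\lambda_\ell$ limit interference with the main task.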

Design and Architecture Considerations: There are multiple ways to implement layer-wise interpretable sub-models, each with trade-offs:

  • Probing Approach (Post-hoc): We can train each sub-model after the transformer is trained, in a post-hoc probing fashion. Here we treat the transformer as fixed and train $F_\ell$ to map $h_\ell$ to a desired explanation output. For example, $F_\ell$ could be a simple classifier trained on a labeled dataset to predict some known property of the input or output that we believe layer $\ell$ should capture. If no labeled data for intermediate concepts is available, we might train $F_\ell$ to mimic the transformer’s own final output (as a distillation target) or to predict a proxy like the logit lens distribution (Eliciting Latent Predictions from Transformers with the Tuned Lens). Post-hoc training has the advantage that it doesn’t interfere with the original model’s training or performance. We can probe for many different types of information from the same $h_\ell$. The downside is that $F_\ell$ might extract information that is present in $h_\ell$ but not necessarily used by the model, raising questions of faithfulness. However, techniques like the Tuned Lens (which learns a small affine transformation per layer to better align intermediate hidden states with the final logits) show that even a simple learned mapping can reveal the model’s latent predictions more reliably than decoding the raw hidden states directly (Eliciting Latent Predictions from Transformers with the Tuned Lens).
  • Integrated Training (Auxiliary Losses): Alternatively, we can integrate the sub-models into the transformer’s training process. In this setup, when training the main model, we add auxiliary objectives for each layer: the sub-model’s output $E_\ell$ should correctly predict some target (e.g. part-of-speech tags for layer 1, semantic roles for layer 6, final answer for layer $N$, etc.). This is reminiscent of “deep supervision” or early-exit networks, where intermediate layers are directly trained to produce outputs (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). By providing an explicit learning signal to each layer, we encourage the model to organize information in that layer in a human-interpretable way. For example, if layer 5 is trained to output a rough answer prediction, the model might ensure that by layer 5 it has already gathered most information needed for the answer, making layer 5’s hidden state more interpretable. Integrated training can thus bake interpretability into the model’s representations. The trade-off is additional complexity in training and potential interference with the primary task optimization – the model might sacrifice a bit of final accuracy to satisfy intermediate explanations. Careful weighting of losses and ensuring $F_\ell$ models are lightweight can mitigate this.
  • Choice of Sub-model: The simplest sub-model is a linear probe (a single linear layer mapping hidden state to some predefined labels or vocabulary). Linear probes are easy to interpret (weights highlight which dimensions correlate with which features) and cheap to train (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). They have been widely used to analyze transformers, e.g. testing which layers encode grammatical number or gender. However, linear probes might be too restrictive for rich explanations. More complex sub-models like a two-layer MLP or a mini-transformer decoder can capture nonlinear relations in $h_\ell$ and produce more flexible outputs (like sentences). We propose experimenting with a small transformer decoder as $F_\ell$ that generates a textual explanation from the hidden state. Such a decoder can be thought of as “asking the model to explain itself in English.” It could be trained on a corpus where explanations are available, or even jointly trained with a language modeling objective using the main model’s hidden state as context. Prior work (DecoderLens) has shown that using the model’s own decoder (in an encoder-decoder architecture) to interpret encoder layers yields meaningful word sequences explaining those layers’ content (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). In our approach, if using a decoder sub-model, we might feed $h_\ell$ through a cross-attention mechanism to an explanation decoder that has learned to output sentences like the ones we hypothesized (“Layer $\ell$ thinks…”). This is an exciting direction because it directly produces human-language rationales for each layer’s computations.
  • Computational Trade-offs: Adding sub-models naturally increases the computational load. If we have $N$ layers and attach even a small network to each, inference time could grow (especially if we run all sub-models for every input to get a full explanation). There are ways to manage this: we could run sub-models only when needed (for research/debugging, not every production inference), or design them to be very efficient (e.g. linear heads). Memory is a smaller concern: frameworks already return every intermediate activation $h_\ell$ on request (e.g. output_hidden_states=True in Hugging Face Transformers), but running a large decoder on each layer’s output would be heavy. A compromise is to attach sub-models to a subset of layers (e.g. every 2nd or 3rd layer) if fine granularity is not needed. We will demonstrate with code that extracting hidden states is straightforward; training and using sub-models can be scaled according to the desired interpretability depth and available resources. In Section 5, we discuss how our approach compares in cost-benefit to other interpretability techniques.
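A back-of-the-envelope count shows why the choice of head matters. Assuming BERT-base's dimensions (hidden size 768, a 30,522-token vocabulary, 12 layers, roughly 110M parameters in total), a separate, untied vocabulary-sized linear head on every layer would add more parameters than the model itself. A small classification head is negligible by comparison, and a head that reuses the model's existing output embedding, as the logit lens does, adds no new parameters. The helper names below are hypothetical.

```python
from typing import List


def linear_head_params(hidden_dim: int, num_outputs: int, bias: bool = True) -> int:
    """Parameters of one linear head mapping hidden_dim -> num_outputs."""
    return hidden_dim * num_outputs + (num_outputs if bias else 0)


def probed_layers(num_layers: int, every: int) -> List[int]:
    """Layers that get a sub-model when attaching one to every `every`-th layer."""
    return list(range(every, num_layers + 1, every))


HIDDEN, VOCAB, LAYERS = 768, 30522, 12  # BERT-base dimensions

vocab_head = linear_head_params(HIDDEN, VOCAB)   # untied vocabulary head
binary_head = linear_head_params(HIDDEN, 2)      # e.g. question vs. statement

print(vocab_head)                                  # 23471418 per layer
print(vocab_head * LAYERS)                         # 281657016 for all 12 layers
print(probed_layers(LAYERS, 3))                    # [3, 6, 9, 12]
print(vocab_head * len(probed_layers(LAYERS, 3)))  # 93885672 for every 3rd layer
print(binary_head * LAYERS)                        # 18456 for all 12 layers
```

Vocabulary heads on all 12 layers would come to more than twice BERT-base's size, whereas two-way classifiers on every layer add fewer than 20K parameters. This is why linear or tied heads are the default choice when explanations must run on every input.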

In summary, our proposed approach transforms a standard transformer into an interpretable multi-stage model. Each layer’s computation is accompanied by a transparent “explanation” via a dedicated sub-model. This yields a high-resolution view of the model’s internal decision process, addressing the black-box criticism. We expect layer-wise explanations to help answer questions like: What does the model know after this layer? What intermediate decision did it make? By making the invisible visible at every layer, we aim to greatly enhance the interpretability of transformer-based AI systems.

4. Implementation and Example Code

Implementing layer-wise interpretability involves two main steps: (1) extracting the intermediate activations from a transformer, and (2) designing/training sub-models to interpret those activations. Modern deep learning frameworks make it easy to access hidden states, and we can leverage that to build our interpretability pipeline. In this section, we provide example code snippets and explanations to illustrate how one might realize the proposed approach in practice. The examples are in Python using the Hugging Face Transformers library and PyTorch, which are common tools for transformer models.

Extracting Hidden States: We first need to obtain the hidden state outputs for each layer of a transformer. In Hugging Face’s API, many models allow this by enabling an output_hidden_states flag. For instance, to get hidden states from BERT or GPT-2:

import torch
from transformers import AutoTokenizer, AutoModel

# Load a pretrained transformer (e.g., BERT base) with hidden state output enabled
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_hidden_states=True)

# Encode an example input
sentence = "The capital of France is [MASK]."
inputs = tokenizer(sentence, return_tensors='pt')

# Forward pass to get outputs (including hidden states); no gradients needed for inspection
with torch.no_grad():
    outputs = model(**inputs)
hidden_states = outputs.hidden_states  # tuple of length (num_layers + 1)

print(f"Number of layers (including embedding output): {len(hidden_states)}")
print(f"Shape of hidden state tensor for layer 0 (embeddings): {hidden_states[0].shape}")
print(f"Shape for layer 1: {hidden_states[1].shape}")

This code loads a pre-trained BERT model and runs a sample sentence through it. By specifying output_hidden_states=True, the model returns a tuple of hidden states: hidden_states[0] is the embedding output (layer 0, before any self-attention), and hidden_states[i] for i=1,...,12 is the output of BERT’s i-th layer. Each hidden state is a tensor of shape (batch_size, sequence_length, hidden_dim). Here sequence_length is the number of tokens after tokenization, including the special [CLS] and [SEP] tokens, and hidden_dim is 768 for BERT-base. The code should therefore report 13 hidden states (embeddings plus 12 layers), each of shape torch.Size([1, 9, 768]): 1 sentence, 9 tokens ([CLS], the, capital, of, france, is, [MASK], ., [SEP]), and a 768-dimensional vector per token.
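A quick way to see how much each layer changes a token's representation, one rough signal of where processing happens, is the cosine similarity between that token's vectors at consecutive layers. The sketch below works on plain lists; with the BERT outputs above, a token's trajectory could be gathered as `[h[0, t].tolist() for h in hidden_states]` for a token position `t`. The toy trajectory is hypothetical, and this diagnostic is a supplement to, not part of, the sub-model approach.

```python
import math
from typing import List


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def layer_drift(trajectory: List[List[float]]) -> List[float]:
    """Cosine similarity between a token's vector at layer i-1 and layer i, for i = 1..N.

    Values near 1 mean the layer barely changed the representation.
    """
    return [cosine(prev, cur) for prev, cur in zip(trajectory, trajectory[1:])]


toy = [[1.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]]  # layers 0..3 of one token
print([round(s, 3) for s in layer_drift(toy)])  # [1.0, 0.707, 0.707]
```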

With these hidden states in hand, we can begin to attach interpretation. Suppose we want to see what each layer would predict for the masked word in the example sentence (a fill-in-the-blank task). We can use a Logit Lens-style approach: take the hidden state at the mask token for each layer, and project it into the vocabulary space using BERT’s output layer. In BERT, [MASK] prediction uses a masked-language-modeling head (a dense layer, activation, and layer norm, followed by an output projection tied to the token embeddings), which AutoModel does not load; one would need a class that includes the head, such as BertForMaskedLM, or could approximate it with a linear mapping via the embedding matrix. For demonstration, let’s instead use GPT-2 (a causal language model), whose output layer directly predicts the next token and is tied to the input embeddings. We can retrieve GPT-2’s hidden states and see the model’s top predictions at each layer for the next token:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load GPT-2 small with output of all hidden states
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", output_hidden_states=True)

text = "The capital of France is"
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)
hidden_states = outputs.hidden_states  # tuple of 13 tensors: embeddings + 12 layers

# Get the last token's hidden state from each layer and compute logits
last_token_index = inputs['input_ids'].size(1) - 1  # index of last input token
vocab_matrix = model.transformer.wte.weight.T  # GPT-2 uses tied input embeddings as output weights
final_ln = model.transformer.ln_f              # final layer norm, applied before the output projection
num_layers = len(hidden_states) - 1

with torch.no_grad():
    for layer_idx, h in enumerate(hidden_states[1:], start=1):  # skip hidden_states[0] which is embeddings
        # hidden_states[n] has shape (batch, seq_len, hidden_dim)
        last_token_hidden = h[0, last_token_index, :]  # vector for last token
        # Hugging Face applies ln_f only to the last entry, so apply it to intermediate layers here
        if layer_idx < num_layers:
            last_token_hidden = final_ln(last_token_hidden)
        logits = last_token_hidden @ vocab_matrix      # project to vocab (logit-lens logits)
        top5_indices = logits.topk(5).indices.tolist()
        top5_tokens = [tokenizer.decode([i]) for i in top5_indices]  # keeps leading spaces, e.g. ' Paris'
        print(f"Layer {layer_idx} top predictions: {top5_tokens}")

This snippet uses GPT-2 for simplicity (since it directly predicts next tokens). For each layer, we take the hidden vector corresponding to the last token in the input sequence (“ is” in “The capital of France is”), apply GPT-2’s final layer norm as the logit lens does (Hugging Face already includes it in the last entry of hidden_states), and multiply by the transpose of the embedding matrix to get logits for the next token. Then we decode the top 5 predicted tokens. Running this code would output something like:

Layer 1 top predictions: [' of', ' in', ' the', ',', ' and']
Layer 2 top predictions: [' the', ' of', ' in', ',', ' The']
Layer 3 top predictions: [' Paris', ' Lyon', 'The', ' Rome', ' London']
...
Layer 12 top predictions: [' Paris', ' Lyon', ' London', ' Rome', ' Berlin']

These are illustrative results rather than exact output (actual predictions depend on the checkpoint and on details such as applying the final layer norm), but they demonstrate the expected trend: in early layers, GPT-2 hasn’t narrowed down the next word and predicts common continuations like “ of” or “ in”. By layer 3 or 4, location names (Paris, Lyon, etc.) start appearing as likely candidates for “The capital of France is _”. By the final layer (12), “ Paris” is the top prediction (the correct completion). This is the kind of latent refinement we expect, now made explicit per layer. Such intermediate predictions are a simple form of explanation: they tell us what the model is “thinking” at each layer about the next word. In our framework, one could consider this a very basic sub-model at each layer: a linear map to the vocabulary. Our approach generalizes this idea to more complex and descriptive interpretations, rather than just next-token guesses.
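The trend can be quantified rather than eyeballed: the Tuned Lens paper, for instance, measures how far each layer's lens distribution is from the final output distribution using KL divergence. The pure-Python sketch below assumes the per-layer logits would come from the loop above (e.g. via `logits.tolist()`); the three-token toy logits are hypothetical.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) in nats; assumes q > 0 wherever p > 0 (true for softmax outputs)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kl_to_final(layer_logits: List[List[float]]) -> List[float]:
    """KL(final || layer) for each layer; the last entry is 0 by construction."""
    final = softmax(layer_logits[-1])
    return [kl_divergence(final, softmax(logits)) for logits in layer_logits]


# Hypothetical logits over a 3-token vocabulary (" of", " Paris", " Lyon") at four layers.
toy_logits = [
    [2.0, 0.0, 0.0],
    [1.0, 1.0, 0.5],
    [0.0, 2.5, 1.0],
    [0.0, 3.0, 1.0],
]
print([round(d, 3) for d in kl_to_final(toy_logits)])
```

A steadily shrinking KL is the quantitative version of "the latent prediction converges toward the final answer"; a layer where KL jumps back up would flag a revised hypothesis worth inspecting.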

Training Interpretable Sub-Models: Next, consider training a small model to decode some aspect of the hidden state. As a toy example, let’s train a probe to predict whether a sentence is a question or a statement from a layer’s [CLS] token in BERT. We generate some labeled data (e.g., sentences ending with a question mark vs. a period), feed them through BERT to get hidden states, then train a logistic regression on the hidden features. Although we cannot fully execute training here, a sketch (with the dataset elided) would be:

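import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# output_hidden_states=True makes the model return every layer's activations
model = BertModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
model.eval()

# Suppose we have data: list of sentences and labels (1 if question, 0 if statement)
sentences = ["What is your name?", "This is a cat.", ...]  # "..." = the rest of your labeled data
labels = [1, 0, ...]
LAYER = 8  # hidden_states[0] is the embedding output, so index 8 is the output of layer 8

# Get BERT [CLS] hidden states for each sentence at the chosen layer
cls_features = []
with torch.no_grad():  # no gradients are needed for feature extraction
    for sent in sentences:
        inputs = tokenizer(sent, return_tensors='pt')
        outputs = model(**inputs)
        # [CLS] is always at position 0 of a BERT input
        cls_vector = outputs.hidden_states[LAYER][0, 0, :].numpy()
        cls_features.append(cls_vector)
cls_features = np.stack(cls_features)  # (num_sentences, 768)

# Train a simple logistic regression on these features
clf = LogisticRegression(max_iter=1000).fit(cls_features, labels)
print("Training accuracy:", clf.score(cls_features, labels))
# Inspect coefficients (optional interpretability of the probe itself).
# coef_ has shape (1, 768) for a binary task, so sort row 0 rather than the 2-D array.
important_dims = np.argsort(np.abs(clf.coef_[0]))[::-1][:10]
print("Top contributing hidden dimensions for question detection:", important_dims)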

This code sketch shows how one might post-hoc train a probe on layer 8’s CLS token to classify sentence type. If BERT layer 8 encodes question-ness (which is plausible, as middle layers often capture sentence intent or syntax), the logistic regression should achieve good accuracy. The classifier’s weights also tell us which dimensions of the 768-dim hidden state were most useful for this task, giving some insight into the representation. In practice, one could similarly train sub-models to predict a variety of properties: sentiment, tense, topic, presence of certain keywords, etc., at different layers. Each such probe is effectively an interpretation function applied to the hidden state.
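Training accuracy, as printed in the snippet above, can overstate what a probe has found: with 768 features and a modest dataset, a linear classifier can fit almost any labeling. A common safeguard is to report held-out accuracy next to a baseline, such as the same probe trained on shuffled labels. The sketch below does this in plain NumPy (gradient-descent logistic regression in place of scikit-learn) on synthetic vectors with a planted "question" direction; the data generator and function names are hypothetical stand-ins for real layer-8 [CLS] features.

```python
import numpy as np


def train_logistic_probe(X, y, lr=0.1, steps=500, l2=1e-3):
    """Fit a logistic-regression probe with full-batch gradient descent."""
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(X @ w + b)))  # predicted P(question)
        w -= lr * (X.T @ (p - y) / len(y) + l2 * w)
        b -= lr * np.mean(p - y)
    return w, b


def accuracy(w, b, X, y):
    return float(np.mean(((X @ w + b) > 0).astype(int) == y))


# Synthetic stand-in for [CLS] vectors: a planted "question" direction plus noise.
rng = np.random.default_rng(0)
n, d = 400, 64
direction = rng.normal(size=d)
direction /= np.linalg.norm(direction)
y = rng.integers(0, 2, size=n)
X = rng.normal(size=(n, d)) + np.outer(2 * y - 1, direction) * 2.0

train, test = slice(0, 300), slice(300, n)
w, b = train_logistic_probe(X[train], y[train])
probe_acc = accuracy(w, b, X[test], y[test])

# Baseline: the same probe on shuffled labels should stay near chance on held-out data.
y_shuffled = rng.permutation(y)
w_c, b_c = train_logistic_probe(X[train], y_shuffled[train])
control_acc = accuracy(w_c, b_c, X[test], y_shuffled[test])

print(f"held-out probe accuracy:    {probe_acc:.2f}")
print(f"held-out baseline accuracy: {control_acc:.2f}")
```

A large gap between the two numbers is better evidence that the layer encodes the property than a high training score alone.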

From Probing to Explaining: The ultimate goal, however, is to produce human-readable explanations of the model’s “thought process.” Instead of just predicting properties, we might want each sub-model to output text explaining what the layer is doing. One way to achieve this is to use a language-modeling approach: train a small generative model that takes the hidden state as input and outputs an explanatory sentence. This could be done by creating a synthetic dataset: run a variety of inputs through the transformer, record the hidden states, and pair them with explanations. But who writes the explanations? If we have an existing interpretability tool or heuristic, we could generate them, or we could even use a large language model (like GPT-4) to label what each layer might be doing (as OpenAI did for neurons) (Language models can explain neurons in language models | OpenAI). This is an open-ended task, but as a proof of concept, one could manually label a few cases to train a prototype.

For example, we might take a set of sentences and manually note: at layer 2, the model likely identifies proper nouns; at layer 5, it resolves pronouns; at layer 10, it integrates context to decide sentiment. Training on such data (even if small) can allow a decoder to generate similar explanations for new inputs. A pseudo-training loop for a decoder sub-model might look like:

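# Pseudo-code for training an explanation decoder for a given layer.
# SmallTransformerDecoder and expl_tokenizer are hypothetical: a small seq-to-seq
# decoder that cross-attends to a hidden-state sequence, and its text tokenizer.
explanation_model = SmallTransformerDecoder()
optimizer = torch.optim.AdamW(explanation_model.parameters(), lr=1e-4)

for input_text, gold_explanations in training_data:
    # get hidden states from the main model, which stays frozen
    inputs = tokenizer(input_text, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    h_state = outputs.hidden_states[layer_of_interest][0]  # (seq_len, hidden_dim) for this layer

    # Use h_state as context (e.g., via cross-attention) and score the gold text with
    # teacher forcing; a loss computed on generate() output would not be differentiable.
    target_ids = expl_tokenizer(gold_explanations, return_tensors='pt').input_ids
    loss = explanation_model(context=h_state, labels=target_ids).loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# After training, explanation_model.generate(context=new_h_state) would produce text.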

Here, gold_explanations would be a target sentence we want the explanation model to output for that layer. Over time, the explanation decoder learns to associate patterns in the hidden state with certain explanation phrases. For instance, if many examples show that whenever the hidden state strongly encodes a negative sentiment, the explanation is “Layer X has identified a negative sentiment,” the decoder can learn to output that when it sees similar activation patterns.

Training such an explanation generator is non-trivial – it requires either a labeled corpus or some automated way to get pseudo-labels. One promising approach, as mentioned, is to use an AI to explain an AI: use a large language model to annotate what each layer might be doing by analyzing activations or the input-output behavior, creating a training set for the smaller explanation model. This falls under AI-assisted interpretability and has been piloted for neuron interpretation (Language models can explain neurons in language models | OpenAI).
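One way to sanity-check AI-generated labels like these is the "explain, simulate, score" idea from OpenAI's neuron work: predict activations from the explanation alone and correlate the predictions with the real activations. In the sketch below the simulator is a hypothetical keyword matcher standing in for an LLM, and the activation values are toy numbers chosen for illustration, not measurements.

```python
import numpy as np


def simulate_from_keywords(tokens, keywords):
    """Hypothetical simulator: predict 1.0 on tokens the explanation names, else 0.0."""
    return np.array([1.0 if t.lower() in keywords else 0.0 for t in tokens])


def explanation_score(true_acts, simulated_acts):
    """Pearson correlation between real and simulated activations (higher is better)."""
    true_acts = np.asarray(true_acts, dtype=float)
    simulated_acts = np.asarray(simulated_acts, dtype=float)
    if true_acts.std() == 0 or simulated_acts.std() == 0:
        return 0.0  # correlation is undefined for constant inputs; treat as uninformative
    return float(np.corrcoef(true_acts, simulated_acts)[0, 1])


tokens = ["The", "acting", "was", "excellent", "and", "the", "plot", "wonderful"]
# Toy per-token activations of some feature at one layer (illustrative only).
true_acts = [0.1, 0.2, 0.0, 0.9, 0.1, 0.0, 0.3, 0.8]

good = simulate_from_keywords(tokens, {"excellent", "wonderful"})
bad = simulate_from_keywords(tokens, {"the", "and"})
print("positive-adjective explanation:", round(explanation_score(true_acts, good), 3))
print("function-word explanation:     ", round(explanation_score(true_acts, bad), 3))
```

Explanations that score poorly under this kind of check can be filtered out before they are used as training targets for the smaller explanation model.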

Practical Example Summary: As a concrete example tying everything together, consider explaining a sentiment classifier built on a transformer. We feed a movie review into the model and it predicts “Positive”. With layer-wise sub-models, we might get outputs like:

  • Layer 1 sub-model: “Notable words: ‘excellent’, ‘performances’ – likely focusing on these.”
  • Layer 4 sub-model: “Identified movie-related context and some sentiment-laden adjectives.”
  • Layer 8 sub-model: “The tone appears positive (many positive descriptors).”
  • Layer 12 sub-model: “Overall sentiment is positive.”

These could be generated by training each sub-model to recognize sentiment at different granularities. The code snippets above illustrate how to get hidden states and train simple classifiers, which are the building blocks for achieving such explanations. Integrating a fluent language-generating explanation model would build on the same extraction process with a more complex training regime.

Through these examples, we see that implementing layer-wise interpretability is feasible with standard tools. The hidden states are accessible, and we can attach additional computations to derive explanations. The specific design of the sub-model (linear probe vs. decoder) may differ, but the pipeline remains: run model → get hidden state → apply sub-model to produce explanation. In a research setting, one might iterate on the sub-model design and training strategy to find explanations that are both accurate reflections of the model’s computation and easily understood by humans. In the next section, we compare this approach with existing interpretability methods to highlight its advantages and limitations.
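The run model → get hidden state → apply sub-model pipeline described above can be sketched as a registry that maps layer indices to sub-models. Here each sub-model is assumed to be any callable from a hidden vector to a string; the threshold probes, the "sentiment direction", and the hidden states are hypothetical placeholders for trained sub-models and real activations.

```python
from typing import Callable, Dict, List

import numpy as np

SubModel = Callable[[np.ndarray], str]


def explain_layers(hidden_states: List[np.ndarray],
                   sub_models: Dict[int, SubModel]) -> Dict[int, str]:
    """Run each registered layer's sub-model on that layer's hidden vector.

    hidden_states[0] is the embedding output, so layer l is hidden_states[l],
    matching the Hugging Face convention used in the snippets above.
    """
    return {layer: f(hidden_states[layer]) for layer, f in sorted(sub_models.items())}


def make_direction_probe(direction: np.ndarray, threshold: float,
                         yes: str, no: str) -> SubModel:
    """Hypothetical linear sub-model: say `yes` if the projection exceeds `threshold`."""
    def probe(h: np.ndarray) -> str:
        return yes if float(h @ direction) > threshold else no
    return probe


d = 8
sentiment_dir = np.eye(d)[0]                      # assumed "sentiment" direction (illustrative)
hidden_states = [np.zeros(d) for _ in range(13)]  # embeddings + 12 layers of toy vectors
hidden_states[8][0] = 0.4                         # weak signal at layer 8
hidden_states[12][0] = 2.5                        # strong signal at layer 12

sub_models = {
    8: make_direction_probe(sentiment_dir, 1.0, "Tone appears positive.", "No clear sentiment yet."),
    12: make_direction_probe(sentiment_dir, 1.0, "Overall sentiment is positive.", "Sentiment unclear."),
}
for layer, text in explain_layers(hidden_states, sub_models).items():
    print(f"Layer {layer}: {text}")
```

Keeping sub-models behind a plain callable interface means a linear probe and a text decoder can be swapped in per layer without changing the surrounding pipeline.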

5. Comparison to Existing Interpretability Methods

Our layer-wise sub-model approach offers a new perspective on interpretability, but it’s important to situate it relative to established methods. In this section, we compare it with several categories of interpretability techniques, discussing how it differs and where it may offer improvements.

  • Saliency Maps / Feature Attribution: Traditional saliency methods (like gradients, Integrated Gradients, SmoothGrad, LIME, SHAP) identify which input features most influence the model’s prediction (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). They are popular because they are easy to apply and, in the case of perturbation-based methods like LIME and SHAP, model-agnostic, but they only answer a very limited question: “Which parts of the input mattered?” They do not explain why those parts mattered or how the model processed them. Our approach, in contrast, provides a process-level explanation rather than just an importance ranking. A saliency map might tell us that in a particular movie review, the word “excellent” was important for the positive sentiment prediction. A layer-wise explanation would add how that word’s effect is processed (e.g., layer 1 finds “excellent” is positive, layer 5 confirms many other positive clues, layer 10 aggregates them into a final sentiment). Thus, we complement saliency maps by revealing internal states corresponding to those salient features. Moreover, saliency methods have known issues with faithfulness (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers) – sometimes they highlight inputs that intuitively make sense to humans but are not actually used by the model, or they miss interactions between features. By reading the model’s actual hidden states, our method is more directly tied to what the model is representing at each step, potentially increasing faithfulness. That said, saliency maps are cheap and can be applied when one only needs a quick input attribution, whereas our method requires instrumentation of the model and possibly training sub-models, which is more involved.
  • Attention Visualization: Inspecting attention weights is a method specific to transformer architectures. It can show, for each token, the distribution of attention given to other tokens. While initially promising as an interpretability tool (Is Attention Interpretable in Transformer-Based Large Language Models? Let’s Unpack the Hype), it’s now understood that attention alone can mislead (Is Attention Interpretable in Transformer-Based Large Language Models? Let’s Unpack the Hype). Our approach is fundamentally different from raw attention analysis: we do not rely on a single mechanism like attention, but instead look at the resultant hidden state after all attention and feed-forward computations. In some sense, attention visualization addresses which inputs influence each layer, whereas our explanations address what information the layer is encoding. They are related – if a layer’s sub-model explanation says “this layer is focusing on the relationship between X and Y,” it might correspond to observing a strong attention between tokens X and Y. But attention visuals require interpretation themselves (one must infer why the model attended to those tokens). Our sub-models explicitly state the reason or content, removing that extra step. Additionally, attention weights can be distributed and not localized, making it hard to draw a clear conclusion, whereas a sub-model can be forced to output a clear categorical or text explanation. In summary, attention analysis provides low-level data, while layer-wise sub-models aim to provide a high-level summary. 
They are not mutually exclusive – in fact, an interesting integration could be to feed attention patterns as features into a sub-model that then explains “Layer 5 attended mostly to token X while processing token Y, likely to resolve their relationship.” Some recent work projects attention-head parameters into vocabulary space to find concepts (e.g., attention heads that attend to tokens representing locations or names) (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers), which aligns with the philosophy of interpreting internal components. Our approach generalizes this by looking not just at attention but at the entire state.
  • Probing Methods: Probing involves training simple classifiers on hidden states to detect if certain information (linguistic features, etc.) is present (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). In fact, our use of sub-models has a strong flavor of probing – each $F_\ell$ can be seen as a learned probe revealing something about layer $\ell$. The difference is in purpose and scope. Traditional probing studies often ask yes/no questions like “Does layer 6 encode part-of-speech info?” by checking probe accuracy on that task. Our approach uses probes (sub-models) in a generative and explanatory way, turning them into narrative devices for individual predictions. Another difference is that probing is typically post-hoc and does not influence the model, whereas we consider possibly integrating the sub-models during training for consistent explanations. A challenge noted in the probing literature is that decodability does not imply use: a probe might find information that is linearly decodable from a layer even though the model never relies on it, because it is an artifact or unused capacity. By contrast, if we train sub-models to mimic or predict the model’s own outputs (like a tuned lens, which aligns layer representations with final logits (Eliciting Latent Predictions from Transformers with the Tuned Lens)), we improve the chance that the explanation is aligned with what the model truly computes. Nonetheless, our method must be careful about the “probe interpretability” trap: we don’t want sub-models to hallucinate interpretations that are human-plausible but not actually tied to the model. This is why designing the training objectives of $F_\ell$ to be causally linked to model behavior (e.g. predict model’s output or known intermediate labels) is important, ensuring we explain what the model is really doing rather than what we imagine it might do.
  • Causal Interventions (Ablation and Patching): Another line of interpretability research involves intervening on the model – ablating certain neurons or attention heads, or patching activations from one run into another to see how it affects outputs (Section 4: Transformer Understanding). For example, activation patching (also called causal tracing) can identify which layer’s activations carry the information needed for a particular prediction by replacing them with activations from a different context and observing if the prediction changes (Section 4: Transformer Understanding). These methods are powerful for attributing causal responsibility to parts of the network. Our sub-model approach is mostly observational – we read out interpretations but do not by default tell which parts are critical. However, it could be combined with causal methods: if the sub-model explanation at layer 7 mentions a specific fact or feature, one could verify its importance by ablating that information in the hidden state and seeing if the final output changes. In terms of differences, causal methods usually don’t attempt to explain what a circuit does in human terms; they just locate it. Our method would produce the description, which could then guide where to intervene. We see these approaches as complementary: for example, one might first get a layer-wise explanation (say it says “Layer 5 has resolved that ‘he’ refers to John”), and then perform an experiment zeroing out layer-5 features related to coreference to confirm that the model’s answer changes. In summary, causal interventions give evidence of importance, whereas our sub-models give semantic meaning; both are needed for a full picture.
  • Inherently Interpretable Models: Some researchers argue for building inherently interpretable models, like attention-only models with explanations, or concept bottleneck models where the model predicts human-understandable concepts as an intermediate step. Our approach can be seen as inserting a form of concept bottleneck at every layer, but not strictly bottlenecking the model – rather “reading off” concepts. We do not require the main model to only use those concepts; we just extract them. So, we maintain the full power of the black-box model while adding interpretability on the side. Approaches like concept bottlenecks make the model’s intermediate concepts explicit by design, but often at a cost of performance or flexibility, and they require predefined concepts. Our method is more flexible in that sub-models can learn whichever patterns are present in hidden states, which might be complex or uninterpretable in raw form but can be translated (for example, an unusual combination of hidden dimensions might correspond to a concept like “the sentence is in passive voice” which a trained sub-model could recognize). Compared to fully transparent models (like decision trees or sparse logical models), we do not achieve the same level of simplicity, but we aim to get the best of both worlds: high performance of transformers with an overlay of interpretability.
  • Logit Lens and Tuned Lens: A specific comparison is to the Logit Lens (Eliciting Latent Predictions from Transformers with the Tuned Lens) and its improved version the Tuned Lens (Eliciting Latent Predictions from Transformers with the Tuned Lens), since these are perhaps the closest existing ideas to layer-wise interpretation. The Logit Lens, as discussed, takes each layer’s hidden state and directly decodes it into an output (like vocabulary logits), essentially giving a rough prediction at each layer. This is a special case of our approach where the sub-model is “the final linear layer of the network (or its transpose) applied at every layer.” It provides insight into the model’s evolving prediction. The Tuned Lens augments this by actually training a small affine transformation for each layer to better predict the final logits (Eliciting Latent Predictions from Transformers with the Tuned Lens), acknowledging that each layer’s basis might differ. These techniques validate that intermediate layers contain predictive information, and they produce a kind of explanation: “what does the model think so far?” However, they don’t explain why the model is thinking that. Our layer-wise sub-models generalize the lens idea to richer explanations. Instead of just outputting the model’s implicit next-word guess or classification at layer $\ell$, we want to output why that guess is made – e.g., which features were recognized. So one could say we are adding an interpretability layer on top of the logit lens. We also allow sub-models to output things that are not the final prediction, such as intermediate decisions or features (which logit lens cannot directly do since it’s tied to final logits). 
In terms of performance, the tuned lens already shows that even a very lightweight sub-model (one affine layer per transformer layer) can track the model’s final predictions considerably more closely than the logit lens (Eliciting Latent Predictions from Transformers with the Tuned Lens), indicating that our approach is feasible without requiring a massive sub-model: much of the information is readily present.
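To illustrate the tuned-lens idea of a per-layer affine "translator", the sketch below fits one by least squares on toy data where the final hidden state is a linear transform of an intermediate one, then compares how far the plain logit lens and the translated lens are from the final output distribution. This is a simplification under stated assumptions: the published tuned lens trains each translator by minimizing KL divergence to the model's final distribution and keeps the model's final LayerNorm, both of which are omitted here, and `W_U` is a random stand-in for an unembedding matrix.

```python
import numpy as np


def fit_affine_translator(H_layer, H_final):
    """Least-squares affine map (A, b) with H_layer @ A + b ≈ H_final."""
    X = np.hstack([H_layer, np.ones((H_layer.shape[0], 1))])  # bias column
    coef, *_ = np.linalg.lstsq(X, H_final, rcond=None)
    return coef[:-1], coef[-1]  # A: (d, d), b: (d,)


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def mean_kl(p, q, eps=1e-12):
    """Average KL(p || q) over rows."""
    return float(np.mean(np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1)))


# Toy "model": the final hidden state is a fixed linear transform of a middle one, plus noise.
rng = np.random.default_rng(0)
n, d, vocab = 500, 16, 40
W_U = rng.normal(size=(d, vocab))  # stand-in unembedding matrix
H_mid = rng.normal(size=(n, d))
M = rng.normal(size=(d, d)) / np.sqrt(d)
H_final = H_mid @ M + 0.05 * rng.normal(size=(n, d))

p_final = softmax(H_final @ W_U)
kl_logit = mean_kl(p_final, softmax(H_mid @ W_U))  # logit lens: decode the middle layer directly

A, b = fit_affine_translator(H_mid, H_final)  # fit in-sample, for brevity
kl_tuned = mean_kl(p_final, softmax((H_mid @ A + b) @ W_U))  # translate, then decode

print(f"KL(final || logit lens):  {kl_logit:.3f}")
print(f"KL(final || affine lens): {kl_tuned:.3f}")
```

The affine translator is tiny (d × d weights plus a bias per layer), which is the sense in which a lightweight sub-model can suffice when the information is already linearly present.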

Strengths of Layer-Wise Sub-Models: The primary strength of our approach is the granularity and richness of the explanations. By covering each layer, we obtain a stepwise trace of the model’s computation in human-interpretable terms. This can be invaluable for debugging and understanding failure modes. For instance, if a model answers a question incorrectly, a layer-by-layer explanation might reveal that it actually picked up the correct answer in an intermediate layer but then lost it (perhaps due to a later layer’s interference) – information that neither a saliency map nor a final explanation alone would show. It could also reveal if the model is relying on spurious features at some stage (e.g., an explanation says “layer 3 focuses on the user’s name when answering a medical question”, which might hint at bias). Compared to single-shot explanations (like explaining the final decision after the fact), our method is better grounded by construction: it reads the model’s actual internal data at each point, which reduces (though does not eliminate) the chance of rationalizing or ignoring what actually happens inside. Additionally, sub-models might be reusable to some extent: one could potentially train a single explainer for a given transformer architecture, though that is speculative.

Weaknesses and Challenges: One obvious cost is the computational and implementation overhead, as discussed. Another key challenge is evaluation: how do we measure the quality of the explanations from sub-models? They could be incorrect or insufficient. If a sub-model is poorly trained or the task is too hard (say, trying to generate a perfect English description of a hidden state), it might output misleading information. This is dangerous because it could give a false sense of understanding. We must ensure that sub-model outputs are validated, either by checking alignment with known attributes (if available) or via human expert review in critical cases. There is also the risk that forcing interpretability could constrain the model. If sub-models are used during training, the model might over-optimize for producing something interpretable at each layer, which could degrade final performance or push the model to shape its representations so they are easy for the sub-model to describe rather than faithful to its actual computation (a kind of Goodhart’s law effect). We should manage this by keeping sub-models “on the side” rather than in full control of representations, or by iterative refinement where we ensure they truly capture what’s needed. Another weakness is that our approach currently provides a high-level description per layer, but not a full mechanistic understanding of the computations (we aren’t explicitly identifying which neurons or weights implement that step). So, it may not directly solve the problem of understanding the network at the circuit level – it’s more of an observability tool. In comparison, methods like causal abstraction aim to map the entire model onto a human-understandable algorithm (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers), which is a stronger guarantee but far harder to do on complex models. Our method is a more pragmatic middle ground: easy to deploy, yields understandable info, but not a formal explanation of internal mechanics.
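A minimal version of such validation, and of the ablation check suggested in the causal-interventions comparison, is directional ablation: if a sub-model claims a layer encodes a feature along direction v, project v out of the hidden state and see whether a downstream readout moves. The sketch uses toy vectors and a hypothetical linear readout that depends on the claimed direction by construction; in a real model one would write the edited state back into the forward pass (for example with a PyTorch forward hook) and compare the final outputs.

```python
import numpy as np


def ablate_direction(h, v):
    """Remove h's component along v (project onto v's orthogonal complement)."""
    v = v / np.linalg.norm(v)
    return h - (h @ v) * v


def output_shift(h, v, readout):
    """Change in a scalar downstream readout when direction v is ablated from h."""
    return float(readout(h) - readout(ablate_direction(h, v)))


rng = np.random.default_rng(1)
d = 32
concept = rng.normal(size=d)  # direction a sub-model claims this layer uses
unrelated = rng.normal(size=d)
unrelated -= (unrelated @ concept) / (concept @ concept) * concept  # orthogonal to concept
head = 0.5 * concept  # hypothetical downstream head that reads `concept`


def readout(x):
    """Hypothetical downstream scalar output (e.g., a sentiment logit)."""
    return float(x @ head)


h = rng.normal(size=d) + 4.0 * concept / np.linalg.norm(concept)

print("shift, claimed direction:  ", round(output_shift(h, concept, readout), 3))
print("shift, unrelated direction:", round(output_shift(h, unrelated, readout), 3))
```

If ablating the direction a sub-model points to leaves the output unchanged, that is a warning sign that the explanation describes decodable but unused information.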

Integration with Other Frameworks: We foresee that layer-wise explanations could be integrated into existing interpretability dashboards and frameworks. For instance, Google’s Language Interpretability Tool (LIT) or other interactive visualization systems could show the sub-model’s explanation for whichever layer the user selects with a slider. One could imagine an interactive interface where a user hovers over a transformer’s layers and sees a tooltip like “Layer 5: combining clause information; likely doing coreference” – all generated by our approach. This would add a new dimension to such tools, which currently might show attention weights or neuron activations without context. Furthermore, these explanations could feed into accountability mechanisms: e.g., in sensitive applications, the system could log the layer-wise reasoning for each decision, so that if a bad outcome occurs, auditors can inspect not just the input-output correlation but the internal decision trail.

In conclusion, compared to existing methods, Layer-Wise Sub-Model interpretability is a more process-oriented and fine-grained approach. It does not replace feature attribution or causal analysis, but augments them by filling in the blanks of what the model actually does at each step. Its strength lies in creating an explanatory narrative from within the model, offering potentially greater insight and trust. Its challenges lie in ensuring those narratives are accurate and efficiently obtained. The next section will reference related research that inspired this approach and provide context, before we discuss future implications.

6. References to Related Research

Our approach builds upon and intersects with several important lines of AI interpretability research. Below, we highlight key related works and how they inform or contrast with the Layer-Wise Sub-Model approach:

  • Mechanistic Interpretability and Circuits: Olah et al. (2020) pioneered Circuits analysis, breaking down neural networks into interpretable pieces (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). They manually traced computations in vision models and small language models, identifying meaningful circuit motifs (e.g., edge detectors, curve detectors in vision, or syntax heads in language). Elhage et al. (2021) further developed a mathematical framework for reverse-engineering transformers, successfully dissecting small models (like two-layer attention-only models) into algorithmic components (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). They introduced concepts like induction heads – attention heads that help a model continue a repeated sequence, effectively implementing a copy mechanism in context (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). Our work is inspired by these findings, as they demonstrate that even black-box transformers contain identifiable algorithms. However, instead of manually finding circuits, we automate the explanation via sub-models. 
Mechanistic interpretability provides ground truth for what kinds of computations occur (e.g., an induction head), and a sophisticated sub-model could potentially recognize and explain “this layer is performing an induction step on a repeated sequence.” Recent works like Transformer Circuits (Olah, 2022) and others have catalogued numerous such phenomena, giving us a vocabulary of mechanisms that we would like our explanations to capture (e.g., “negative head suppressing duplicate token” (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models), “feed-forward layer storing factual memory” (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models)).
  • Neuron-Level Interpretability: Research by Bau et al. (2017) on Network Dissection showed that individual neurons in CNNs often correspond to human-recognizable visual concepts (like “tree neuron” or “dog face neuron”). In NLP, neurons are more entangled, but there has been progress: Dalvi et al. (2019) and others probed individual neurons in machine translation models, finding neurons that track specific linguistic features (like formality). Goh et al. (2021) discovered multimodal neurons in CLIP models that respond to high-level concepts (e.g., a neuron that fires for “Spider-Man” whether as text or image). OpenAI’s 2023 work took this further by using GPT-4 to generate explanations for neurons in GPT-2 (Language models can explain neurons in language models | OpenAI), as discussed earlier. These works emphasize interpreting units of the network. Our layer-wise approach operates at a higher level (the aggregate effect of a whole layer), but it could benefit from neuron-level insights. For example, if neuron $i$ in layer 5 is known to track sentiment, and our layer-5 sub-model outputs “the sentence sentiment is positive,” that’s a nice confirmation that our sub-model picked up what that neuron was doing. Conversely, if our explanation says something is happening, one could drill down and identify the neurons or heads responsible, essentially linking our high-level explanations to low-level ones. Projects like CircuitGPT or Automated Circuit Discovery aim to algorithmically find sets of neurons/heads that implement functions (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models). Those could be used to generate structured explanations (“Layer 7 contains a 2-head circuit that performs X”), which is a complementary direction – bridging our explanatory text to actual circuit diagrams.
  • Probing and Layer Analysis: A significant body of work has analyzed what different transformer layers encode. For example, Tenney et al. (2019) “BERT Rediscovers the Classical NLP Pipeline” used a suite of probing tasks to show that BERT’s layers progressively handle syntax then semantics ([1905.05950] BERT Rediscovers the Classical NLP Pipeline). Jawahar et al. (2019) similarly found lower layers encode surface features, middle layers syntactic relations, upper layers semantic abstractions. Our approach is a natural continuation: given that each layer has a predominant role, we attempt to articulate that role for each input. There is also work on early exiting: e.g., Liu et al. (2020) and Zhou et al. (2020) trained classifiers on top of each layer so that if an early layer is already confident enough, the model can exit with a prediction to save computation (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). This demonstrates that intermediate layers can be pushed to produce outputs. Our method can be seen as early exiting for explanations – every layer “exits” to an explanation even if not to a final prediction. In fact, ideas from early-exit training (like how to avoid interfering with later layers, how to gauge confidence) are useful for us too. Another related concept is auxiliary heads used during training (e.g., InceptionNet used intermediate classifiers to help train deep networks). We similarly add heads, though for interpretability rather than purely for training support.
  • Attention Interpretation Debate: We have touched on this, but to cite explicitly: Jain & Wallace (2019) “Attention is not Explanation” argued against naive use of attention (Is Attention Interpretable in Transformer-Based Large Language Models? Let’s Unpack the Hype), while Wiegreffe & Pinter (2019) “Attention is not not Explanation” offered a rebuttal that attention can provide insight if used properly (Is Attention Interpretable in Transformer-Based Large Language Models? Let’s Unpack the Hype). Our work sidesteps this debate by not focusing solely on attention weights. However, we note research like Chefer et al. (2021) and Abnar & Zuidema (2020) that tried to aggregate attention across layers or apply gradient adjustments to get better attributions from transformers. Those methods aim to produce a single importance map that accounts for multi-hop attention paths. In contrast, our method outputs distinct information per layer, which might actually help understand those multi-hop paths in a clearer way (layer by layer rather than a diluted aggregate).
  • Interpretability in Other Modalities: While our discussion is framed around language transformers, the concept of layer-wise explanations could transfer to vision transformers or multimodal models. Research in vision interpretability, such as feature visualization (Olah et al., 2017), generates images that maximize neuron activations (27 Learned Features – Interpretable Machine Learning). It complements the logit lens rather than mirroring it: where the logit lens decodes an internal state into outputs, feature visualization optimizes the input and asks “what input would cause this internal neuron to activate strongly?” For a layer-wise view in vision, one might generate an image or textual description of what each layer is detecting (e.g., “Layer 3: edges and textures; Layer 6: object parts like wheels or eyes; Layer 10: full objects like cars or animals”). Some recent works have started to analyze vision transformer layers in this way, finding that early layers act like CNN filters and later ones have token mixing patterns. Our approach could potentially attach a small CNN or captioning model to each layer of a vision transformer to describe its activation patterns in words (“this layer highlights regions that look like fur”).
  • Tuned Lens and Patchscopes: As mentioned, the Logit Lens (nostalgebraist, 2020) and the Tuned Lens (Belrose et al., 2023) provide simple but powerful tools for interpreting layers (Eliciting Latent Predictions from Transformers with the Tuned Lens). Halawi et al. (2022) used a logit lens to examine how few-shot examples are processed. Dar et al. (2023) projected model parameters, including attention-head weights, into the vocabulary space and found coherent concepts (like certain attention heads consistently attending to person names) (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). Ghandeharioun et al. (2024) proposed Patchscopes, a general framework that decodes a hidden representation by patching it into a separate forward pass of a language model, using a target prompt designed to elicit a human-readable output (DecoderLens: Layerwise Interpretation of Encoder-Decoder Transformers). Patchscopes can be seen as a sibling idea to our sub-models: both map internal vectors to human-interpretable outputs (like vocabulary distributions or text). Our approach is very much aligned with these, with the main extension being the emphasis on a layer-by-layer explanatory narrative. We incorporate the spirit of the Tuned Lens (alignment transformations per layer) and Patchscopes (decoding internal representations into language) but apply it to generating natural explanations and not only predicting final outputs.
  • Causal Abstraction and Model Editing: Another relevant area is causal abstraction, where one defines a high-level model (like a symbolic model or a simplified flowchart) and tries to map the neural network onto it, verifying that certain latent variables correspond to concepts. For example, Geiger et al. (2021) tested whether neural models implement a human-readable causal model by aligning its high-level variables with internal representations and checking that alignment with interchange interventions. Our layer-wise explanations could serve as a bridge to such high-level abstractions: the sub-model outputs might act as the “interpretable variables” that a causal abstraction would posit. There is also model-editing work (such as ROME and MEMIT) that locates factual associations in specific layers of transformers and intervenes to change the model’s knowledge. These studies often pinpoint the feed-forward (MLP) modules of particular layers as carrying a fact (e.g., “Paris is the capital of France” might be stored mainly in some mid-layer MLP), which our explanation for that layer might surface (“Layer 10 recalls: Paris → capital of France”). Our method could therefore help identify where a fact is being applied, complementing model-editing findings about where facts are stored.
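The feature-visualization idea from the first item above can be illustrated with a minimal sketch of activation maximization. It assumes a toy single-layer tanh network with random weights standing in for one layer of a vision model; the function name `activation_maximization` and its parameters are hypothetical. Real feature visualization (Olah et al., 2017) optimizes image pixels through a deep network with extra regularization, so this shows only the core loop: gradient ascent on the input, with the input held on the unit sphere.

```python
import numpy as np

def activation_maximization(W, b, unit, steps=1000, lr=1.0, seed=0):
    """Gradient ascent on the input x to maximize tanh(W[unit] @ x + b[unit]).

    x is kept on the unit sphere so the activation cannot be raised
    simply by scaling the input up.
    """
    rng = np.random.default_rng(seed)
    x = rng.normal(size=W.shape[1])
    x /= np.linalg.norm(x)
    for _ in range(steps):
        pre = W[unit] @ x + b[unit]
        grad = (1.0 - np.tanh(pre) ** 2) * W[unit]  # d tanh(pre) / dx
        x = x + lr * grad
        x /= np.linalg.norm(x)  # project back onto the unit sphere
    return x

# Toy "layer" with 4 units over an 8-dimensional input.
rng = np.random.default_rng(1)
W, b = rng.normal(size=(4, 8)), np.zeros(4)
x_star = activation_maximization(W, b, unit=2)
# For a single tanh unit, the optimum is the direction of its weight vector.
cosine = x_star @ W[2] / np.linalg.norm(W[2])
print(f"cosine(x*, W[2]) = {cosine:.4f}")
```

For a single tanh unit the optimum is simply its weight direction, which makes the result easy to check; in a deep network there is no such closed form, which is why the iterative search is needed.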
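The difference between the Logit Lens and the Tuned Lens described above can be sketched in a few lines of NumPy. This is a toy under stated assumptions: random matrices stand in for a model’s unembedding `W_U` and its hidden states, the final layer norm has no learned gain or bias, and the per-layer translator is fit by least squares to the final hidden state rather than trained to minimize KL divergence to the model’s output distribution as Belrose et al. do. All function names are hypothetical.

```python
import numpy as np

def layer_norm(h, eps=1e-5):
    # Simplified final LayerNorm without learned gain/bias.
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(var + eps)

def logit_lens(h, W_U):
    """Decode an intermediate state directly with the final norm and unembedding."""
    return layer_norm(h) @ W_U

def fit_translator(H_layer, H_final):
    """Least-squares affine map (A, c) such that H_layer @ A + c approximates H_final."""
    X = np.hstack([H_layer, np.ones((len(H_layer), 1))])  # bias column
    coef, *_ = np.linalg.lstsq(X, H_final, rcond=None)
    return coef[:-1], coef[-1]

def tuned_lens(h, A, c, W_U):
    """Translate into the final layer's basis first, then decode like the logit lens."""
    return layer_norm(h @ A + c) @ W_U

rng = np.random.default_rng(0)
n, d, vocab = 200, 16, 50
W_U = rng.normal(size=(d, vocab))
H_final = rng.normal(size=(n, d))
# Simulate an intermediate layer whose basis is rotated relative to the final one.
Q, _ = np.linalg.qr(rng.normal(size=(d, d)))
H_layer = H_final @ Q + 0.01 * rng.normal(size=(n, d))

final_pred = logit_lens(H_final, W_U).argmax(-1)
logit_agree = np.mean(logit_lens(H_layer, W_U).argmax(-1) == final_pred)
A, c = fit_translator(H_layer, H_final)
tuned_agree = np.mean(tuned_lens(H_layer, A, c, W_U).argmax(-1) == final_pred)
print(f"top-1 agreement: logit lens {logit_agree:.2f}, tuned lens {tuned_agree:.2f}")
```

In this toy setup the logit lens fails because the intermediate basis is rotated, while the fitted translator undoes the rotation. It is a cartoon of the representational drift that motivates a learned per-layer translator; real layers drift far less cleanly.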
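The interchange interventions used in causal-abstraction work, and the closely related activation patching used to localize facts, can be illustrated on a toy two-layer network. This is a minimal sketch assuming random weights and a score of our own choosing: the fraction of the base-to-source output difference (projected onto that direction) that is recovered when only the chosen hidden units are copied from the source run. The names `forward` and `interchange_effect` are hypothetical.

```python
import numpy as np

def forward(x, params, patch=None):
    """Toy two-layer network. patch=(idx, values) overwrites hidden units idx with values[idx]."""
    W1, b1, W2, b2 = params
    h = np.tanh(x @ W1 + b1)
    if patch is not None:
        idx, values = patch
        h[idx] = values[idx]  # interchange: take these units from the source run
    return h @ W2 + b2, h

def interchange_effect(base_x, source_x, params, idx):
    """Fraction of the base->source output change recovered by patching hidden units idx."""
    base_out, _ = forward(base_x, params)
    source_out, source_h = forward(source_x, params)
    patched_out, _ = forward(base_x, params, patch=(idx, source_h))
    diff = source_out - base_out
    return float((patched_out - base_out) @ diff / (diff @ diff))

rng = np.random.default_rng(0)
params = (rng.normal(size=(6, 10)), np.zeros(10), rng.normal(size=(10, 3)), np.zeros(3))
base_x, source_x = rng.normal(size=6), rng.normal(size=6)
for idx in (np.array([], dtype=int), np.arange(5), np.arange(10)):
    effect = interchange_effect(base_x, source_x, params, idx)
    print(len(idx), "units patched ->", round(effect, 3))
```

Patching all hidden units recovers the full difference by construction and patching none recovers nothing; a concept is localized when a small subset of units recovers most of it.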

In summary, our approach is informed by a broad swath of research: from the detailed dissection of networks at the neuron and circuit level (A Practical Review of Mechanistic Interpretability for Transformer-Based Language Models), to probing studies that treat layers as linguistic feature extractors ([1905.05950] BERT Rediscovers the Classical NLP Pipeline), to recent lens methods that attach interpretability transforms to intermediate layers (Eliciting Latent Predictions from Transformers with the Tuned Lens). We contribute a unifying perspective that ties these together into a single framework of layer-wise transparent interpretation. By building on these works, we stand on the shoulders of prior interpretability research while pushing toward more comprehensive and accessible explanations of transformer models.

7. Future Implications

Adopting layer-wise sub-model interpretability could have far-reaching implications for the development, deployment, and oversight of AI systems. In this final section, we explore potential impacts, extensions of the approach, and open questions that remain to be addressed.

Shaping AI Regulation and Accountability: As AI systems are used in high-stakes decisions (medical, legal, financial), regulators are increasingly interested in “Explainable AI” and even in legal rights to an explanation. If every layer of a model can be interpreted and explained, we move closer to AI whose decision process can be audited in detail. For instance, in an AI-driven loan approval system that uses a transformer to analyze customer data, a layer-wise explanation might show that layer 3 assessed credit history, layer 5 assessed income stability, layer 8 combined these into a risk profile, and the final layer made the decision. Such a multi-step explanation could help satisfy regulatory demands by demonstrating a logical progression (and by making it possible to point out bias if, say, one layer put undue weight on a sensitive attribute). In essence, this approach could provide a traceable chain of reasoning that end-to-end deep models lack today. Moreover, if an AI system causes harm or makes a mistake, intermediate explanations could clarify where the flaw occurred: was it a misinterpretation of the input early on, or an unreasonable combination of factors at the end? This could help attribute responsibility to specific components or phases of the model’s processing. Organizations might maintain logs of the layer-wise explanations for critical decisions, creating an audit trail for later review. Of course, these explanations must be stored and communicated properly, and they must be comprehensible to stakeholders (perhaps by translating technical jargon into lay terms). Overall, accountability is bolstered when a model can effectively say “here’s what I was thinking at each step” instead of just “here’s my final answer.”
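An audit trail of this kind could be as simple as an append-only JSON Lines log with one record per decision. The sketch below assumes a hypothetical record schema (`DecisionRecord` with `decision_id`, `model_version`, `final_output`, and `layer_explanations`) and uses an in-memory stream in place of a file; a real deployment would also need access control, retention rules, and tamper protection, none of which are shown.

```python
import io
import json
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import IO, Dict, List

@dataclass
class DecisionRecord:
    # Hypothetical schema for one audited decision.
    decision_id: str
    model_version: str
    final_output: str
    layer_explanations: Dict[int, str]  # layer index -> sub-model explanation
    timestamp: str = field(
        default_factory=lambda: datetime.now(timezone.utc).isoformat()
    )

def append_record(record: DecisionRecord, stream: IO[str]) -> None:
    """Append one decision as a single JSON line."""
    stream.write(json.dumps(asdict(record), sort_keys=True) + "\n")

def load_records(stream: IO[str]) -> List[dict]:
    """Read all records back. Note: JSON turns integer layer keys into strings."""
    return [json.loads(line) for line in stream if line.strip()]

log = io.StringIO()  # stands in for an append-only file
append_record(
    DecisionRecord(
        decision_id="loan-0001",
        model_version="demo-v1",
        final_output="declined",
        layer_explanations={
            3: "assessed credit history",
            5: "assessed income stability",
            8: "combined factors into a risk profile",
        },
    ),
    log,
)
log.seek(0)
records = load_records(log)
print(records[0]["layer_explanations"]["8"])
```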

Real-Time Interpretability Tools: One exciting practical extension is the development of real-time dashboards or monitoring systems powered by layer-wise interpretability. Imagine a debugging interface for an NLP model where, as you type a sentence, you see a live update of each layer’s interpretation. Such a tool would be invaluable for researchers and engineers: they could feed in examples and immediately spot whether, say, layer 5 is reaching an incorrect intermediate conclusion. For example, a dashboard could highlight: Layer 5 misunderstanding: it thinks “bank” refers to a river bank instead of a financial bank. This early misinterpretation could be corrected (by adjusting training data or model architecture) before it propagates to an incorrect final answer. In deployed systems, simplified versions of this idea could detect anomalies. If a layer’s explanation deviates from expected patterns (for example, if layer 7 of a security chatbot suddenly indicates that the model is interpreting a harmless query as an attack because of some quirk), the system could raise a flag to a human overseer or trigger a safe intervention. This is analogous to a pilot watching the altimeter and airspeed indicator: our approach provides multiple “meters” inside the model that can be monitored to ensure it stays on track.
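A monitor of this kind could track, for each layer, a scalar score produced by that layer’s sub-model (for example, the probability it assigns to an “attack” interpretation) and flag values far from that layer’s running baseline. In this sketch the score, the z-score rule, the warmup length, and the `LayerMonitor` class are all hypothetical choices; flagged values are kept out of the baseline so that one anomaly does not mask the next.

```python
import math
from collections import defaultdict

class LayerMonitor:
    """Flags per-layer scores that deviate strongly from that layer's running baseline."""

    def __init__(self, threshold=4.0, warmup=30):
        self.threshold = threshold  # how many standard deviations count as anomalous
        self.warmup = warmup        # observations needed before flagging starts
        # layer -> [count, mean, M2] for Welford's online mean/variance
        self.stats = defaultdict(lambda: [0, 0.0, 0.0])

    def observe(self, layer, score):
        """Return True if score is anomalous for this layer; otherwise absorb it."""
        n, mean, m2 = self.stats[layer]
        if n >= self.warmup:
            std = math.sqrt(m2 / (n - 1))
            if abs(score - mean) > self.threshold * std:
                return True  # flagged; keep it out of the baseline
        # Welford update of the running mean and sum of squared deviations
        n += 1
        delta = score - mean
        mean += delta / n
        m2 += delta * (score - mean)
        self.stats[layer] = [n, mean, m2]
        return False

monitor = LayerMonitor()
normal = [0.10 + 0.01 * (i % 5) for i in range(100)]  # typical layer-7 scores
flags = [monitor.observe(7, s) for s in normal]
print(any(flags), monitor.observe(7, 0.90))
```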

Guiding Model Improvement and Safety: Layer-wise explanations not only explain models but might also guide how we improve them. For instance, if layer 4 consistently produces an explanation that is logically redundant or irrelevant to the final task, the model may have an inefficient layer, and we could try pruning that layer or merging its function into another. Or, if a certain type of error often arises from a wrong intermediate decision (say, in a translation model, layer 6’s explanation reveals that it consistently gets verb tenses wrong before the final output), we can target that layer with additional training data or an architectural tweak (like an auxiliary loss that enforces tense correctness). From a safety perspective, if we can identify layers that handle potentially risky transformations of the input (e.g., in a dialogue model, a layer that moves from understanding the user’s query to formulating a raw response), we might focus security measures there. In other words, understanding which layer does what allows for modular risk assessment: some layers might be more critical to keep interpretable and controlled (like those handling factual recall, to prevent misinformation) while others matter less.
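Alongside the explanations, a cheap signal for spotting candidate redundant layers is how much each layer changes its input. This sketch assumes access to hidden-state snapshots before and after every layer and scores each layer by the mean, over tokens, of one minus the cosine similarity between its input and output; the function names and the `tol` threshold are hypothetical. A low score only nominates a layer for an ablation test; it does not show that the layer can safely be pruned.

```python
import numpy as np

def layer_change_scores(hidden_states):
    """hidden_states: array of shape (num_layers + 1, num_tokens, d), the hidden state
    before the first layer and after each layer. Returns, per layer, the mean over
    tokens of 1 - cosine(input, output)."""
    h_in, h_out = hidden_states[:-1], hidden_states[1:]
    dots = (h_in * h_out).sum(axis=-1)
    norms = np.linalg.norm(h_in, axis=-1) * np.linalg.norm(h_out, axis=-1)
    return (1.0 - dots / np.maximum(norms, 1e-12)).mean(axis=-1)

def redundancy_candidates(hidden_states, tol=0.01):
    """Layers whose output is nearly parallel to their input: candidates for ablation tests."""
    return [i for i, s in enumerate(layer_change_scores(hidden_states)) if s < tol]

# Synthetic residual stream: each layer adds a random update, except layer 2.
rng = np.random.default_rng(0)
states = [rng.normal(size=(10, 16))]
for layer in range(4):
    update = 0.0 if layer == 2 else rng.normal(size=(10, 16))
    states.append(states[-1] + update)
hidden_states = np.stack(states)
print(layer_change_scores(hidden_states).round(3), redundancy_candidates(hidden_states))
```

Here layer 2 adds nothing to the hidden state, so it is the only layer nominated.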

Shaping Future Architecture Design: If layer-wise interpretability proves valuable, architects of new models might start designing with interpretability in mind. This could mean building models with explicit, interpretable intermediate representations (not unlike how traditional software has comments or logs at checkpoints). For example, future transformer variants might include dedicated “explanation vectors” meant to carry a summary of each layer’s state in human-interpretable form, encouraged during training to align with certain concepts. Alternatively, we might see hybrid models where each layer outputs a short description that is fed into the next layer alongside the usual hidden state, so that layers effectively communicate with each other in human-readable language. This is speculative, but it resonates with the idea of having models “think step by step” (as in chain-of-thought prompting), except that the steps would happen internally. If a model could internally generate a chain of thought that is also understandable to us, it would be a breakthrough for transparency. Our sub-model technique is a step in that direction, even if the chain of thought is currently extracted post hoc rather than generated intrinsically.

Shifting Culture in Model Evaluation: With more interpretable models, the community might begin to expect explanations as a standard part of model output. It’s possible that leaderboards and benchmarks evolve to not only consider accuracy on tasks, but also the quality of explanations provided. If our approach becomes efficient, one could imagine requiring models to output a justification at each step for sensitive tasks (like medical diagnosis). This could be scored or at least qualitatively evaluated. It might discourage purely black-box solutions if a model that can explain its intermediate steps is seen as more trustworthy or superior, even if raw accuracy is similar. Over time, this could lead to a norm where interpretability is built-in and not just an afterthought. Our approach could serve as a template for how to integrate and evaluate such capabilities.

Open Questions: Despite the promise, there are many open research questions and challenges to be explored:

  • How to ensure faithfulness? We need robust methods to verify that sub-model explanations truly reflect the causal factors behind the main model’s decisions. Techniques like causal analysis or counterfactual checking (altering the input or hidden state and seeing whether the explanation changes accordingly) could be used to test faithfulness. Developing quantitative metrics for explanation quality is an open area: e.g., measures of alignment between explanation and model behavior, or how well humans can predict the model’s output from the explanations alone.
  • What is the optimal complexity for sub-models? If the sub-model is too simple, it might not capture the nuance of the layer’s state; if too complex, it might itself be a black box or overfit. Research is needed to find the sweet spot (e.g., is a 1-layer probe enough, or do we need a full transformer decoder?). Also, do we need one sub-model per layer, or can a single parametric model handle multiple layers’ interpretations if given the layer index as input (reducing total parameters)?
  • How to scale to very large models? Modern large language models can have on the order of a hundred layers, so producing an independent, heavyweight explanation for each is impractical. We might need techniques to select the most important layers to explain (perhaps layers where big changes in the representation happen), or use sparse explanation triggers (producing a detailed explanation only when something notable occurs and defaulting to briefer information otherwise). There is also the question of whether our approach can handle models with billions of parameters, or whether there are stability issues (e.g., if different layers encode entangled concepts, can our method disentangle them well enough for explanation?).
  • Could models deceive their own sub-models? In an adversarial scenario, a model might learn to encode information in forms that the sub-model cannot easily interpret (if, say, the model “knows” it is being watched and does not want to reveal something). This scenario is somewhat far-fetched for now, but as models become more agentic or more aware of their training and evaluation context, they might game the interpretability tools. Ensuring robust interpretation that cannot be tricked is an open problem (it touches on AI safety: making sure the AI cannot hide its true intentions).
  • User Understanding: Even if we produce explanations, will end-users (or even developers) always understand them? If an explanation says “Layer 7: activated a semantic induction head linking subject and object in the query,” that might confuse a non-expert. How do we translate our internal explanations into simpler terms or visualize them for different audiences? Perhaps a two-tier explanation, where a technical one is available for developers and a high-level one for users.
  • Generality: We have mainly discussed transformer encoders (like BERT) and decoders (like GPT). How would this work in other architectures? For RNNs (less common now, but conceptually similar), one could apply layer-wise or time-step-wise sub-models. For non-NLP domains (reinforcement learning policies, for example), could an agent’s policy network layers be explained in terms of intermediate decisions or value estimates? If an RL agent has a vision module followed by a decision module, those may be natural layers to explain (“vision recognized an obstacle, decision layer chose to jump”). Extending these ideas to different AI systems will be important to make interpretability ubiquitous.
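The counterfactual check described under faithfulness above can be sketched on synthetic hidden states. Assumptions, all hypothetical: the main model’s decision and the sub-model’s explanation are both simple functions of a hidden state; the intervention reflects each state across the hyperplane orthogonal to a chosen concept direction, flipping that concept while leaving orthogonal directions untouched; and faithfulness is scored as the fraction of examples where the explanation changes exactly when the decision does. Passing such a check is necessary but not sufficient for faithfulness.

```python
import numpy as np

def counterfactual_agreement(H, concept_dir, explain, predict):
    """Reflect each hidden state along concept_dir and score how often the
    explanation changes exactly when the model's prediction changes."""
    u = concept_dir / np.linalg.norm(concept_dir)
    H_cf = H - 2.0 * np.outer(H @ u, u)  # Householder reflection: flips the u-component
    expl_changed = explain(H) != explain(H_cf)
    pred_changed = predict(H) != predict(H_cf)
    return float(np.mean(expl_changed == pred_changed))

rng = np.random.default_rng(0)
H = rng.normal(size=(100, 8))   # synthetic hidden states
u, v = np.eye(8)[0], np.eye(8)[1]

def predict(X):
    return X @ u > 0  # the model's decision depends only on direction u

def faithful_explain(X):
    return X @ u > 0  # explanation reads the same direction the model uses

def unfaithful_explain(X):
    return X @ v > 0  # explanation reads an unrelated direction

faithful_score = counterfactual_agreement(H, u, faithful_explain, predict)
unfaithful_score = counterfactual_agreement(H, u, unfaithful_explain, predict)
print(faithful_score, unfaithful_score)
```

The faithful explanation flips whenever the decision flips, while the unfaithful one never moves under the intervention, so the two score at opposite extremes.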

Realistic Optimism: In the near future, we anticipate researchers demonstrating layer-wise interpretability on more complex tasks and models, perhaps starting with medium-sized models where it’s tractable to label and evaluate the explanations. If successful, this approach can be one piece of the puzzle for trustworthy AI. It won’t solve everything – we still need good user experience design to present explanations, and rigorous validation – but it provides a technical pathway to illuminating the black box from the inside out. Ultimately, the vision is that asking a transformer “What are you thinking?” will be a standard capability, and the transformer (through our sub-models) will answer by revealing its chain of thought in a way we can verify and understand. This would mark a significant paradigm shift from the opaque deep learning models of the past toward AI systems that expose their reasoning and earn our trust through transparency. We believe Layer-Wise Sub-Model Interpretability is a promising step in that direction, and we encourage further exploration and refinement of this approach in the research community.