Ablation Based Counterfactuals

Zheng Dai1, David K Gifford1

1Computer Science and Artificial Intelligence Laboratory, Massachusetts Institute of Technology
Contact: *@mit.edu (replace "*" with "zhengdai")

Abstract

Motivation

We tackle the challenge of analyzing the relationship between diffusion models and their training data. Diffusion models are a class of generative models that generate high-quality samples, but at present it is difficult to characterize how they depend upon their training data due to the complexity of these models and their sampling process. The analysis of this dependence has both scientific and regulatory applications due to the widespread adoption of these models.

Our Contributions

  • We introduce Ablation Based Counterfactuals (ABC), a method of performing counterfactual analysis that relies on model ablation rather than model retraining, enabling us to analyze the extremely complex relationship between diffusion models and their training data.
  • We show that we can train a model that can be ablated without loss of function by training an ensemble of diffusion models.
  • We use ablation to study the limits of training data attribution by enumerating full counterfactual landscapes, and show that single source attributability diminishes with increasing training data size, culminating in unattributable samples, whose existence we demonstrate.

Technical Details

Counterfactual Analysis

To discover what effect a piece of training data had on a generated sample, we consider the counterfactual question: "what would have been generated if the piece of data in question were missing?". This can be answered by going back to the training set, removing the piece of data in question, and training a new model on the incomplete dataset. We can then regenerate the sample using the newly trained model to produce a counterfactual sample: the sample that would have been generated had the model been trained on the incomplete dataset. The counterfactual sample can then be compared with the factual sample generated by the original model trained on the full dataset to assess the effect of that piece of training data. We describe this type of analysis as counterfactual analysis.
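As a concrete (if impractical) sketch of this procedure, the following Python fragment spells out the leave-one-out retraining loop. The callables `train_model` and `generate` are hypothetical stand-ins for the actual training and sampling routines, which are not specified here.

```python
def counterfactual_by_retraining(dataset, index, train_model, generate, seed=0):
    """Leave-one-out counterfactual analysis by retraining (sketch).

    `train_model(data, seed)` and `generate(model, seed)` are caller-supplied
    stand-ins for the real training and sampling procedures; `seed` fixes the
    sampling randomness so the factual and counterfactual samples are comparable.
    """
    # Factual sample: model trained on the full dataset.
    factual_model = train_model(dataset, seed)
    factual_sample = generate(factual_model, seed)

    # Counterfactual sample: retrain from scratch without dataset[index],
    # then regenerate with the same sampling randomness.
    incomplete = dataset[:index] + dataset[index + 1:]
    counterfactual_model = train_model(incomplete, seed)
    counterfactual_sample = generate(counterfactual_model, seed)

    # Comparing the two samples reveals the effect of dataset[index].
    return factual_sample, counterfactual_sample
```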

Performing counterfactual analysis by retraining models on incomplete datasets is extremely computationally expensive. Exploring an entire counterfactual landscape (i.e. all possible counterfactual samples that can be produced by leaving out one piece of training data), which we need to do in order to fully understand the contributions of all parts of the training dataset, is effectively impossible.

Figure 1. Cartoon illustration of the causal relationship between the training data and the final generated sample. The training data (left) is used to train a model (middle), which is then used to generate a sample (right). Hover over the training set to see samples from the training data. Removing samples leads to changes in the final generated sample. The changes in the final generated sample can be used to infer what the model learned from the training data.

Computing Counterfactual Samples Efficiently via Ablation

Instead of training a single model on the entire training dataset, we train sets of model parameters independently on different subsets of the training dataset. These independently trained parameters are then combined into a single model. If we then use the overall model to generate a sample, we can remove the causal influence of any given piece of training data by removing all parameters of the model that were trained on it. Since this ablated model has effectively never seen the training data that we removed, we can use it to generate counterfactual samples for counterfactual analysis.

One way of creating an ablatable model is to train a set of models independently and combine them into an ensemble, where the ensemble output is the arithmetic mean of the outputs of its members.
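A minimal sketch of such an ensemble is shown below, assuming each member exposes a `predict` method and has been trained only on its own shard of the training data (the sharding and training steps are outside the sketch, and the interface names are our own).

```python
import numpy as np

class AblatableEnsemble:
    """Ensemble whose output is the arithmetic mean of its non-ablated members.

    `members` maps a shard identifier to a model trained only on that shard;
    the `predict` method of each member is an assumed interface.
    """

    def __init__(self, members):
        self.members = dict(members)  # shard_id -> independently trained model
        self.ablated = set()

    def ablate(self, shard_id):
        # Remove the causal influence of every training example in this shard.
        self.ablated.add(shard_id)

    def unablate(self, shard_id):
        self.ablated.discard(shard_id)

    def predict(self, x):
        # Arithmetic mean over the members that have not been ablated.
        outputs = [m.predict(x) for sid, m in self.members.items()
                   if sid not in self.ablated]
        return np.mean(outputs, axis=0)
```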

Figure 2. Using ablation to compute counterfactuals with a digit generator trained on 384 digits (only 10 training digits are shown on the left, click here to show all 384). Hover over the training set on the left to highlight in red the flow of causal influence from the training data (left) to the independently trained model parameters (small circles) to the combined model (large circle) to the generated sample (right). Click a member of the training set on the left to remove it via ablation, and click it again to unablate it. Note how the flow of causal influence is broken, and how it is not broken for any other member of the training dataset. Click here to generate a different image.

Results

Ensembles of Diffusion Models are Viable Image Generators

In order to use ablation, we must construct a model with sufficient redundancy that groups of parameters can be ablated without loss of function. An ensemble of identical model architectures fulfills this requirement, where the output of the ensemble is the arithmetic average of the model outputs. We find that such ensembles of diffusion models are viable image generators.
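To make the averaging concrete, here is one plausible way the ensemble could enter a DDPM-style sampling step: the per-member noise predictions are averaged before the usual update. The `predict_noise` method and the precomputed schedules (`alphas`, `alphas_bar`, `sigmas`) are illustrative assumptions, not the exact sampler used in the paper.

```python
import numpy as np

def ensemble_noise(models, x_t, t, ablated=frozenset()):
    """Arithmetic mean of per-member noise predictions, skipping ablated members."""
    preds = [m.predict_noise(x_t, t) for i, m in enumerate(models) if i not in ablated]
    return np.mean(preds, axis=0)

def ddpm_step(models, x_t, t, alphas, alphas_bar, sigmas, rng, ablated=frozenset()):
    """One DDPM-style ancestral sampling step driven by the ensemble's averaged
    noise prediction. The schedules are precomputed arrays indexed by timestep."""
    eps = ensemble_noise(models, x_t, t, ablated)
    mean = (x_t - (1.0 - alphas[t]) / np.sqrt(1.0 - alphas_bar[t]) * eps) / np.sqrt(alphas[t])
    noise = rng.standard_normal(x_t.shape) if t > 0 else np.zeros_like(x_t)
    return mean + sigmas[t] * noise
```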

Figure 3. Examples of images that were sampled by diffusion model ensembles. Each image was generated by a different ensemble. Click here to generate different images.

Attributability Diminishes with Training Dataset Size

If the counterfactual landscape consists of samples that are all similar to the generated sample, then we must conclude that no training sample contributed significantly to the generation of that sample. We operationalize this notion of diminished attributability by defining the counterfactual radius: the largest possible distance between a generated sample and a counterfactual sample. Equivalently, no counterfactual sample lies outside the ball of that radius centered on the generated sample. A low counterfactual radius corresponds to low attributability.
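In code, the counterfactual radius of a generated sample is simply the maximum distance to any member of its counterfactual landscape. The sketch below assumes the samples are already arrays of identical shape and uses Euclidean distance, matching the convention in Figure 4.

```python
import numpy as np

def counterfactual_radius(sample, counterfactuals):
    """Largest Euclidean distance from the generated sample to any
    counterfactual sample in its counterfactual landscape."""
    sample = np.asarray(sample, dtype=float)
    return max(np.linalg.norm(sample - np.asarray(cf, dtype=float))
               for cf in counterfactuals)
```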

We compute the counterfactual radius of generated samples for 23 different models that were trained on varying datasets, and find that there is a strong negative relationship between the training set size and the counterfactual radius.

[Interactive plot: counterfactual radius (y-axis) versus dataset size (x-axis, log scale); see Figure 4 caption below.]

Figure 4. The counterfactual radius is negatively correlated with the training data size. Given a model, we plot the training set size on the x-axis and the geometric mean of the counterfactual radii of the samples generated by the model on the y-axis. Hover over a point to expand it into a box and strip plot of the individual samples' radii before the geometric mean was taken (note that the x coordinates of the strip plot points are meaningless), and to see the size of the samples' counterfactual landscapes and the dataset that was used to train the model. The size of the counterfactual landscape does not always match the training set size, since for some datasets we remove multiple samples at once (for example, for CelebA datasets we remove all samples that correspond to a single individual at once). Counterfactual radii are measured in Euclidean distance, and all images were scaled to 256-by-256 and 3 channels to ensure consistency between datasets.

Samples can be Unattributable

It is possible for the counterfactual radius to be zero. This can occur when we sample from discrete spaces, such as handwritten digits where each pixel value is either black or white. We refer to a sample with a zero counterfactual radius as unattributable, since for every training sample, we can produce a counterfactual sample that certifies that the outcome of the sampling process would be entirely unchanged if that training sample had not existed.
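Under the same conventions, checking for unattributability reduces to checking that the counterfactual radius is exactly zero, i.e. that every counterfactual sample is identical to the generated sample. A minimal sketch:

```python
import numpy as np

def is_unattributable(sample, counterfactuals):
    """True when the counterfactual radius is zero: every counterfactual
    sample is identical to the generated sample."""
    sample = np.asarray(sample)
    return all(np.array_equal(sample, np.asarray(cf)) for cf in counterfactuals)
```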

Figure 5. Examples of unattributable samples generated from a model that generates binary images of handwritten digits.

Check out the preprint for more details
Scripts for running the experiments

Citation

"Z. Dai and D. Gifford. "Ablation Based Counterfactuals.". In: ArXiv. 2024. url: https://arxiv.org/abs/2406.07908."