Influence Functions in Machine Learning
- Introduction
- Robust Statistics
- Influence Functions on Groups
- The Calculation Bottleneck
- The Problem with Influence Functions
- Libraries
- Applications
- Conclusion
- References
Introduction
With the increasing complexity of machine learning models, their predictions are not easily interpretable by humans, and the models are usually treated as black boxes. To address this issue, the growing field of explainability tries to understand why those models make certain predictions. In recent years, the work by (Koh & Liang, 2017) has attracted a lot of attention in many fields, using the idea of influence functions (Hampel, 1974) to identify the training points most responsible for a given prediction.
Robust Statistics
Statistical methods rely explicitly or implicitly on assumptions about the data and the problem stated. These assumptions usually concern the probability distribution of the dataset. The most widely used framework assumes that the observed data have a normal (Gaussian) distribution, and this classical approach has been used for regression, analysis of variance and multivariate analysis. However, real-life data is noisy and contains atypical observations, called outliers. Those observations deviate from the general pattern of the data, and classical estimates such as the sample mean and sample variance can be strongly affected by them. This can result in a bad fit to the data. Robust statistics provides estimators that still give a good fit for data containing outliers (Maronna et al., 2006).
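As a small illustration of how a single outlier can distort a classical estimate, the sketch below compares the sample mean with the median, a common robust alternative. The numbers are made up for this example and are not taken from any dataset discussed here.

```python
from statistics import mean, median

# Hypothetical sample: five typical readings plus one gross outlier.
clean = [9.8, 10.1, 10.0, 9.9, 10.2]
contaminated = clean + [100.0]

# How far does each estimate move when the outlier is added?
mean_shift = mean(contaminated) - mean(clean)
median_shift = median(contaminated) - median(clean)

print(f"shift in the mean:   {mean_shift:.2f}")
print(f"shift in the median: {median_shift:.2f}")
```

The mean is pulled far away from the bulk of the data by one point, while the median barely moves.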
Influence Functions
The influence function (IF) was first introduced in “The Influence Curve and Its Role in Robust Estimation” (Hampel, 1974), and it measures the impact of an infinitesimal perturbation on an estimator. The very interesting work by (Koh & Liang, 2017) brought this methodology into machine learning.
Influence Functions in Machine Learning
Consider an image classification task where the goal is to predict the label for a given image. We want to measure the impact of a particular training image on a testing image. A naive approach is to remove the training image and retrain the model. However, this approach is prohibitively expensive. To overcome this problem, influence functions upweight that particular point by an infinitesimal amount and measure the impact on the loss function without having to retrain the model.
Figure 1: The fish image is upweighted by an infinitesimal amount so the model tries harder to fit that particular sample. Image by the author.
Change in Parameters
The empirical risk minimizer can be defined as follows:
\[\begin{equation} \hat\theta = \arg\min_{\theta} \frac{1}{n} \sum_{i=1}^{n} \mathcal{L}(z_i, \theta) \end{equation}\]
where \(z_i\) is the \(i\)-th point of the training sample. First, we need to understand how the parameters \(\hat\theta\) change after upweighting a particular training point \(z\) by an infinitesimal amount \(\epsilon\). This change is \(\hat\theta_{\epsilon,z} - \hat\theta\), where \(\hat\theta\) is the original set of parameters for the full training data and \(\hat\theta_{\epsilon,z}\) is the new set of parameters after upweighting:
\[\begin{equation} \hat\theta_{\epsilon,z} = \arg\min_{\theta} \frac{1}{n}\sum_{i=1}^{n}\mathcal{L}(z_i,\theta) + \epsilon \mathcal{L}(z,\theta) \end{equation}\]
As we want to measure the rate of change of the parameters after perturbing the point, the derivation made by (Cook & Weisberg, 1982) yields the following:
\[\begin{equation} I(z) = \frac{d\hat\theta_{\epsilon,z}}{d\epsilon} \bigg|_{\epsilon=0} = -H_{\hat\theta}^{-1}\nabla_{\theta} \mathcal{L}(z,\hat\theta) \end{equation}\]
where \(H_{\hat\theta} = \frac{1}{n}\sum_{i=1}^n \nabla_{\theta}^2 \mathcal{L}(z_i,\hat\theta)\) is the Hessian matrix of the empirical risk, which is assumed to be positive definite (symmetric with all positive eigenvalues).
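For readers who want to see where this expression comes from, here is a brief sketch of the usual argument, assuming the empirical risk is twice differentiable and strictly convex. The shorthand \(R(\theta)\) and \(\Delta_\epsilon\) is introduced only for this sketch:

\[R(\theta) = \frac{1}{n}\sum_{i=1}^{n}\mathcal{L}(z_i,\theta), \qquad \Delta_\epsilon = \hat\theta_{\epsilon,z} - \hat\theta\]

Because \(\hat\theta_{\epsilon,z}\) minimizes \(R(\theta) + \epsilon\mathcal{L}(z,\theta)\), the gradient of that objective vanishes there:

\[0 = \nabla_\theta R(\hat\theta_{\epsilon,z}) + \epsilon \nabla_\theta \mathcal{L}(z,\hat\theta_{\epsilon,z})\]

Expanding the right-hand side around \(\hat\theta\) to first order gives:

\[0 \approx \left[\nabla_\theta R(\hat\theta) + \epsilon\nabla_\theta\mathcal{L}(z,\hat\theta)\right] + \left[\nabla_\theta^2 R(\hat\theta) + \epsilon\nabla_\theta^2\mathcal{L}(z,\hat\theta)\right]\Delta_\epsilon\]

Since \(\hat\theta\) minimizes \(R\), we have \(\nabla_\theta R(\hat\theta) = 0\), and \(\nabla_\theta^2 R(\hat\theta) = H_{\hat\theta}\). Keeping only the terms that are first order in \(\epsilon\):

\[\Delta_\epsilon \approx -H_{\hat\theta}^{-1}\nabla_\theta\mathcal{L}(z,\hat\theta)\,\epsilon\]

Differentiating with respect to \(\epsilon\) at \(\epsilon = 0\) recovers the expression for \(I(z)\) above.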
Equation \(3\) gives the influence of upweighting a single training point \(z\) on the parameters \(\hat\theta\). Removing \(z\) is the same as upweighting it by \(\epsilon = -\frac{1}{n}\), so \(-\frac{1}{n} I(z)\) approximates the parameter change obtained by removing \(z\) and re-training the model.
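The claim that \(-\frac{1}{n} I(z)\) approximates the effect of removing a point can be checked numerically on a small problem. The sketch below is only an illustration under assumptions that are not in the original text: it uses ridge regression on synthetic data, with the per-example loss \(\frac{1}{2}(x^T\theta - y)^2 + \frac{\lambda}{2}\|\theta\|^2\) so that the Hessian is positive definite and re-training has a closed form. The names `fit_ridge`, `lam` and `j` are made up for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, lam = 200, 3, 0.1          # assumed sizes and penalty, for illustration
X = rng.normal(size=(n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.normal(size=n)

def fit_ridge(X, y, lam):
    """Minimise the mean of 0.5*(x@theta - y)^2 + 0.5*lam*||theta||^2."""
    m = X.shape[0]
    hessian = X.T @ X / m + lam * np.eye(X.shape[1])
    return np.linalg.solve(hessian, X.T @ y / m)

theta_hat = fit_ridge(X, y, lam)
H = X.T @ X / n + lam * np.eye(d)            # Hessian of the empirical risk

j = 7                                        # training point to remove
grad_j = X[j] * (X[j] @ theta_hat - y[j]) + lam * theta_hat
influence_params = -np.linalg.solve(H, grad_j)   # I(z_j), no explicit inverse
predicted_change = -influence_params / n          # removing = upweighting by -1/n

keep = np.arange(n) != j
actual_change = fit_ridge(X[keep], y[keep], lam) - theta_hat

print("predicted:", predicted_change)
print("actual:   ", actual_change)
```

Note that the sketch solves a linear system with `np.linalg.solve` instead of forming the inverse Hessian, which is the usual choice when only a single product with the inverse is needed.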
Change in the Loss Function
As we want to measure the change in the loss function for a particular testing point, applying the chain rule gives the following equation:
\[\begin{equation} I(z, z_{test}) = \frac{d \mathcal{L}(z_{test},\hat\theta_{\epsilon, z})}{d\epsilon} \bigg|_{\epsilon=0} = -\nabla_\theta \mathcal{L}(z_{test},\hat\theta)^T H_{\hat\theta}^{-1} \nabla_\theta \mathcal{L}(z,\hat\theta) \end{equation}\]
\(-\frac{1}{n} I(z, z_{test})\) approximately measures the change in the loss at \(z_{test}\) when \(z\) is removed from the training set. This is based on the assumption that the underlying loss function is twice differentiable and strictly convex in the parameters \(\theta\). Some loss functions are not differentiable (e.g. hinge loss), so in this case, one of the contributions of Koh’s work is to approximate the loss with a smooth, differentiable version near the margin.
Influence Functions on Groups
As previously seen, influence functions measure the impact of a single training point on a single testing point. They are based on a first-order approximation, which is fairly accurate for small changes. To study the effect of removing a large group of training points, (Koh et al., 2019) analyze how well influence functions carry over to this setting and find that they can still be used in some particular cases. The influence of a group \(\mathcal{U}\) can be written as the sum of the influences of the individual points in the group:
\[\sum_{z_i \in \mathcal{U}} I(z_i, z_{test}) = -\nabla_\theta \mathcal{L}(z_{test},\hat\theta)^T H_{\hat\theta}^{-1} \sum_{z_i \in \mathcal{U}} \nabla_\theta \mathcal{L}(z_i,\hat\theta)\]
With \(I(\mathcal{U})^{(1)}\) denoting the first-order group influence, (Basu et al., 2020) propose a second-order group influence function to capture informative cross-dependencies among samples:
\[I(\mathcal{U})^{(2)} = I(\mathcal{U})^{(1)} + I(\mathcal{U})^{'}\]
Here, the first-order group influence function \(I(\mathcal{U})^{(1)}\) is defined as:
\[I(\mathcal{U})^{(1)} = \frac{\partial \theta_{\mathcal{U}}^{\epsilon}}{\partial \epsilon} \bigg|_{\epsilon=0}\]
And the second-order term \(I(\mathcal{U})^{'}\) as:
\[I(\mathcal{U})^{'} = \frac{\partial^2 \theta_{\mathcal{U}}^{\epsilon}}{\partial \epsilon^2} \bigg|_{\epsilon=0}\]
This technique was shown empirically to improve the selection of the most influential group for a test sample across different group sizes and types. The idea is to capture more information when the changes to the underlying model are relatively large.
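The sketch below illustrates why the first-order sum becomes less reliable as the group grows. It is not the second-order method of (Basu et al., 2020); it only compares the summed first-order influence on the parameters with actual re-training, for a small and a large group. The setup (ridge regression on synthetic data, the group sizes, and names such as `relative_error`) is assumed for this example.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d, lam = 200, 3, 0.1
X = rng.normal(size=(n, d))
y = X @ np.array([1.0, -2.0, 0.5]) + 0.5 * rng.normal(size=n)

def fit_ridge(X, y, lam):
    """Minimise the mean of 0.5*(x@theta - y)^2 + 0.5*lam*||theta||^2."""
    m = X.shape[0]
    return np.linalg.solve(X.T @ X / m + lam * np.eye(X.shape[1]), X.T @ y / m)

theta_hat = fit_ridge(X, y, lam)
H = X.T @ X / n + lam * np.eye(d)
# Per-example loss gradients at theta_hat, one row per training point.
grads = X * (X @ theta_hat - y)[:, None] + lam * theta_hat

def relative_error(group):
    """Error of the first-order group prediction versus re-training."""
    # First-order group influence: sum of the individual influences.
    group_influence = -np.linalg.solve(H, grads[group].sum(axis=0))
    predicted = -group_influence / n
    keep = np.ones(n, dtype=bool)
    keep[group] = False
    actual = fit_ridge(X[keep], y[keep], lam) - theta_hat
    return np.linalg.norm(predicted - actual) / np.linalg.norm(actual)

err_small = relative_error(np.arange(2))      # remove 1% of the data
err_large = relative_error(np.arange(100))    # remove 50% of the data
print(f"relative error, group of 2:   {err_small:.3f}")
print(f"relative error, group of 100: {err_large:.3f}")
```

The larger group changes the model too much for a first-order approximation, which is the regime the second-order correction is designed for.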
The Calculation Bottleneck
Computing the inverse Hessian is quite expensive and infeasible for a network with lots of parameters: with \(n\) training points and \(p\) parameters, forming and inverting the Hessian takes \(O(np^2 + p^3)\) operations. In numpy, the inverse can be calculated using numpy.linalg.inv. As a side note, numpy is mostly written in C and the high-level functions are Python bindings. Nevertheless, it is still an expensive function.
In the PyTorch framework, you can compute the Hessian using torch.autograd.functional.hessian and then invert it with torch.linalg.inv. I’m going to expand a little bit here because this is a bit tricky. The module torch.nn contains different classes that provide useful methods for models that inherit from nn.Module. Functional modules take NN modules and turn them into purely functional, stateless ones, so you can explicitly pass parameters to a function. torch.autograd.functional requires the parameters to be passed to a function (see the long discussion here).
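As a minimal sketch of the point about passing parameters to a function, the example below writes the empirical risk as a pure function of a parameter tensor and hands it to torch.autograd.functional.hessian. The model (L2-regularised logistic regression), the synthetic data and the names `empirical_risk`, `theta` and `lam` are assumptions made for this illustration, and `theta` is set to zero as a stand-in for trained parameters. For a real nn.Module, the parameters would first have to be exposed as explicit function arguments in a similar way.

```python
import torch
import torch.nn.functional as F
from torch.autograd.functional import hessian

torch.manual_seed(0)
n, d, lam = 100, 3, 0.1
X = torch.randn(n, d, dtype=torch.float64)
w_true = torch.tensor([1.5, -1.0, 0.5], dtype=torch.float64)
y = (X @ w_true > 0).to(torch.float64)

def empirical_risk(theta):
    """Mean logistic loss plus an L2 penalty, as a pure function of theta."""
    logits = X @ theta
    return F.binary_cross_entropy_with_logits(logits, y) + 0.5 * lam * theta.dot(theta)

theta = torch.zeros(d, dtype=torch.float64)   # stand-in for trained parameters
H = hessian(empirical_risk, theta)            # tensor of shape (d, d)
H_inv = torch.linalg.inv(H)                   # fine for tiny d, infeasible for large models

print(H.shape)
```

For three parameters the explicit inverse is harmless, but the same two calls on a network with millions of parameters would need a Hessian with trillions of entries, which motivates the approximations that follow.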
Conjugate Gradients
Conjugate gradient (Shewchuk, 1994) is an iterative method for solving large systems of linear equations, and it is effective for systems of the form \(Ax = b\) where \(A\) is symmetric and positive definite. In (Martens, 2010), this idea is used in a second-order optimization technique that approximates the Hessian implicitly (Hessian-free optimization). This method does not invert the Hessian directly but calculates the inverse-Hessian-vector product \(H^{-1}v\) as the solution of an optimization problem:
\[H^{-1} v = \arg\min_{t} \left( \tfrac{1}{2} t^T H t - v^T t \right)\]
Linear Time Stochastic Second-Order Algorithm (LiSSA)
The main idea of LiSSA (Agarwal et al., 2017) is to use a Taylor expansion (Neumann series) to construct a natural estimator of the inverse Hessian, which is valid when \(\|I - H\| < 1\):
\[H^{-1} = \sum^{\infty}_{i=0} (I - H)^i\]
Writing \(H_{j}^{-1}\) for the sum of the first \(j + 1\) terms, so that \(\lim_{j \to \infty} H_{j}^{-1} = H^{-1}\), the series can be computed recursively:
\[H_{j}^{-1} = \sum^{j}_{i=0} (I - H)^i = I + (I - H) H^{-1}_{j-1}\]
FastIF
In order to improve scalability and reduce computational cost, FastIF (Guo et al., 2021) presents a set of modifications to improve the runtime. The work uses k-nearest neighbours (kNN) to narrow the search space down to a small subset of candidate training points, which is inexpensive in this context.
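To make the two iterative estimators described earlier in this section concrete, the sketch below computes \(H^{-1}v\) with conjugate gradients and with the Neumann recursion, and compares both against a direct solve. Several things here are assumptions for illustration only: the Hessian is a small synthetic positive definite matrix, `hvp` stands in for a Hessian-vector product that would normally come from automatic differentiation, and the Neumann version is deterministic, whereas LiSSA proper replaces \(H\) with stochastic mini-batch estimates. The `scale` factor is also an assumption: it divides \(H\) so that the series converges.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200, 5
A = rng.normal(size=(n, d))
H = A.T @ A / n + 0.1 * np.eye(d)    # synthetic positive definite "Hessian"
v = rng.normal(size=d)

def hvp(t):
    """Hessian-vector product; only this is needed, never H^{-1} itself."""
    return H @ t

def conjugate_gradient(hvp, v, iters=50, tol=1e-10):
    """Solve H t = v, i.e. minimise 0.5 * t^T H t - v^T t."""
    t = np.zeros_like(v)
    r = v - hvp(t)                   # residual
    p = r.copy()                     # search direction
    rs = r @ r
    for _ in range(iters):
        Hp = hvp(p)
        alpha = rs / (p @ Hp)
        t += alpha * p
        r -= alpha * Hp
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return t

def neumann_ihvp(hvp, v, scale, iters=300):
    """Recursion est_j = v + (I - H/scale) est_{j-1}, rescaled at the end."""
    est = v.copy()
    for _ in range(iters):
        est = v + est - hvp(est) / scale
    return est / scale               # undo the scaling of H

scale = 1.1 * np.linalg.eigvalsh(H).max()   # makes all eigenvalues of H/scale < 1
exact = np.linalg.solve(H, v)
cg_estimate = conjugate_gradient(hvp, v)
neumann_estimate = neumann_ihvp(hvp, v, scale)

print("max error, conjugate gradients:", np.abs(cg_estimate - exact).max())
print("max error, Neumann recursion:  ", np.abs(neumann_estimate - exact).max())
```

Both routines touch the Hessian only through `hvp`, which is what allows them to scale to models where the full matrix cannot be stored.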
The Problem with Influence Functions
Influence functions are an approximation and do not always produce correct values. In some particular settings, influence estimates can lose a significant amount of quality. They are known to work with convex loss functions, but in non-convex setups the estimates may not work as expected. The work ‘Influence Functions in Deep Learning are Fragile’ (Basu et al., 2021) examines the conditions under which influence estimation can be applied to deep networks through extensive experimentation. In short, there are a few obstacles:
- The estimation in deeper architectures is erroneous, possibly due to poor inverse Hessian estimation. Weight-decay regularization can help.
- Wide networks perform poorly. When increasing the width of a network, the correlation between the true difference in the loss and the influence function decreases substantially.
- Scaling influence functions is challenging. ImageNet contains 1.2 million images in the training set, which makes it difficult to evaluate whether influence functions are effective, since it is computationally prohibitive to re-train the model multiple times, leaving out one training point each time.
Libraries
There are several implementations available in Python with PyTorch and TensorFlow. A few others are written in R and Matlab.
Influence Functions
The official implementation of (Koh & Liang, 2017), built on TensorFlow.
Influence Functions for PyTorch
PyTorch implementation. It uses stochastic estimation to calculate the influence.
Torch Influence
A recent implementation (Jul/2022) of influence functions on PyTorch, providing three different ways to calculate the inverse Hessian: direct computation and inversion with torch.autograd, truncated conjugate gradients and LiSSA.
Fast Influence Functions
A modified influence function computation using k-Nearest Neighbors (kNN), implemented in PyTorch.
Other implementations
Influence Function with LiSSA
A simple implementation with LiSSA on TensorFlow.
Influence Pytorch
One-file code with the implementation for a random classification problem.
IF notebook
Python notebook with IF applied to other algorithms (e.g. trees).
Influence Functions Pytorch
Another implementation of influence functions.
Applications
- Explainability: This is the most common use and the one explored so far: measuring the impact of a training point to explain the prediction at a given testing point.
- Adversarial Attacks: Real-world data is noisy, and it can be problematic for machine learning. Adversarial machine learning methods feed a model deceptive input in order to change the predictions of a classifier. Influence functions can help by identifying how to modify a training point to increase the loss on a target point.
- Label mismatch: Toy datasets are pretty good for experimentation, but real data might contain many mislabeled examples. The idea is to calculate the influence of a training point on its own loss, \(I(z_{i}, z_{i})\), which approximates how the loss on \(z_i\) would change if that point was removed. Email spam is a good example since it usually relies on the user’s input to decide whether an email is spam or not.
Conclusion
The very interesting work from (Koh & Liang, 2017) brought influence functions to the context of machine learning. The technique itself was introduced more than 40 years ago by (Hampel, 1974). One of the main contributions is showing how to apply it to non-differentiable loss functions (e.g. hinge loss). In addition to that, the paper uses other existing ideas to overcome the computation issue, such as conjugate gradients and the LiSSA algorithm. Subsequent work studied influence functions on groups (Koh et al., 2019), (Basu et al., 2020). The latter used second-order influence functions to capture hidden information when the group size is relatively large. I believe this is a powerful technique that will continue to inspire new ideas in many different areas. One example is pruning, where a single-shot pruning technique was based on connection sensitivity (Lee et al., 2019), exploring the idea of perturbing weights in a network. Another is the area of graphs, where a popular framework, JK Networks (Xu et al., 2018), uses perturbation analysis to measure the impact of a change in one node embedding on another node embedding.
References
- Koh, P. W., & Liang, P. (2017). Understanding Black-box Predictions via Influence Functions. In D. Precup & Y. W. Teh (Eds.), Proceedings of the 34th International Conference on Machine Learning (Vol. 70, pp. 1885–1894). PMLR.
- Hampel, F. R. (1974). The Influence Curve and Its Role in Robust Estimation. Journal of the American Statistical Association, 69(346), 383–393.
- Maronna, R. A., Martin, R. D., & Yohai, V. J. (2006). Robust Statistics: Theory and Methods. Wiley.
- Cook, R. D., & Weisberg, S. (1982). Residuals and Influence in Regression. New York: Chapman and Hall.
- Koh, P. W. W., Ang, K.-S., Teo, H., & Liang, P. S. (2019). On the Accuracy of Influence Functions for Measuring Group Effects. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d’Alché-Buc, E. Fox, & R. Garnett (Eds.), Advances in Neural Information Processing Systems (Vol. 32). Curran Associates, Inc.
- Basu, S., You, X., & Feizi, S. (2020). On Second-Order Group Influence Functions for Black-Box Predictions. In H. Daumé III & A. Singh (Eds.), Proceedings of the 37th International Conference on Machine Learning (Vol. 119, pp. 715–724). PMLR.
- Shewchuk, J. R. (1994). An Introduction to the Conjugate Gradient Method Without the Agonizing Pain.
- Martens, J. (2010). Deep Learning via Hessian-Free Optimization. Proceedings of the 27th International Conference on Machine Learning, 735–742.
- Agarwal, N., Bullins, B., & Hazan, E. (2017). Second-Order Stochastic Optimization for Machine Learning in Linear Time. Journal of Machine Learning Research, 18(116), 1–40.
- Guo, H., Rajani, N., Hase, P., Bansal, M., & Xiong, C. (2021). FastIF: Scalable Influence Functions for Efficient Model Interpretation and Debugging. Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, 10333–10350.
- Basu, S., Pope, P., & Feizi, S. (2021). Influence Functions in Deep Learning Are Fragile. International Conference on Learning Representations.
- Lee, N., Ajanthan, T., & Torr, P. (2019). SNIP: Single-Shot Network Pruning Based on Connection Sensitivity. International Conference on Learning Representations.
- Xu, K., Li, C., Tian, Y., Sonobe, T., Kawarabayashi, K., & Jegelka, S. (2018). Representation Learning on Graphs with Jumping Knowledge Networks. ICML, 5449–5458.