
Chapter 2: Linear models

Before we dive into the discussion of adversarial attacks and defenses on deep networks, it is worth considering the situation that arises when the hypothesis class is linear. That is, for the multi-class setting $h_\theta : \mathbb{R}^n \rightarrow \mathbb{R}^k$, we consider a classifier of the form

$$h_\theta(x) = Wx + b$$

where $\theta = \{W \in \mathbb{R}^{k \times n}, b \in \mathbb{R}^k\}$. We will also shortly consider a binary classifier of a slightly different form, as many of the ideas are a bit easier to describe in this setting, before returning to the multi-class case.

Substituting this hypothesis back into our robust optimization framework, and focusing on the case where the perturbation set $\Delta$ is a norm ball $\Delta = \{\delta : \|\delta\| \leq \epsilon\}$ (we don’t actually specify the type of norm, so this could be $\ell_\infty$, $\ell_2$, $\ell_1$, etc.), we arrive at the min-max problem over the training set $D$

$$\min_{W,b} \frac{1}{|D|} \sum_{(x,y) \in D} \max_{\|\delta\| \leq \epsilon} \ell(W(x+\delta) + b, y)$$

The key point we will emphasize in this section is that, under this formulation, we can solve the inner maximization exactly for binary classification, and provide a relatively tight upper bound for multi-class classification. Furthermore, because the resulting minimization problem is still convex in $\theta$ (we will see shortly that it remains convex even after maximizing over $\delta$), the resulting robust training procedure can also be solved optimally, and thus we can achieve the globally optimal robust classifier (at least for binary classification). This is in stark contrast to the deep network case, where neither the inner maximization problem nor the outer minimization problem can be solved globally (in the case of the outer minimization, this holds even if we assume exact solutions of the inner problem, due to the non-convexity of the network itself).

However, understanding the linear case provides important insights into the theory and practice of adversarial robustness, and also provides connections to more commonly-studied methods in machine learning such as support vector machines.

Binary classification

Let’s begin by considering the case of binary classification, i.e., $k=2$ in the multi-class setting we described above. In this case, rather than use the multi-class cross entropy loss, we’ll adopt the more common approach of using the binary cross entropy, or logistic loss. In this setting, we have our hypothesis function

$$h_\theta(x) = w^T x + b$$

for $\theta = \{w \in \mathbb{R}^n, b \in \mathbb{R}\}$, class label $y \in \{+1,-1\}$, and loss function

$$\ell(h_\theta(x), y) = \log(1+\exp(-y \cdot h_\theta(x))) \equiv L(y \cdot h_\theta(x))$$

where for convenience we define the function $L(z) = \log(1+\exp(-z))$, which we will use below when discussing how to solve the optimization problems involving this loss. The semantics of this setup are that for a data point $x$, the classifier predicts class $+1$ with probability

$$p(y = +1 \mid x) = \frac{1}{1+\exp(-h_\theta(x))}$$

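Aside: Again, for those who may be unfamiliar with how this setting relates to the multiclass case we saw before, note that if we use the traditional multiclass cross entropy loss with two classes, with logits $h_\theta(x)_1$ and $h_\theta(x)_2$, the probability of predicting class 1 would be given by

$$\frac{\exp(h_\theta(x)_1)}{\exp(h_\theta(x)_1) + \exp(h_\theta(x)_2)} = \frac{1}{1+\exp(h_\theta(x)_2 - h_\theta(x)_1)}$$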

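and similarly the probability of predicting class 2

$$\frac{\exp(h_\theta(x)_2)}{\exp(h_\theta(x)_1) + \exp(h_\theta(x)_2)} = \frac{1}{1+\exp(h_\theta(x)_1 - h_\theta(x)_2)}$$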

We can thus define a single scalar-valued hypothesis

$$h'_\theta(x) = h_\theta(x)_1 - h_\theta(x)_2$$

with the associated probabilities

$$p(y \mid x) = \frac{1}{1+\exp(-y \cdot h'_\theta(x))}$$

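for $y$ defined as $y \in \{+1,-1\}$ as written above (class 1 corresponding to $y=+1$ and class 2 to $y=-1$). Taking the negative log of this quantity gives

$$-\log p(y \mid x) = \log(1+\exp(-y \cdot h'_\theta(x))) = L(y \cdot h'_\theta(x))$$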

which is exactly the logistic loss we define above.
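As a quick numerical check of this aside, the sketch below computes the two-class softmax cross entropy and the logistic loss $L$ for a few hypothetical logit pairs and compares them. It is plain Python using only the standard library, and the helper names `L` and `two_class_cross_entropy` are illustrative, not taken from any library.

```python
import math


def L(z):
    """Logistic loss L(z) = log(1 + exp(-z)), evaluated without overflow."""
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


def two_class_cross_entropy(h1, h2, label):
    """Softmax cross entropy for the two logits (h1, h2); label is 1 or 2."""
    m = max(h1, h2)  # subtract the max logit before exponentiating, for stability
    log_sum_exp = m + math.log(math.exp(h1 - m) + math.exp(h2 - m))
    return log_sum_exp - (h1 if label == 1 else h2)


# Class 1 corresponds to y = +1, class 2 to y = -1, and the scalar hypothesis is h1 - h2.
for h1, h2 in [(2.0, -1.0), (-0.5, 0.3), (10.0, 9.0)]:
    for label, y in [(1, +1), (2, -1)]:
        print(label, two_class_cross_entropy(h1, h2, label), L(y * (h1 - h2)))
```

The two columns agree for every pair, which is the identity derived above.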

Solving the inner maximization problem

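Now let’s return to the robust optimization problem, and consider the inner maximization problem, which in this case takes the form

$$\max_{\|\delta\| \leq \epsilon} \ell(w^T(x+\delta) + b, y) = \max_{\|\delta\| \leq \epsilon} L\left(y \cdot (w^T(x+\delta) + b)\right)$$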

The key point we need to make here is that in this setting, it is actually possible to solve this inner maximization problem exactly. To show this, first note that $L$, as we described it earlier, is a scalar function that is monotonically decreasing, and looks like the following:

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-4,4)
plt.plot(x, np.log(1+np.exp(-x)))
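One detail the plot leaves implicit is why $L$ is decreasing: its derivative is $L'(z) = -1/(1+\exp(z))$, which is negative for every $z$. The sketch below (standard-library Python; the helper names are illustrative) evaluates $L$ and $L'$ on a few points. It also uses a rearranged formula for negative $z$, since the direct expression $\log(1+\exp(-z))$ overflows once $-z$ is large; the plotting range $[-4,4]$ above is safe either way.

```python
import math


def L(z):
    """Logistic loss L(z) = log(1 + exp(-z))."""
    if z >= 0:
        return math.log1p(math.exp(-z))
    # For z < 0, exp(-z) can overflow, so use log(1 + e^{-z}) = -z + log(1 + e^{z}).
    return -z + math.log1p(math.exp(z))


def dL(z):
    """Derivative L'(z) = -1 / (1 + exp(z)); negative for every z (it can underflow to -0.0 for huge z)."""
    if z >= 0:
        e = math.exp(-z)
        return -e / (1.0 + e)
    return -1.0 / (1.0 + math.exp(z))


grid = [-1000.0, -4.0, -1.0, 0.0, 1.0, 4.0, 1000.0]
values = [L(z) for z in grid]
slopes = [dL(z) for z in grid]
for z, v, s in zip(grid, values, slopes):
    print(f"z = {z:8.1f}   L(z) = {v:.6g}   L'(z) = {s:.6g}")

# The direct formula fails where the rearranged one does not:
try:
    math.log(1 + math.exp(1000.0))
except OverflowError:
    print("log(1 + exp(-z)) overflows at z = -1000")
```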

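Because the function is monotonically decreasing, maximizing it as a function of a scalar is equivalent to minimizing that scalar. That is

$$
\begin{aligned}
\max_{\|\delta\| \leq \epsilon} L\left(y \cdot (w^T(x+\delta) + b)\right) & = L\left(\min_{\|\delta\| \leq \epsilon} y \cdot (w^T(x+\delta) + b)\right) \\
& = L\left(y \cdot (w^Tx + b) + \min_{\|\delta\| \leq \epsilon} y \cdot w^T\delta\right)
\end{aligned}
$$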

where we get the second line by just distributing out the linear terms.

So we need to consider how to solve the problem

$$\min_{\|\delta\| \leq \epsilon} y \cdot w^T\delta$$

To get the intuition here, let’s just consider the case that $y = +1$, and consider an $\ell_\infty$ norm constraint $\|\delta\|_\infty \leq \epsilon$. Since the $\ell_\infty$ norm says that each element in $\delta$ must have magnitude less than or equal to $\epsilon$, we clearly minimize this quantity when we set $\delta_i = -\epsilon$ for $w_i \geq 0$ and $\delta_i = \epsilon$ for $w_i < 0$. For $y = -1$, we would just flip these quantities. That is, the optimal solution to the above optimization problem for the $\ell_\infty$ norm is given by

$$\delta^\star = -y\epsilon \cdot \mathrm{sign}(w)$$

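Furthermore, we can also determine the function value achieved by this solution,

$$y \cdot w^T\delta^\star = y \cdot \sum_{i=1}^n -y\epsilon \cdot \mathrm{sign}(w_i) w_i = -y^2\epsilon \sum_{i=1}^n |w_i| = -\epsilon \|w\|_1$$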

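Thus, we can actually analytically compute the solution of the inner maximization problem, which just has the form

$$\max_{\|\delta\|_\infty \leq \epsilon} L\left(y \cdot (w^T(x+\delta) + b)\right) = L\left(y \cdot (w^Tx + b) - \epsilon\|w\|_1\right)$$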
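To make the $\ell_\infty$ argument concrete, the sketch below compares the closed-form $\delta^\star = -y\epsilon\cdot\mathrm{sign}(w)$ against a brute-force search over the $2^n$ corners of the $\ell_\infty$ ball (a linear function over a box is minimized at a corner, so the search is exact for small $n$). It also checks that the attack lowers the margin $y\cdot(w^Tx+b)$ by exactly $\epsilon\|w\|_1$. The weights, input, bias, and $\epsilon$ are hypothetical random values, and the code is plain Python.

```python
import itertools
import random


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def sign(v):
    return (v > 0) - (v < 0)


def linf_attack(w, y, eps):
    """Closed-form minimizer of y * w^T delta over ||delta||_inf <= eps: delta* = -y * eps * sign(w)."""
    return [-y * eps * sign(wi) for wi in w]


def brute_force_min(w, y, eps):
    """Exact minimum over the corners {-eps, +eps}^n of the l_inf ball (cost grows as 2^n)."""
    return min(y * dot(w, d) for d in itertools.product((-eps, eps), repeat=len(w)))


random.seed(0)
n, eps, b = 8, 0.2, 0.1          # hypothetical sizes and bias
w = [random.gauss(0, 1) for _ in range(n)]
x = [random.random() for _ in range(n)]
l1 = sum(abs(wi) for wi in w)

for y in (+1, -1):
    delta = linf_attack(w, y, eps)
    clean_margin = y * (dot(w, x) + b)
    adv_margin = y * (dot(w, [xi + di for xi, di in zip(x, delta)]) + b)
    print(y, y * dot(w, delta), -eps * l1, brute_force_min(w, y, eps))
    print("   margin drops by", clean_margin - adv_margin, "; eps * ||w||_1 =", eps * l1)
```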

Therefore, instead of solving the robust min-max problem as an actual min-max problem, we have been able to convert it to a pure minimization problem, given by

$$\min_{w,b} \frac{1}{|D|} \sum_{(x,y) \in D} L\left(y \cdot (w^Tx + b) - \epsilon\|w\|_1\right)$$

This problem is still convex in $w,b$, so it can be solved exactly, and e.g., SGD will also approach the globally optimal solution. A little more generally, it turns out that for a general norm the optimization problem has the solution

$$\min_{\|\delta\| \leq \epsilon} y \cdot w^T\delta = -\epsilon\|w\|_*$$

where $\|\cdot\|_*$ denotes the dual norm of our original norm bound on $\delta$ ($\|\cdot\|_p$ and $\|\cdot\|_q$ are dual norms for $1/p + 1/q = 1$). So regardless of our norm constraint, we can actually solve the robust optimization problem via a single minimization problem (and find the analytical solution to the worst-case adversarial attack), without the need to explicitly solve a min-max problem.
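For other norms the same recipe applies with the dual norm. As a sketch (assuming a nonzero $w$), the minimizer for the $\ell_2$ ball is $\delta^\star = -y\epsilon\, w/\|w\|_2$ with value $-\epsilon\|w\|_2$, and for the $\ell_1$ ball it spends the whole budget on the coordinate with the largest $|w_j|$, giving $-\epsilon\|w\|_\infty$. The plain-Python code below checks both against randomly sampled feasible perturbations; the sampling is only a sanity test, not a proof (the proof is Hölder’s inequality, $|w^T\delta| \leq \|w\|_*\|\delta\|$). All names and sizes are illustrative.

```python
import math
import random


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def l2_attack(w, y, eps):
    """Minimizer of y * w^T delta over ||delta||_2 <= eps (assumes w != 0); value is -eps * ||w||_2."""
    norm = math.sqrt(dot(w, w))
    return [-y * eps * wi / norm for wi in w]


def l1_attack(w, y, eps):
    """Minimizer over ||delta||_1 <= eps: the whole budget goes to the largest |w_j|; value is -eps * ||w||_inf."""
    j = max(range(len(w)), key=lambda i: abs(w[i]))
    delta = [0.0] * len(w)
    delta[j] = -y * eps * math.copysign(1.0, w[j])
    return delta


def random_feasible(n, eps, p):
    """A random delta with ||delta||_p <= eps, for p in {1, 2} (used only as a sanity check)."""
    g = [random.gauss(0, 1) for _ in range(n)]
    norm = sum(abs(gi) for gi in g) if p == 1 else math.sqrt(dot(g, g))
    radius = eps * random.random()
    return [radius * gi / norm for gi in g]


random.seed(1)
n, eps = 6, 0.3                  # hypothetical sizes
w = [random.gauss(0, 1) for _ in range(n)]
for y in (+1, -1):
    for p, attack, dual in [(2, l2_attack, math.sqrt(dot(w, w))),
                            (1, l1_attack, max(abs(wi) for wi in w))]:
        best = y * dot(w, attack(w, y, eps))
        sampled = min(y * dot(w, random_feasible(n, eps, p)) for _ in range(2000))
        print(f"y={y:+d} p={p}: closed form {best:.4f}, -eps*||w||_* = {-eps * dual:.4f}, best sample {sampled:.4f}")
```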

Note that the final robust optimization problem (now adopting the general form),

$$\min_{w,b} \frac{1}{|D|} \sum_{(x,y) \in D} L\left(y \cdot (w^Tx + b) - \epsilon\|w\|_*\right)$$

looks an awful lot like the typical norm-regularized objective we commonly consider in machine learning

$$\min_{w,b} \frac{1}{|D|} \sum_{(x,y) \in D} L\left(y \cdot (w^Tx + b)\right) + \epsilon\|w\|_*$$

with the exception that the regularization term is inside the loss function. Intuitively, this means that in the robust optimization case, if a point is far from the decision boundary, we effectively don’t penalize the norm of the parameters (because $L$ flattens out for large margins), but we do penalize the norm of the parameters (transformed by the loss function) for a point that is close to the decision boundary. The connections between such formulations and, e.g., support vector machines have been studied extensively.
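The difference between the two objectives can be seen one example at a time. In the sketch below, `penalty` stands for a hypothetical value of $\epsilon\|w\|_*$ and `margin` for $y\cdot(w^Tx+b)$; averaging the per-example regularized term over the data set gives back the regularized objective, since the added constant is the same for every example. The robust loss adds almost nothing for large margins, and because $L$ is 1-Lipschitz ($|L'(z)| < 1$) it never adds more than the penalty itself.

```python
import math


def L(z):
    """Logistic loss L(z) = log(1 + exp(-z)), evaluated without overflow."""
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


def robust_term(margin, penalty):
    """Per-example robust loss L(margin - penalty), with margin = y(w^T x + b), penalty = eps * ||w||_*."""
    return L(margin - penalty)


def regularized_term(margin, penalty):
    """Per-example term of the norm-regularized objective: L(margin) + penalty."""
    return L(margin) + penalty


penalty = 0.5                      # hypothetical eps * ||w||_*, e.g. eps = 0.1 and ||w||_* = 5
margins = [-2.0, 0.0, 2.0, 5.0, 20.0]
robust_extra = [robust_term(m, penalty) - L(m) for m in margins]
regularized_extra = [regularized_term(m, penalty) - L(m) for m in margins]
for m, r, g in zip(margins, robust_extra, regularized_extra):
    print(f"margin {m:5.1f}: robust adds {r:.2e}, regularized adds {g:.2e}")
```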

Illustration of binary classification setting

Let’s see what this looks like for an actual linear classifier. In doing so, we can also get a sense of how well traditional linear models might work to also prevent adversarial examples (spoiler: not very well, unless you do regularize). To do so, we’re going to consider the MNIST data set, which will actually serve as a running example for the vast majority of the rest of this tutorial. MNIST is actually a fairly poor choice of problem for many reasons: in addition to being very small for modern ML, it also has the property that it can easily be “binarized”, i.e., because the pixel values are essentially just black and white, we can remove most $\ell_\infty$ noise by just rounding to 0 or 1, and then classifying the resulting image. But presuming we don’t use such strategies, it is still a reasonable choice for initial experiments, and small enough that some of the more complex methods we discuss in further sections can still be run in a reasonable amount of time.
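To see why binarization is such an effective (and therefore misleading) defense here, the sketch below perturbs a hypothetical, perfectly black-and-white 784-pixel image with $\ell_\infty$ noise of size $\epsilon=0.2$, including the worst-case sign pattern, and rounds it back. Any $\epsilon < 0.5$ is undone exactly for pure 0/1 pixels; real MNIST digits have some gray pixels along stroke edges, so in practice the recovery is only approximate.

```python
import random


def binarize(pixels, threshold=0.5):
    """Round each pixel to 0 or 1, the simple preprocessing described above."""
    return [1.0 if p >= threshold else 0.0 for p in pixels]


random.seed(0)
eps = 0.2
clean = [float(random.random() > 0.8) for _ in range(784)]   # hypothetical pure black/white image
random_noise = [p + random.uniform(-eps, eps) for p in clean]
# Worst case for rounding: push every pixel eps toward the threshold.
worst_case = [p - eps if p == 1.0 else p + eps for p in clean]

print(binarize(random_noise) == clean, binarize(worst_case) == clean)
```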

Since we’re in the binary classification setting for now, let’s focus on the even easier problem of just classifying between 0s and 1s in the MNIST data (we’ll return to the multi-class setting for linear models shortly). Let’s first load the data using the PyTorch library and build a simple linear classifier using gradient descent. Note that rather than relabeling the data as +1/-1 and computing the $L$ function directly, the code keeps PyTorch’s 0/1 labels and uses its built-in binary cross entropy on the logits; for a label $t \in \{0,1\}$ this equals $L((2t-1) \cdot h_\theta(x))$, i.e., exactly the logistic loss above with $y = 2t-1$.

Let’s first load the MNIST data reduced to the 0/1 examples.

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

mnist_train = datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("./data", train=False, download=True, transform=transforms.ToTensor())

# keep only the 0s and 1s (newer torchvision versions store images and labels in .data and .targets;
# the older train_data / train_labels names are read-only aliases and cannot be reassigned)
train_idx = mnist_train.targets <= 1
mnist_train.data = mnist_train.data[train_idx]
mnist_train.targets = mnist_train.targets[train_idx]

test_idx = mnist_test.targets <= 1
mnist_test.data = mnist_test.data[test_idx]
mnist_test.targets = mnist_test.targets[test_idx]

train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

Now let’s build a simple linear classifier (the nn.Linear module does this, containing the weights in the .weight object and the bias in the .bias object). The $L$ function above could be computed with nn.Softplus applied to the negated input (nn.Softplus computes $\log(1+\exp(z))$), but the code below uses nn.BCEWithLogitsLoss, which works directly on the logits with 0/1 labels; both are more numerically stable than calling the exp and log functions directly.

import torch
import torch.nn as nn
import torch.optim as optim

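# do a single pass over the data; if an optimizer is given, also take one gradient step per batch
def epoch(loader, model, opt=None):
    total_loss, total_err = 0.,0.
    for X,y in loader:
        # flatten each 28x28 image into a 784-vector and keep the single output logit
        yp = model(X.view(X.shape[0], -1))[:,0]
        # binary cross entropy on the logits with 0/1 labels, i.e. L((2y-1) * yp)
        loss = nn.BCEWithLogitsLoss()(yp, y.float())
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()
        
        # an error is a positive logit on a 0 or a negative logit on a 1
        total_err += ((yp > 0) * (y==0) + (yp < 0) * (y==1)).sum().item()
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)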
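The claim that nn.BCEWithLogitsLoss with 0/1 labels is the logistic loss in disguise can be checked directly. Written out, binary cross entropy on a logit $z$ with target $t\in\{0,1\}$ is $-[t\log\sigma(z) + (1-t)\log(1-\sigma(z))]$, where $\sigma$ is the sigmoid; the plain-Python sketch below evaluates that formula by hand (it does not call PyTorch) and compares it with $L((2t-1)z)$. The helper names are illustrative.

```python
import math


def sigmoid(z):
    """Logistic sigmoid, written to avoid overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_on_logit(z, t):
    """Binary cross entropy for logit z and 0/1 target t, from the textbook formula (not PyTorch)."""
    p = sigmoid(z)
    return -(t * math.log(p) + (1 - t) * math.log(1 - p))


def L(z):
    """Logistic loss L(z) = log(1 + exp(-z))."""
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


for z in [-5.0, -0.3, 0.0, 1.7, 6.0]:
    for t in (0, 1):
        print(f"z={z:5.1f} t={t}: BCE {bce_on_logit(z, t):.6f}  L((2t-1)z) {L((2 * t - 1) * z):.6f}")
```

The mapping $y = 2t-1$ is the same one the robust training code later in this section uses to turn 0/1 labels into $\pm 1$.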

We’ll train the classifier for 10 epochs, though note that MNIST 0/1 binary classification is a very easy problem, and after one epoch we have basically converged to the final test error (though the test loss still decreases). The test error fluctuates between 0.000473 and 0.000946, which on the 2,115-image test set corresponds to just one or two mistakes.

model = nn.Linear(784, 1)
opt = optim.SGD(model.parameters(), lr=1.)
print("Train Err", "Train Loss", "Test Err", "Test Loss", sep="\t")
for i in range(10):
    train_err, train_loss = epoch(train_loader, model, opt)
    test_err, test_loss = epoch(test_loader, model)
    print(*("{:.6f}".format(i) for i in (train_err, train_loss, test_err, test_loss)), sep="\t")
Train Err	Train Loss	Test Err	Test Loss
0.007501	0.015405	0.000946	0.003278
0.001342	0.005392	0.000946	0.002892
0.001342	0.004438	0.000473	0.002560
0.001105	0.003788	0.000946	0.002495
0.000947	0.003478	0.000946	0.002297
0.000947	0.003251	0.000946	0.002161
0.000711	0.002940	0.000473	0.002159
0.000790	0.002793	0.000946	0.002109
0.000711	0.002650	0.000946	0.002107
0.000790	0.002529	0.000946	0.001997

In case you’re curious, we can actually look at a test example that the classifier makes a mistake on, which indeed seems a bit odd relative to most 0s and 1s.

X_test = (test_loader.dataset.data.float()/255).view(len(test_loader.dataset),-1)
y_test = test_loader.dataset.targets
yp = model(X_test)[:,0]
idx = (yp > 0) * (y_test == 0) + (yp < 0) * (y_test == 1)
# show the first misclassified example (there may be more than one)
plt.imshow(1-X_test[idx][0].view(28,28).numpy(), cmap="gray")
plt.title("True Label: {}".format(y_test[idx][0].item()))

Hopefully you’ve already noticed something else about the adversarial examples we generate in the linear case: because the optimal perturbation is equal to

$$\delta^\star = -y\epsilon \cdot \mathrm{sign}(w)$$

which depends on the example only through its label $y$, the best perturbation to apply is the same across all examples of a given class (and simply flips sign between the two classes). Note, however, that to get the best valid perturbation, we should really be constraining $x + \delta$ to lie in $[0,1]$, which doesn’t hold here. For simplicity, we’ll ignore this for now and go ahead and add this same $\delta$ anyway (even if it gives us a technically invalid image). After all, to the classifier the inputs are just numerical values, so they can always be greater than one or less than zero; the performance we’ll see also applies to the case where we clip values, it just adds a bit of unnecessary hassle.
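Clipping is less of a hassle than it might sound in the linear case. A sketch, assuming $x\in[0,1]^n$: the objective $y\cdot w^T\delta$ is separable across coordinates, so each coordinate should move to whichever end of its feasible interval $[\max(-\epsilon,-x_i), \min(\epsilon, 1-x_i)]$ lowers $y w_i\delta_i$, which is the same as clipping $x - y\epsilon\cdot\mathrm{sign}(w)$ to $[0,1]$. The plain-Python code below checks this against brute force over the corners of that interval box; the function names and random inputs are illustrative.

```python
import itertools
import random


def sign(v):
    return (v > 0) - (v < 0)


def objective(w, y, delta):
    return y * sum(wi * di for wi, di in zip(w, delta))


def clipped_linf_attack(x, w, y, eps):
    """Minimize y * w^T delta s.t. ||delta||_inf <= eps and 0 <= x + delta <= 1 (assumes x in [0, 1]^n).
    Each coordinate goes to the end of [max(-eps, -x_i), min(eps, 1 - x_i)] that lowers y * w_i * delta_i,
    which is the same as clipping x - y * eps * sign(w) to [0, 1]."""
    return [min(max(xi - y * eps * sign(wi), 0.0), 1.0) - xi for xi, wi in zip(x, w)]


def brute_force_min(x, w, y, eps):
    """Exact minimum over the corners of the feasible box (linear objective, so a corner is optimal)."""
    box = [(max(-eps, -xi), min(eps, 1.0 - xi)) for xi in x]
    return min(objective(w, y, d) for d in itertools.product(*box))


random.seed(0)
n, eps = 8, 0.2                                            # hypothetical sizes
x = [0.0, 1.0] + [random.random() for _ in range(n - 2)]   # include pixels already at 0 and 1
w = [random.gauss(0, 1) for _ in range(n)]
for y in (+1, -1):
    delta = clipped_linf_attack(x, w, y, eps)
    unclipped = -eps * sum(abs(wi) for wi in w)            # value of the unconstrained attack
    print(y, objective(w, y, delta), brute_force_min(x, w, y, eps), unclipped)
```

The clipped attack is never stronger than the unconstrained one, since it optimizes over a smaller set.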

Let’s look at the actual perturbation, to try to get a sense of it.

epsilon = 0.2
delta = epsilon * model.weight.detach().sign().view(28,28)
plt.imshow(1-delta.numpy(), cmap="gray")

It’s perhaps not all that obvious, but if you squint you can see that maybe there is a vertical line (like a 1) in black pixels, and a circle (like a 0) in white. The intuition here is that moving in the black direction makes the classifier think the image is more like a 1, while moving in the white direction makes it look more like a 0. But the picture here is not perfect, and if you didn’t know to look for this, it may not be obvious. We’ll return to this picture after training a robust model.

Let’s next see what happens to the test accuracy when we make this (optimal) adversarial attack on the images in the test set.

def epoch_adv(loader, model, delta):
    total_loss, total_err = 0.,0.
    for X,y in loader:
        # (2y-1) maps the 0/1 labels to -1/+1, so this adds delta* = -y * epsilon * sign(w) to each image
        yp = model((X-(2*y.float()[:,None,None,None]-1)*delta).view(X.shape[0], -1))[:,0]
        loss = nn.BCEWithLogitsLoss()(yp, y.float())
        total_err += ((yp > 0) * (y==0) + (yp < 0) * (y==1)).sum().item()
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)
print(epoch_adv(test_loader, model, delta[None,None,:,:]))
(0.8458628841607565, 3.4517438034075654)

So, allowing perturbations within the $\ell_\infty$ ball of size $\epsilon=0.2$, the classifier goes from essentially zero error to 84.6% error. Unlike the ImageNet case, the perturbed images here are recognizably different (we just overlay the noise you saw above), but this would definitely not be sufficient to fool most humans into misrecognizing the image.

f,ax = plt.subplots(5,5, sharey=True)
for i in range(25):
    # overlay the class-dependent perturbation -(2y-1) * delta on test image i
    ax[i%5][i//5].imshow(1-(X_test[i].view(28,28) - (2*y_test[i]-1)*delta).numpy(), cmap="gray")
    ax[i%5][i//5].axis("off")

Training robust linear models

We’ve now seen that a standard linear model suffers from a lot of the same problems as deep models (though it should be said, they are still slightly more resilient than standard training for deep networks, for which an $\ell_\infty$ ball with $\epsilon=0.2$ could easily create 100% error). But we also know that we can easily perform exact robust optimization (i.e., solving the equivalent of the min-max problem) by simply incorporating the $\ell_1$ norm into the objective. Putting this into the standard binary cross entropy loss that PyTorch implements (which uses labels of 0/1 by default, not -1/+1), takes a bit of munging, but the training procedure is still quite simple: we just subtract $\epsilon(2y-1)\|w\|_1$ from the predictions (the $2y-1$ scales the 0/1 entries to -1/+1).

# do a single pass over the data
def epoch_robust(loader, model, epsilon, opt=None):
    total_loss, total_err = 0.,0.
    for X,y in loader:
        yp = model(X.view(X.shape[0], -1))[:,0] - epsilon*(2*y.float()-1)*model.weight.norm(1)
        loss = nn.BCEWithLogitsLoss()(yp, y.float())
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()
        
        total_err += ((yp > 0) * (y==0) + (yp < 0) * (y==1)).sum().item()
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)
model = nn.Linear(784, 1)
opt = optim.SGD(model.parameters(), lr=1e-1)
epsilon = 0.2
print("Rob. Train Err", "Rob. Train Loss", "Rob. Test Err", "Rob. Test Loss", sep="\t")
for i in range(20):
    train_err, train_loss = epoch_robust(train_loader, model, epsilon, opt)
    test_err, test_loss = epoch_robust(test_loader, model, epsilon)
    print(*("{:.6f}".format(i) for i in (train_err, train_loss, test_err, test_loss)), sep="\t")
Rob. Train Err	Rob. Train Loss	Rob. Test Err	Rob. Test Loss
0.147414	0.376791	0.073759	0.228654
0.073352	0.223381	0.053901	0.176481
0.062929	0.197301	0.043026	0.154818
0.057008	0.183879	0.038298	0.139773
0.052981	0.174964	0.040662	0.143639
0.050059	0.167973	0.037352	0.132365
0.048164	0.162836	0.032624	0.119755
0.046190	0.158340	0.033570	0.123211
0.044769	0.154719	0.029787	0.118066
0.043979	0.152048	0.027423	0.118974
0.041058	0.149381	0.026478	0.110074
0.040268	0.147034	0.027423	0.114998
0.039874	0.145070	0.026950	0.109395
0.038452	0.143232	0.026950	0.109015
0.037663	0.141919	0.027896	0.113093
0.036715	0.140546	0.026478	0.103066
0.036321	0.139162	0.026478	0.107028
0.035610	0.138088	0.025059	0.104717
0.035215	0.137290	0.025059	0.104803
0.034741	0.136175	0.025059	0.106629

We stated it above, but we should emphasize that all the numbers reported above are the robust (i.e., worst-case adversarial) errors and losses. So by training with the robust optimization problem, we’re able to train a model such that for $\epsilon=0.2$, no adversarial attack can lead to more than 2.5% error on the test set. Quite an improvement from the ~85% that the standard training had. But how well does it do on clean (non-adversarial) data?

train_err, train_loss = epoch(train_loader, model)
test_err, test_loss = epoch(test_loader, model)
print("Train Err", "Train Loss", "Test Err", "Test Loss", sep="\t")
print(*("{:.6f}".format(i) for i in (train_err, train_loss, test_err, test_loss)), sep="\t")
Train Err	Train Loss	Test Err	Test Loss
0.006080	0.015129	0.003783	0.008186

We’re getting about 0.4% error on the test set. This is good, but not as good as we were doing with standard training; we’re now making 8 mistakes on the test set, instead of the one or two that we were making before. And this is not just a random effect of this particular problem, or of the fact that it is relatively easy. Rather, perhaps somewhat surprisingly, there is a fundamental tradeoff between clean accuracy and robust accuracy, and doing better on the robust error leads to higher clean error. We will return to this point in much more detail later.

Finally, let’s look at the image of the optimal perturbation for this robust model.

delta = epsilon * model.weight.detach().sign().view(28,28)
plt.imshow(1-delta.numpy(), cmap="gray")

That looks substantially more like a zero than what we saw before. Thus, we have some (admittedly, at this point, fairly weak) evidence that robust training may also lead to “adversarial directions” that are inherently more meaningful. Rather than fooling the classifier by just adding “random noise”, we actually need to start moving the image in the direction of an actual new image (and even doing so, at least with this size of epsilon, we aren’t very successful at fooling the classifier). This idea will also come up later.
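For readers who want to see the robust objective without autograd or 0/1-label munging, here is a hedged, plain-Python sketch of (sub)gradient descent on $\frac{1}{|D|}\sum L(y\cdot(w^Tx+b) - \epsilon\|w\|_1)$ with $\pm1$ labels. It uses a hypothetical two-dimensional toy data set rather than MNIST, and it takes $\mathrm{sign}(w_j)$ as the (sub)gradient of $|w_j|$; the learning rate and step count are arbitrary choices for the toy problem, not values from the experiments above.

```python
import math
import random


def L(z):
    """Logistic loss L(z) = log(1 + exp(-z)), evaluated without overflow."""
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


def dL(z):
    """L'(z) = -1 / (1 + exp(z)), evaluated without overflow."""
    if z >= 0:
        e = math.exp(-z)
        return -e / (1.0 + e)
    return -1.0 / (1.0 + math.exp(z))


def sign(v):
    return (v > 0) - (v < 0)


def robust_margins(data, w, b, eps):
    """Worst-case l_inf margins y * (w^T x + b) - eps * ||w||_1 for every (x, y) with y in {+1, -1}."""
    l1 = sum(abs(wi) for wi in w)
    return [y * (sum(wi * xi for wi, xi in zip(w, x)) + b) - eps * l1 for x, y in data]


def robust_loss_and_error(data, w, b, eps):
    z = robust_margins(data, w, b, eps)
    return sum(L(zi) for zi in z) / len(z), sum(zi <= 0 for zi in z) / len(z)


def train_robust(data, eps, lr=0.5, steps=200):
    """(Sub)gradient descent on the mean robust loss, starting from w = 0, b = 0."""
    n = len(data[0][0])
    w, b = [0.0] * n, 0.0
    for _ in range(steps):
        g = [dL(zi) for zi in robust_margins(data, w, b, eps)]  # dLoss_i / dz_i
        # dz_i/dw_j = y_i * x_ij - eps * sign(w_j)  and  dz_i/db = y_i
        grad_w = [sum(gi * (y * x[j] - eps * sign(w[j])) for gi, (x, y) in zip(g, data)) / len(data)
                  for j in range(n)]
        grad_b = sum(gi * y for gi, (_, y) in zip(g, data)) / len(data)
        w = [wj - lr * gj for wj, gj in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


# Hypothetical toy data: class y in {+1, -1} sits around the point (y, y), with uniform noise.
random.seed(0)
data = [([y + random.uniform(-0.5, 0.5), y + random.uniform(-0.5, 0.5)], y)
        for y in (+1, -1) for _ in range(50)]
eps = 0.2
loss0, err0 = robust_loss_and_error(data, [0.0, 0.0], 0.0, eps)
w, b = train_robust(data, eps)
loss1, err1 = robust_loss_and_error(data, w, b, eps)
print(f"start: robust loss {loss0:.3f}, robust error {err0:.2f}")
print(f"end:   robust loss {loss1:.3f}, robust error {err1:.2f}, w = {w}, b = {b:.3f}")
```

Because the objective is convex, plain gradient descent like this is all the robust training needs; no inner attack loop appears anywhere.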

Multi-class classification [coming soon]

Before moving on to the deep learning setting, let’s briefly consider the multi-class extension of what we presented above. After all, most deep classifiers that we care about are actually multi-class classifiers, using the cross entropy loss or something similar. Recalling what we defined before, this means we are considering the linear hypothesis function

$$h_\theta(x) = Wx + b$$

which results in an inner maximization problem of the form

$$\max_{\|\delta\| \leq \epsilon} \ell(W(x+\delta) + b, y)$$

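Unfortunately, in the multi-class case it turns out that it is no longer possible to solve the inner maximization problem exactly. Specifically, if we consider the cross entropy loss plugged into the above expression

$$\max_{\|\delta\| \leq \epsilon} \left( \log \sum_{j=1}^k \exp\left(w_j^T(x+\delta) + b_j\right) - \left(w_y^T(x+\delta) + b_y\right) \right)$$

where $w_j$ denotes the $j$th row of $W$.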

Here, unlike the binary case, we cannot push the max over $\delta$ inside the nonlinear function (the log-sum-exp function is convex, so maximizing over it is difficult in general).
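One way to see the difficulty concretely: because the cross entropy is convex in $\delta$, its maximum over the $\ell_\infty$ ball is attained at one of the $2^n$ corners, so for tiny $n$ the inner problem can be solved exactly by enumeration. The plain-Python sketch below does this for a hypothetical random $W$, $b$, and $x$ with $n=6$; the same enumeration would need $2^{784}$ loss evaluations for MNIST, which is why an exact solution is out of reach and a bound is needed instead.

```python
import itertools
import math
import random


def cross_entropy(logits, y):
    """Multiclass cross entropy log(sum_j exp(z_j)) - z_y, with the max subtracted for stability."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[y]


def linear_logits(W, b, x):
    """Logits W x + b for a k x n weight matrix W stored as a list of rows."""
    return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(W, b)]


def exact_inner_max_linf(W, b, x, y, eps):
    """Exact max over ||delta||_inf <= eps of cross_entropy(W (x + delta) + b, y), by checking all 2^n corners.
    Valid because the loss is convex in delta, so its maximum over the box is attained at a corner."""
    best, best_delta = -math.inf, None
    for delta in itertools.product((-eps, eps), repeat=len(x)):
        value = cross_entropy(linear_logits(W, b, [xi + di for xi, di in zip(x, delta)]), y)
        if value > best:
            best, best_delta = value, delta
    return best, best_delta


random.seed(0)
k, n, eps, y = 3, 6, 0.3, 0                       # hypothetical sizes, radius, and true class
W = [[random.gauss(0, 1) for _ in range(n)] for _ in range(k)]
b = [random.gauss(0, 1) for _ in range(k)]
x = [random.random() for _ in range(n)]
best, best_delta = exact_inner_max_linf(W, b, x, y, eps)
print("clean loss:", cross_entropy(linear_logits(W, b, x), y))
print("worst-case loss:", best, "at delta =", best_delta)
print("corners checked:", 2 ** n, "(an MNIST-sized input would need 2 ** 784)")
```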