# bnn_adversary.py
# Runs an FGSM adversarial attack on every test sample for a set of epsilons and
# records the Bayesian network's prediction and its standard deviation to CSV.
from pathlib import Path
from typing import Callable, List, Tuple

import pandas
import pyro
import torch

from helper.adversary import fgsm_attack
from helper.config import Configuration
from helper.data_loader import get_test_loader
from networks import BNNWrapper
def run_attack(
    bnn: BNNWrapper,
    loss_fn: Callable[..., torch.Tensor],
    x: torch.Tensor,
    y: torch.Tensor,
    epsilons: List[float],
    batch_id: int,
) -> Tuple[pandas.DataFrame, List[torch.Tensor], List[torch.Tensor]]:
    """Attack one test sample with FGSM at several strengths and record the BNN's response.

    The gradient of ``loss_fn`` with respect to the input is computed once and
    reused for every epsilon.

    Args:
        bnn: Wrapped Bayesian network. Its ``model``/``guide`` are passed to
            ``loss_fn``, and its ``predict`` method must return ``(mean, std)``.
        loss_fn: Differentiable loss called as
            ``loss_fn(model, guide, x_data=..., y_data=...)`` -- e.g. the
            ``differentiable_loss`` method of a ``pyro.infer.Trace_ELBO``.
        x: Input image(s), reshaped to rows of 28 * 28 values before use.
            Because labels and predictions are read with ``.item()``, only a
            batch containing a single sample is supported.
        y: Label tensor for ``x`` (a single element).
        epsilons: Attack strengths forwarded to ``fgsm_attack``.
        batch_id: Value stored in the ``id`` column of every row.

    Returns:
        ``(frame, perturbed_images, perturbations)``. ``frame`` has one row per
        epsilon with columns ``id``, ``epsilon``, ``y`` (true label), ``y_``
        (argmax of the predicted mean) and ``std`` (the predicted std for that
        class). The two lists hold ``fgsm_attack``'s two outputs in
        ``epsilons`` order.
    """
    # Work on a detached copy so the caller's tensor is not flagged as
    # requiring grad as a side effect.
    x = x.to(bnn.device).detach().requires_grad_(True)
    y = y.to(bnn.device)

    loss = loss_fn(bnn.model, bnn.guide, x_data=x.view(-1, 28 * 28), y_data=y)
    # autograd.grad returns only d(loss)/dx; unlike loss.backward() it does not
    # accumulate gradients into the network's parameters on every call.
    (data_grad,) = torch.autograd.grad(loss, x)

    # The attack only needs the input values; passing a detached tensor keeps
    # the returned images from referencing this call's autograd graph.
    x_clean = x.detach()
    label = y.item()

    rows = {"id": [], "epsilon": [], "y": [], "y_": [], "std": []}
    perturbed_images = []
    perturbations = []
    for epsilon in epsilons:
        perturbed_image, perturbation = fgsm_attack(x_clean, epsilon, data_grad)
        perturbed_images.append(perturbed_image)
        perturbations.append(perturbation)

        mean, std = bnn.predict(perturbed_image.view(-1, 28 * 28))
        predicted = mean.max(1).indices.item()
        rows["id"].append(batch_id)
        rows["epsilon"].append(epsilon)
        rows["y"].append(label)
        rows["y_"].append(predicted)
        rows["std"].append(std[0][predicted].item())

    return pandas.DataFrame.from_dict(rows), perturbed_images, perturbations
if __name__ == "__main__":
    config = Configuration()
    bnn = BNNWrapper()
    bnn.load_model()
    # run_attack differentiates the returned loss with respect to the input, so
    # the differentiable_loss variant of the ELBO is used.
    loss_fn = pyro.infer.Trace_ELBO(
        num_particles=config.bnn_adversary_samples
    ).differentiable_loss

    # run_attack reads labels and predictions with .item(), so feed one
    # sample per batch.
    test_loader = get_test_loader(batch_size=1, shuffle=False)
    n_samples = len(test_loader.dataset)

    result = []  # type: List[pandas.DataFrame]
    for batch_id, (x, y) in enumerate(test_loader):
        # run_attack returns (frame, perturbed_images, perturbations); only the
        # frame is aggregated. Appending the whole tuple would make
        # pandas.concat fail and keep every perturbed image alive in memory.
        frame, _, _ = run_attack(bnn, loss_fn, x, y, config.epsilons, batch_id)
        result.append(frame)
        if batch_id % 100 == 0:
            print(f"Step {batch_id}/{n_samples}")

    # ignore_index=True yields a fresh 0..N-1 index (same as reset_index(drop=True)).
    result_df = pandas.concat(result, ignore_index=True)  # type: pandas.DataFrame
    result_path = Path("data/")
    result_path.mkdir(exist_ok=True, parents=False)
    result_df.to_csv(result_path.joinpath(f"{config.id:02}_bnn_result.csv"))