| | |
| | import torch |
| | import torch.nn as nn |
| | import torchvision.transforms as transforms |
| | from torch.autograd import Variable |
| | from util.feature_extraction_utils import warp_image, normalize_batch |
| | from util.prepare_utils import get_ensemble, extract_features |
| | from lpips_pytorch import LPIPS |
| | from tqdm import trange |
| |
|
| | tensor_transform = transforms.ToTensor() |
| | pil_transform = transforms.ToPILImage() |
| |
|
| |
|
| | class Attack(nn.Module): |
| | def __init__( |
| | self, |
| | models, |
| | dim, |
| | attack_type, |
| | eps, |
| | c_sim=0.5, |
| | net_type="alex", |
| | lr=0.05, |
| | n_iters=100, |
| | noise_size=0.001, |
| | n_starts=10, |
| | c_tv=None, |
| | sigma_gf=None, |
| | kernel_size_gf=None, |
| | combination=False, |
| | warp=False, |
| | theta_warp=None, |
| | V_reduction=None, |
| | ): |
| | super(Attack, self).__init__() |
| | self.extractor_ens = get_ensemble( |
| | models, sigma_gf, kernel_size_gf, combination, V_reduction, warp, theta_warp |
| | ) |
| | |
| | self.dim = dim |
| | self.eps = eps |
| | self.c_sim = c_sim |
| | self.net_type = net_type |
| | self.lr = lr |
| | self.n_iters = n_iters |
| | self.noise_size = noise_size |
| | self.n_starts = n_starts |
| | self.c_tv = None |
| | self.attack_type = attack_type |
| | self.warp = warp |
| | self.theta_warp = theta_warp |
| | if self.attack_type == "lpips": |
| | self.lpips_loss = LPIPS(self.net_type) |
| |
|
| | def execute(self, images, dir_vec, direction): |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | print("Device in Excute:", device) |
| | self.lpips_loss.to(device) |
| | images = Variable(images).to(device) |
| | dir_vec = dir_vec.to(device) |
| | |
| | dir_vec_norm = dir_vec.norm(dim=2).unsqueeze(2).to(device) |
| | dist = torch.zeros(images.shape[0]).to(device) |
| | adv_images = images.detach().clone() |
| |
|
| | if self.warp: |
| | self.face_img = warp_image(images, self.theta_warp) |
| |
|
| | for start in range(self.n_starts): |
| | |
| | adv_images_old = adv_images.detach().clone() |
| | dist_old = dist.clone() |
| | |
| | noise_uniform = Variable( |
| | 2 * self.noise_size * torch.rand(images.size()) - self.noise_size |
| | ).to(device) |
| | adv_images = Variable( |
| | images.detach().clone() + noise_uniform, requires_grad=True |
| | ).to(device) |
| |
|
| | for i in trange(self.n_iters): |
| | adv_features = extract_features( |
| | adv_images, self.extractor_ens, self.dim |
| | ).to(device) |
| | |
| | loss = direction * torch.mean( |
| | (adv_features - dir_vec) ** 2 / dir_vec_norm |
| | ) |
| |
|
| | if self.c_tv is not None: |
| | tv_out = self.total_var_reg(images, adv_images) |
| | loss -= self.c_tv * tv_out |
| |
|
| | if self.attack_type == "lpips": |
| | lpips_out = self.lpips_reg(images, adv_images) |
| | loss -= self.c_sim * lpips_out |
| |
|
| | grad = torch.autograd.grad(loss, [adv_images]) |
| | adv_images = adv_images + self.lr * grad[0].sign() |
| | perturbation = adv_images - images |
| |
|
| | if self.attack_type == "sgd": |
| | perturbation = torch.clamp( |
| | perturbation, min=-self.eps, max=self.eps |
| | ) |
| | adv_images = images + perturbation |
| |
|
| | adv_images = torch.clamp(adv_images, min=0, max=1) |
| | adv_features = extract_features( |
| | adv_images, self.extractor_ens, self.dim |
| | ).to(device) |
| | dist = torch.mean((adv_features - dir_vec) ** 2 / dir_vec_norm, dim=[1, 2]) |
| |
|
| | if direction == 1: |
| | adv_images[dist < dist_old] = adv_images_old[dist < dist_old] |
| | dist[dist < dist_old] = dist_old[dist < dist_old] |
| | else: |
| | adv_images[dist > dist_old] = adv_images_old[dist > dist_old] |
| | dist[dist > dist_old] = dist_old[dist > dist_old] |
| | return adv_images.detach().cpu() |
| |
|
| | def lpips_reg(self, images, adv_images): |
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| | if self.warp: |
| | face_adv = warp_image(adv_images, self.theta_warp) |
| | lpips_out = self.lpips_loss( |
| | normalize_batch(self.face_img).to(device), |
| | normalize_batch(face_adv).to(device), |
| | )[0][0][0][0] / (2 * adv_images.shape[0]) |
| | lpips_out += self.lpips_loss( |
| | normalize_batch(images).to(device), |
| | normalize_batch(adv_images).to(device), |
| | )[0][0][0][0] / (2 * adv_images.shape[0]) |
| |
|
| | else: |
| | lpips_out = ( |
| | self.lpips_loss( |
| | normalize_batch(images).to(device), |
| | normalize_batch(adv_images).to(device), |
| | )[0][0][0][0] |
| | / adv_images.shape[0] |
| | ) |
| |
|
| | return lpips_out |
| |
|
| | def total_var_reg(images, adv_images): |
| | perturbation = adv_images - images |
| | tv = torch.mean( |
| | torch.abs(perturbation[:, :, :, :-1] - perturbation[:, :, :, 1:]) |
| | ) + torch.mean( |
| | torch.abs(perturbation[:, :, :-1, :] - perturbation[:, :, 1:, :]) |
| | ) |
| |
|
| | return tv |
| |
|