face xray and fwa update
YZY-stack committed Dec 1, 2023
1 parent 9cc2866 commit 71e29cf
Showing 3 changed files with 74 additions and 90 deletions.
10 changes: 6 additions & 4 deletions training/dataset/ff_blend.py
@@ -57,13 +57,15 @@ def randomdownscale(self,img):
 class FFBlendDataset(data.Dataset):
     def __init__(self, config=None):
         # Check if the dictionary has already been created
-        if os.path.exists('nearest_face_info_new.pkl'):
-            with open('nearest_face_info_new.pkl', 'rb') as f:
+        if os.path.exists('training/lib/nearest_face_info.pkl'):
+            with open('training/lib/nearest_face_info.pkl', 'rb') as f:
                 face_info = pickle.load(f)
+        else:
+            raise ValueError(f"Need to run the dataset/generate_xray_nearest.py before training the face xray.")
         self.face_info = face_info
         # Check if the dictionary has already been created
-        if os.path.exists('landmark_dict_ffall.pkl'):
-            with open('landmark_dict_ffall.pkl', 'rb') as f:
+        if os.path.exists('training/lib/landmark_dict_ffall.pkl'):
+            with open('training/lib/landmark_dict_ffall.pkl', 'rb') as f:
                 landmark_dict = pickle.load(f)
         self.landmark_dict = landmark_dict
         self.imid_list = self.get_training_imglist()
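Note: with this change, `FFBlendDataset` fails fast instead of silently training without the precomputed nearest-face mapping. A minimal pre-flight check, assuming both pickles are plain dicts keyed by image path (their exact layout is not shown in this diff):

```python
import os
import pickle

# Hypothetical sanity check before launching Face X-ray training.
for pkl in ('training/lib/nearest_face_info.pkl', 'training/lib/landmark_dict_ffall.pkl'):
    if not os.path.exists(pkl):
        raise FileNotFoundError(f'{pkl} is missing; run dataset/generate_xray_nearest.py first')
    with open(pkl, 'rb') as f:
        mapping = pickle.load(f)
    print(pkl, len(mapping))  # assumed to be a dict keyed by image path
```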
150 changes: 66 additions & 84 deletions training/dataset/fwa_blend.py
@@ -21,6 +21,7 @@
 from skimage.draw import polygon
 from scipy import linalg
 import heapq as hq
+import albumentations as A
 
 import torch
 from torch.autograd import Variable
@@ -40,12 +41,12 @@
 from scipy.ndimage.filters import gaussian_filter
 from skimage.transform import AffineTransform, warp
 
-from dataset.ff_blend import FFBlendDataset
+from dataset.abstract_dataset import DeepfakeAbstractBaseDataset
 
 
 # Define face detector and predictor models
 face_detector = dlib.get_frontal_face_detector()
-predictor_path = '../preprocessing/dlib_tools/shape_predictor_81_face_landmarks.dat'
+predictor_path = './preprocessing/dlib_tools/shape_predictor_81_face_landmarks.dat'
 face_predictor = dlib.shape_predictor(predictor_path)
 
@@ -70,6 +71,22 @@
 landmarks_2D = np.stack([mean_face_x, mean_face_y], axis=1)
 
 
+class RandomDownScale(A.core.transforms_interface.ImageOnlyTransform):
+    def apply(self,img,**params):
+        return self.randomdownscale(img)
+
+    def randomdownscale(self,img):
+        keep_ratio=True
+        keep_input_shape=True
+        H,W,C=img.shape
+        ratio_list=[2,4]
+        r=ratio_list[np.random.randint(len(ratio_list))]
+        img_ds=cv2.resize(img,(int(W/r),int(H/r)),interpolation=cv2.INTER_NEAREST)
+        if keep_input_shape:
+            img_ds=cv2.resize(img_ds,(W,H),interpolation=cv2.INTER_LINEAR)
+        return img_ds
+
+
 def umeyama( src, dst, estimate_scale ):
     """Estimate N-D similarity transformation with or without scaling.
     Parameters
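Note: `RandomDownScale` simulates low-quality inputs with a 2x or 4x nearest-neighbour downscale followed by a bilinear resize back to the input shape. A minimal usage sketch (the input file name is illustrative):

```python
import cv2
import albumentations as A

img = cv2.imread('example_frame.png')       # illustrative input
aug = A.Compose([RandomDownScale(p=1.0)])   # the transform defined above
img_lq = aug(image=img)['image']            # same H x W, high-frequency detail removed
```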
@@ -302,49 +319,51 @@ def align(im, face_detector, lmark_predictor, scale=0):
     return face_list
 
 
-class FWABlendDataset(FFBlendDataset):
+class FWABlendDataset(DeepfakeAbstractBaseDataset):
     def __init__(self, config=None):
-        # Check if the dictionary has already been created
-        if os.path.exists('nearest_face_info_new.pkl'):
-            with open('nearest_face_info_new.pkl', 'rb') as f:
-                face_info = pickle.load(f)
-            self.face_info = face_info
-        # Check if the dictionary has already been created
-        if os.path.exists('landmark_dict_ffall.pkl'):
-            with open('landmark_dict_ffall.pkl', 'rb') as f:
-                landmark_dict = pickle.load(f)
-        self.landmark_dict = landmark_dict
-        self.imid_list = self.get_training_imglist()
+        super().__init__(config, mode='train')
         self.transforms = T.Compose([
-            # T.RandomHorizontalFlip(),
-            # T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
-            # T.ColorJitter(brightness=0.1, contrast=0.1,
-            #               saturation=0.1, hue=0.1),
             T.ToTensor(),
-            T.Normalize(mean=[0.5, 0.5, 0.5],
-                        std=[0.5, 0.5, 0.5])
+            T.Normalize(mean=config['mean'],
+                        std=config['std'])
         ])
+        self.data_dict = {
+            'img_list': self.imid_list
+        }
+        self.resolution = config['resolution']
 
 
-    def get_training_imglist(self):
-        """
-        Get the list of training images.
-        """
-        random.seed(1024)  # Fix the random seed for reproducibility
-        imid_list = list(self.landmark_dict.keys())
-        # imid_list = [imid.replace('landmarks', 'frames').replace('npy', 'png') for imid in imid_list]
-        random.shuffle(imid_list)
-        return imid_list
+    def blended_aug(self, im):
+        transform = A.Compose([
+            A.RGBShift((-20,20),(-20,20),(-20,20),p=0.3),
+            A.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=0.3),
+            A.RandomBrightnessContrast(brightness_limit=(-0.3,0.3), contrast_limit=(-0.3,0.3), p=0.3),
+            A.ImageCompression(quality_lower=40, quality_upper=100,p=0.5)
+        ])
+        # Apply transformations
+        im_aug = transform(image=im)
+        return im_aug['image']
 
 
-    def preprocess_images(self, imid_fg, imid_bg):
+    def data_aug(self, im):
         """
-        Load foreground and background images and face shapes, and apply transformations.
+        Apply data augmentation on the input image using albumentations.
         """
-        im = cv2.imread(imid_fg.replace('landmarks', 'frames').replace('npy', 'png'))
+        transform = A.Compose([
+            A.Compose([
+                A.RGBShift((-20,20),(-20,20),(-20,20),p=0.3),
+                A.HueSaturationValue(hue_shift_limit=(-0.3,0.3), sat_shift_limit=(-0.3,0.3), val_shift_limit=(-0.3,0.3), p=1),
+                A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1,0.1), p=1),
+            ],p=1),
+            A.OneOf([
+                RandomDownScale(p=1),
+                A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1),
+            ],p=1),
+        ], p=1.)
         # Apply transformations
         im_aug = transform(image=im)
         return im_aug['image']
 
 
+    def blend_images(self, img_path):
+        im = cv2.imread(img_path)
+
         # Get the alignment of the head
         face_cache = align(im, face_detector, face_predictor)
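Note: the reworked `FWABlendDataset` now pulls its file list, normalization statistics, and output resolution from the shared config rather than loading pickles itself. A hedged construction sketch; the config path is hypothetical, but the keys (`mean`, `std`, `resolution`) are the ones referenced above:

```python
import yaml

with open('training/config/detector/fwa.yaml', 'r') as f:  # hypothetical config file
    config = yaml.safe_load(f)
dataset = FWABlendDataset(config)  # needs config['mean'], config['std'], config['resolution']
```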
@@ -377,35 +396,35 @@ def preprocess_images(self, imid_fg, imid_bg):
         im = np.array(self.blended_aug(im))
 
         # Warp the face back to the original image
-        im, masked_face = face_warp(im, masked_face, face_cache[0][0], 256, [0, 0])
-        shape = get_2d_aligned_landmarks(face_cache[0], 256, [0, 0])
+        im, masked_face = face_warp(im, masked_face, face_cache[0][0], self.resolution, [0, 0])
+        shape = get_2d_aligned_landmarks(face_cache[0], self.resolution, [0, 0])
         return im, masked_face
 
 
-    def process_images(self, imid_fg, imid_bg, index):
+    def process_images(self, img_path, index):
         """
         Process an image following the data generation pipeline.
         """
-        im, mask = self.preprocess_images(imid_fg, imid_bg)
+        blended_im, mask = self.blend_images(img_path)
 
         # Prepare images and titles for the combined image
-        imid_fg = cv2.imread(imid_fg.replace('landmarks', 'frames').replace('npy', 'png'))
+        imid_fg = cv2.imread(img_path)
         imid_fg = np.array(self.data_aug(imid_fg))
 
-        if im is None or mask is None:
+        if blended_im is None or mask is None:
             return imid_fg, None
 
         # images = [
         #     imid_fg,
         #     np.where(mask.astype(np.uint8)>0, 255, 0),
-        #     im,
+        #     blended_im,
         # ]
         # titles = ["Image", "Mask", "Blended Image"]
 
         # # Save the combined image
         # os.makedirs('fwa_examples_2', exist_ok=True)
         # self.save_combined_image(images, titles, index, f'fwa_examples_2/combined_image_{index}.png')
-        return imid_fg, im
+        return imid_fg, blended_im
@@ -468,10 +487,10 @@ def __getitem__(self, index):
         """
         Get an item from the dataset by index.
         """
-        one_lmk_path = self.imid_list[index]
-        label = 1 if one_lmk_path.split('/')[6]=='manipulated_sequences' else 0
+        one_img_path = self.data_dict['image'][index]
+        label = 1 if one_img_path.split('/')[6]=='manipulated_sequences' else 0
         blend_label = 1
-        imid, manipulate_img = self.process_images(one_lmk_path, one_lmk_path, index)
+        imid, manipulate_img = self.process_images(one_img_path, index)
 
         if manipulate_img is None:
             manipulate_img = deepcopy(imid)
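Note: the label is still derived from a fixed position in the file path, so the FaceForensics++ root must sit at the same depth on disk. An illustration with a hypothetical path (the index shifts if the root is nested differently):

```python
p = '/mnt/data/datasets/rgb/FaceForensics++/manipulated_sequences/Deepfakes/c23/frames/000/000.png'
print(p.split('/')[6])  # 'manipulated_sequences'
label = 1 if p.split('/')[6] == 'manipulated_sequences' else 0  # -> 1
```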
@@ -521,40 +540,3 @@ def collate_fn(batch):
         }
 
         return data_dict
-
-
-    def __len__(self):
-        """
-        Get the length of the dataset.
-        """
-        return len(self.imid_list)
-
-
-if __name__ == "__main__":
-    dataset = FFBlendDataset()
-    print('dataset lenth: ', len(dataset))
-
-    def tensor2bgr(im):
-        img = im.squeeze().cpu().numpy().transpose(1, 2, 0)
-        img = (img + 1)/2 * 255
-        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-        return img
-
-    def tensor2gray(im):
-        img = im.squeeze().cpu().numpy()
-        img = img * 255
-        return img
-
-    for i, data_dict in enumerate(dataset):
-        if i > 20:
-            break
-        if label == 1:
-            if not use_mouth:
-                img, boudary = im
-                cv2.imwrite('{}_whole.png'.format(i), tensor2bgr(img))
-                cv2.imwrite('{}_boudnary.png'.format(i), tensor2gray(boudary))
-            else:
-                img, mouth, boudary = im
-                cv2.imwrite('{}_whole.png'.format(i), tensor2bgr(img))
-                cv2.imwrite('{}_mouth.png'.format(i), tensor2bgr(mouth))
-                cv2.imwrite('{}_boudnary.png'.format(i), tensor2gray(boudary))
4 changes: 2 additions & 2 deletions training/detectors/facexray_detector.py
@@ -88,12 +88,12 @@ def __init__(self, config):
         self.correct, self.total = 0, 0
 
     def build_backbone(self, config):
-        cfg_path = '/home/zhiyuanyan/disfin/deepfake_benchmark/training/config/backbone/cls_hrnet_w48.yaml'
+        cfg_path = './training/config/backbone/cls_hrnet_w48.yaml'
         # parse options and load config
         with open(cfg_path, 'r') as f:
             cfg_config = yaml.safe_load(f)
         convnet = get_cls_net(cfg_config)
-        saved = torch.load('./pretrained/hrnetv2_w48_imagenet_pretrained.pth', map_location='cpu')
+        saved = torch.load('./training/pretrained/hrnetv2_w48_imagenet_pretrained.pth', map_location='cpu')
         convnet.load_state_dict(saved, False)
         print('Load HRnet')
         return convnet
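Note: `convnet.load_state_dict(saved, False)` passes `strict=False` positionally, so the ImageNet checkpoint loads even though HRNet's classification head does not match the detector. Written out explicitly for clarity:

```python
# Equivalent, more explicit form of the load above.
saved = torch.load('./training/pretrained/hrnetv2_w48_imagenet_pretrained.pth', map_location='cpu')
convnet.load_state_dict(saved, strict=False)  # tolerate missing/unexpected head keys
```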
