Skip to content

Commit

Permalink
upload source codes
Browse files Browse the repository at this point in the history
  • Loading branch information
柏灌 committed May 31, 2021
1 parent 7f8c666 commit 23b8820
Show file tree
Hide file tree
Showing 65 changed files with 6,031 additions and 8 deletions.
34 changes: 26 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
# GAN Prior Embedded Network for Blind Face Restoration in the Wild

[Paper](<https://arxiv.org/abs/2105.06070>) | [Demo](<https://vision.aliyun.com/experience/detail?spm=a211p3.14020179.J_7524944390.17.66cd4850wVDkUQ&tagName=facebody&children=EnhanceFace/>)
[Paper](https://arxiv.org/abs/2105.06070) | [Supplementary](https://www4.comp.polyu.edu.hk/~cslzhang/paper/GPEN-cvpr21-supp.pdf) | [Demo](https://vision.aliyun.com/experience/detail?spm=a211p3.14020179.J_7524944390.17.66cd4850wVDkUQ&tagName=facebody&children=EnhanceFace)

[Tao Yang](https://cg.cs.tsinghua.edu.cn/people/~tyang/), Peiran Ren, Xuansong Xie, [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang/)
_[DAMO Academy, Alibaba Group](https://damo.alibaba.com/), Hangzhou, China_
_[Department of Computing, The Hong Kong Polytechnic University](http://www.comp.polyu.edu.hk/), Hong Kong, China_
[Tao Yang](https://cg.cs.tsinghua.edu.cn/people/~tyang), Peiran Ren, Xuansong Xie, [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)
_[DAMO Academy, Alibaba Group](https://damo.alibaba.com), Hangzhou, China_
_[Department of Computing, The Hong Kong Polytechnic University](http://www.comp.polyu.edu.hk), Hong Kong, China_

<img src="figs/real_00.png" width="390px"/> <img src="figs/real_01.png" width="390px"/>
<img src="figs/real_02.png" width="390px"/> <img src="figs/real_03.png" width="390px"/>

<img src="figs/Solvay_conference_1927_comp.jpg" width="784px"/>

## Usage

![python](https://img.shields.io/badge/python-v3.7.4-green.svg?style=plastic)
![pytorch](https://img.shields.io/badge/pytorch-v1.7.1-green.svg?style=plastic)
![cuda](https://img.shields.io/badge/cuda-v10.2.89-green.svg?style=plastic)
![driver](https://img.shields.io/badge/driver-v460.73.01-green.svg?style=plastic)
![gcc](https://img.shields.io/badge/gcc-v7.5.0-green.svg?style=plastic)

<img src="figs/real_00.png" width="390px"/> <img src="figs/real_01.png" width="390px"/>
<img src="figs/real_02.png" width="390px"/> <img src="figs/real_03.png" width="390px"/>
- Clone this repository:
```bash
git clone https://github.com/yangxy/GPEN.git
cd GPEN
```
- Download RetinaFace model and our pre-trained model (not our best model due to commercial issues) and put them into ``weights/``.

<img src="figs/Solvay_conference_1927_comp.jpg" width="784px"/>
[RetinaFace-R50](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth) | [GPEN-512](https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-512.pth)

Coming soon, stay tuned.
- Restore faces:
```bash
python face_enhancement.py
```

## Main idea
<img src="figs/architecture.png" width="784px"/>
Expand All @@ -34,5 +49,8 @@ If our work is useful for your research, please consider citing:
## License
© Alibaba, 2021. For academic and non-commercial use only.

## Acknowledgments
We borrow some codes from [Pytorch_Retinaface](https://github.com/biubug6/Pytorch_Retinaface) and [stylegan2-pytorch](https://github.com/rosinality/stylegan2-pytorch).

## Contact
If you have any questions or suggestions about this paper, feel free to reach me at yangtao9009@gmail.com
18 changes: 18 additions & 0 deletions __init_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
'''
@paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
@author: yangxy (yangtao9009@gmail.com)
'''
import os.path as osp
import sys

def add_path(path):
if path not in sys.path:
sys.path.insert(0, path)

this_dir = osp.dirname(__file__)

path = osp.join(this_dir, 'retinaface')
add_path(path)

path = osp.join(this_dir, 'face_model')
add_path(path)
263 changes: 263 additions & 0 deletions align_faces.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 24 15:43:29 2017
@author: zhaoy
"""
import cv2
import numpy as np
from skimage import transform as trans

# reference facial points, a list of coordinates (x,y)
REFERENCE_FACIAL_POINTS = [
[30.29459953, 51.69630051],
[65.53179932, 51.50139999],
[48.02519989, 71.73660278],
[33.54930115, 92.3655014],
[62.72990036, 92.20410156]
]

DEFAULT_CROP_SIZE = (96, 112)


def _umeyama(src, dst, estimate_scale=True, scale=1.0):
"""Estimate N-D similarity transformation with or without scaling.
Parameters
----------
src : (M, N) array
Source coordinates.
dst : (M, N) array
Destination coordinates.
estimate_scale : bool
Whether to estimate scaling factor.
Returns
-------
T : (N + 1, N + 1)
The homogeneous similarity transformation matrix. The matrix contains
NaN values only if the problem is not well-conditioned.
References
----------
.. [1] "Least-squares estimation of transformation parameters between two
point patterns", Shinji Umeyama, PAMI 1991, :DOI:`10.1109/34.88573`
"""

num = src.shape[0]
dim = src.shape[1]

# Compute mean of src and dst.
src_mean = src.mean(axis=0)
dst_mean = dst.mean(axis=0)

# Subtract mean from src and dst.
src_demean = src - src_mean
dst_demean = dst - dst_mean

# Eq. (38).
A = dst_demean.T @ src_demean / num

# Eq. (39).
d = np.ones((dim,), dtype=np.double)
if np.linalg.det(A) < 0:
d[dim - 1] = -1

T = np.eye(dim + 1, dtype=np.double)

U, S, V = np.linalg.svd(A)

# Eq. (40) and (43).
rank = np.linalg.matrix_rank(A)
if rank == 0:
return np.nan * T
elif rank == dim - 1:
if np.linalg.det(U) * np.linalg.det(V) > 0:
T[:dim, :dim] = U @ V
else:
s = d[dim - 1]
d[dim - 1] = -1
T[:dim, :dim] = U @ np.diag(d) @ V
d[dim - 1] = s
else:
T[:dim, :dim] = U @ np.diag(d) @ V

if estimate_scale:
# Eq. (41) and (42).
scale = 1.0 / src_demean.var(axis=0).sum() * (S @ d)
else:
scale = scale

T[:dim, dim] = dst_mean - scale * (T[:dim, :dim] @ src_mean.T)
T[:dim, :dim] *= scale

return T, scale


class FaceWarpException(Exception):
def __str__(self):
return 'In File {}:{}'.format(
__file__, super.__str__(self))


def get_reference_facial_points(output_size=None,
inner_padding_factor=0.0,
outer_padding=(0, 0),
default_square=False):
tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
tmp_crop_size = np.array(DEFAULT_CROP_SIZE)

# 0) make the inner region a square
if default_square:
size_diff = max(tmp_crop_size) - tmp_crop_size
tmp_5pts += size_diff / 2
tmp_crop_size += size_diff

if (output_size and
output_size[0] == tmp_crop_size[0] and
output_size[1] == tmp_crop_size[1]):
print('output_size == DEFAULT_CROP_SIZE {}: return default reference points'.format(tmp_crop_size))
return tmp_5pts

if (inner_padding_factor == 0 and
outer_padding == (0, 0)):
if output_size is None:
print('No paddings to do: return default reference points')
return tmp_5pts
else:
raise FaceWarpException(
'No paddings to do, output_size must be None or {}'.format(tmp_crop_size))

# check output size
if not (0 <= inner_padding_factor <= 1.0):
raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')

if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0)
and output_size is None):
output_size = tmp_crop_size * \
(1 + inner_padding_factor * 2).astype(np.int32)
output_size += np.array(outer_padding)
print(' deduced from paddings, output_size = ', output_size)

if not (outer_padding[0] < output_size[0]
and outer_padding[1] < output_size[1]):
raise FaceWarpException('Not (outer_padding[0] < output_size[0]'
'and outer_padding[1] < output_size[1])')

# 1) pad the inner region according inner_padding_factor
# print('---> STEP1: pad the inner region according inner_padding_factor')
if inner_padding_factor > 0:
size_diff = tmp_crop_size * inner_padding_factor * 2
tmp_5pts += size_diff / 2
tmp_crop_size += np.round(size_diff).astype(np.int32)

# print(' crop_size = ', tmp_crop_size)
# print(' reference_5pts = ', tmp_5pts)

# 2) resize the padded inner region
# print('---> STEP2: resize the padded inner region')
size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
# print(' crop_size = ', tmp_crop_size)
# print(' size_bf_outer_pad = ', size_bf_outer_pad)

if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
raise FaceWarpException('Must have (output_size - outer_padding)'
'= some_scale * (crop_size * (1.0 + inner_padding_factor)')

scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
# print(' resize scale_factor = ', scale_factor)
tmp_5pts = tmp_5pts * scale_factor
# size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
# tmp_5pts = tmp_5pts + size_diff / 2
tmp_crop_size = size_bf_outer_pad
# print(' crop_size = ', tmp_crop_size)
# print(' reference_5pts = ', tmp_5pts)

# 3) add outer_padding to make output_size
reference_5point = tmp_5pts + np.array(outer_padding)
tmp_crop_size = output_size
# print('---> STEP3: add outer_padding to make output_size')
# print(' crop_size = ', tmp_crop_size)
# print(' reference_5pts = ', tmp_5pts)
#
# print('===> end get_reference_facial_points\n')

return reference_5point


def get_affine_transform_matrix(src_pts, dst_pts):
tfm = np.float32([[1, 0, 0], [0, 1, 0]])
n_pts = src_pts.shape[0]
ones = np.ones((n_pts, 1), src_pts.dtype)
src_pts_ = np.hstack([src_pts, ones])
dst_pts_ = np.hstack([dst_pts, ones])

A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)

if rank == 3:
tfm = np.float32([
[A[0, 0], A[1, 0], A[2, 0]],
[A[0, 1], A[1, 1], A[2, 1]]
])
elif rank == 2:
tfm = np.float32([
[A[0, 0], A[1, 0], 0],
[A[0, 1], A[1, 1], 0]
])

return tfm


def warp_and_crop_face(src_img,
facial_pts,
reference_pts=None,
crop_size=(96, 112),
align_type='smilarity'): #smilarity cv2_affine affine
if reference_pts is None:
if crop_size[0] == 96 and crop_size[1] == 112:
reference_pts = REFERENCE_FACIAL_POINTS
else:
default_square = False
inner_padding_factor = 0
outer_padding = (0, 0)
output_size = crop_size

reference_pts = get_reference_facial_points(output_size,
inner_padding_factor,
outer_padding,
default_square)

ref_pts = np.float32(reference_pts)
ref_pts_shp = ref_pts.shape
if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
raise FaceWarpException(
'reference_pts.shape must be (K,2) or (2,K) and K>2')

if ref_pts_shp[0] == 2:
ref_pts = ref_pts.T

src_pts = np.float32(facial_pts)
src_pts_shp = src_pts.shape
if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
raise FaceWarpException(
'facial_pts.shape must be (K,2) or (2,K) and K>2')

if src_pts_shp[0] == 2:
src_pts = src_pts.T

if src_pts.shape != ref_pts.shape:
raise FaceWarpException(
'facial_pts and reference_pts must have the same shape')

if align_type is 'cv2_affine':
tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
tfm_inv = cv2.getAffineTransform(ref_pts[0:3], src_pts[0:3])
elif align_type is 'affine':
tfm = get_affine_transform_matrix(src_pts, ref_pts)
tfm_inv = get_affine_transform_matrix(ref_pts, src_pts)
else:
params, scale = _umeyama(src_pts, ref_pts)
tfm = params[:2, :]

params, _ = _umeyama(ref_pts, src_pts, False, scale=1.0/scale)
tfm_inv = params[:2, :]

face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]), flags=3)

return face_img, tfm_inv
Binary file added examples/imgs/Solvay_conference_1927.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_COMP.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_GPEN.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face00.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face01.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face02.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face03.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face04.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face05.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face06.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face07.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face08.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face09.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face10.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face11.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face12.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face13.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face14.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face15.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face16.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face17.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face18.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face19.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face20.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face21.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/outs/Solvay_conference_1927_face22.jpg
Binary file added examples/outs/Solvay_conference_1927_face24.jpg
Binary file added examples/outs/Solvay_conference_1927_face25.jpg
Binary file added examples/outs/Solvay_conference_1927_face26.jpg
Binary file added examples/outs/Solvay_conference_1927_face27.jpg
Binary file added examples/outs/Solvay_conference_1927_face28.jpg
Loading

0 comments on commit 23b8820

Please sign in to comment.