Improved visual models' output
- All models: resize input images with PIL (ImageOps.fit)
- BlendGAN: align and crop the input face with FaceAlign (face-detection model)
- ArcaneGAN: switch to the newer v0.3 model
ma7dev committed Dec 13, 2021
1 parent 991866e commit f4c7847
Showing 6 changed files with 73 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1 +1 @@
-# ASC_ROBOT
+# OnlySudo
31 changes: 15 additions & 16 deletions src/ai/ArcaneGAN/main.py
@@ -7,7 +7,7 @@

from tqdm.notebook import tqdm
from glob import glob
-from PIL import Image
+from PIL import Image, ImageOps
import requests
from io import BytesIO
import numpy as np
@@ -118,37 +118,36 @@ def process(output_path,mtcnn,model,img):

project_path = "/home/alotaima/Projects/side/onlysudo/src/ai/ArcaneGAN"
args.outdir = '/src/api/public/ai/arcane'

+args.size = 1024

if args.url == '':
    print('Need url or streamer name!')
    exit()
-size = 256
-mtcnn = MTCNN(image_size=256, margin=80)
+mtcnn = MTCNN(image_size=args.size, margin=80)

means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

-t_stds = torch.tensor(stds).cuda().half()[:,None,None]
-t_means = torch.tensor(means).cuda().half()[:,None,None]
+t_stds = torch.tensor(stds).cuda()[:,None,None]
+t_means = torch.tensor(means).cuda()[:,None,None]

img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means,stds)])

-version = '0.2' #@param ['0.1','0.2']
+version = '0.3' #@param ['0.1','0.2']
model_path = f'{project_path}/ArcaneGANv{version}.jit'

root_path = '/'.join(os.path.abspath(os.getcwd()).split('/')[:-2])
output_path = f"{root_path}{args.outdir}/{args.filename}"

-model = torch.jit.load(model_path).eval().cuda().half()
+model = torch.jit.load(model_path).eval().cuda()

response = requests.get(args.url, stream = True)
-img = Image.open(BytesIO(response.content))
-width, height = img.size
-max_ = max(width,height)
-if max_ > 1080:
-    ratio = max_ / 1080
-    img = img.resize((int(width*ratio),int(height*ratio)))
-
-process(output_path,mtcnn,model,img)
+img = Image.open(BytesIO(response.content)).convert('RGB')
+img = ImageOps.fit(img, (args.size, args.size), centering=(0.5, 0.5))

process(output_path,mtcnn,model,img)

print("Done!")
50 changes: 33 additions & 17 deletions src/ai/BlendGAN/main.py
@@ -12,6 +12,11 @@
import random
from urllib.request import urlopen
import time
+from PIL import Image, ImageOps
+import requests
+from io import BytesIO
+
+from ffhq_dataset.gen_aligned_image import FaceAlign

seed = 0

@@ -29,6 +34,7 @@
parser.add_argument('--style_selected', type=int, default=0)
parser.add_argument('--filename', type=str, default=f"{time.time()}.jpg")
args = parser.parse_args()

project_path = "/home/alotaima/Projects/side/onlysudo/src/ai/BlendGAN"
args.size = 1024
args.ckpt = f'{project_path}/pretrained_models/blendgan.pt'
@@ -39,20 +45,19 @@
args.channel_multiplier = 2
args.latent = 512
args.n_mlp = 8

if args.url == '':
    print('Need url or a streamer name!')
    exit()

outdir = args.outdir
root_path = '/'.join(os.path.abspath(os.getcwd()).split('/')[:-2])
output_path = f"{root_path}{args.outdir}/{args.filename}"
print(f"/ai/style_transfer/{args.filename}")
print(args.url, args.style_selected, args.filename)
print(args.style_img_path)
# print(os.path.join(args.style_img_path, '*.*'))
-style_img_paths = sorted(glob.glob(os.path.join(args.style_img_path, '*.*')))[:]
# print(style_img_paths)
# print(args)
# exit()

+print('start')
+fa = FaceAlign()

checkpoint = torch.load(args.ckpt)
model_dict = checkpoint['g_ema']

@@ -64,34 +69,45 @@

psp_encoder = PSPEncoder(args.psp_encoder_ckpt, output_size=args.size).to(device)
psp_encoder.eval()

+response = requests.get(args.url, stream = True)
+img_in = Image.open(BytesIO(response.content)).convert('RGB')
+img_in = np.array(img_in)
+img_in = img_in[:, :, ::-1].copy()

-def url_to_image(url):
-    resp = urlopen(url)
-    image = np.asarray(bytearray(resp.read()), dtype="uint8")
-    image = cv2.imdecode(image, cv2.IMREAD_COLOR)
-    return image
+img_in = fa.get_crop_image(img_in)

-# name_in = os.path.splitext(os.path.basename(input_img_path))[0]
-# img_in = cv2.imread(input_img_path, 1)
-img_in = url_to_image(args.url)
img_in_ten = cv2ten(img_in, device)
+img_in = cv2.resize(img_in, (args.size, args.size))

+# img_in = cv2.cvtColor(img_in, cv2.COLOR_BGR2RGB)
+# img_in = Image.fromarray(img_in)
+# img_in = ImageOps.fit(img_in, (args.size, args.size), centering=(0.5, 0.5))
+# img_in = np.array(img_in)
+# img_in = img_in[:, :, ::-1].copy()

+style_img_paths = sorted(glob.glob(os.path.join(args.style_img_path, '*.*')))[:]
style_img_path = style_img_paths[args.style_selected]

name_style = os.path.splitext(os.path.basename(style_img_path))[0]
img_style = cv2.imread(style_img_path, 1)

img_style_ten = cv2ten(img_style, device)
+img_style = cv2.resize(img_style, (args.size, args.size))

+# img_style = cv2.cvtColor(img_style, cv2.COLOR_BGR2RGB)
+# img_style = Image.fromarray(img_style)
+# img_style = ImageOps.fit(img_style, (args.size, args.size), centering=(0.5, 0.5))
+# img_style = np.array(img_style)
+# img_style = img_style[:, :, ::-1].copy()

with torch.no_grad():
    sample_style = g_ema.get_z_embed(img_style_ten)
    sample_in = psp_encoder(img_in_ten)
    img_out_ten, _ = g_ema([sample_in], z_embed=sample_style, add_weight_index=args.add_weight_index,
                           input_is_latent=True, return_latents=False, randomize_noise=False)
    img_out = ten2cv(img_out_ten)
+print(style_img_path)
+print(output_path)

out = np.concatenate([img_in, img_out], axis=1)
cv2.imwrite(output_path, out)

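The BlendGAN change swaps the urllib/cv2.imdecode fetch for requests + PIL, then routes the frame through FaceAlign before encoding. One subtlety worth a sketch: PIL decodes RGB while the rest of the pipeline is OpenCV/BGR, hence the channel flip. Assuming the ffhq_dataset module from the BlendGAN repo and a placeholder URL:

    from io import BytesIO

    import numpy as np
    import requests
    from PIL import Image

    from ffhq_dataset.gen_aligned_image import FaceAlign

    url = "https://example.com/face.jpg"  # placeholder

    # PIL decodes to RGB; flip the channel axis to BGR for the cv2-based code.
    response = requests.get(url, stream=True)
    rgb = np.array(Image.open(BytesIO(response.content)).convert("RGB"))
    img_in = rgb[:, :, ::-1].copy()

    # FFHQ-style face detection + aligned crop, as used in the diff above.
    fa = FaceAlign()
    img_in = fa.get_crop_image(img_in)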
20 changes: 9 additions & 11 deletions src/ai/anime/main.py
@@ -3,7 +3,8 @@
import time
import torch

-from PIL import Image
+from PIL import Image, ImageOps
+from torchvision import transforms
import requests
from io import BytesIO
import numpy as np
@@ -25,23 +26,20 @@
    print('Need url or streamer name!')
    exit()

-size = 512
+args.size = 1024

-model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator").eval()
-face2paint = torch.hub.load("bryandlee/animegan2-pytorch:main", "face2paint", size=size)
+print('start')
+
+model = torch.hub.load("bryandlee/animegan2-pytorch:main", "generator", device="cuda").eval()
+face2paint = torch.hub.load("bryandlee/animegan2-pytorch:main", "face2paint", device="cuda", size=args.size)

response = requests.get(args.url, stream = True)
img = Image.open(BytesIO(response.content)).convert("RGB")
-width, height = img.size
-max_ = max(width,height)
-if max_ > 1080:
-    ratio = max_ / 1080
-    img = img.resize((int(width*ratio),int(height*ratio)))
+img = ImageOps.fit(img, (args.size, args.size), centering=(0.5, 0.5))

out = face2paint(model, img)

-img = img.resize((size,size))

out = np.concatenate([img, out], axis=1)
cv2.imwrite(output_path, cv2.cvtColor(out, cv2.COLOR_BGR2RGB))

print("Done!")
24 changes: 14 additions & 10 deletions src/ai/pixel/main.py
@@ -2,7 +2,7 @@
import os
import numpy as np
from skimage import io
-from PIL import Image
+from PIL import Image, ImageOps
import requests
from io import BytesIO
import numpy as np
@@ -29,22 +29,26 @@
    print('Need url or streamer name!')
    exit()

-scale = 14
+args.size = 1024
+args.factor = 2

+print('start')

response = requests.get(args.url, stream = True)
-image = np.asarray(Image.open(BytesIO(response.content)).convert("RGB"))
+img = Image.open(BytesIO(response.content)).convert("RGB")
+img = ImageOps.fit(img, (args.size, args.size), centering=(0.5, 0.5))
+image = np.asarray(img)

out = Pyx(
-    factor=scale,
+    factor=args.factor,
    palette=args.palette,
-    upscale = scale,
-    depth=2,
-    # dither="none",
-    # alpha=.6,
-    # boost=True
+    upscale = args.factor,
+    depth=2
).fit_transform(image)

out = cv2.resize(out, (image.shape[1],image.shape[0]), interpolation=cv2.INTER_NEAREST)
out = np.concatenate([image, out], axis=1)

-cv2.imwrite(output_path, cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
+cv2.imwrite(output_path, cv2.cvtColor(out, cv2.COLOR_BGR2RGB))

print('Done!')
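
The pixel script's switch from scale = 14 to factor/upscale = 2 changes what Pyx does: factor controls how aggressively the image is downsampled before palette quantization, and upscale blows the result back up. A minimal sketch with the pyxelate API as used above (the input path is a placeholder):

    import cv2
    import numpy as np
    from pyxelate import Pyx
    from skimage import io

    image = io.imread("input.jpg")  # RGB uint8 array

    # factor=2 halves the resolution before color quantization; upscale=2
    # restores roughly the original size, so the pixels look chunky but the
    # frame stays comparable to the input. palette = number of output colors.
    out = Pyx(factor=2, palette=10, upscale=2, depth=2).fit_transform(image)

    # Snap to the exact input dimensions before the side-by-side concat.
    out = cv2.resize(out, (image.shape[1], image.shape[0]),
                     interpolation=cv2.INTER_NEAREST)
    side_by_side = np.concatenate([image, out], axis=1)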
2 changes: 1 addition & 1 deletion src/twitch_bot/commands/pixel.js
@@ -10,7 +10,7 @@ module.exports = {
const filename = `${makeid(5)}.jpg`,
streamer = channel.replace("#", "");

-let palette = 8;
+let palette = 10;

let url = `https://static-cdn.jtvnw.net/previews-ttv/live_user_${streamer}.jpg`;
if (args.length > 0) {
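On the bot side, pixel.js only bumps the default palette from 8 to 10 colors. The command builds Twitch's live-preview URL and hands it to the Python script; a hedged sketch of that handoff in Python, where the exact CLI flags are assumptions inferred from the argparse usage in the diffs above:

    import subprocess
    import time

    def run_pixel(streamer: str, palette: int = 10) -> str:
        # Twitch serves a live thumbnail for any currently-streaming channel.
        url = f"https://static-cdn.jtvnw.net/previews-ttv/live_user_{streamer}.jpg"
        filename = f"{time.time()}.jpg"  # the bot uses a random 5-char id instead
        # Assumed flags: --url / --palette / --filename (see pixel/main.py diff).
        subprocess.run(
            ["python", "src/ai/pixel/main.py",
             "--url", url, "--palette", str(palette), "--filename", filename],
            check=True,
        )
        return filename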
