Add Replicate demo and API #21

Open · wants to merge 5 commits into base: master
2 changes: 1 addition & 1 deletion README.md
@@ -1,6 +1,6 @@
## Scaling Up to Excellence: Practicing Model Scaling for Photo-Realistic Image Restoration In the Wild

- > [[Paper](https://arxiv.org/abs/2401.13627)] &emsp; [[Project Page](http://supir.xpixel.group/)] &emsp; [Online Demo (Coming soon)] <br>
+ > [[Paper](https://arxiv.org/abs/2401.13627)] &emsp; [[Project Page](http://supir.xpixel.group/)] &emsp; [[Replicate Demo](https://replicate.com/cjwbw/supir)] <br>
> Fanghua, Yu, [Jinjin Gu](https://www.jasongt.com/), Zheyuan Li, Jinfan Hu, Xiangtao Kong, [Xintao Wang](https://xinntao.github.io/), [Jingwen He](https://scholar.google.com.hk/citations?user=GUxrycUAAAAJ), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ), [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ) <br>
> Shenzhen Institute of Advanced Technology; Shanghai AI Laboratory; University of Sydney; The Hong Kong Polytechnic University; ARC Lab, Tencent PCG; The Chinese University of Hong Kong <br>

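The new README link points at the hosted demo. Not part of the diff, but as a hedged sketch of how that endpoint could be called programmatically with the `replicate` Python client: the model slug comes from the README change above, the input names come from `predict.py` below, and the version hash is a placeholder to copy from the model page.

```python
# Sketch: call the hosted SUPIR demo via the replicate client.
# Assumes `pip install replicate` and REPLICATE_API_TOKEN in the environment;
# the version hash below is a placeholder, not a real value.
import replicate

output = replicate.run(
    "cjwbw/supir:<version-hash>",  # look up the current version on replicate.com/cjwbw/supir
    input={
        "image": open("low_quality.png", "rb"),  # input names mirror predict.py
        "upscale": 2,
        "use_llava": True,
    },
)
print(output)  # URL of the restored image
```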
41 changes: 41 additions & 0 deletions cog.yaml
@@ -0,0 +1,41 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  gpu: true
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_version: "3.11"
  python_packages:
    - sentencepiece==0.1.98
    - tokenizers==0.13.3
    - torch>=2.1.0
    - torchvision>=0.16.0
    - uvicorn==0.21.1
    - transformers==4.28.1
    - accelerate==0.18.0
    - scikit-learn==1.2.2
    - einops==0.7.0
    - einops-exts==0.0.4
    - timm==0.9.8
    - openai-clip==1.0.1
    - kornia==0.6.9
    - matplotlib==3.7.1
    - ninja==1.11.1
    - omegaconf==2.3.0
    - open-clip-torch==2.17.1
    - opencv-python==4.7.0.72
    - pandas==2.0.1
    - Pillow==9.4.0
    - pytorch-lightning==2.1.2
    - PyYAML==6.0
    - scipy==1.12.0
    - tqdm==4.65.0
    - triton==2.1.0
    - webdataset==0.2.48
    - xformers>=0.0.20
  run:
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.6.0/pget_linux_x86_64" && chmod +x /usr/local/bin/pget
predict: "predict.py:Predictor"
213 changes: 213 additions & 0 deletions predict.py
@@ -0,0 +1,213 @@
# Prediction interface for Cog ⚙️
# https://github.com/replicate/cog/blob/main/docs/python.md

import os
import subprocess
import time
from omegaconf import OmegaConf
from PIL import Image
from cog import BasePredictor, Input, Path

from SUPIR.util import (
    create_SUPIR_model,
    PIL2Tensor,
    Tensor2PIL,
    convert_dtype,
)
from llava.llava_agent import LLavaAgent
import CKPT_PTH

SUPIR_v0Q_URL = "https://weights.replicate.delivery/default/SUPIR-v0Q.ckpt"
SUPIR_v0F_URL = "https://weights.replicate.delivery/default/SUPIR-v0F.ckpt"
LLAVA_URL = "https://weights.replicate.delivery/default/llava-v1.5-13b.tar"
LLAVA_CLIP_URL = (
"https://weights.replicate.delivery/default/clip-vit-large-patch14-336.tar"
)
SDXL_URL = "https://weights.replicate.delivery/default/stable-diffusion-xl-base-1.0/sd_xl_base_1.0_0.9vae.safetensors"
SDXL_CLIP1_URL = "https://weights.replicate.delivery/default/clip-vit-large-patch14.tar"
SDXL_CLIP2_URL = (
"https://weights.replicate.delivery/default/CLIP-ViT-bigG-14-laion2B-39B-b160k.tar"
)

MODEL_CACHE = "/opt/data/private/AIGC_pretrain/" # Follow the default in CKPT_PTH.py
LLAVA_CLIP_PATH = CKPT_PTH.LLAVA_CLIP_PATH
LLAVA_MODEL_PATH = CKPT_PTH.LLAVA_MODEL_PATH
SDXL_CLIP1_PATH = CKPT_PTH.SDXL_CLIP1_PATH
SDXL_CLIP2_CACHE = f"{MODEL_CACHE}/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k"
SDXL_CKPT = f"{MODEL_CACHE}/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors"
SUPIR_CKPT_F = f"{MODEL_CACHE}/SUPIR_cache/SUPIR-v0F.ckpt"
SUPIR_CKPT_Q = f"{MODEL_CACHE}/SUPIR_cache/SUPIR-v0Q.ckpt"


def download_weights(url, dest, extract=True):
    start = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    args = ["pget"]
    if extract:
        args.append("-x")
    subprocess.check_call(args + [url, dest], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        for model_dir in [
            MODEL_CACHE,
            f"{MODEL_CACHE}/SUPIR_cache",
            f"{MODEL_CACHE}/SDXL_cache",
        ]:
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
        if not os.path.exists(SUPIR_CKPT_Q):
            download_weights(SUPIR_v0Q_URL, SUPIR_CKPT_Q, extract=False)
        if not os.path.exists(SUPIR_CKPT_F):
            download_weights(SUPIR_v0F_URL, SUPIR_CKPT_F, extract=False)
        if not os.path.exists(LLAVA_MODEL_PATH):
            download_weights(LLAVA_URL, LLAVA_MODEL_PATH)
        if not os.path.exists(LLAVA_CLIP_PATH):
            download_weights(LLAVA_CLIP_URL, LLAVA_CLIP_PATH)
        if not os.path.exists(SDXL_CLIP1_PATH):
            download_weights(SDXL_CLIP1_URL, SDXL_CLIP1_PATH)
        if not os.path.exists(SDXL_CKPT):
            download_weights(SDXL_URL, SDXL_CKPT, extract=False)
        # SDXL's second text encoder (CLIP-ViT-bigG) extracts into its own cache dir
        if not os.path.exists(SDXL_CLIP2_CACHE):
            download_weights(SDXL_CLIP2_URL, SDXL_CLIP2_CACHE)

        self.supir_device = "cuda:0"
        self.llava_device = "cuda:0"
        ae_dtype = "bf16"  # Inference data type of AutoEncoder
        diff_dtype = "bf16"  # Inference data type of Diffusion

        self.models = {
            k: create_SUPIR_model("options/SUPIR_v0.yaml", SUPIR_sign=k).to(
                self.supir_device
            )
            for k in ["Q", "F"]
        }

        for k in ["Q", "F"]:
            self.models[k].ae_dtype = convert_dtype(ae_dtype)
            self.models[k].model.dtype = convert_dtype(diff_dtype)

        # load LLaVA
        self.llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=self.llava_device)

    def predict(
        self,
        model_name: str = Input(
            description="Choose a model. SUPIR-v0Q uses the default training settings from the paper, with high generalization and high image quality in most cases. SUPIR-v0F is trained with light-degradation settings; its Stage1 encoder retains more detail when the input degradation is light.",
            choices=["SUPIR-v0Q", "SUPIR-v0F"],
            default="SUPIR-v0Q",
        ),
        image: Path = Input(description="Low quality input image."),
        upscale: int = Input(
            description="Upsampling ratio of given inputs.", default=1
        ),
        min_size: float = Input(
            description="Minimum resolution of output images.", default=1024
        ),
        edm_steps: int = Input(
            description="Number of steps for the EDM sampling schedule.",
            ge=1,
            le=500,
            default=50,
        ),
        use_llava: bool = Input(
            description="Use LLaVA model to get captions.", default=True
        ),
        a_prompt: str = Input(
            description="Additive positive prompt for the inputs.",
            default="Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, skin pore detailing, hyper sharpness, perfect without deformations.",
        ),
        n_prompt: str = Input(
            description="Negative prompt for the inputs.",
            default="painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature, jpeg artifacts, deformed, lowres, over-smooth",
        ),
        color_fix_type: str = Input(
            description="Color fixing type.",
            choices=["None", "AdaIn", "Wavelet"],
            default="Wavelet",
        ),
        s_stage1: int = Input(
            description="Control strength of Stage1 (a negative value disables it).",
            default=-1,
        ),
        s_churn: float = Input(
            description="Original churn hyperparameter of EDM.", default=5
        ),
        s_noise: float = Input(
            description="Original noise hyperparameter of EDM.", default=1.003
        ),
        s_cfg: float = Input(
            description="Classifier-free guidance scale for prompts.",
            ge=1,
            le=20,
            default=7.5,
        ),
        s_stage2: float = Input(description="Control strength of Stage2.", default=1.0),
        linear_CFG: bool = Input(
            description="Linearly (with sigma) increase CFG from 'spt_linear_CFG' to s_cfg.",
            default=False,
        ),
        linear_s_stage2: bool = Input(
            description="Linearly (with sigma) increase s_stage2 from 'spt_linear_s_stage2' to s_stage2.",
            default=False,
        ),
        spt_linear_CFG: float = Input(
            description="Start point of linearly increasing CFG.", default=1.0
        ),
        spt_linear_s_stage2: float = Input(
            description="Start point of linearly increasing s_stage2.", default=0.0
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed.", default=None
        ),
    ) -> Path:
"""Run a single prediction on the model"""

if seed is None:
seed = int.from_bytes(os.urandom(2), "big")
print(f"Using seed: {seed}")

model = self.models["Q"] if model_name == "SUPIR-v0Q" else self.models["F"]

lq_img = Image.open(str(image))
lq_img, h0, w0 = PIL2Tensor(lq_img, upsacle=upscale, min_size=min_size)
lq_img = lq_img.unsqueeze(0).to(self.supir_device)[:, :3, :, :]

# step 1: Pre-denoise for LLaVA)
clean_imgs = model.batchify_denoise(lq_img)
clean_PIL_img = Tensor2PIL(clean_imgs[0], h0, w0)

# step 2: LLaVA
captions = [""]
if use_llava:
captions = self.llava_agent.gen_image_caption([clean_PIL_img])
print(f"Captions from LLaVA: {captions}")

# step 3: Diffusion Process
samples = model.batchify_sample(
lq_img,
captions,
num_steps=edm_steps,
restoration_scale=s_stage1,
s_churn=s_churn,
s_noise=s_noise,
cfg_scale=s_cfg,
control_scale=s_stage2,
seed=seed,
num_samples=1,
p_p=a_prompt,
n_p=n_prompt,
color_fix_type=color_fix_type,
use_linear_CFG=linear_CFG,
use_linear_control_scale=linear_s_stage2,
cfg_scale_start=spt_linear_CFG,
control_scale_start=spt_linear_s_stage2,
)

out_path = "/tmp/out.png"
Tensor2PIL(samples[0], h0, w0).save(out_path)
return Path(out_path)
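Also not part of the diff, but possibly useful while reviewing: a local smoke test that drives `Predictor` directly. All keyword arguments are spelled out because the `Input(...)` defaults only take effect when the model is served through Cog; the test image path and the presence of a CUDA GPU are assumptions.

```python
# Sketch: run one prediction without the Cog server (assumes a CUDA GPU and
# that setup() can download all weights). Not part of this PR.
from cog import Path

from predict import Predictor

predictor = Predictor()
predictor.setup()
out = predictor.predict(
    model_name="SUPIR-v0Q",
    image=Path("low_quality.png"),  # hypothetical local test image
    upscale=2,
    min_size=1024,
    edm_steps=50,
    use_llava=True,
    a_prompt="",
    n_prompt="",
    color_fix_type="Wavelet",
    s_stage1=-1,
    s_churn=5,
    s_noise=1.003,
    s_cfg=7.5,
    s_stage2=1.0,
    linear_CFG=False,
    linear_s_stage2=False,
    spt_linear_CFG=1.0,
    spt_linear_s_stage2=0.0,
    seed=None,
)
print(out)  # /tmp/out.png
```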