image_blurring.py
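"""Text-prompted image blurring.

OWLv2 zero-shot object detection turns text prompts into bounding boxes,
MobileSAM (loaded through ultralytics) turns those boxes into segmentation
masks, and a Gaussian blur is applied inside the aggregated mask while the
rest of the image keeps its original pixels.
"""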
import os
import torch
from transformers import pipeline
from ultralytics import SAM
from PIL import Image, ImageFilter
from torchvision import transforms


class BlurImage(object):
    def __init__(self,
                 device="mps",
                 owlvit_ckpt="google/owlv2-base-patch16-ensemble",
                 owlvit_task="zero-shot-object-detection",
                 sam_ckpt="mobile_sam.pt",
                 ):
        self.module_dir = os.path.dirname(__file__)
        self.device = self.initialize_device(device)
        # Zero-shot object detector (OWLv2): turns text prompts into bounding boxes.
        self.owlvit = pipeline(model=owlvit_ckpt, task=owlvit_task, device=self.device)
        # SAM: turns bounding-box prompts into segmentation masks.
        self.sam = SAM(model=sam_ckpt)
        self.create_dirs(self.module_dir)

    def blur(self, image, text_prompts, blur_intensity=50, labels=None, save=True, size=None):
        """Returns a blurred image based on the given text prompts"""
        if isinstance(image, str):
            image_name = image
            image = self.read_image(self.module_dir, "images-to-blur", image_name, size)
        else:
            image_name = "noname"
            if size:
                image = self.resize_image(image, size)
        aggregated_mask = self.get_aggregated_mask(image, text_prompts, labels)
        # Blur the whole image, then paste the original pixels back everywhere
        # outside the aggregated segmentation mask.
        blurred_image = self.blur_entire_image(image, radius=blur_intensity)
        blurred_image[:, ~aggregated_mask] = transforms.functional.pil_to_tensor(image)[:, ~aggregated_mask]
        blurred_image = transforms.functional.to_pil_image(blurred_image)
        if save:
            blurred_image.save(os.path.join(self.module_dir, "blurred-images",
                                            os.path.splitext(image_name)[0] + "_blurred_image.jpg"))
        return blurred_image

    def get_aggregated_mask(self, image, text_prompts, labels=None):
        """Returns the aggregated mask for the given image, text prompts, and labels"""
        aggregated_mask = 0
        for annotation in self.get_annotations(image, text_prompts, labels):
            aggregated_mask += annotation["segmentation"].float()
        return aggregated_mask > 0

    def get_annotations(self, image, text_prompts, labels=None):
        """Returns annotations predicted by SAM"""
        return self.annotate_sam_inference(*self.predict_masks(image,
                                                               self.predict_boxes(image, text_prompts),
                                                               labels))

    def predict_masks(self, image, box_prompts, labels=None):
        """Returns masks predicted by SAM"""
        if labels is None:
            labels = [1] * len(box_prompts)
        return self.sam(image, bboxes=box_prompts, labels=labels)

    def predict_boxes(self, image, prompts):
        """Returns bounding boxes for the given image and prompts"""
        return [list(pred["box"].values()) for pred in self.owlvit(image, prompts)]

    def initialize_device(self, device):
        """Initializes device based on availability"""
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        return torch.device(device)

    def blur_entire_image(self, image, radius=50):
        """Returns Gaussian-blurred image as a tensor"""
        return transforms.functional.pil_to_tensor(image.filter(ImageFilter.GaussianBlur(radius=radius)))

    def annotate_sam_inference(self, inference, area_threshold=0):
        """Returns a list of annotation dicts

        Args:
            inference (ultralytics.engine.results.Results): Output of the model
            area_threshold (int): Threshold for the segmentation area
        """
        annotations = []
        for i in range(len(inference.masks.data)):
            annotation = {}
            annotation["id"] = i
            annotation["segmentation"] = inference.masks.data[i].cpu() == 1
            annotation["area"] = annotation["segmentation"].sum()
            if annotation["area"] >= area_threshold:
                annotations += [annotation]
        return annotations

    def read_image(self, root, base_folder, image_name, size=None):
        """Returns the opened image for the given root, base folder, and image name"""
        image = Image.open(os.path.join(root, base_folder, image_name))
        if size:
            image = self.resize_image(image, size)
        return image

    def resize_image(self, image, size):
        """Returns the resized image"""
        return image.resize(size)

    def create_dirs(self, root):
        """Creates required directories under the given root"""
        dir_names = ["images-to-blur", "blurred-images"]
        for dir_name in dir_names:
            os.makedirs(os.path.join(root, dir_name), exist_ok=True)


if __name__ == "__main__":
    blur_image = BlurImage(device="cpu")

    # Example calls on a PIL image loaded directly (kept for reference):
    # image = Image.open("images-to-blur/dogs.jpg").resize((1024, 1024))
    # annotations = blur_image.get_annotations(image, ["small nose"])
    # print(annotations)
    # print(blur_image.get_aggregated_mask(image, ["small nose"]).shape)

    # Blur every "jacket" region in images-to-blur/dogs.jpg and save the result
    # to blurred-images/dogs_blurred_image.jpg.
    blur_image.blur("dogs.jpg", ["jacket"])
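
    # A minimal, hedged sketch (not part of the original script): blur() should
    # also accept an in-memory PIL image, multiple prompts, and save=False, in
    # which case the blurred image is returned without being written to disk.
    # The file name "street.jpg" below is hypothetical.
    # pil_image = Image.open("images-to-blur/street.jpg")
    # preview = blur_image.blur(pil_image, ["license plate", "face"],
    #                           blur_intensity=30, save=False, size=(1024, 1024))
    # preview.show()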