This repository has been archived by the owner on Oct 19, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 244
/
Copy pathimage_recovery.py
55 lines (38 loc) · 2 KB
/
image_recovery.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#!/usr/bin/env python
# coding: utf-8
# ## 프로젝트 1. 경사 하강법으로 이미지 복원하기
# ### 프로젝트 개요와 목표
# 이번 프로젝트에서 우리가 풀 문제는 다음과 같습니다.
# 이미지 처리를 위해 만들어 두었던 weird_function() 함수에 실수로 버그가 들어가 100×100 픽셀의 오염된 미미지가 만들어졌습니다. 이 오염된 이미지와 오염되기 전 원본 이미지를 동시에 파일로 저장하려고 했으나, 모종의 이유로 원본 이미지 파일은 삭제된 상황입니다. 다행히도 weird_function()의 소스코드는 남아 있습니다. 오염된 이미지와 weird_function()을 활용해 원본 이미지를 복원해봅시다.
# *참고자료: https://github.com/jcjohnson/pytorch-examples, NYU Intro2ML*
import torch
import pickle
import matplotlib.pyplot as plt
shp_original_img = (100, 100)
broken_image = torch.FloatTensor( pickle.load(open('./broken_image_t.p', 'rb'),encoding='latin1' ) )
plt.imshow(broken_image.view(100,100))
def weird_function(x, n_iter=5):
h = x
filt = torch.tensor([-1./3, 1./3, -1./3])
for i in range(n_iter):
zero_tensor = torch.tensor([1.0*0])
h_l = torch.cat( (zero_tensor, h[:-1]), 0)
h_r = torch.cat((h[1:], zero_tensor), 0 )
h = filt[0] * h + filt[2] * h_l + filt[1] * h_r
if i % 2 == 0:
h = torch.cat( (h[h.shape[0]//2:],h[:h.shape[0]//2]), 0 )
return h
def distance_loss(hypothesis, broken_image):
return torch.dist(hypothesis, broken_image)
random_tensor = torch.randn(10000, dtype = torch.float)
lr = 0.8
for i in range(0,20000):
random_tensor.requires_grad_(True)
hypothesis = weird_function(random_tensor)
loss = distance_loss(hypothesis, broken_image)
loss.backward()
with torch.no_grad():
random_tensor = random_tensor - lr*random_tensor.grad
if i % 1000 == 0:
print('Loss at {} = {}'.format(i, loss.item()))
plt.imshow(random_tensor.view(100,100).data)