-
-
Notifications
You must be signed in to change notification settings - Fork 98
/
Copy pathhelpers.py
60 lines (50 loc) · 2.31 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import numpy as np
import tensorflow as tf
def vector_plot(vecs, xlim, ylim, cols=("#1190FF", "#FF9A13"), alpha=1):
    """Draw 2-D vectors as arrows anchored at the origin.

    Args:
        vecs: Sequence of 2-component vectors (lists, NumPy arrays or
            tensors). Each one is drawn from (0, 0) to (vec[0], vec[1]).
        xlim: Half-width of the horizontal range; the x axis spans
            [-xlim, xlim].
        ylim: Half-height of the vertical range; the y axis spans
            [-ylim, ylim].
        cols: Colours indexed per vector; needs at least ``len(vecs)``
            entries.
        alpha: Either one opacity shared by every vector, or a list/tuple
            holding one opacity per vector.
    """
    # rc_context only applies its settings while entered, so it must wrap the
    # drawing calls; a bare call would build the context manager and discard it.
    style = {'axes.edgecolor': 'orange', 'xtick.color': 'red', 'ytick.color': 'red'}
    with plt.rc_context(style):
        # Black reference axes through the origin, kept beneath the arrows.
        plt.axvline(x=0, color='k', zorder=0)
        plt.axhline(y=0, color='k', zorder=0)
        per_vector_alpha = isinstance(alpha, (list, tuple))
        for i, vec in enumerate(vecs):
            alpha_i = alpha[i] if per_vector_alpha else alpha
            # [x0, y0, dx, dy]: tail at the origin, head at the vector tip.
            coords = np.concatenate([[0, 0], vec])
            # angles/scale_units='xy' with scale=1 draws the arrow in data
            # units, so its length matches the vector instead of autoscaling.
            plt.quiver([coords[0]], [coords[1]], [coords[2]], [coords[3]],
                       angles='xy', scale_units='xy', scale=1,
                       color=cols[i], alpha=alpha_i)
        plt.xlim(-xlim, xlim)
        plt.ylim(-ylim, ylim)
        plt.grid()
def plot_vector2d(vector2d, origin=(0, 0), **options):
    """Draw a single 2-D vector as an arrow whose tail sits at ``origin``.

    Args:
        vector2d: Indexable object whose ``[0]`` and ``[1]`` entries give the
            arrow's x and y extent.
        origin: ``(x, y)`` position of the arrow's tail. Defaults to (0, 0).
        **options: Extra keyword arguments forwarded to ``plt.arrow`` (for
            example ``color`` or ``linestyle``). They may also override the
            default head geometry (``head_width``, ``head_length``,
            ``length_includes_head``).

    Returns:
        Whatever ``plt.arrow`` returns for the drawn arrow.
    """
    # Defaults first, caller options last. This lets callers override the head
    # geometry instead of hitting "got multiple values for keyword argument".
    arrow_kwargs = {
        "head_width": 0.2,
        "head_length": 0.3,
        "length_includes_head": True,
        **options,
    }
    return plt.arrow(origin[0], origin[1], vector2d[0], vector2d[1],
                     **arrow_kwargs)
def plot_transform(P_before, P_after, text_before, text_after, name,
                   color=('#FF9A13', '#1190FF'), axis=(0, 5, 0, 4),
                   arrows=False):
    """Plot a shape before and after a linear transformation.

    Args:
        P_before: Matrix of points before the transform. It is transposed
            before being used as polygon vertices / arrow vectors, so it
            presumably holds one point per column (2 x N) — confirm against
            callers.
        P_after: Matrix of the same points after the transform, same layout.
        text_before: Label drawn at (-0.25, 1).
        text_after: Label drawn at (1.5, 0).
        name: Plot title (drawn in white).
        color: Two colours. ``color[0]`` is used for the dashed "before"
            arrows, the "after" polygon and the ``text_after`` label.
            ``color[1]`` is used for the solid "after" arrows and the
            ``text_before`` label.
        axis: ``[xmin, xmax, ymin, ymax]`` passed to ``plt.axis``.
        arrows: If true, also draw each point as an arrow from the origin.
    """
    # Transpose once and reuse for both the arrows and the polygons.
    points_before = tf.transpose(P_before)
    points_after = tf.transpose(P_after)
    # rc_context only applies while entered, so it must wrap the drawing
    # (including the arrows, which may be what creates the axes).
    style = {'axes.edgecolor': 'orange', 'xtick.color': 'red', 'ytick.color': 'red'}
    with plt.rc_context(style):
        if arrows:
            for vector_before, vector_after in zip(points_before, points_after):
                plot_vector2d(vector_before, color=color[0], linestyle="--")
                plot_vector2d(vector_after, color=color[1], linestyle="-")
        # The "before" polygon keeps matplotlib's default face colour.
        plt.gca().add_artist(Polygon(points_before, alpha=0.2))
        plt.gca().add_artist(Polygon(points_after, alpha=0.3, color=color[0]))
        plt.text(-.25, 1, text_before, size=18, color=color[1])
        plt.text(1.5, 0, text_after, size=18, color=color[0])
        plt.title(name, color='w')
        plt.axis(axis)
        plt.grid()
def evaluate(tensors):
    """Convert every `Tensor`/`EagerTensor` inside a nested structure to NumPy.

    Args:
      tensors: A single `Tensor`/`EagerTensor`, or a `list`, `tuple`,
        `namedtuple` (or any combination thereof) containing them. Leaves
        that are not tensors are passed through untouched.

    Returns:
      An object with the same structure as `tensors`, in which each tensor
      leaf is replaced by the result of its `.numpy()` method.
    """
    def _leaf_to_numpy(leaf):
        # Only tensor leaves are materialised; anything else is kept as-is.
        if not tf.is_tensor(leaf):
            return leaf
        return leaf.numpy()

    # map_structure flattens, applies the converter, and re-packs into the
    # original nesting in a single call.
    return tf.nest.map_structure(_leaf_to_numpy, tensors)