Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Viz rotation #53

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions bayes3d/genjax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from genjax.incremental import Diff, NoChange, UnknownChange

import bayes3d as b
import bayes3d.scene_graph

from .genjax_distributions import (
contact_params_uniform,
Expand Down Expand Up @@ -127,14 +128,14 @@ def get_far_plane(trace):


def add_object(trace, key, obj_id, parent, face_parent, face_child):
N = b.get_indices(trace).shape[0] + 1
N = get_indices(trace).shape[0] + 1
choices = trace.get_choices()
choices[f"parent_{N-1}"] = parent
choices[f"id_{N-1}"] = obj_id
choices[f"face_parent_{N-1}"] = face_parent
choices[f"face_child_{N-1}"] = face_child
choices[f"contact_params_{N-1}"] = jnp.zeros(3)
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1]
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[0]


add_object_jit = jax.jit(add_object)
Expand All @@ -151,7 +152,7 @@ def print_trace(trace):


def viz_trace_meshcat(trace, colors=None):
b.clear()
b.clear_visualizer()
b.show_cloud(
"1", b.apply_transform_jit(trace["image"].reshape(-1, 3), trace["camera_pose"])
)
Expand Down Expand Up @@ -223,14 +224,14 @@ def enumerator(trace, key, *args):
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[2]
)[0]

def enumerator_with_weight(trace, key, *args):
return trace.update(
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[1:3]
)[0:2]

def enumerator_score(trace, key, *args):
return enumerator(trace, key, *args).get_score()
Expand Down Expand Up @@ -301,4 +302,4 @@ def update_address(trace, key, address, value):
key,
genjax.choice_map({address: value}),
tuple(map(lambda v: Diff(v, UnknownChange), trace.args)),
)[2]
)[0]
58 changes: 58 additions & 0 deletions bayes3d/viser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import viser
import random
import time

import imageio.v3 as iio
import numpy as onp

server.add_frame(
"/tree",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(random.random() * 2.0, 2.0, 0.2),
)
server.add_frame(
"/tree/branch",
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(random.random() * 2.0, 2.0, 0.2),
)

client_handle = list(server.get_clients().values())[0]

p,q = client_handle.camera.position, client_handle.camera.wxyz

client_handle.camera.position = p
client_handle.camera.wxyz = q

img = client_handle.camera.get_render(100,100)



server = viser.ViserServer()

import os
import trimesh
i = 9
model_dir = os.path.join(b.utils.get_assets_dir(), "ycb_video_models/models")
mesh_path = os.path.join(model_dir, b.utils.ycb_loader.MODEL_NAMES[i],"textured.obj")
mesh = trimesh.load(mesh_path)

server.add_mesh_trimesh(
name="/trimesh",
mesh=mesh,
)

server.reset_scene()


server.add_mesh(
name="/trimesh",
vertices=mesh.vertices,
faces=mesh.faces,
)

sphere = trimesh.creation.uv_sphere(0.1, (10,10,))
server.add_mesh(
name="/trimesh2",
vertices=sphere.vertices * np.array([1.0, 2.0, 3.0]),
faces=sphere.faces,
)
86 changes: 68 additions & 18 deletions bayes3d/viz/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
from PIL import Image, ImageDraw, ImageFont

import bayes3d.utils
Expand Down Expand Up @@ -45,13 +46,12 @@ def preprocess_for_viz(img):
return depth_np


cmap = copy.copy(plt.get_cmap("turbo"))
cmap = copy.copy(plt.get_cmap('turbo'))
cmap.set_bad(color=(1.0, 1.0, 1.0, 1.0))


def get_depth_image(image, min_val=None, max_val=None, remove_max=True):
def get_depth_image(image, max=None):
"""Convert a depth image to a PIL image.

Args:
image (np.ndarray): Depth image. Shape (H, W).
min (float): Minimum depth value for colormap.
Expand All @@ -60,28 +60,22 @@ def get_depth_image(image, min_val=None, max_val=None, remove_max=True):
Returns:
PIL.Image: Depth image visualized as a PIL image.
"""
if len(image.shape) > 2:
depth = np.array(image[:, :, -1])
depth = np.array(image)
if max is None:
maxim = depth.max()
else:
depth = np.array(image)

if max_val is None:
max_val = depth.max()
if not remove_max:
max_val += 1
if min_val is None:
min_val = depth.min()

mask = (depth < max_val) * (depth > min_val)
maxim = max
mask = depth < maxim
depth[np.logical_not(mask)] = np.nan
depth = (depth - min_val) / (max_val - min_val + 1e-10)
vmin = depth[mask].min()
vmax = depth[mask].max()
depth = (depth - vmin) / (vmax - vmin)

img = Image.fromarray(
np.rint(cmap(depth) * 255.0).astype(np.int8), mode="RGBA"
).convert("RGB")
return img


def get_rgb_image(image, max=255.0):
"""Convert an RGB image to a PIL image.

Expand Down Expand Up @@ -465,3 +459,59 @@ def viz_graph(num_nodes, edges, filename, node_names=None):
)
filename_prefix, filetype = filename.split(".")
g_out.render(filename_prefix, format=filetype)




def visualize_rotation_headings(rot_matrices, init_vector = None):
"""
rot_matrices: rotation matrics of shape (N,3,3)
init_vector: rotation about init vector of shape (3,)
if not specified, it will use the first rotation matrix about the vector [1,0,0]
"""

def plot_vector(fig, start, end, name, color, width = 5):
fig.add_trace(go.Scatter3d(x=[start[0], end[0]], y=[start[1], end[1]], z=[start[2], end[2]],
mode='lines+text',
line=dict(width=width, color=color)))

# Sphere
u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
x = np.cos(u) * np.sin(v)
y = np.sin(u) * np.sin(v)
z = np.cos(v)

# Initialize figure
fig = go.Figure(data=[go.Surface(z=z, x=x, y=y, colorscale='Greys', opacity=0.3, showscale = False)])

# Original unit vector
if init_vector is None:
unit_vector = np.array([1, 0, 0])
origin_vector = rot_matrices[0] @ unit_vector

else:
unit_vector = init_vector/np.linalg.norm(init_vector)
origin_vector = unit_vector

plot_vector(fig, [0, 0, 0], origin_vector, "Original", "blue", width = 5)

# Apply rotation matrices to the unit vector
for i, R in enumerate(rot_matrices):
if i == 0 and init_vector is None:
continue
transformed_vector = R @ unit_vector
plot_vector(fig, [0, 0, 0], transformed_vector, f"Transformed {i+1}", "red")

# Update layout for a better view
fig.update_layout(scene=dict(xaxis_title='X axis',
yaxis_title='Y axis',
zaxis_title='Z axis',
xaxis=dict(range=[-1,1], autorange=False),
yaxis=dict(range=[-1,1], autorange=False),
zaxis=dict(range=[-1,1], autorange=False),
aspectratio=dict(x=1, y=1, z=1)),
showlegend = False,

margin=dict(l=0, r=0, b=0, t=0))

fig.show()
373 changes: 373 additions & 0 deletions demo_c2f.ipynb

Large diffs are not rendered by default.

Loading
Loading