Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade to genjax 0.2.0; incorporate Nishad's fixes for b3d tutorial #55

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bayes3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .transforms_3d import *
from .viz import *

RENDERER = None
RENDERER: "Renderer" = None

__version__ = metadata.version("bayes3d")

Expand Down
13 changes: 7 additions & 6 deletions bayes3d/genjax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from genjax.incremental import Diff, NoChange, UnknownChange

import bayes3d as b
import bayes3d.scene_graph

from .genjax_distributions import (
contact_params_uniform,
Expand Down Expand Up @@ -127,14 +128,14 @@ def get_far_plane(trace):


def add_object(trace, key, obj_id, parent, face_parent, face_child):
N = b.get_indices(trace).shape[0] + 1
N = get_indices(trace).shape[0] + 1
choices = trace.get_choices()
choices[f"parent_{N-1}"] = parent
choices[f"id_{N-1}"] = obj_id
choices[f"face_parent_{N-1}"] = face_parent
choices[f"face_child_{N-1}"] = face_child
choices[f"contact_params_{N-1}"] = jnp.zeros(3)
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[1]
return model.importance(key, choices, (jnp.arange(N), *trace.get_args()[1:]))[0]


add_object_jit = jax.jit(add_object)
Expand All @@ -151,7 +152,7 @@ def print_trace(trace):


def viz_trace_meshcat(trace, colors=None):
b.clear()
b.clear_visualizer()
b.show_cloud(
"1", b.apply_transform_jit(trace["image"].reshape(-1, 3), trace["camera_pose"])
)
Expand Down Expand Up @@ -223,14 +224,14 @@ def enumerator(trace, key, *args):
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[2]
)[0]

def enumerator_with_weight(trace, key, *args):
return trace.update(
key,
chm_builder(addresses, args, chm_args),
argdiff_f(trace),
)[1:3]
)[0:2]

def enumerator_score(trace, key, *args):
return enumerator(trace, key, *args).get_score()
Expand Down Expand Up @@ -301,4 +302,4 @@ def update_address(trace, key, address, value):
key,
genjax.choice_map({address: value}),
tuple(map(lambda v: Diff(v, UnknownChange), trace.args)),
)[2]
)[0]
12 changes: 6 additions & 6 deletions bayes3d/scene_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,16 @@ def relative_pose_from_edge(
face_child,
dims_child,
):
x, y, angle = contact_params
contact_transform = t3d.transform_from_pos(jnp.array([x, y, 0.0])).dot(
t3d.transform_from_axis_angle(jnp.array([1.0, 1.0, 0.0]), jnp.pi).dot(
t3d.transform_from_axis_angle(jnp.array([0.0, 0.0, 1.0]), angle)
)
)
contact_transform = contact_params_to_pose(contact_params)
child_plane = get_contact_planes(dims_child)[face_child]
return contact_transform.dot(jnp.linalg.inv(child_plane))


def relative_pose_from_edge_pose(contact_pose, face_child, dims_child):
child_plane = get_contact_planes(dims_child)[face_child]
return contact_pose.dot(jnp.linalg.inv(child_plane))


relative_pose_from_edge_jit = jax.jit(relative_pose_from_edge)
relative_pose_from_edge_parallel_jit = jax.jit(
jax.vmap(
Expand Down
20 changes: 11 additions & 9 deletions bayes3d/viz/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ def get_depth_image(image, min_val=None, max_val=None, remove_max=True):
depth = np.array(image)

if max_val is None:
max_val = depth.max()
max_val = depth[depth < depth.max()].max(initial=0.0)
if not remove_max:
max_val += 1
if min_val is None:
min_val = depth.min()
min_val = depth[depth > depth.min()].min(initial=0.0)

mask = (depth < max_val) * (depth > min_val)
depth[np.logical_not(mask)] = np.nan
Expand Down Expand Up @@ -208,7 +208,7 @@ def scale_image(img, factor):
return img.resize((int(w * factor), int(h * factor)))


def vstack_images(images, border=10):
def vstack_images(images, border=10, bg_color=(255, 255, 255)):
"""Stack images vertically.

Args:
Expand All @@ -224,7 +224,7 @@ def vstack_images(images, border=10):
max_w = max(max_w, w)
sum_h += h

full_image = Image.new("RGB", (max_w, sum_h), (255, 255, 255))
full_image = Image.new("RGB", (max_w, sum_h), bg_color)
running_h = 0
for img in images:
w, h = img.size
Expand All @@ -233,7 +233,7 @@ def vstack_images(images, border=10):
return full_image


def hstack_images(images, border=10):
def hstack_images(images, border=10, bg_color=(255, 255, 255)):
"""Stack images horizontally.

Args:
Expand All @@ -249,7 +249,7 @@ def hstack_images(images, border=10):
max_h = max(max_h, h)
sum_w += w

full_image = Image.new("RGB", (sum_w, max_h), (255, 255, 255))
full_image = Image.new("RGB", (sum_w, max_h), bg_color)
running_w = 0
for img in images:
w, h = img.size
Expand All @@ -258,7 +258,7 @@ def hstack_images(images, border=10):
return full_image


def hvstack_images(images, h, w, border=10):
def hvstack_images(images, h, w, border=10, bg_color=(255, 255, 255)):
"""Stack images in a grid.

Args:
Expand All @@ -274,10 +274,12 @@ def hvstack_images(images, h, w, border=10):
images_to_vstack = []

for row_idx in range(h):
hstacked_row = hstack_images(images[row_idx * w : (row_idx + 1) * w])
hstacked_row = hstack_images(
images[row_idx * w : (row_idx + 1) * w], border=border, bg_color=bg_color
)
images_to_vstack.append(hstacked_row)

return vstack_images(images_to_vstack)
return vstack_images(images_to_vstack, border=border, bg_color=bg_color)


def multi_panel(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
]
dependencies = [
"distinctipy",
"genjax==0.1.1",
"genjax>=0.2.0",
"graphviz",
"imageio",
"matplotlib",
Expand Down
Loading